// Source listing: forgejo/vendor/github.com/pingcap/tidb/driver.go (561 lines, 14 KiB, Go)

// Copyright 2013 The ql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSES/QL-LICENSE file.
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
// database/sql/driver
package tidb
import (
"database/sql"
"database/sql/driver"
"io"
"net/url"
"path/filepath"
"strings"
"sync"
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/util/types"
)
const (
// DriverName is the name under which RegisterDriver registers the TiDB
// driver with database/sql.
DriverName = "tidb"
)
var (
// Compile-time assertions that the driver types satisfy the
// database/sql/driver interfaces they are used as.
_ driver.Conn = (*driverConn)(nil)
_ driver.Execer = (*driverConn)(nil)
_ driver.Queryer = (*driverConn)(nil)
_ driver.Tx = (*driverConn)(nil)
_ driver.Result = (*driverResult)(nil)
_ driver.Rows = (*driverRows)(nil)
_ driver.Stmt = (*driverStmt)(nil)
_ driver.Driver = (*sqlDriver)(nil)
// SQL statements executed by driverConn.Begin, Commit and Rollback.
txBeginSQL = "BEGIN;"
txCommitSQL = "COMMIT;"
txRollbackSQL = "ROLLBACK;"
// errNoResult is returned by the query paths when a statement that should
// produce rows yields no record set.
errNoResult = errors.New("query statement does not produce a result set (no top level SELECT)")
)
// errList collects several errors so that they can be reported as one.
type errList []error

// driverParams holds the settings parseDriverDSN extracts from a DSN.
type driverParams struct {
	storePath string
	dbName    string
	// parseTime controls how mysql.Time values are handed to database/sql:
	// when true they are passed as time.Time, otherwise as their string form.
	// The option carries the same name as in the mysql driver for
	// compatibility, although no additional parsing is needed here.
	parseTime bool
}

// append records err in the list; a nil err is ignored.
func (e *errList) append(err error) {
	if err == nil {
		return
	}
	*e = append(*e, err)
}

// error returns the list as an error value, or nil when the list is empty.
func (e errList) error() error {
	if len(e) > 0 {
		return e
	}
	return nil
}

// Error implements the error interface. The messages of all collected errors
// are joined with newlines, in the order they were appended.
func (e errList) Error() string {
	msgs := make([]string, 0, len(e))
	for _, item := range e {
		msgs = append(msgs, item.Error())
	}
	return strings.Join(msgs, "\n")
}
// params converts the driver argument values into a []interface{} of the same
// length and order, suitable for passing as variadic arguments to the session.
// The result is always a non-nil slice, even for empty or nil input.
func params(args []driver.Value) []interface{} {
	converted := make([]interface{}, 0, len(args))
	for _, arg := range args {
		converted = append(converted, arg)
	}
	return converted
}
var (
// tidbDriver is the driver instance RegisterDriver hands to sql.Register.
tidbDriver = &sqlDriver{}
// driverOnce makes RegisterDriver perform the registration only once.
driverOnce sync.Once
)
// RegisterDriver registers TiDB driver.
// The name argument can be optionally prefixed by "engine://". In that case the
// prefix is recognized as a storage engine name.
//
// The name argument can be optionally prefixed by "memory://". In that case
// the prefix is stripped before interpreting it as a name of a memory-only,
// volatile DB.
//
// [0]: http://golang.org/pkg/database/sql/driver/
func RegisterDriver() {
driverOnce.Do(func() { sql.Register(DriverName, tidbDriver) })
}
// sqlDriver implements the interface required by database/sql/driver.
// Its only state is a mutex.
type sqlDriver struct {
// mu is held by Open while it checks for and creates the target schema, and
// is briefly taken by driverConn.Close.
mu sync.Mutex
}
// lock acquires the driver mutex.
func (d *sqlDriver) lock() {
d.mu.Lock()
}
// unlock releases the driver mutex.
func (d *sqlDriver) unlock() {
d.mu.Unlock()
}
// parseDriverDSN splits a DSN into the store path and the database name.
//
// The last element of host+path is taken as the database name; everything
// before it becomes the path of the store URL, whose host is cleared. The
// store path is the re-serialized URL, so the scheme and any query string are
// retained in it. The only driver option read from the query string is
// "parseTime", which is enabled only by the exact value "true".
//
// An error is returned when the DSN cannot be parsed as a URL, when the
// database name is empty, or when nothing is left for the store path once the
// database name has been removed.
func parseDriverDSN(dsn string) (params *driverParams, err error) {
	parsed, err := url.Parse(dsn)
	if err != nil {
		return nil, errors.Trace(err)
	}

	// degenerate reports whether a cleaned path element carries no usable name.
	degenerate := func(p string) bool {
		switch p {
		case "", ".", string(filepath.Separator):
			return true
		}
		return false
	}

	full := filepath.Join(parsed.Host, parsed.Path)
	name := filepath.Clean(filepath.Base(full))
	if degenerate(name) {
		return nil, errors.Errorf("invalid DB name %q", name)
	}

	// Drop the database name; what remains locates the store.
	storeDir := filepath.Clean(filepath.Dir(full))
	if degenerate(storeDir) {
		return nil, errors.Errorf("invalid dsn %q", dsn)
	}

	parsed.Host = ""
	parsed.Path = storeDir

	params = &driverParams{
		storePath: parsed.String(),
		dbName:    name,
	}
	// Additional driver options live in the query string.
	params.parseTime = parsed.Query().Get("parseTime") == "true"
	return params, nil
}
// Open returns a new connection to the database.
//
// The dsn must be a URL of the form 'engine://path/dbname?params'.
// Engine is the storage name registered with RegisterStore.
// Path is in a storage specific format.
// Params are key-value pairs separated by '&'; optional params are storage
// specific.
// Examples:
// goleveldb://relative/path/test
// boltdb:///absolute/path/test
// hbase://zk1,zk2,zk3/hbasetbl/test?tso=zk
//
// Open may return a cached connection (one previously closed), but doing so is
// unnecessary; the sql package maintains a pool of idle connections for
// efficient re-use.
//
// The time handling of the mysql driver can be imitated by passing
// ?parseTime=true.
//
// Open creates the schema named by the DSN if it does not exist yet; the
// existence check and the creation run under the driver mutex.
//
// The returned connection is only used by one goroutine at a time.
func (d *sqlDriver) Open(dsn string) (driver.Conn, error) {
	cfg, err := parseDriverDSN(dsn)
	if err != nil {
		return nil, errors.Trace(err)
	}

	store, err := NewStore(cfg.storePath)
	if err != nil {
		return nil, errors.Trace(err)
	}

	rawSess, err := CreateSession(store)
	if err != nil {
		return nil, errors.Trace(err)
	}
	s := rawSess.(*session)

	d.lock()
	defer d.unlock()

	schema := model.NewCIStr(cfg.dbName)
	dom := sessionctx.GetDomain(s)
	if !dom.InfoSchema().SchemaExists(schema) {
		charset := &ast.CharsetOpt{
			Chs: "utf8",
			Col: "utf8_bin",
		}
		if err = dom.DDL().CreateSchema(s, schema, charset); err != nil {
			return nil, errors.Trace(err)
		}
	}

	// NOTE(review): the connection receives a fresh sqlDriver rather than d,
	// so the mutex it locks in Close is not the one held here — confirm
	// whether that is intended.
	return newDriverConn(s, &sqlDriver{}, schema.O, cfg)
}
// driverConn is a connection to a database. It is not used concurrently by
// multiple goroutines.
//
// Conn is assumed to be stateful.
type driverConn struct {
// s is the session every statement of this connection is executed against.
s Session
// driver is the sqlDriver whose mutex Close acquires.
driver *sqlDriver
// stmts caches prepared statements keyed by their query text; entries are
// added by Prepare and removed by driverStmt.Close.
stmts map[string]driver.Stmt
// params holds the DSN options; it is passed on to the rows produced by
// queries.
params *driverParams
}
// newDriverConn builds a driverConn around sess with an empty statement cache
// and switches the session to schema by executing a "use" statement. If that
// statement fails, the traced error is returned and no connection is handed
// out.
func newDriverConn(sess *session, d *sqlDriver, schema string, params *driverParams) (driver.Conn, error) {
	conn := &driverConn{
		s:      sess,
		driver: d,
		stmts:  make(map[string]driver.Stmt),
		params: params,
	}
	if _, err := conn.s.Execute("use " + schema); err != nil {
		return nil, errors.Trace(err)
	}
	return conn, nil
}
// Prepare returns a prepared statement, bound to this connection.
//
// The statement is prepared on the session and stored in the connection's
// statement cache under its query text, replacing any entry already cached
// for the same text. A preparation error is returned as is.
func (c *driverConn) Prepare(query string) (driver.Stmt, error) {
	id, nParams, fields, err := c.s.PrepareStmt(query)
	if err != nil {
		return nil, err
	}

	prepared := &driverStmt{
		conn:       c,
		query:      query,
		stmtID:     id,
		paramCount: nParams,
	}
	// The statement is flagged as a query when preparation reported fields.
	prepared.isQuery = fields != nil

	c.stmts[query] = prepared
	return prepared, nil
}
// Close invalidates and potentially stops any current prepared statements and
// transactions, marking this connection as no longer in use.
//
// Every statement still present in the connection's cache is dropped from its
// session; the failures, if any, are returned together as one error.
//
// Because the sql package maintains a free pool of connections and only calls
// Close when there's a surplus of idle connections, it shouldn't be necessary
// for drivers to do their own connection caching.
func (c *driverConn) Close() error {
	var failures errList
	for _, cached := range c.stmts {
		ds := cached.(*driverStmt)
		dropErr := ds.conn.s.DropPreparedStmt(ds.stmtID)
		failures.append(dropErr)
	}

	c.driver.lock()
	defer c.driver.unlock()

	return failures.error()
}
// Begin starts a new transaction by executing txBeginSQL on the session. The
// connection itself serves as the returned driver.Tx. It fails when the
// connection has no session or when the statement cannot be executed.
func (c *driverConn) Begin() (driver.Tx, error) {
	if c.s == nil {
		return nil, errors.Errorf("Need init first")
	}

	_, err := c.s.Execute(txBeginSQL)
	if err != nil {
		return nil, errors.Trace(err)
	}
	return c, nil
}
// Commit implements driver.Tx. It executes txCommitSQL on the session and then
// calls FinishTxn(false). Without a session it returns
// terror.CommitNotInTransaction.
func (c *driverConn) Commit() error {
	if c.s == nil {
		return terror.CommitNotInTransaction
	}
	if _, err := c.s.Execute(txCommitSQL); err != nil {
		return errors.Trace(err)
	}
	finishErr := c.s.FinishTxn(false)
	return errors.Trace(finishErr)
}
// Rollback implements driver.Tx by executing txRollbackSQL on the session.
// Without a session it returns terror.RollbackNotInTransaction.
func (c *driverConn) Rollback() error {
	if c.s == nil {
		return terror.RollbackNotInTransaction
	}
	_, err := c.s.Execute(txRollbackSQL)
	// errors.Trace passes a nil error through unchanged.
	return errors.Trace(err)
}
// Exec implements driver.Execer by delegating to driverExec.
//
// Execer is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement Execer, the sql package's DB.Exec will first
// prepare a query, execute the statement, and then close the statement.
//
// The Execer contract permits returning driver.ErrSkip; this method does not
// return it itself.
func (c *driverConn) Exec(query string, args []driver.Value) (driver.Result, error) {
return c.driverExec(query, args)
}
// getStmt returns the cached prepared statement for query, preparing and
// caching a new one through Prepare when the query has not been seen before.
func (c *driverConn) getStmt(query string) (stmt driver.Stmt, err error) {
	if cached, found := c.stmts[query]; found {
		return cached, nil
	}

	stmt, err = c.Prepare(query)
	if err != nil {
		return nil, errors.Trace(err)
	}
	return stmt, nil
}
// driverExec runs a statement that is not expected to return rows.
//
// With arguments, the query goes through a (possibly cached) prepared
// statement. Without arguments it is executed directly on the session, and the
// result is built from the session's last insert ID and affected row count.
func (c *driverConn) driverExec(query string, args []driver.Value) (driver.Result, error) {
	if len(args) > 0 {
		stmt, err := c.getStmt(query)
		if err != nil {
			return nil, errors.Trace(err)
		}
		return stmt.Exec(args)
	}

	if _, err := c.s.Execute(query); err != nil {
		return nil, errors.Trace(err)
	}
	return &driverResult{
		lastInsertID: int64(c.s.LastInsertID()),
		rowsAffected: int64(c.s.AffectedRows()),
	}, nil
}
// Query implements driver.Queryer by delegating to driverQuery.
//
// Queryer is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement Queryer, the sql package's DB.Query will first
// prepare a query, execute the statement, and then close the statement.
//
// The Queryer contract permits returning driver.ErrSkip; this method does not
// return it itself.
func (c *driverConn) Query(query string, args []driver.Value) (driver.Rows, error) {
return c.driverQuery(query, args)
}
// driverQuery runs a statement that is expected to return rows.
//
// With arguments, the query goes through a (possibly cached) prepared
// statement. Without arguments it is executed directly on the session and the
// first record set is wrapped as the result; errNoResult is returned when the
// execution produced no record set at all.
func (c *driverConn) driverQuery(query string, args []driver.Value) (driver.Rows, error) {
	if len(args) > 0 {
		stmt, err := c.getStmt(query)
		if err != nil {
			return nil, errors.Trace(err)
		}
		return stmt.Query(args)
	}

	recordSets, err := c.s.Execute(query)
	switch {
	case err != nil:
		return nil, errors.Trace(err)
	case len(recordSets) == 0:
		return nil, errors.Trace(errNoResult)
	}
	return &driverRows{params: c.params, rs: recordSets[0]}, nil
}
// driverResult is the result of a query execution.
type driverResult struct {
// lastInsertID is the value reported by LastInsertId.
lastInsertID int64
// rowsAffected is the value reported by RowsAffected.
rowsAffected int64
}
// LastInsertId returns the database's auto-generated ID after, for example, an
// INSERT into a table with primary key. The returned error is always nil.
func (r *driverResult) LastInsertId() (int64, error) { // -golint
return r.lastInsertID, nil
}
// RowsAffected returns the number of rows affected by the query. The returned
// error is always nil.
func (r *driverResult) RowsAffected() (int64, error) {
return r.rowsAffected, nil
}
// driverRows is an iterator over an executed query's results.
type driverRows struct {
// rs is the underlying record set; when it is nil the iterator behaves as an
// empty result (no columns, Next returns io.EOF).
rs ast.RecordSet
// params supplies the parseTime option consulted by Next when converting
// time values.
params *driverParams
}
// Columns returns the names of the columns. The number of columns of the
// result is inferred from the length of the slice. If a particular column
// name isn't known, an empty string should be returned for that entry.
//
// An iterator without a record set reports no columns.
func (r *driverRows) Columns() []string {
	if r.rs == nil {
		return []string{}
	}

	// Columns has no way to report an error, so the one from Fields is
	// discarded; in that case whatever fields were returned are used.
	fields, _ := r.rs.Fields()

	cols := make([]string, 0, len(fields))
	for _, field := range fields {
		cols = append(cols, field.ColumnAsName.O)
	}
	return cols
}
// Close closes the rows iterator by closing the underlying record set. It is a
// no-op when there is no record set.
func (r *driverRows) Close() error {
	if r.rs == nil {
		return nil
	}
	return r.rs.Close()
}
// Next is called to populate the next row of data into the provided slice. The
// provided slice will be the same size as the Columns() are wide.
//
// Each datum of the row is converted by kind: NULL becomes nil, numeric kinds
// keep their Go value, bytes stay []byte, and bit, decimal, duration, enum,
// hex and set values are rendered as strings. Time values are rendered as
// strings unless the parseTime option is set, in which case their Time field
// is passed through.
//
// Next returns io.EOF when there are no more rows or when the iterator has no
// record set, and an error when the row width differs from len(dest) or a
// datum has a kind that is not handled.
func (r *driverRows) Next(dest []driver.Value) error {
	if r.rs == nil {
		return io.EOF
	}

	row, err := r.rs.Next()
	switch {
	case err != nil:
		return errors.Trace(err)
	case row == nil:
		return io.EOF
	}

	if got, want := len(row.Data), len(dest); got != want {
		return errors.Errorf("field count mismatch: got %d, need %d", got, want)
	}

	for i := range row.Data {
		d := row.Data[i]
		switch d.Kind() {
		case types.KindNull:
			dest[i] = nil
		case types.KindInt64:
			dest[i] = d.GetInt64()
		case types.KindUint64:
			dest[i] = d.GetUint64()
		case types.KindFloat32:
			dest[i] = d.GetFloat32()
		case types.KindFloat64:
			dest[i] = d.GetFloat64()
		case types.KindString:
			dest[i] = d.GetString()
		case types.KindBytes:
			dest[i] = d.GetBytes()
		case types.KindMysqlBit:
			dest[i] = d.GetMysqlBit().ToString()
		case types.KindMysqlDecimal:
			dest[i] = d.GetMysqlDecimal().String()
		case types.KindMysqlDuration:
			dest[i] = d.GetMysqlDuration().String()
		case types.KindMysqlEnum:
			dest[i] = d.GetMysqlEnum().String()
		case types.KindMysqlHex:
			dest[i] = d.GetMysqlHex().ToString()
		case types.KindMysqlSet:
			dest[i] = d.GetMysqlSet().String()
		case types.KindMysqlTime:
			t := d.GetMysqlTime()
			if r.params.parseTime {
				dest[i] = t.Time
			} else {
				dest[i] = t.String()
			}
		default:
			return errors.Errorf("unable to handle type %T", d.GetValue())
		}
	}
	return nil
}
// driverStmt is a prepared statement. It is bound to a driverConn and not used
// by multiple goroutines concurrently.
type driverStmt struct {
// conn is the connection the statement was prepared on.
conn *driverConn
// query is the SQL text; it is also the key of this statement in conn.stmts.
query string
// stmtID is the identifier returned by the session's PrepareStmt, used to
// execute and to drop the statement.
stmtID uint32
// paramCount is the placeholder count reported by PrepareStmt and returned
// by NumInput.
paramCount int
// isQuery is true when PrepareStmt reported non-nil fields for the statement.
isQuery bool
}
// Close closes the statement: it drops the prepared statement from the
// session and removes it from the connection's statement cache.
//
// The cache entry is removed only when it still refers to this statement.
// Preparing the same query text again replaces the cache entry, and closing
// the older statement must not evict the newer one, otherwise the newer
// statement would no longer be dropped when the connection is closed.
//
// The error from dropping the statement, if any, is returned (traced); the
// cache is cleaned up regardless.
//
// As of Go 1.1, a Stmt will not be closed if it's in use by any queries.
func (s *driverStmt) Close() error {
	err := s.conn.s.DropPreparedStmt(s.stmtID)
	if cached, ok := s.conn.stmts[s.query]; ok && cached == driver.Stmt(s) {
		delete(s.conn.stmts, s.query)
	}
	// errors.Trace passes a nil error through unchanged.
	return errors.Trace(err)
}
// NumInput returns the number of placeholder parameters, as reported when the
// statement was prepared.
//
// If NumInput returns >= 0, the sql package will sanity check argument counts
// from callers and return errors to the caller before the statement's Exec or
// Query methods are called.
//
// NumInput may also return -1, if the driver doesn't know its number of
// placeholders. In that case, the sql package will not sanity check Exec or
// Query argument counts.
func (s *driverStmt) NumInput() int {
return s.paramCount
}
// Exec executes a query that doesn't return rows, such as an INSERT or UPDATE.
//
// The prepared statement is executed on the connection's session with args as
// its parameters. On success the result carries the session's last insert ID
// and affected row count; on failure the traced error is returned.
func (s *driverStmt) Exec(args []driver.Value) (driver.Result, error) {
	c := s.conn
	if _, err := c.s.ExecutePreparedStmt(s.stmtID, params(args)...); err != nil {
		return nil, errors.Trace(err)
	}
	// s has already been dereferenced above, so no nil check is needed here.
	return &driverResult{
		lastInsertID: int64(c.s.LastInsertID()),
		rowsAffected: int64(c.s.AffectedRows()),
	}, nil
}
// Query executes a query that may return rows, such as a SELECT.
//
// The prepared statement is executed on the connection's session with args as
// its parameters. When no record set comes back, a statement that was prepared
// as a query fails with errNoResult, while any other statement yields an empty
// rows iterator.
func (s *driverStmt) Query(args []driver.Value) (driver.Rows, error) {
	rs, err := s.conn.s.ExecutePreparedStmt(s.stmtID, params(args)...)
	if err != nil {
		return nil, errors.Trace(err)
	}

	switch {
	case rs != nil:
		return &driverRows{params: s.conn.params, rs: rs}, nil
	case s.isQuery:
		return nil, errors.Trace(errNoResult)
	default:
		// The statement is not a query.
		return &driverRows{}, nil
	}
}
// init registers the TiDB driver with database/sql when the package is loaded.
func init() {
RegisterDriver()
}