aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/denisenkom/go-mssqldb/mssql.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/mssql.go')
-rw-r--r--vendor/github.com/denisenkom/go-mssqldb/mssql.go609
1 files changed, 0 insertions, 609 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql.go b/vendor/github.com/denisenkom/go-mssqldb/mssql.go
deleted file mode 100644
index b997cd68..00000000
--- a/vendor/github.com/denisenkom/go-mssqldb/mssql.go
+++ /dev/null
@@ -1,609 +0,0 @@
-package mssql
-
-import (
- "database/sql"
- "database/sql/driver"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "math"
- "net"
- "reflect"
- "strings"
- "time"
- "golang.org/x/net/context" // use the "x/net/context" for backwards compatibility.
-)
-
-var driverInstance = &MssqlDriver{processQueryText: true}
-var driverInstanceNoProcess = &MssqlDriver{processQueryText: false}
-
-func init() {
- sql.Register("mssql", driverInstance)
- sql.Register("sqlserver", driverInstanceNoProcess)
-}
-
-// Abstract the dialer for testing and for non-TCP based connections.
-type dialer interface {
- Dial(addr string) (net.Conn, error)
-}
-
-var createDialer func(p *connectParams) dialer
-
-type tcpDialer struct {
- nd *net.Dialer
-}
-
-func (d tcpDialer) Dial(addr string) (net.Conn, error) {
- return d.nd.Dial("tcp", addr)
-}
-
-type MssqlDriver struct {
- log optionalLogger
-
- processQueryText bool
-}
-
-func SetLogger(logger Logger) {
- driverInstance.SetLogger(logger)
- driverInstanceNoProcess.SetLogger(logger)
-}
-
-func (d *MssqlDriver) SetLogger(logger Logger) {
- d.log = optionalLogger{logger}
-}
-
-type MssqlConn struct {
- sess *tdsSession
- transactionCtx context.Context
-
- processQueryText bool
-}
-
-func (c *MssqlConn) simpleProcessResp(ctx context.Context) error {
- tokchan := make(chan tokenStruct, 5)
- go processResponse(ctx, c.sess, tokchan)
- for tok := range tokchan {
- switch token := tok.(type) {
- case doneStruct:
- if token.isError() {
- return token.getError()
- }
- case error:
- return token
- }
- }
- return nil
-}
-
-func (c *MssqlConn) Commit() error {
- if err := c.sendCommitRequest(); err != nil {
- return err
- }
- return c.simpleProcessResp(c.transactionCtx)
-}
-
-func (c *MssqlConn) sendCommitRequest() error {
- headers := []headerStruct{
- {hdrtype: dataStmHdrTransDescr,
- data: transDescrHdr{c.sess.tranid, 1}.pack()},
- }
- if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
- if c.sess.logFlags&logErrors != 0 {
- c.sess.log.Printf("Failed to send CommitXact with %v", err)
- }
- return driver.ErrBadConn
- }
- return nil
-}
-
-func (c *MssqlConn) Rollback() error {
- if err := c.sendRollbackRequest(); err != nil {
- return err
- }
- return c.simpleProcessResp(c.transactionCtx)
-}
-
-func (c *MssqlConn) sendRollbackRequest() error {
- headers := []headerStruct{
- {hdrtype: dataStmHdrTransDescr,
- data: transDescrHdr{c.sess.tranid, 1}.pack()},
- }
- if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
- if c.sess.logFlags&logErrors != 0 {
- c.sess.log.Printf("Failed to send RollbackXact with %v", err)
- }
- return driver.ErrBadConn
- }
- return nil
-}
-
-func (c *MssqlConn) Begin() (driver.Tx, error) {
- return c.begin(context.Background(), isolationUseCurrent)
-}
-
-func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (driver.Tx, error) {
- err := c.sendBeginRequest(ctx, tdsIsolation)
- if err != nil {
- return nil, err
- }
- return c.processBeginResponse(ctx)
-}
-
-func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
- c.transactionCtx = ctx
- headers := []headerStruct{
- {hdrtype: dataStmHdrTransDescr,
- data: transDescrHdr{0, 1}.pack()},
- }
- if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, ""); err != nil {
- if c.sess.logFlags&logErrors != 0 {
- c.sess.log.Printf("Failed to send BeginXact with %v", err)
- }
- return driver.ErrBadConn
- }
- return nil
-}
-
-func (c *MssqlConn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
- if err := c.simpleProcessResp(ctx); err != nil {
- return nil, err
- }
- // successful BEGINXACT request will return sess.tranid
- // for started transaction
- return c, nil
-}
-
-func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) {
- return d.open(dsn)
-}
-
-func (d *MssqlDriver) open(dsn string) (*MssqlConn, error) {
- params, err := parseConnectParams(dsn)
- if err != nil {
- return nil, err
- }
-
- sess, err := connect(d.log, params)
- if err != nil {
- // main server failed, try fail-over partner
- if params.failOverPartner == "" {
- return nil, err
- }
-
- params.host = params.failOverPartner
- if params.failOverPort != 0 {
- params.port = params.failOverPort
- }
-
- sess, err = connect(d.log, params)
- if err != nil {
- // fail-over partner also failed, now fail
- return nil, err
- }
- }
-
- conn := &MssqlConn{sess, context.Background(), d.processQueryText}
- conn.sess.log = d.log
- return conn, nil
-}
-
-func (c *MssqlConn) Close() error {
- return c.sess.buf.transport.Close()
-}
-
-type MssqlStmt struct {
- c *MssqlConn
- query string
- paramCount int
- notifSub *queryNotifSub
-}
-
-type queryNotifSub struct {
- msgText string
- options string
- timeout uint32
-}
-
-func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) {
- return c.prepareContext(context.Background(), query)
-}
-
-func (c *MssqlConn) prepareContext(ctx context.Context, query string) (*MssqlStmt, error) {
- paramCount := -1
- if c.processQueryText {
- query, paramCount = parseParams(query)
- }
- return &MssqlStmt{c, query, paramCount, nil}, nil
-}
-
-func (s *MssqlStmt) Close() error {
- return nil
-}
-
-func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Duration) {
- to := uint32(timeout / time.Second)
- if to < 1 {
- to = 1
- }
- s.notifSub = &queryNotifSub{id, options, to}
-}
-
-func (s *MssqlStmt) NumInput() int {
- return s.paramCount
-}
-
-func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
- headers := []headerStruct{
- {hdrtype: dataStmHdrTransDescr,
- data: transDescrHdr{s.c.sess.tranid, 1}.pack()},
- }
-
- if s.notifSub != nil {
- headers = append(headers, headerStruct{hdrtype: dataStmHdrQueryNotif,
- data: queryNotifHdr{s.notifSub.msgText, s.notifSub.options, s.notifSub.timeout}.pack()})
- }
-
- // no need to check number of parameters here, it is checked by database/sql
- if s.c.sess.logFlags&logSQL != 0 {
- s.c.sess.log.Println(s.query)
- }
- if s.c.sess.logFlags&logParams != 0 && len(args) > 0 {
- for i := 0; i < len(args); i++ {
- s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i])
- }
-
- }
- if len(args) == 0 {
- if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil {
- if s.c.sess.logFlags&logErrors != 0 {
- s.c.sess.log.Printf("Failed to send SqlBatch with %v", err)
- }
- return driver.ErrBadConn
- }
- } else {
- params := make([]Param, len(args)+2)
- decls := make([]string, len(args))
- params[0] = makeStrParam(s.query)
- for i, val := range args {
- params[i+2], err = s.makeParam(val.Value)
- if err != nil {
- return
- }
- var name string
- if len(val.Name) > 0 {
- name = "@" + val.Name
- } else {
- name = fmt.Sprintf("@p%d", val.Ordinal)
- }
- params[i+2].Name = name
- decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+2].ti))
- }
- params[1] = makeStrParam(strings.Join(decls, ","))
- if err = sendRpc(s.c.sess.buf, headers, Sp_ExecuteSql, 0, params); err != nil {
- if s.c.sess.logFlags&logErrors != 0 {
- s.c.sess.log.Printf("Failed to send Rpc with %v", err)
- }
- return driver.ErrBadConn
- }
- }
- return
-}
-
-type namedValue struct {
- Name string
- Ordinal int
- Value driver.Value
-}
-
-func convertOldArgs(args []driver.Value) []namedValue {
- list := make([]namedValue, len(args))
- for i, v := range args {
- list[i] = namedValue{
- Ordinal: i + 1,
- Value: v,
- }
- }
- return list
-}
-
-func (s *MssqlStmt) Query(args []driver.Value) (driver.Rows, error) {
- return s.queryContext(context.Background(), convertOldArgs(args))
-}
-
-func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (driver.Rows, error) {
- if err := s.sendQuery(args); err != nil {
- return nil, err
- }
- return s.processQueryResponse(ctx)
-}
-
-func (s *MssqlStmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
- tokchan := make(chan tokenStruct, 5)
- ctx, cancel := context.WithCancel(ctx)
- go processResponse(ctx, s.c.sess, tokchan)
- // process metadata
- var cols []columnStruct
-loop:
- for tok := range tokchan {
- switch token := tok.(type) {
- // by ignoring DONE token we effectively
- // skip empty result-sets
- // this improves results in queryes like that:
- // set nocount on; select 1
- // see TestIgnoreEmptyResults test
- //case doneStruct:
- //break loop
- case []columnStruct:
- cols = token
- break loop
- case doneStruct:
- if token.isError() {
- return nil, token.getError()
- }
- case error:
- return nil, token
- }
- }
- res = &MssqlRows{sess: s.c.sess, tokchan: tokchan, cols: cols, cancel: cancel}
- return
-}
-
-func (s *MssqlStmt) Exec(args []driver.Value) (driver.Result, error) {
- return s.exec(context.Background(), convertOldArgs(args))
-}
-
-func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
- if err := s.sendQuery(args); err != nil {
- return nil, err
- }
- return s.processExec(ctx)
-}
-
-func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err error) {
- tokchan := make(chan tokenStruct, 5)
- go processResponse(ctx, s.c.sess, tokchan)
- var rowCount int64
- for token := range tokchan {
- switch token := token.(type) {
- case doneInProcStruct:
- if token.Status&doneCount != 0 {
- rowCount += int64(token.RowCount)
- }
- case doneStruct:
- if token.Status&doneCount != 0 {
- rowCount += int64(token.RowCount)
- }
- if token.isError() {
- return nil, token.getError()
- }
- case error:
- return nil, token
- }
- }
- return &MssqlResult{s.c, rowCount}, nil
-}
-
-type MssqlRows struct {
- sess *tdsSession
- cols []columnStruct
- tokchan chan tokenStruct
-
- nextCols []columnStruct
-
- cancel func()
-}
-
-func (rc *MssqlRows) Close() error {
- rc.cancel()
- for _ = range rc.tokchan {
- }
- rc.tokchan = nil
- return nil
-}
-
-func (rc *MssqlRows) Columns() (res []string) {
- res = make([]string, len(rc.cols))
- for i, col := range rc.cols {
- res[i] = col.ColName
- }
- return
-}
-
-func (rc *MssqlRows) Next(dest []driver.Value) error {
- if rc.nextCols != nil {
- return io.EOF
- }
- for tok := range rc.tokchan {
- switch tokdata := tok.(type) {
- case []columnStruct:
- rc.nextCols = tokdata
- return io.EOF
- case []interface{}:
- for i := range dest {
- dest[i] = tokdata[i]
- }
- return nil
- case doneStruct:
- if tokdata.isError() {
- return tokdata.getError()
- }
- case error:
- return tokdata
- }
- }
- return io.EOF
-}
-
-func (rc *MssqlRows) HasNextResultSet() bool {
- return rc.nextCols != nil
-}
-
-func (rc *MssqlRows) NextResultSet() error {
- rc.cols = rc.nextCols
- rc.nextCols = nil
- if rc.cols == nil {
- return io.EOF
- }
- return nil
-}
-
-// It should return
-// the value type that can be used to scan types into. For example, the database
-// column type "bigint" this should return "reflect.TypeOf(int64(0))".
-func (r *MssqlRows) ColumnTypeScanType(index int) reflect.Type {
- return makeGoLangScanType(r.cols[index].ti)
-}
-
-// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
-// database system type name without the length. Type names should be uppercase.
-// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
-// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
-// "TIMESTAMP".
-func (r *MssqlRows) ColumnTypeDatabaseTypeName(index int) string {
- return makeGoLangTypeName(r.cols[index].ti)
-}
-
-// RowsColumnTypeLength may be implemented by Rows. It should return the length
-// of the column type if the column is a variable length type. If the column is
-// not a variable length type ok should return false.
-// If length is not limited other than system limits, it should return math.MaxInt64.
-// The following are examples of returned values for various types:
-// TEXT (math.MaxInt64, true)
-// varchar(10) (10, true)
-// nvarchar(10) (10, true)
-// decimal (0, false)
-// int (0, false)
-// bytea(30) (30, true)
-func (r *MssqlRows) ColumnTypeLength(index int) (int64, bool) {
- return makeGoLangTypeLength(r.cols[index].ti)
-}
-
-// It should return
-// the precision and scale for decimal types. If not applicable, ok should be false.
-// The following are examples of returned values for various types:
-// decimal(38, 4) (38, 4, true)
-// int (0, 0, false)
-// decimal (math.MaxInt64, math.MaxInt64, true)
-func (r *MssqlRows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
- return makeGoLangTypePrecisionScale(r.cols[index].ti)
-}
-
-// The nullable value should
-// be true if it is known the column may be null, or false if the column is known
-// to be not nullable.
-// If the column nullability is unknown, ok should be false.
-func (r *MssqlRows) ColumnTypeNullable(index int) (nullable, ok bool) {
- nullable = r.cols[index].Flags&colFlagNullable != 0
- ok = true
- return
-}
-
-func makeStrParam(val string) (res Param) {
- res.ti.TypeId = typeNVarChar
- res.buffer = str2ucs2(val)
- res.ti.Size = len(res.buffer)
- return
-}
-
-func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
- if val == nil {
- res.ti.TypeId = typeNVarChar
- res.buffer = nil
- res.ti.Size = 2
- return
- }
- switch val := val.(type) {
- case int64:
- res.ti.TypeId = typeIntN
- res.buffer = make([]byte, 8)
- res.ti.Size = 8
- binary.LittleEndian.PutUint64(res.buffer, uint64(val))
- case float64:
- res.ti.TypeId = typeFltN
- res.ti.Size = 8
- res.buffer = make([]byte, 8)
- binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val))
- case []byte:
- res.ti.TypeId = typeBigVarBin
- res.ti.Size = len(val)
- res.buffer = val
- case string:
- res = makeStrParam(val)
- case bool:
- res.ti.TypeId = typeBitN
- res.ti.Size = 1
- res.buffer = make([]byte, 1)
- if val {
- res.buffer[0] = 1
- }
- case time.Time:
- if s.c.sess.loginAck.TDSVersion >= verTDS73 {
- res.ti.TypeId = typeDateTimeOffsetN
- res.ti.Scale = 7
- res.ti.Size = 10
- buf := make([]byte, 10)
- res.buffer = buf
- days, ns := dateTime2(val)
- ns /= 100
- buf[0] = byte(ns)
- buf[1] = byte(ns >> 8)
- buf[2] = byte(ns >> 16)
- buf[3] = byte(ns >> 24)
- buf[4] = byte(ns >> 32)
- buf[5] = byte(days)
- buf[6] = byte(days >> 8)
- buf[7] = byte(days >> 16)
- _, offset := val.Zone()
- offset /= 60
- buf[8] = byte(offset)
- buf[9] = byte(offset >> 8)
- } else {
- res.ti.TypeId = typeDateTimeN
- res.ti.Size = 8
- res.buffer = make([]byte, 8)
- ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
- dur := val.Sub(ref)
- days := dur / (24 * time.Hour)
- tm := (300 * (dur % (24 * time.Hour))) / time.Second
- binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(days))
- binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm))
- }
- default:
- err = fmt.Errorf("mssql: unknown type for %T", val)
- return
- }
- return
-}
-
-type MssqlResult struct {
- c *MssqlConn
- rowsAffected int64
-}
-
-func (r *MssqlResult) RowsAffected() (int64, error) {
- return r.rowsAffected, nil
-}
-
-func (r *MssqlResult) LastInsertId() (int64, error) {
- s, err := r.c.Prepare("select cast(@@identity as bigint)")
- if err != nil {
- return 0, err
- }
- defer s.Close()
- rows, err := s.Query(nil)
- if err != nil {
- return 0, err
- }
- defer rows.Close()
- dest := make([]driver.Value, 1)
- err = rows.Next(dest)
- if err != nil {
- return 0, err
- }
- if dest[0] == nil {
- return -1, errors.New("There is no generated identity value")
- }
- lastInsertId := dest[0].(int64)
- return lastInsertId, nil
-}