diff options
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/mssql.go')
-rw-r--r-- | vendor/github.com/denisenkom/go-mssqldb/mssql.go | 609 |
1 files changed, 0 insertions, 609 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql.go b/vendor/github.com/denisenkom/go-mssqldb/mssql.go deleted file mode 100644 index b997cd68..00000000 --- a/vendor/github.com/denisenkom/go-mssqldb/mssql.go +++ /dev/null @@ -1,609 +0,0 @@ -package mssql - -import ( - "database/sql" - "database/sql/driver" - "encoding/binary" - "errors" - "fmt" - "io" - "math" - "net" - "reflect" - "strings" - "time" - "golang.org/x/net/context" // use the "x/net/context" for backwards compatibility. -) - -var driverInstance = &MssqlDriver{processQueryText: true} -var driverInstanceNoProcess = &MssqlDriver{processQueryText: false} - -func init() { - sql.Register("mssql", driverInstance) - sql.Register("sqlserver", driverInstanceNoProcess) -} - -// Abstract the dialer for testing and for non-TCP based connections. -type dialer interface { - Dial(addr string) (net.Conn, error) -} - -var createDialer func(p *connectParams) dialer - -type tcpDialer struct { - nd *net.Dialer -} - -func (d tcpDialer) Dial(addr string) (net.Conn, error) { - return d.nd.Dial("tcp", addr) -} - -type MssqlDriver struct { - log optionalLogger - - processQueryText bool -} - -func SetLogger(logger Logger) { - driverInstance.SetLogger(logger) - driverInstanceNoProcess.SetLogger(logger) -} - -func (d *MssqlDriver) SetLogger(logger Logger) { - d.log = optionalLogger{logger} -} - -type MssqlConn struct { - sess *tdsSession - transactionCtx context.Context - - processQueryText bool -} - -func (c *MssqlConn) simpleProcessResp(ctx context.Context) error { - tokchan := make(chan tokenStruct, 5) - go processResponse(ctx, c.sess, tokchan) - for tok := range tokchan { - switch token := tok.(type) { - case doneStruct: - if token.isError() { - return token.getError() - } - case error: - return token - } - } - return nil -} - -func (c *MssqlConn) Commit() error { - if err := c.sendCommitRequest(); err != nil { - return err - } - return c.simpleProcessResp(c.transactionCtx) -} - -func (c *MssqlConn) sendCommitRequest() error { - headers := []headerStruct{ - {hdrtype: dataStmHdrTransDescr, - data: transDescrHdr{c.sess.tranid, 1}.pack()}, - } - if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil { - if c.sess.logFlags&logErrors != 0 { - c.sess.log.Printf("Failed to send CommitXact with %v", err) - } - return driver.ErrBadConn - } - return nil -} - -func (c *MssqlConn) Rollback() error { - if err := c.sendRollbackRequest(); err != nil { - return err - } - return c.simpleProcessResp(c.transactionCtx) -} - -func (c *MssqlConn) sendRollbackRequest() error { - headers := []headerStruct{ - {hdrtype: dataStmHdrTransDescr, - data: transDescrHdr{c.sess.tranid, 1}.pack()}, - } - if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil { - if c.sess.logFlags&logErrors != 0 { - c.sess.log.Printf("Failed to send RollbackXact with %v", err) - } - return driver.ErrBadConn - } - return nil -} - -func (c *MssqlConn) Begin() (driver.Tx, error) { - return c.begin(context.Background(), isolationUseCurrent) -} - -func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (driver.Tx, error) { - err := c.sendBeginRequest(ctx, tdsIsolation) - if err != nil { - return nil, err - } - return c.processBeginResponse(ctx) -} - -func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error { - c.transactionCtx = ctx - headers := []headerStruct{ - {hdrtype: dataStmHdrTransDescr, - data: transDescrHdr{0, 1}.pack()}, - } - if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, ""); err != nil { - if c.sess.logFlags&logErrors != 0 { - c.sess.log.Printf("Failed to send BeginXact with %v", err) - } - return driver.ErrBadConn - } - return nil -} - -func (c *MssqlConn) processBeginResponse(ctx context.Context) (driver.Tx, error) { - if err := c.simpleProcessResp(ctx); err != nil { - return nil, err - } - // successful BEGINXACT request will return sess.tranid - // for started transaction - return c, nil -} - -func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) { - return d.open(dsn) -} - -func (d *MssqlDriver) open(dsn string) (*MssqlConn, error) { - params, err := parseConnectParams(dsn) - if err != nil { - return nil, err - } - - sess, err := connect(d.log, params) - if err != nil { - // main server failed, try fail-over partner - if params.failOverPartner == "" { - return nil, err - } - - params.host = params.failOverPartner - if params.failOverPort != 0 { - params.port = params.failOverPort - } - - sess, err = connect(d.log, params) - if err != nil { - // fail-over partner also failed, now fail - return nil, err - } - } - - conn := &MssqlConn{sess, context.Background(), d.processQueryText} - conn.sess.log = d.log - return conn, nil -} - -func (c *MssqlConn) Close() error { - return c.sess.buf.transport.Close() -} - -type MssqlStmt struct { - c *MssqlConn - query string - paramCount int - notifSub *queryNotifSub -} - -type queryNotifSub struct { - msgText string - options string - timeout uint32 -} - -func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) { - return c.prepareContext(context.Background(), query) -} - -func (c *MssqlConn) prepareContext(ctx context.Context, query string) (*MssqlStmt, error) { - paramCount := -1 - if c.processQueryText { - query, paramCount = parseParams(query) - } - return &MssqlStmt{c, query, paramCount, nil}, nil -} - -func (s *MssqlStmt) Close() error { - return nil -} - -func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Duration) { - to := uint32(timeout / time.Second) - if to < 1 { - to = 1 - } - s.notifSub = &queryNotifSub{id, options, to} -} - -func (s *MssqlStmt) NumInput() int { - return s.paramCount -} - -func (s *MssqlStmt) sendQuery(args []namedValue) (err error) { - headers := []headerStruct{ - {hdrtype: dataStmHdrTransDescr, - data: transDescrHdr{s.c.sess.tranid, 1}.pack()}, - } - - if s.notifSub != nil { - headers = append(headers, headerStruct{hdrtype: dataStmHdrQueryNotif, - data: queryNotifHdr{s.notifSub.msgText, s.notifSub.options, s.notifSub.timeout}.pack()}) - } - - // no need to check number of parameters here, it is checked by database/sql - if s.c.sess.logFlags&logSQL != 0 { - s.c.sess.log.Println(s.query) - } - if s.c.sess.logFlags&logParams != 0 && len(args) > 0 { - for i := 0; i < len(args); i++ { - s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i]) - } - - } - if len(args) == 0 { - if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil { - if s.c.sess.logFlags&logErrors != 0 { - s.c.sess.log.Printf("Failed to send SqlBatch with %v", err) - } - return driver.ErrBadConn - } - } else { - params := make([]Param, len(args)+2) - decls := make([]string, len(args)) - params[0] = makeStrParam(s.query) - for i, val := range args { - params[i+2], err = s.makeParam(val.Value) - if err != nil { - return - } - var name string - if len(val.Name) > 0 { - name = "@" + val.Name - } else { - name = fmt.Sprintf("@p%d", val.Ordinal) - } - params[i+2].Name = name - decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+2].ti)) - } - params[1] = makeStrParam(strings.Join(decls, ",")) - if err = sendRpc(s.c.sess.buf, headers, Sp_ExecuteSql, 0, params); err != nil { - if s.c.sess.logFlags&logErrors != 0 { - s.c.sess.log.Printf("Failed to send Rpc with %v", err) - } - return driver.ErrBadConn - } - } - return -} - -type namedValue struct { - Name string - Ordinal int - Value driver.Value -} - -func convertOldArgs(args []driver.Value) []namedValue { - list := make([]namedValue, len(args)) - for i, v := range args { - list[i] = namedValue{ - Ordinal: i + 1, - Value: v, - } - } - return list -} - -func (s *MssqlStmt) Query(args []driver.Value) (driver.Rows, error) { - return s.queryContext(context.Background(), convertOldArgs(args)) -} - -func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (driver.Rows, error) { - if err := s.sendQuery(args); err != nil { - return nil, err - } - return s.processQueryResponse(ctx) -} - -func (s *MssqlStmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) { - tokchan := make(chan tokenStruct, 5) - ctx, cancel := context.WithCancel(ctx) - go processResponse(ctx, s.c.sess, tokchan) - // process metadata - var cols []columnStruct -loop: - for tok := range tokchan { - switch token := tok.(type) { - // by ignoring DONE token we effectively - // skip empty result-sets - // this improves results in queryes like that: - // set nocount on; select 1 - // see TestIgnoreEmptyResults test - //case doneStruct: - //break loop - case []columnStruct: - cols = token - break loop - case doneStruct: - if token.isError() { - return nil, token.getError() - } - case error: - return nil, token - } - } - res = &MssqlRows{sess: s.c.sess, tokchan: tokchan, cols: cols, cancel: cancel} - return -} - -func (s *MssqlStmt) Exec(args []driver.Value) (driver.Result, error) { - return s.exec(context.Background(), convertOldArgs(args)) -} - -func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) { - if err := s.sendQuery(args); err != nil { - return nil, err - } - return s.processExec(ctx) -} - -func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err error) { - tokchan := make(chan tokenStruct, 5) - go processResponse(ctx, s.c.sess, tokchan) - var rowCount int64 - for token := range tokchan { - switch token := token.(type) { - case doneInProcStruct: - if token.Status&doneCount != 0 { - rowCount += int64(token.RowCount) - } - case doneStruct: - if token.Status&doneCount != 0 { - rowCount += int64(token.RowCount) - } - if token.isError() { - return nil, token.getError() - } - case error: - return nil, token - } - } - return &MssqlResult{s.c, rowCount}, nil -} - -type MssqlRows struct { - sess *tdsSession - cols []columnStruct - tokchan chan tokenStruct - - nextCols []columnStruct - - cancel func() -} - -func (rc *MssqlRows) Close() error { - rc.cancel() - for _ = range rc.tokchan { - } - rc.tokchan = nil - return nil -} - -func (rc *MssqlRows) Columns() (res []string) { - res = make([]string, len(rc.cols)) - for i, col := range rc.cols { - res[i] = col.ColName - } - return -} - -func (rc *MssqlRows) Next(dest []driver.Value) error { - if rc.nextCols != nil { - return io.EOF - } - for tok := range rc.tokchan { - switch tokdata := tok.(type) { - case []columnStruct: - rc.nextCols = tokdata - return io.EOF - case []interface{}: - for i := range dest { - dest[i] = tokdata[i] - } - return nil - case doneStruct: - if tokdata.isError() { - return tokdata.getError() - } - case error: - return tokdata - } - } - return io.EOF -} - -func (rc *MssqlRows) HasNextResultSet() bool { - return rc.nextCols != nil -} - -func (rc *MssqlRows) NextResultSet() error { - rc.cols = rc.nextCols - rc.nextCols = nil - if rc.cols == nil { - return io.EOF - } - return nil -} - -// It should return -// the value type that can be used to scan types into. For example, the database -// column type "bigint" this should return "reflect.TypeOf(int64(0))". -func (r *MssqlRows) ColumnTypeScanType(index int) reflect.Type { - return makeGoLangScanType(r.cols[index].ti) -} - -// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the -// database system type name without the length. Type names should be uppercase. -// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT", -// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML", -// "TIMESTAMP". -func (r *MssqlRows) ColumnTypeDatabaseTypeName(index int) string { - return makeGoLangTypeName(r.cols[index].ti) -} - -// RowsColumnTypeLength may be implemented by Rows. It should return the length -// of the column type if the column is a variable length type. If the column is -// not a variable length type ok should return false. -// If length is not limited other than system limits, it should return math.MaxInt64. -// The following are examples of returned values for various types: -// TEXT (math.MaxInt64, true) -// varchar(10) (10, true) -// nvarchar(10) (10, true) -// decimal (0, false) -// int (0, false) -// bytea(30) (30, true) -func (r *MssqlRows) ColumnTypeLength(index int) (int64, bool) { - return makeGoLangTypeLength(r.cols[index].ti) -} - -// It should return -// the precision and scale for decimal types. If not applicable, ok should be false. -// The following are examples of returned values for various types: -// decimal(38, 4) (38, 4, true) -// int (0, 0, false) -// decimal (math.MaxInt64, math.MaxInt64, true) -func (r *MssqlRows) ColumnTypePrecisionScale(index int) (int64, int64, bool) { - return makeGoLangTypePrecisionScale(r.cols[index].ti) -} - -// The nullable value should -// be true if it is known the column may be null, or false if the column is known -// to be not nullable. -// If the column nullability is unknown, ok should be false. -func (r *MssqlRows) ColumnTypeNullable(index int) (nullable, ok bool) { - nullable = r.cols[index].Flags&colFlagNullable != 0 - ok = true - return -} - -func makeStrParam(val string) (res Param) { - res.ti.TypeId = typeNVarChar - res.buffer = str2ucs2(val) - res.ti.Size = len(res.buffer) - return -} - -func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) { - if val == nil { - res.ti.TypeId = typeNVarChar - res.buffer = nil - res.ti.Size = 2 - return - } - switch val := val.(type) { - case int64: - res.ti.TypeId = typeIntN - res.buffer = make([]byte, 8) - res.ti.Size = 8 - binary.LittleEndian.PutUint64(res.buffer, uint64(val)) - case float64: - res.ti.TypeId = typeFltN - res.ti.Size = 8 - res.buffer = make([]byte, 8) - binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val)) - case []byte: - res.ti.TypeId = typeBigVarBin - res.ti.Size = len(val) - res.buffer = val - case string: - res = makeStrParam(val) - case bool: - res.ti.TypeId = typeBitN - res.ti.Size = 1 - res.buffer = make([]byte, 1) - if val { - res.buffer[0] = 1 - } - case time.Time: - if s.c.sess.loginAck.TDSVersion >= verTDS73 { - res.ti.TypeId = typeDateTimeOffsetN - res.ti.Scale = 7 - res.ti.Size = 10 - buf := make([]byte, 10) - res.buffer = buf - days, ns := dateTime2(val) - ns /= 100 - buf[0] = byte(ns) - buf[1] = byte(ns >> 8) - buf[2] = byte(ns >> 16) - buf[3] = byte(ns >> 24) - buf[4] = byte(ns >> 32) - buf[5] = byte(days) - buf[6] = byte(days >> 8) - buf[7] = byte(days >> 16) - _, offset := val.Zone() - offset /= 60 - buf[8] = byte(offset) - buf[9] = byte(offset >> 8) - } else { - res.ti.TypeId = typeDateTimeN - res.ti.Size = 8 - res.buffer = make([]byte, 8) - ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC) - dur := val.Sub(ref) - days := dur / (24 * time.Hour) - tm := (300 * (dur % (24 * time.Hour))) / time.Second - binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(days)) - binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm)) - } - default: - err = fmt.Errorf("mssql: unknown type for %T", val) - return - } - return -} - -type MssqlResult struct { - c *MssqlConn - rowsAffected int64 -} - -func (r *MssqlResult) RowsAffected() (int64, error) { - return r.rowsAffected, nil -} - -func (r *MssqlResult) LastInsertId() (int64, error) { - s, err := r.c.Prepare("select cast(@@identity as bigint)") - if err != nil { - return 0, err - } - defer s.Close() - rows, err := s.Query(nil) - if err != nil { - return 0, err - } - defer rows.Close() - dest := make([]driver.Value, 1) - err = rows.Next(dest) - if err != nil { - return 0, err - } - if dest[0] == nil { - return -1, errors.New("There is no generated identity value") - } - lastInsertId := dest[0].(int64) - return lastInsertId, nil -} |