aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/denisenkom/go-mssqldb/mssql.go
diff options
context:
space:
mode:
authorUnknwon <u@gogs.io>2017-02-13 20:52:35 -0500
committerUnknwon <u@gogs.io>2017-02-13 20:52:35 -0500
commitf967e9d02114553535a267d076df37201f74ddf5 (patch)
tree9c3032f722a2523e153a69c5c02b6dc3ad4345af /vendor/github.com/denisenkom/go-mssqldb/mssql.go
parent5179063e71d7234250209389e6f69db1cd6f0caa (diff)
vendor: add new dependency (#3772)
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/mssql.go')
-rw-r--r--vendor/github.com/denisenkom/go-mssqldb/mssql.go609
1 files changed, 609 insertions, 0 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql.go b/vendor/github.com/denisenkom/go-mssqldb/mssql.go
new file mode 100644
index 00000000..b997cd68
--- /dev/null
+++ b/vendor/github.com/denisenkom/go-mssqldb/mssql.go
@@ -0,0 +1,609 @@
+package mssql
+
+import (
+ "database/sql"
+ "database/sql/driver"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "math"
+ "net"
+ "reflect"
+ "strings"
+ "time"
+ "golang.org/x/net/context" // use the "x/net/context" for backwards compatibility.
+)
+
+var driverInstance = &MssqlDriver{processQueryText: true}
+var driverInstanceNoProcess = &MssqlDriver{processQueryText: false}
+
+func init() {
+ sql.Register("mssql", driverInstance)
+ sql.Register("sqlserver", driverInstanceNoProcess)
+}
+
+// Abstract the dialer for testing and for non-TCP based connections.
+type dialer interface {
+ Dial(addr string) (net.Conn, error)
+}
+
+var createDialer func(p *connectParams) dialer
+
+type tcpDialer struct {
+ nd *net.Dialer
+}
+
+func (d tcpDialer) Dial(addr string) (net.Conn, error) {
+ return d.nd.Dial("tcp", addr)
+}
+
+type MssqlDriver struct {
+ log optionalLogger
+
+ processQueryText bool
+}
+
+func SetLogger(logger Logger) {
+ driverInstance.SetLogger(logger)
+ driverInstanceNoProcess.SetLogger(logger)
+}
+
+func (d *MssqlDriver) SetLogger(logger Logger) {
+ d.log = optionalLogger{logger}
+}
+
+type MssqlConn struct {
+ sess *tdsSession
+ transactionCtx context.Context
+
+ processQueryText bool
+}
+
+func (c *MssqlConn) simpleProcessResp(ctx context.Context) error {
+ tokchan := make(chan tokenStruct, 5)
+ go processResponse(ctx, c.sess, tokchan)
+ for tok := range tokchan {
+ switch token := tok.(type) {
+ case doneStruct:
+ if token.isError() {
+ return token.getError()
+ }
+ case error:
+ return token
+ }
+ }
+ return nil
+}
+
+func (c *MssqlConn) Commit() error {
+ if err := c.sendCommitRequest(); err != nil {
+ return err
+ }
+ return c.simpleProcessResp(c.transactionCtx)
+}
+
+func (c *MssqlConn) sendCommitRequest() error {
+ headers := []headerStruct{
+ {hdrtype: dataStmHdrTransDescr,
+ data: transDescrHdr{c.sess.tranid, 1}.pack()},
+ }
+ if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
+ if c.sess.logFlags&logErrors != 0 {
+ c.sess.log.Printf("Failed to send CommitXact with %v", err)
+ }
+ return driver.ErrBadConn
+ }
+ return nil
+}
+
+func (c *MssqlConn) Rollback() error {
+ if err := c.sendRollbackRequest(); err != nil {
+ return err
+ }
+ return c.simpleProcessResp(c.transactionCtx)
+}
+
+func (c *MssqlConn) sendRollbackRequest() error {
+ headers := []headerStruct{
+ {hdrtype: dataStmHdrTransDescr,
+ data: transDescrHdr{c.sess.tranid, 1}.pack()},
+ }
+ if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
+ if c.sess.logFlags&logErrors != 0 {
+ c.sess.log.Printf("Failed to send RollbackXact with %v", err)
+ }
+ return driver.ErrBadConn
+ }
+ return nil
+}
+
+func (c *MssqlConn) Begin() (driver.Tx, error) {
+ return c.begin(context.Background(), isolationUseCurrent)
+}
+
+func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (driver.Tx, error) {
+ err := c.sendBeginRequest(ctx, tdsIsolation)
+ if err != nil {
+ return nil, err
+ }
+ return c.processBeginResponse(ctx)
+}
+
+func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
+ c.transactionCtx = ctx
+ headers := []headerStruct{
+ {hdrtype: dataStmHdrTransDescr,
+ data: transDescrHdr{0, 1}.pack()},
+ }
+ if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, ""); err != nil {
+ if c.sess.logFlags&logErrors != 0 {
+ c.sess.log.Printf("Failed to send BeginXact with %v", err)
+ }
+ return driver.ErrBadConn
+ }
+ return nil
+}
+
+func (c *MssqlConn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
+ if err := c.simpleProcessResp(ctx); err != nil {
+ return nil, err
+ }
+ // successful BEGINXACT request will return sess.tranid
+ // for started transaction
+ return c, nil
+}
+
+func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) {
+ return d.open(dsn)
+}
+
+func (d *MssqlDriver) open(dsn string) (*MssqlConn, error) {
+ params, err := parseConnectParams(dsn)
+ if err != nil {
+ return nil, err
+ }
+
+ sess, err := connect(d.log, params)
+ if err != nil {
+ // main server failed, try fail-over partner
+ if params.failOverPartner == "" {
+ return nil, err
+ }
+
+ params.host = params.failOverPartner
+ if params.failOverPort != 0 {
+ params.port = params.failOverPort
+ }
+
+ sess, err = connect(d.log, params)
+ if err != nil {
+ // fail-over partner also failed, now fail
+ return nil, err
+ }
+ }
+
+ conn := &MssqlConn{sess, context.Background(), d.processQueryText}
+ conn.sess.log = d.log
+ return conn, nil
+}
+
+func (c *MssqlConn) Close() error {
+ return c.sess.buf.transport.Close()
+}
+
+type MssqlStmt struct {
+ c *MssqlConn
+ query string
+ paramCount int
+ notifSub *queryNotifSub
+}
+
+type queryNotifSub struct {
+ msgText string
+ options string
+ timeout uint32
+}
+
+func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) {
+ return c.prepareContext(context.Background(), query)
+}
+
+func (c *MssqlConn) prepareContext(ctx context.Context, query string) (*MssqlStmt, error) {
+ paramCount := -1
+ if c.processQueryText {
+ query, paramCount = parseParams(query)
+ }
+ return &MssqlStmt{c, query, paramCount, nil}, nil
+}
+
+func (s *MssqlStmt) Close() error {
+ return nil
+}
+
+func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Duration) {
+ to := uint32(timeout / time.Second)
+ if to < 1 {
+ to = 1
+ }
+ s.notifSub = &queryNotifSub{id, options, to}
+}
+
+func (s *MssqlStmt) NumInput() int {
+ return s.paramCount
+}
+
+func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
+ headers := []headerStruct{
+ {hdrtype: dataStmHdrTransDescr,
+ data: transDescrHdr{s.c.sess.tranid, 1}.pack()},
+ }
+
+ if s.notifSub != nil {
+ headers = append(headers, headerStruct{hdrtype: dataStmHdrQueryNotif,
+ data: queryNotifHdr{s.notifSub.msgText, s.notifSub.options, s.notifSub.timeout}.pack()})
+ }
+
+ // no need to check number of parameters here, it is checked by database/sql
+ if s.c.sess.logFlags&logSQL != 0 {
+ s.c.sess.log.Println(s.query)
+ }
+ if s.c.sess.logFlags&logParams != 0 && len(args) > 0 {
+ for i := 0; i < len(args); i++ {
+ s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i])
+ }
+
+ }
+ if len(args) == 0 {
+ if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil {
+ if s.c.sess.logFlags&logErrors != 0 {
+ s.c.sess.log.Printf("Failed to send SqlBatch with %v", err)
+ }
+ return driver.ErrBadConn
+ }
+ } else {
+ params := make([]Param, len(args)+2)
+ decls := make([]string, len(args))
+ params[0] = makeStrParam(s.query)
+ for i, val := range args {
+ params[i+2], err = s.makeParam(val.Value)
+ if err != nil {
+ return
+ }
+ var name string
+ if len(val.Name) > 0 {
+ name = "@" + val.Name
+ } else {
+ name = fmt.Sprintf("@p%d", val.Ordinal)
+ }
+ params[i+2].Name = name
+ decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+2].ti))
+ }
+ params[1] = makeStrParam(strings.Join(decls, ","))
+ if err = sendRpc(s.c.sess.buf, headers, Sp_ExecuteSql, 0, params); err != nil {
+ if s.c.sess.logFlags&logErrors != 0 {
+ s.c.sess.log.Printf("Failed to send Rpc with %v", err)
+ }
+ return driver.ErrBadConn
+ }
+ }
+ return
+}
+
+type namedValue struct {
+ Name string
+ Ordinal int
+ Value driver.Value
+}
+
+func convertOldArgs(args []driver.Value) []namedValue {
+ list := make([]namedValue, len(args))
+ for i, v := range args {
+ list[i] = namedValue{
+ Ordinal: i + 1,
+ Value: v,
+ }
+ }
+ return list
+}
+
+func (s *MssqlStmt) Query(args []driver.Value) (driver.Rows, error) {
+ return s.queryContext(context.Background(), convertOldArgs(args))
+}
+
+func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (driver.Rows, error) {
+ if err := s.sendQuery(args); err != nil {
+ return nil, err
+ }
+ return s.processQueryResponse(ctx)
+}
+
+func (s *MssqlStmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
+ tokchan := make(chan tokenStruct, 5)
+ ctx, cancel := context.WithCancel(ctx)
+ go processResponse(ctx, s.c.sess, tokchan)
+ // process metadata
+ var cols []columnStruct
+loop:
+ for tok := range tokchan {
+ switch token := tok.(type) {
+ // by ignoring DONE token we effectively
+ // skip empty result-sets
+ // this improves results in queryes like that:
+ // set nocount on; select 1
+ // see TestIgnoreEmptyResults test
+ //case doneStruct:
+ //break loop
+ case []columnStruct:
+ cols = token
+ break loop
+ case doneStruct:
+ if token.isError() {
+ return nil, token.getError()
+ }
+ case error:
+ return nil, token
+ }
+ }
+ res = &MssqlRows{sess: s.c.sess, tokchan: tokchan, cols: cols, cancel: cancel}
+ return
+}
+
+func (s *MssqlStmt) Exec(args []driver.Value) (driver.Result, error) {
+ return s.exec(context.Background(), convertOldArgs(args))
+}
+
+func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
+ if err := s.sendQuery(args); err != nil {
+ return nil, err
+ }
+ return s.processExec(ctx)
+}
+
+func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err error) {
+ tokchan := make(chan tokenStruct, 5)
+ go processResponse(ctx, s.c.sess, tokchan)
+ var rowCount int64
+ for token := range tokchan {
+ switch token := token.(type) {
+ case doneInProcStruct:
+ if token.Status&doneCount != 0 {
+ rowCount += int64(token.RowCount)
+ }
+ case doneStruct:
+ if token.Status&doneCount != 0 {
+ rowCount += int64(token.RowCount)
+ }
+ if token.isError() {
+ return nil, token.getError()
+ }
+ case error:
+ return nil, token
+ }
+ }
+ return &MssqlResult{s.c, rowCount}, nil
+}
+
+type MssqlRows struct {
+ sess *tdsSession
+ cols []columnStruct
+ tokchan chan tokenStruct
+
+ nextCols []columnStruct
+
+ cancel func()
+}
+
+func (rc *MssqlRows) Close() error {
+ rc.cancel()
+ for _ = range rc.tokchan {
+ }
+ rc.tokchan = nil
+ return nil
+}
+
+func (rc *MssqlRows) Columns() (res []string) {
+ res = make([]string, len(rc.cols))
+ for i, col := range rc.cols {
+ res[i] = col.ColName
+ }
+ return
+}
+
+func (rc *MssqlRows) Next(dest []driver.Value) error {
+ if rc.nextCols != nil {
+ return io.EOF
+ }
+ for tok := range rc.tokchan {
+ switch tokdata := tok.(type) {
+ case []columnStruct:
+ rc.nextCols = tokdata
+ return io.EOF
+ case []interface{}:
+ for i := range dest {
+ dest[i] = tokdata[i]
+ }
+ return nil
+ case doneStruct:
+ if tokdata.isError() {
+ return tokdata.getError()
+ }
+ case error:
+ return tokdata
+ }
+ }
+ return io.EOF
+}
+
+func (rc *MssqlRows) HasNextResultSet() bool {
+ return rc.nextCols != nil
+}
+
+func (rc *MssqlRows) NextResultSet() error {
+ rc.cols = rc.nextCols
+ rc.nextCols = nil
+ if rc.cols == nil {
+ return io.EOF
+ }
+ return nil
+}
+
+// It should return
+// the value type that can be used to scan types into. For example, the database
+// column type "bigint" this should return "reflect.TypeOf(int64(0))".
+func (r *MssqlRows) ColumnTypeScanType(index int) reflect.Type {
+ return makeGoLangScanType(r.cols[index].ti)
+}
+
+// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
+// database system type name without the length. Type names should be uppercase.
+// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
+// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
+// "TIMESTAMP".
+func (r *MssqlRows) ColumnTypeDatabaseTypeName(index int) string {
+ return makeGoLangTypeName(r.cols[index].ti)
+}
+
+// RowsColumnTypeLength may be implemented by Rows. It should return the length
+// of the column type if the column is a variable length type. If the column is
+// not a variable length type ok should return false.
+// If length is not limited other than system limits, it should return math.MaxInt64.
+// The following are examples of returned values for various types:
+// TEXT (math.MaxInt64, true)
+// varchar(10) (10, true)
+// nvarchar(10) (10, true)
+// decimal (0, false)
+// int (0, false)
+// bytea(30) (30, true)
+func (r *MssqlRows) ColumnTypeLength(index int) (int64, bool) {
+ return makeGoLangTypeLength(r.cols[index].ti)
+}
+
+// It should return
+// the precision and scale for decimal types. If not applicable, ok should be false.
+// The following are examples of returned values for various types:
+// decimal(38, 4) (38, 4, true)
+// int (0, 0, false)
+// decimal (math.MaxInt64, math.MaxInt64, true)
+func (r *MssqlRows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
+ return makeGoLangTypePrecisionScale(r.cols[index].ti)
+}
+
+// The nullable value should
+// be true if it is known the column may be null, or false if the column is known
+// to be not nullable.
+// If the column nullability is unknown, ok should be false.
+func (r *MssqlRows) ColumnTypeNullable(index int) (nullable, ok bool) {
+ nullable = r.cols[index].Flags&colFlagNullable != 0
+ ok = true
+ return
+}
+
+func makeStrParam(val string) (res Param) {
+ res.ti.TypeId = typeNVarChar
+ res.buffer = str2ucs2(val)
+ res.ti.Size = len(res.buffer)
+ return
+}
+
+func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
+ if val == nil {
+ res.ti.TypeId = typeNVarChar
+ res.buffer = nil
+ res.ti.Size = 2
+ return
+ }
+ switch val := val.(type) {
+ case int64:
+ res.ti.TypeId = typeIntN
+ res.buffer = make([]byte, 8)
+ res.ti.Size = 8
+ binary.LittleEndian.PutUint64(res.buffer, uint64(val))
+ case float64:
+ res.ti.TypeId = typeFltN
+ res.ti.Size = 8
+ res.buffer = make([]byte, 8)
+ binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val))
+ case []byte:
+ res.ti.TypeId = typeBigVarBin
+ res.ti.Size = len(val)
+ res.buffer = val
+ case string:
+ res = makeStrParam(val)
+ case bool:
+ res.ti.TypeId = typeBitN
+ res.ti.Size = 1
+ res.buffer = make([]byte, 1)
+ if val {
+ res.buffer[0] = 1
+ }
+ case time.Time:
+ if s.c.sess.loginAck.TDSVersion >= verTDS73 {
+ res.ti.TypeId = typeDateTimeOffsetN
+ res.ti.Scale = 7
+ res.ti.Size = 10
+ buf := make([]byte, 10)
+ res.buffer = buf
+ days, ns := dateTime2(val)
+ ns /= 100
+ buf[0] = byte(ns)
+ buf[1] = byte(ns >> 8)
+ buf[2] = byte(ns >> 16)
+ buf[3] = byte(ns >> 24)
+ buf[4] = byte(ns >> 32)
+ buf[5] = byte(days)
+ buf[6] = byte(days >> 8)
+ buf[7] = byte(days >> 16)
+ _, offset := val.Zone()
+ offset /= 60
+ buf[8] = byte(offset)
+ buf[9] = byte(offset >> 8)
+ } else {
+ res.ti.TypeId = typeDateTimeN
+ res.ti.Size = 8
+ res.buffer = make([]byte, 8)
+ ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
+ dur := val.Sub(ref)
+ days := dur / (24 * time.Hour)
+ tm := (300 * (dur % (24 * time.Hour))) / time.Second
+ binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(days))
+ binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm))
+ }
+ default:
+ err = fmt.Errorf("mssql: unknown type for %T", val)
+ return
+ }
+ return
+}
+
+type MssqlResult struct {
+ c *MssqlConn
+ rowsAffected int64
+}
+
+func (r *MssqlResult) RowsAffected() (int64, error) {
+ return r.rowsAffected, nil
+}
+
+func (r *MssqlResult) LastInsertId() (int64, error) {
+ s, err := r.c.Prepare("select cast(@@identity as bigint)")
+ if err != nil {
+ return 0, err
+ }
+ defer s.Close()
+ rows, err := s.Query(nil)
+ if err != nil {
+ return 0, err
+ }
+ defer rows.Close()
+ dest := make([]driver.Value, 1)
+ err = rows.Next(dest)
+ if err != nil {
+ return 0, err
+ }
+ if dest[0] == nil {
+ return -1, errors.New("There is no generated identity value")
+ }
+ lastInsertId := dest[0].(int64)
+ return lastInsertId, nil
+}