diff options
Diffstat (limited to 'vendor/github.com/lib/pq/conn.go')
-rw-r--r-- | vendor/github.com/lib/pq/conn.go | 1893 |
1 files changed, 0 insertions, 1893 deletions
diff --git a/vendor/github.com/lib/pq/conn.go b/vendor/github.com/lib/pq/conn.go deleted file mode 100644 index cb44c692..00000000 --- a/vendor/github.com/lib/pq/conn.go +++ /dev/null @@ -1,1893 +0,0 @@ -package pq - -import ( - "bufio" - "crypto/md5" - "crypto/tls" - "crypto/x509" - "database/sql" - "database/sql/driver" - "encoding/binary" - "errors" - "fmt" - "io" - "io/ioutil" - "net" - "os" - "os/user" - "path" - "path/filepath" - "strconv" - "strings" - "time" - "unicode" - - "github.com/lib/pq/oid" -) - -// Common error types -var ( - ErrNotSupported = errors.New("pq: Unsupported command") - ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction") - ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") - ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.") - ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.") - - errUnexpectedReady = errors.New("unexpected ReadyForQuery") - errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") - errNoLastInsertId = errors.New("no LastInsertId available after the empty statement") -) - -type drv struct{} - -func (d *drv) Open(name string) (driver.Conn, error) { - return Open(name) -} - -func init() { - sql.Register("postgres", &drv{}) -} - -type parameterStatus struct { - // server version in the same format as server_version_num, or 0 if - // unavailable - serverVersion int - - // the current location based on the TimeZone value of the session, if - // available - currentLocation *time.Location -} - -type transactionStatus byte - -const ( - txnStatusIdle transactionStatus = 'I' - txnStatusIdleInTransaction transactionStatus = 'T' - txnStatusInFailedTransaction transactionStatus = 'E' -) - -func (s transactionStatus) String() string { - switch s { - case txnStatusIdle: - return "idle" - case txnStatusIdleInTransaction: - return "idle in transaction" - case txnStatusInFailedTransaction: - return "in a failed transaction" - default: - errorf("unknown transactionStatus %d", s) - } - - panic("not reached") -} - -type Dialer interface { - Dial(network, address string) (net.Conn, error) - DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) -} - -type defaultDialer struct{} - -func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) { - return net.Dial(ntw, addr) -} -func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout(ntw, addr, timeout) -} - -type conn struct { - c net.Conn - buf *bufio.Reader - namei int - scratch [512]byte - txnStatus transactionStatus - - parameterStatus parameterStatus - - saveMessageType byte - saveMessageBuffer []byte - - // If true, this connection is bad and all public-facing functions should - // return ErrBadConn. - bad bool - - // If set, this connection should never use the binary format when - // receiving query results from prepared statements. Only provided for - // debugging. - disablePreparedBinaryResult bool - - // Whether to always send []byte parameters over as binary. Enables single - // round-trip mode for non-prepared Query calls. - binaryParameters bool -} - -// Handle driver-side settings in parsed connection string. -func (c *conn) handleDriverSettings(o values) (err error) { - boolSetting := func(key string, val *bool) error { - if value := o.Get(key); value != "" { - if value == "yes" { - *val = true - } else if value == "no" { - *val = false - } else { - return fmt.Errorf("unrecognized value %q for %s", value, key) - } - } - return nil - } - - err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult) - if err != nil { - return err - } - err = boolSetting("binary_parameters", &c.binaryParameters) - if err != nil { - return err - } - return nil -} - -func (c *conn) handlePgpass(o values) { - // if a password was supplied, do not process .pgpass - _, ok := o["password"] - if ok { - return - } - filename := os.Getenv("PGPASSFILE") - if filename == "" { - // XXX this code doesn't work on Windows where the default filename is - // XXX %APPDATA%\postgresql\pgpass.conf - user, err := user.Current() - if err != nil { - return - } - filename = filepath.Join(user.HomeDir, ".pgpass") - } - fileinfo, err := os.Stat(filename) - if err != nil { - return - } - mode := fileinfo.Mode() - if mode&(0x77) != 0 { - // XXX should warn about incorrect .pgpass permissions as psql does - return - } - file, err := os.Open(filename) - if err != nil { - return - } - defer file.Close() - scanner := bufio.NewScanner(io.Reader(file)) - hostname := o.Get("host") - ntw, _ := network(o) - port := o.Get("port") - db := o.Get("dbname") - username := o.Get("user") - // From: https://github.com/tg/pgpass/blob/master/reader.go - getFields := func(s string) []string { - fs := make([]string, 0, 5) - f := make([]rune, 0, len(s)) - - var esc bool - for _, c := range s { - switch { - case esc: - f = append(f, c) - esc = false - case c == '\\': - esc = true - case c == ':': - fs = append(fs, string(f)) - f = f[:0] - default: - f = append(f, c) - } - } - return append(fs, string(f)) - } - for scanner.Scan() { - line := scanner.Text() - if len(line) == 0 || line[0] == '#' { - continue - } - split := getFields(line) - if len(split) != 5 { - continue - } - if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) { - o["password"] = split[4] - return - } - } -} - -func (c *conn) writeBuf(b byte) *writeBuf { - c.scratch[0] = b - return &writeBuf{ - buf: c.scratch[:5], - pos: 1, - } -} - -func Open(name string) (_ driver.Conn, err error) { - return DialOpen(defaultDialer{}, name) -} - -func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { - // Handle any panics during connection initialization. Note that we - // specifically do *not* want to use errRecover(), as that would turn any - // connection errors into ErrBadConns, hiding the real error message from - // the user. - defer errRecoverNoErrBadConn(&err) - - o := make(values) - - // A number of defaults are applied here, in this order: - // - // * Very low precedence defaults applied in every situation - // * Environment variables - // * Explicitly passed connection information - o.Set("host", "localhost") - o.Set("port", "5432") - // N.B.: Extra float digits should be set to 3, but that breaks - // Postgres 8.4 and older, where the max is 2. - o.Set("extra_float_digits", "2") - for k, v := range parseEnviron(os.Environ()) { - o.Set(k, v) - } - - if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") { - name, err = ParseURL(name) - if err != nil { - return nil, err - } - } - - if err := parseOpts(name, o); err != nil { - return nil, err - } - - // Use the "fallback" application name if necessary - if fallback := o.Get("fallback_application_name"); fallback != "" { - if !o.Isset("application_name") { - o.Set("application_name", fallback) - } - } - - // We can't work with any client_encoding other than UTF-8 currently. - // However, we have historically allowed the user to set it to UTF-8 - // explicitly, and there's no reason to break such programs, so allow that. - // Note that the "options" setting could also set client_encoding, but - // parsing its value is not worth it. Instead, we always explicitly send - // client_encoding as a separate run-time parameter, which should override - // anything set in options. - if enc := o.Get("client_encoding"); enc != "" && !isUTF8(enc) { - return nil, errors.New("client_encoding must be absent or 'UTF8'") - } - o.Set("client_encoding", "UTF8") - // DateStyle needs a similar treatment. - if datestyle := o.Get("datestyle"); datestyle != "" { - if datestyle != "ISO, MDY" { - panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v", - "ISO, MDY", datestyle)) - } - } else { - o.Set("datestyle", "ISO, MDY") - } - - // If a user is not provided by any other means, the last - // resort is to use the current operating system provided user - // name. - if o.Get("user") == "" { - u, err := userCurrent() - if err != nil { - return nil, err - } else { - o.Set("user", u) - } - } - - cn := &conn{} - err = cn.handleDriverSettings(o) - if err != nil { - return nil, err - } - cn.handlePgpass(o) - - cn.c, err = dial(d, o) - if err != nil { - return nil, err - } - cn.ssl(o) - cn.buf = bufio.NewReader(cn.c) - cn.startup(o) - - // reset the deadline, in case one was set (see dial) - if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" { - err = cn.c.SetDeadline(time.Time{}) - } - return cn, err -} - -func dial(d Dialer, o values) (net.Conn, error) { - ntw, addr := network(o) - // SSL is not necessary or supported over UNIX domain sockets - if ntw == "unix" { - o["sslmode"] = "disable" - } - - // Zero or not specified means wait indefinitely. - if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" { - seconds, err := strconv.ParseInt(timeout, 10, 0) - if err != nil { - return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) - } - duration := time.Duration(seconds) * time.Second - // connect_timeout should apply to the entire connection establishment - // procedure, so we both use a timeout for the TCP connection - // establishment and set a deadline for doing the initial handshake. - // The deadline is then reset after startup() is done. - deadline := time.Now().Add(duration) - conn, err := d.DialTimeout(ntw, addr, duration) - if err != nil { - return nil, err - } - err = conn.SetDeadline(deadline) - return conn, err - } - return d.Dial(ntw, addr) -} - -func network(o values) (string, string) { - host := o.Get("host") - - if strings.HasPrefix(host, "/") { - sockPath := path.Join(host, ".s.PGSQL."+o.Get("port")) - return "unix", sockPath - } - - return "tcp", net.JoinHostPort(host, o.Get("port")) -} - -type values map[string]string - -func (vs values) Set(k, v string) { - vs[k] = v -} - -func (vs values) Get(k string) (v string) { - return vs[k] -} - -func (vs values) Isset(k string) bool { - _, ok := vs[k] - return ok -} - -// scanner implements a tokenizer for libpq-style option strings. -type scanner struct { - s []rune - i int -} - -// newScanner returns a new scanner initialized with the option string s. -func newScanner(s string) *scanner { - return &scanner{[]rune(s), 0} -} - -// Next returns the next rune. -// It returns 0, false if the end of the text has been reached. -func (s *scanner) Next() (rune, bool) { - if s.i >= len(s.s) { - return 0, false - } - r := s.s[s.i] - s.i++ - return r, true -} - -// SkipSpaces returns the next non-whitespace rune. -// It returns 0, false if the end of the text has been reached. -func (s *scanner) SkipSpaces() (rune, bool) { - r, ok := s.Next() - for unicode.IsSpace(r) && ok { - r, ok = s.Next() - } - return r, ok -} - -// parseOpts parses the options from name and adds them to the values. -// -// The parsing code is based on conninfo_parse from libpq's fe-connect.c -func parseOpts(name string, o values) error { - s := newScanner(name) - - for { - var ( - keyRunes, valRunes []rune - r rune - ok bool - ) - - if r, ok = s.SkipSpaces(); !ok { - break - } - - // Scan the key - for !unicode.IsSpace(r) && r != '=' { - keyRunes = append(keyRunes, r) - if r, ok = s.Next(); !ok { - break - } - } - - // Skip any whitespace if we're not at the = yet - if r != '=' { - r, ok = s.SkipSpaces() - } - - // The current character should be = - if r != '=' || !ok { - return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) - } - - // Skip any whitespace after the = - if r, ok = s.SkipSpaces(); !ok { - // If we reach the end here, the last value is just an empty string as per libpq. - o.Set(string(keyRunes), "") - break - } - - if r != '\'' { - for !unicode.IsSpace(r) { - if r == '\\' { - if r, ok = s.Next(); !ok { - return fmt.Errorf(`missing character after backslash`) - } - } - valRunes = append(valRunes, r) - - if r, ok = s.Next(); !ok { - break - } - } - } else { - quote: - for { - if r, ok = s.Next(); !ok { - return fmt.Errorf(`unterminated quoted string literal in connection string`) - } - switch r { - case '\'': - break quote - case '\\': - r, _ = s.Next() - fallthrough - default: - valRunes = append(valRunes, r) - } - } - } - - o.Set(string(keyRunes), string(valRunes)) - } - - return nil -} - -func (cn *conn) isInTransaction() bool { - return cn.txnStatus == txnStatusIdleInTransaction || - cn.txnStatus == txnStatusInFailedTransaction -} - -func (cn *conn) checkIsInTransaction(intxn bool) { - if cn.isInTransaction() != intxn { - cn.bad = true - errorf("unexpected transaction status %v", cn.txnStatus) - } -} - -func (cn *conn) Begin() (_ driver.Tx, err error) { - if cn.bad { - return nil, driver.ErrBadConn - } - defer cn.errRecover(&err) - - cn.checkIsInTransaction(false) - _, commandTag, err := cn.simpleExec("BEGIN") - if err != nil { - return nil, err - } - if commandTag != "BEGIN" { - cn.bad = true - return nil, fmt.Errorf("unexpected command tag %s", commandTag) - } - if cn.txnStatus != txnStatusIdleInTransaction { - cn.bad = true - return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) - } - return cn, nil -} - -func (cn *conn) Commit() (err error) { - if cn.bad { - return driver.ErrBadConn - } - defer cn.errRecover(&err) - - cn.checkIsInTransaction(true) - // We don't want the client to think that everything is okay if it tries - // to commit a failed transaction. However, no matter what we return, - // database/sql will release this connection back into the free connection - // pool so we have to abort the current transaction here. Note that you - // would get the same behaviour if you issued a COMMIT in a failed - // transaction, so it's also the least surprising thing to do here. - if cn.txnStatus == txnStatusInFailedTransaction { - if err := cn.Rollback(); err != nil { - return err - } - return ErrInFailedTransaction - } - - _, commandTag, err := cn.simpleExec("COMMIT") - if err != nil { - if cn.isInTransaction() { - cn.bad = true - } - return err - } - if commandTag != "COMMIT" { - cn.bad = true - return fmt.Errorf("unexpected command tag %s", commandTag) - } - cn.checkIsInTransaction(false) - return nil -} - -func (cn *conn) Rollback() (err error) { - if cn.bad { - return driver.ErrBadConn - } - defer cn.errRecover(&err) - - cn.checkIsInTransaction(true) - _, commandTag, err := cn.simpleExec("ROLLBACK") - if err != nil { - if cn.isInTransaction() { - cn.bad = true - } - return err - } - if commandTag != "ROLLBACK" { - return fmt.Errorf("unexpected command tag %s", commandTag) - } - cn.checkIsInTransaction(false) - return nil -} - -func (cn *conn) gname() string { - cn.namei++ - return strconv.FormatInt(int64(cn.namei), 10) -} - -func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) { - b := cn.writeBuf('Q') - b.string(q) - cn.send(b) - - for { - t, r := cn.recv1() - switch t { - case 'C': - res, commandTag = cn.parseComplete(r.string()) - case 'Z': - cn.processReadyForQuery(r) - if res == nil && err == nil { - err = errUnexpectedReady - } - // done - return - case 'E': - err = parseError(r) - case 'I': - res = emptyRows - case 'T', 'D': - // ignore any results - default: - cn.bad = true - errorf("unknown response for simple query: %q", t) - } - } -} - -func (cn *conn) simpleQuery(q string) (res *rows, err error) { - defer cn.errRecover(&err) - - b := cn.writeBuf('Q') - b.string(q) - cn.send(b) - - for { - t, r := cn.recv1() - switch t { - case 'C', 'I': - // We allow queries which don't return any results through Query as - // well as Exec. We still have to give database/sql a rows object - // the user can close, though, to avoid connections from being - // leaked. A "rows" with done=true works fine for that purpose. - if err != nil { - cn.bad = true - errorf("unexpected message %q in simple query execution", t) - } - if res == nil { - res = &rows{ - cn: cn, - } - } - res.done = true - case 'Z': - cn.processReadyForQuery(r) - // done - return - case 'E': - res = nil - err = parseError(r) - case 'D': - if res == nil { - cn.bad = true - errorf("unexpected DataRow in simple query execution") - } - // the query didn't fail; kick off to Next - cn.saveMessage(t, r) - return - case 'T': - // res might be non-nil here if we received a previous - // CommandComplete, but that's fine; just overwrite it - res = &rows{cn: cn} - res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r) - - // To work around a bug in QueryRow in Go 1.2 and earlier, wait - // until the first DataRow has been received. - default: - cn.bad = true - errorf("unknown response for simple query: %q", t) - } - } -} - -type noRows struct{} - -var emptyRows noRows - -var _ driver.Result = noRows{} - -func (noRows) LastInsertId() (int64, error) { - return 0, errNoLastInsertId -} - -func (noRows) RowsAffected() (int64, error) { - return 0, errNoRowsAffected -} - -// Decides which column formats to use for a prepared statement. The input is -// an array of type oids, one element per result column. -func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, colFmtData []byte) { - if len(colTyps) == 0 { - return nil, colFmtDataAllText - } - - colFmts = make([]format, len(colTyps)) - if forceText { - return colFmts, colFmtDataAllText - } - - allBinary := true - allText := true - for i, o := range colTyps { - switch o { - // This is the list of types to use binary mode for when receiving them - // through a prepared statement. If a type appears in this list, it - // must also be implemented in binaryDecode in encode.go. - case oid.T_bytea: - fallthrough - case oid.T_int8: - fallthrough - case oid.T_int4: - fallthrough - case oid.T_int2: - colFmts[i] = formatBinary - allText = false - - default: - allBinary = false - } - } - - if allBinary { - return colFmts, colFmtDataAllBinary - } else if allText { - return colFmts, colFmtDataAllText - } else { - colFmtData = make([]byte, 2+len(colFmts)*2) - binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts))) - for i, v := range colFmts { - binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v)) - } - return colFmts, colFmtData - } -} - -func (cn *conn) prepareTo(q, stmtName string) *stmt { - st := &stmt{cn: cn, name: stmtName} - - b := cn.writeBuf('P') - b.string(st.name) - b.string(q) - b.int16(0) - - b.next('D') - b.byte('S') - b.string(st.name) - - b.next('S') - cn.send(b) - - cn.readParseResponse() - st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse() - st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult) - cn.readReadyForQuery() - return st -} - -func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { - if cn.bad { - return nil, driver.ErrBadConn - } - defer cn.errRecover(&err) - - if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { - return cn.prepareCopyIn(q) - } - return cn.prepareTo(q, cn.gname()), nil -} - -func (cn *conn) Close() (err error) { - // Skip cn.bad return here because we always want to close a connection. - defer cn.errRecover(&err) - - // Ensure that cn.c.Close is always run. Since error handling is done with - // panics and cn.errRecover, the Close must be in a defer. - defer func() { - cerr := cn.c.Close() - if err == nil { - err = cerr - } - }() - - // Don't go through send(); ListenerConn relies on us not scribbling on the - // scratch buffer of this connection. - return cn.sendSimpleMessage('X') -} - -// Implement the "Queryer" interface -func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err error) { - if cn.bad { - return nil, driver.ErrBadConn - } - defer cn.errRecover(&err) - - // Check to see if we can use the "simpleQuery" interface, which is - // *much* faster than going through prepare/exec - if len(args) == 0 { - return cn.simpleQuery(query) - } - - if cn.binaryParameters { - cn.sendBinaryModeQuery(query, args) - - cn.readParseResponse() - cn.readBindResponse() - rows := &rows{cn: cn} - rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse() - cn.postExecuteWorkaround() - return rows, nil - } else { - st := cn.prepareTo(query, "") - st.exec(args) - return &rows{ - cn: cn, - colNames: st.colNames, - colTyps: st.colTyps, - colFmts: st.colFmts, - }, nil - } -} - -// Implement the optional "Execer" interface for one-shot queries -func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { - if cn.bad { - return nil, driver.ErrBadConn - } - defer cn.errRecover(&err) - - // Check to see if we can use the "simpleExec" interface, which is - // *much* faster than going through prepare/exec - if len(args) == 0 { - // ignore commandTag, our caller doesn't care - r, _, err := cn.simpleExec(query) - return r, err - } - - if cn.binaryParameters { - cn.sendBinaryModeQuery(query, args) - - cn.readParseResponse() - cn.readBindResponse() - cn.readPortalDescribeResponse() - cn.postExecuteWorkaround() - res, _, err = cn.readExecuteResponse("Execute") - return res, err - } else { - // Use the unnamed statement to defer planning until bind - // time, or else value-based selectivity estimates cannot be - // used. - st := cn.prepareTo(query, "") - r, err := st.Exec(args) - if err != nil { - panic(err) - } - return r, err - } -} - -func (cn *conn) send(m *writeBuf) { - _, err := cn.c.Write(m.wrap()) - if err != nil { - panic(err) - } -} - -func (cn *conn) sendStartupPacket(m *writeBuf) { - // sanity check - if m.buf[0] != 0 { - panic("oops") - } - - _, err := cn.c.Write((m.wrap())[1:]) - if err != nil { - panic(err) - } -} - -// Send a message of type typ to the server on the other end of cn. The -// message should have no payload. This method does not use the scratch -// buffer. -func (cn *conn) sendSimpleMessage(typ byte) (err error) { - _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'}) - return err -} - -// saveMessage memorizes a message and its buffer in the conn struct. -// recvMessage will then return these values on the next call to it. This -// method is useful in cases where you have to see what the next message is -// going to be (e.g. to see whether it's an error or not) but you can't handle -// the message yourself. -func (cn *conn) saveMessage(typ byte, buf *readBuf) { - if cn.saveMessageType != 0 { - cn.bad = true - errorf("unexpected saveMessageType %d", cn.saveMessageType) - } - cn.saveMessageType = typ - cn.saveMessageBuffer = *buf -} - -// recvMessage receives any message from the backend, or returns an error if -// a problem occurred while reading the message. -func (cn *conn) recvMessage(r *readBuf) (byte, error) { - // workaround for a QueryRow bug, see exec - if cn.saveMessageType != 0 { - t := cn.saveMessageType - *r = cn.saveMessageBuffer - cn.saveMessageType = 0 - cn.saveMessageBuffer = nil - return t, nil - } - - x := cn.scratch[:5] - _, err := io.ReadFull(cn.buf, x) - if err != nil { - return 0, err - } - - // read the type and length of the message that follows - t := x[0] - n := int(binary.BigEndian.Uint32(x[1:])) - 4 - var y []byte - if n <= len(cn.scratch) { - y = cn.scratch[:n] - } else { - y = make([]byte, n) - } - _, err = io.ReadFull(cn.buf, y) - if err != nil { - return 0, err - } - *r = y - return t, nil -} - -// recv receives a message from the backend, but if an error happened while -// reading the message or the received message was an ErrorResponse, it panics. -// NoticeResponses are ignored. This function should generally be used only -// during the startup sequence. -func (cn *conn) recv() (t byte, r *readBuf) { - for { - var err error - r = &readBuf{} - t, err = cn.recvMessage(r) - if err != nil { - panic(err) - } - - switch t { - case 'E': - panic(parseError(r)) - case 'N': - // ignore - default: - return - } - } -} - -// recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by -// the caller to avoid an allocation. -func (cn *conn) recv1Buf(r *readBuf) byte { - for { - t, err := cn.recvMessage(r) - if err != nil { - panic(err) - } - - switch t { - case 'A', 'N': - // ignore - case 'S': - cn.processParameterStatus(r) - default: - return t - } - } -} - -// recv1 receives a message from the backend, panicking if an error occurs -// while attempting to read it. All asynchronous messages are ignored, with -// the exception of ErrorResponse. -func (cn *conn) recv1() (t byte, r *readBuf) { - r = &readBuf{} - t = cn.recv1Buf(r) - return t, r -} - -func (cn *conn) ssl(o values) { - verifyCaOnly := false - tlsConf := tls.Config{} - switch mode := o.Get("sslmode"); mode { - // "require" is the default. - case "", "require": - // We must skip TLS's own verification since it requires full - // verification since Go 1.3. - tlsConf.InsecureSkipVerify = true - - // From http://www.postgresql.org/docs/current/static/libpq-ssl.html: - // Note: For backwards compatibility with earlier versions of PostgreSQL, if a - // root CA file exists, the behavior of sslmode=require will be the same as - // that of verify-ca, meaning the server certificate is validated against the - // CA. Relying on this behavior is discouraged, and applications that need - // certificate validation should always use verify-ca or verify-full. - if _, err := os.Stat(o.Get("sslrootcert")); err == nil { - verifyCaOnly = true - } else { - o.Set("sslrootcert", "") - } - case "verify-ca": - // We must skip TLS's own verification since it requires full - // verification since Go 1.3. - tlsConf.InsecureSkipVerify = true - verifyCaOnly = true - case "verify-full": - tlsConf.ServerName = o.Get("host") - case "disable": - return - default: - errorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) - } - - cn.setupSSLClientCertificates(&tlsConf, o) - cn.setupSSLCA(&tlsConf, o) - - w := cn.writeBuf(0) - w.int32(80877103) - cn.sendStartupPacket(w) - - b := cn.scratch[:1] - _, err := io.ReadFull(cn.c, b) - if err != nil { - panic(err) - } - - if b[0] != 'S' { - panic(ErrSSLNotSupported) - } - - client := tls.Client(cn.c, &tlsConf) - if verifyCaOnly { - cn.verifyCA(client, &tlsConf) - } - cn.c = client -} - -// verifyCA carries out a TLS handshake to the server and verifies the -// presented certificate against the effective CA, i.e. the one specified in -// sslrootcert or the system CA if sslrootcert was not specified. -func (cn *conn) verifyCA(client *tls.Conn, tlsConf *tls.Config) { - err := client.Handshake() - if err != nil { - panic(err) - } - certs := client.ConnectionState().PeerCertificates - opts := x509.VerifyOptions{ - DNSName: client.ConnectionState().ServerName, - Intermediates: x509.NewCertPool(), - Roots: tlsConf.RootCAs, - } - for i, cert := range certs { - if i == 0 { - continue - } - opts.Intermediates.AddCert(cert) - } - _, err = certs[0].Verify(opts) - if err != nil { - panic(err) - } -} - -// This function sets up SSL client certificates based on either the "sslkey" -// and "sslcert" settings (possibly set via the environment variables PGSSLKEY -// and PGSSLCERT, respectively), or if they aren't set, from the .postgresql -// directory in the user's home directory. If the file paths are set -// explicitly, the files must exist. The key file must also not be -// world-readable, or this function will panic with -// ErrSSLKeyHasWorldPermissions. -func (cn *conn) setupSSLClientCertificates(tlsConf *tls.Config, o values) { - var missingOk bool - - sslkey := o.Get("sslkey") - sslcert := o.Get("sslcert") - if sslkey != "" && sslcert != "" { - // If the user has set an sslkey and sslcert, they *must* exist. - missingOk = false - } else { - // Automatically load certificates from ~/.postgresql. - user, err := user.Current() - if err != nil { - // user.Current() might fail when cross-compiling. We have to - // ignore the error and continue without client certificates, since - // we wouldn't know where to load them from. - return - } - - sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") - sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") - missingOk = true - } - - // Check that both files exist, and report the error or stop, depending on - // which behaviour we want. Note that we don't do any more extensive - // checks than this (such as checking that the paths aren't directories); - // LoadX509KeyPair() will take care of the rest. - keyfinfo, err := os.Stat(sslkey) - if err != nil && missingOk { - return - } else if err != nil { - panic(err) - } - _, err = os.Stat(sslcert) - if err != nil && missingOk { - return - } else if err != nil { - panic(err) - } - - // If we got this far, the key file must also have the correct permissions - kmode := keyfinfo.Mode() - if kmode != kmode&0600 { - panic(ErrSSLKeyHasWorldPermissions) - } - - cert, err := tls.LoadX509KeyPair(sslcert, sslkey) - if err != nil { - panic(err) - } - tlsConf.Certificates = []tls.Certificate{cert} -} - -// Sets up RootCAs in the TLS configuration if sslrootcert is set. -func (cn *conn) setupSSLCA(tlsConf *tls.Config, o values) { - if sslrootcert := o.Get("sslrootcert"); sslrootcert != "" { - tlsConf.RootCAs = x509.NewCertPool() - - cert, err := ioutil.ReadFile(sslrootcert) - if err != nil { - panic(err) - } - - ok := tlsConf.RootCAs.AppendCertsFromPEM(cert) - if !ok { - errorf("couldn't parse pem in sslrootcert") - } - } -} - -// isDriverSetting returns true iff a setting is purely for configuring the -// driver's options and should not be sent to the server in the connection -// startup packet. -func isDriverSetting(key string) bool { - switch key { - case "host", "port": - return true - case "password": - return true - case "sslmode", "sslcert", "sslkey", "sslrootcert": - return true - case "fallback_application_name": - return true - case "connect_timeout": - return true - case "disable_prepared_binary_result": - return true - case "binary_parameters": - return true - - default: - return false - } -} - -func (cn *conn) startup(o values) { - w := cn.writeBuf(0) - w.int32(196608) - // Send the backend the name of the database we want to connect to, and the - // user we want to connect as. Additionally, we send over any run-time - // parameters potentially included in the connection string. If the server - // doesn't recognize any of them, it will reply with an error. - for k, v := range o { - if isDriverSetting(k) { - // skip options which can't be run-time parameters - continue - } - // The protocol requires us to supply the database name as "database" - // instead of "dbname". - if k == "dbname" { - k = "database" - } - w.string(k) - w.string(v) - } - w.string("") - cn.sendStartupPacket(w) - - for { - t, r := cn.recv() - switch t { - case 'K': - case 'S': - cn.processParameterStatus(r) - case 'R': - cn.auth(r, o) - case 'Z': - cn.processReadyForQuery(r) - return - default: - errorf("unknown response for startup: %q", t) - } - } -} - -func (cn *conn) auth(r *readBuf, o values) { - switch code := r.int32(); code { - case 0: - // OK - case 3: - w := cn.writeBuf('p') - w.string(o.Get("password")) - cn.send(w) - - t, r := cn.recv() - if t != 'R' { - errorf("unexpected password response: %q", t) - } - - if r.int32() != 0 { - errorf("unexpected authentication response: %q", t) - } - case 5: - s := string(r.next(4)) - w := cn.writeBuf('p') - w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s)) - cn.send(w) - - t, r := cn.recv() - if t != 'R' { - errorf("unexpected password response: %q", t) - } - - if r.int32() != 0 { - errorf("unexpected authentication response: %q", t) - } - default: - errorf("unknown authentication response: %d", code) - } -} - -type format int - -const formatText format = 0 -const formatBinary format = 1 - -// One result-column format code with the value 1 (i.e. all binary). -var colFmtDataAllBinary []byte = []byte{0, 1, 0, 1} - -// No result-column format codes (i.e. all text). -var colFmtDataAllText []byte = []byte{0, 0} - -type stmt struct { - cn *conn - name string - colNames []string - colFmts []format - colFmtData []byte - colTyps []oid.Oid - paramTyps []oid.Oid - closed bool -} - -func (st *stmt) Close() (err error) { - if st.closed { - return nil - } - if st.cn.bad { - return driver.ErrBadConn - } - defer st.cn.errRecover(&err) - - w := st.cn.writeBuf('C') - w.byte('S') - w.string(st.name) - st.cn.send(w) - - st.cn.send(st.cn.writeBuf('S')) - - t, _ := st.cn.recv1() - if t != '3' { - st.cn.bad = true - errorf("unexpected close response: %q", t) - } - st.closed = true - - t, r := st.cn.recv1() - if t != 'Z' { - st.cn.bad = true - errorf("expected ready for query, but got: %q", t) - } - st.cn.processReadyForQuery(r) - - return nil -} - -func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { - if st.cn.bad { - return nil, driver.ErrBadConn - } - defer st.cn.errRecover(&err) - - st.exec(v) - return &rows{ - cn: st.cn, - colNames: st.colNames, - colTyps: st.colTyps, - colFmts: st.colFmts, - }, nil -} - -func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { - if st.cn.bad { - return nil, driver.ErrBadConn - } - defer st.cn.errRecover(&err) - - st.exec(v) - res, _, err = st.cn.readExecuteResponse("simple query") - return res, err -} - -func (st *stmt) exec(v []driver.Value) { - if len(v) >= 65536 { - errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v)) - } - if len(v) != len(st.paramTyps) { - errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps)) - } - - cn := st.cn - w := cn.writeBuf('B') - w.byte(0) // unnamed portal - w.string(st.name) - - if cn.binaryParameters { - cn.sendBinaryParameters(w, v) - } else { - w.int16(0) - w.int16(len(v)) - for i, x := range v { - if x == nil { - w.int32(-1) - } else { - b := encode(&cn.parameterStatus, x, st.paramTyps[i]) - w.int32(len(b)) - w.bytes(b) - } - } - } - w.bytes(st.colFmtData) - - w.next('E') - w.byte(0) - w.int32(0) - - w.next('S') - cn.send(w) - - cn.readBindResponse() - cn.postExecuteWorkaround() - -} - -func (st *stmt) NumInput() int { - return len(st.paramTyps) -} - -// parseComplete parses the "command tag" from a CommandComplete message, and -// returns the number of rows affected (if applicable) and a string -// identifying only the command that was executed, e.g. "ALTER TABLE". If the -// command tag could not be parsed, parseComplete panics. -func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { - commandsWithAffectedRows := []string{ - "SELECT ", - // INSERT is handled below - "UPDATE ", - "DELETE ", - "FETCH ", - "MOVE ", - "COPY ", - } - - var affectedRows *string - for _, tag := range commandsWithAffectedRows { - if strings.HasPrefix(commandTag, tag) { - t := commandTag[len(tag):] - affectedRows = &t - commandTag = tag[:len(tag)-1] - break - } - } - // INSERT also includes the oid of the inserted row in its command tag. - // Oids in user tables are deprecated, and the oid is only returned when - // exactly one row is inserted, so it's unlikely to be of value to any - // real-world application and we can ignore it. - if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { - parts := strings.Split(commandTag, " ") - if len(parts) != 3 { - cn.bad = true - errorf("unexpected INSERT command tag %s", commandTag) - } - affectedRows = &parts[len(parts)-1] - commandTag = "INSERT" - } - // There should be no affected rows attached to the tag, just return it - if affectedRows == nil { - return driver.RowsAffected(0), commandTag - } - n, err := strconv.ParseInt(*affectedRows, 10, 64) - if err != nil { - cn.bad = true - errorf("could not parse commandTag: %s", err) - } - return driver.RowsAffected(n), commandTag -} - -type rows struct { - cn *conn - colNames []string - colTyps []oid.Oid - colFmts []format - done bool - rb readBuf -} - -func (rs *rows) Close() error { - // no need to look at cn.bad as Next() will - for { - err := rs.Next(nil) - switch err { - case nil: - case io.EOF: - return nil - default: - return err - } - } -} - -func (rs *rows) Columns() []string { - return rs.colNames -} - -func (rs *rows) Next(dest []driver.Value) (err error) { - if rs.done { - return io.EOF - } - - conn := rs.cn - if conn.bad { - return driver.ErrBadConn - } - defer conn.errRecover(&err) - - for { - t := conn.recv1Buf(&rs.rb) - switch t { - case 'E': - err = parseError(&rs.rb) - case 'C', 'I': - continue - case 'Z': - conn.processReadyForQuery(&rs.rb) - rs.done = true - if err != nil { - return err - } - return io.EOF - case 'D': - n := rs.rb.int16() - if err != nil { - conn.bad = true - errorf("unexpected DataRow after error %s", err) - } - if n < len(dest) { - dest = dest[:n] - } - for i := range dest { - l := rs.rb.int32() - if l == -1 { - dest[i] = nil - continue - } - dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i], rs.colFmts[i]) - } - return - default: - errorf("unexpected message after execute: %q", t) - } - } -} - -// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be -// used as part of an SQL statement. For example: -// -// tblname := "my_table" -// data := "my_data" -// err = db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", pq.QuoteIdentifier(tblname)), data) -// -// Any double quotes in name will be escaped. The quoted identifier will be -// case sensitive when used in a query. If the input string contains a zero -// byte, the result will be truncated immediately before it. -func QuoteIdentifier(name string) string { - end := strings.IndexRune(name, 0) - if end > -1 { - name = name[:end] - } - return `"` + strings.Replace(name, `"`, `""`, -1) + `"` -} - -func md5s(s string) string { - h := md5.New() - h.Write([]byte(s)) - return fmt.Sprintf("%x", h.Sum(nil)) -} - -func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) { - // Do one pass over the parameters to see if we're going to send any of - // them over in binary. If we are, create a paramFormats array at the - // same time. - var paramFormats []int - for i, x := range args { - _, ok := x.([]byte) - if ok { - if paramFormats == nil { - paramFormats = make([]int, len(args)) - } - paramFormats[i] = 1 - } - } - if paramFormats == nil { - b.int16(0) - } else { - b.int16(len(paramFormats)) - for _, x := range paramFormats { - b.int16(x) - } - } - - b.int16(len(args)) - for _, x := range args { - if x == nil { - b.int32(-1) - } else { - datum := binaryEncode(&cn.parameterStatus, x) - b.int32(len(datum)) - b.bytes(datum) - } - } -} - -func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { - if len(args) >= 65536 { - errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) - } - - b := cn.writeBuf('P') - b.byte(0) // unnamed statement - b.string(query) - b.int16(0) - - b.next('B') - b.int16(0) // unnamed portal and statement - cn.sendBinaryParameters(b, args) - b.bytes(colFmtDataAllText) - - b.next('D') - b.byte('P') - b.byte(0) // unnamed portal - - b.next('E') - b.byte(0) - b.int32(0) - - b.next('S') - cn.send(b) -} - -func (c *conn) processParameterStatus(r *readBuf) { - var err error - - param := r.string() - switch param { - case "server_version": - var major1 int - var major2 int - var minor int - _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor) - if err == nil { - c.parameterStatus.serverVersion = major1*10000 + major2*100 + minor - } - - case "TimeZone": - c.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) - if err != nil { - c.parameterStatus.currentLocation = nil - } - - default: - // ignore - } -} - -func (c *conn) processReadyForQuery(r *readBuf) { - c.txnStatus = transactionStatus(r.byte()) -} - -func (cn *conn) readReadyForQuery() { - t, r := cn.recv1() - switch t { - case 'Z': - cn.processReadyForQuery(r) - return - default: - cn.bad = true - errorf("unexpected message %q; expected ReadyForQuery", t) - } -} - -func (cn *conn) readParseResponse() { - t, r := cn.recv1() - switch t { - case '1': - return - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) - default: - cn.bad = true - errorf("unexpected Parse response %q", t) - } -} - -func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []oid.Oid) { - for { - t, r := cn.recv1() - switch t { - case 't': - nparams := r.int16() - paramTyps = make([]oid.Oid, nparams) - for i := range paramTyps { - paramTyps[i] = r.oid() - } - case 'n': - return paramTyps, nil, nil - case 'T': - colNames, colTyps = parseStatementRowDescribe(r) - return paramTyps, colNames, colTyps - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) - default: - cn.bad = true - errorf("unexpected Describe statement response %q", t) - } - } -} - -func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []oid.Oid) { - t, r := cn.recv1() - switch t { - case 'T': - return parsePortalRowDescribe(r) - case 'n': - return nil, nil, nil - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) - default: - cn.bad = true - errorf("unexpected Describe response %q", t) - } - panic("not reached") -} - -func (cn *conn) readBindResponse() { - t, r := cn.recv1() - switch t { - case '2': - return - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) - default: - cn.bad = true - errorf("unexpected Bind response %q", t) - } -} - -func (cn *conn) postExecuteWorkaround() { - // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores - // any errors from rows.Next, which masks errors that happened during the - // execution of the query. To avoid the problem in common cases, we wait - // here for one more message from the database. If it's not an error the - // query will likely succeed (or perhaps has already, if it's a - // CommandComplete), so we push the message into the conn struct; recv1 - // will return it as the next message for rows.Next or rows.Close. - // However, if it's an error, we wait until ReadyForQuery and then return - // the error to our caller. - for { - t, r := cn.recv1() - switch t { - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) - case 'C', 'D', 'I': - // the query didn't fail, but we can't process this message - cn.saveMessage(t, r) - return - default: - cn.bad = true - errorf("unexpected message during extended query execution: %q", t) - } - } -} - -// Only for Exec(), since we ignore the returned data -func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) { - for { - t, r := cn.recv1() - switch t { - case 'C': - if err != nil { - cn.bad = true - errorf("unexpected CommandComplete after error %s", err) - } - res, commandTag = cn.parseComplete(r.string()) - case 'Z': - cn.processReadyForQuery(r) - if res == nil && err == nil { - err = errUnexpectedReady - } - return res, commandTag, err - case 'E': - err = parseError(r) - case 'T', 'D', 'I': - if err != nil { - cn.bad = true - errorf("unexpected %q after error %s", t, err) - } - if t == 'I' { - res = emptyRows - } - // ignore any results - default: - cn.bad = true - errorf("unknown %s response: %q", protocolState, t) - } - } -} - -func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []oid.Oid) { - n := r.int16() - colNames = make([]string, n) - colTyps = make([]oid.Oid, n) - for i := range colNames { - colNames[i] = r.string() - r.next(6) - colTyps[i] = r.oid() - r.next(6) - // format code not known when describing a statement; always 0 - r.next(2) - } - return -} - -func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []oid.Oid) { - n := r.int16() - colNames = make([]string, n) - colFmts = make([]format, n) - colTyps = make([]oid.Oid, n) - for i := range colNames { - colNames[i] = r.string() - r.next(6) - colTyps[i] = r.oid() - r.next(6) - colFmts[i] = format(r.int16()) - } - return -} - -// parseEnviron tries to mimic some of libpq's environment handling -// -// To ease testing, it does not directly reference os.Environ, but is -// designed to accept its output. -// -// Environment-set connection information is intended to have a higher -// precedence than a library default but lower than any explicitly -// passed information (such as in the URL or connection string). -func parseEnviron(env []string) (out map[string]string) { - out = make(map[string]string) - - for _, v := range env { - parts := strings.SplitN(v, "=", 2) - - accrue := func(keyname string) { - out[keyname] = parts[1] - } - unsupported := func() { - panic(fmt.Sprintf("setting %v not supported", parts[0])) - } - - // The order of these is the same as is seen in the - // PostgreSQL 9.1 manual. Unsupported but well-defined - // keys cause a panic; these should be unset prior to - // execution. Options which pq expects to be set to a - // certain value are allowed, but must be set to that - // value if present (they can, of course, be absent). - switch parts[0] { - case "PGHOST": - accrue("host") - case "PGHOSTADDR": - unsupported() - case "PGPORT": - accrue("port") - case "PGDATABASE": - accrue("dbname") - case "PGUSER": - accrue("user") - case "PGPASSWORD": - accrue("password") - case "PGSERVICE", "PGSERVICEFILE", "PGREALM": - unsupported() - case "PGOPTIONS": - accrue("options") - case "PGAPPNAME": - accrue("application_name") - case "PGSSLMODE": - accrue("sslmode") - case "PGSSLCERT": - accrue("sslcert") - case "PGSSLKEY": - accrue("sslkey") - case "PGSSLROOTCERT": - accrue("sslrootcert") - case "PGREQUIRESSL", "PGSSLCRL": - unsupported() - case "PGREQUIREPEER": - unsupported() - case "PGKRBSRVNAME", "PGGSSLIB": - unsupported() - case "PGCONNECT_TIMEOUT": - accrue("connect_timeout") - case "PGCLIENTENCODING": - accrue("client_encoding") - case "PGDATESTYLE": - accrue("datestyle") - case "PGTZ": - accrue("timezone") - case "PGGEQO": - accrue("geqo") - case "PGSYSCONFDIR", "PGLOCALEDIR": - unsupported() - } - } - - return out -} - -// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8". -func isUTF8(name string) bool { - // Recognize all sorts of silly things as "UTF-8", like Postgres does - s := strings.Map(alnumLowerASCII, name) - return s == "utf8" || s == "unicode" -} - -func alnumLowerASCII(ch rune) rune { - if 'A' <= ch && ch <= 'Z' { - return ch + ('a' - 'A') - } - if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' { - return ch - } - return -1 // discard -} |