aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/go-sql-driver/mysql/dsn.go
diff options
context:
space:
mode:
authorUnknwon <u@gogs.io>2018-06-15 13:33:08 +0800
committerUnknwon <u@gogs.io>2018-06-15 13:33:08 +0800
commit93f3a7f96a1bc99ba17f02c5fb8e2e5b21671c3c (patch)
tree268fe87bf1ad9cb639214818cc3a51be52b4ba5d /vendor/github.com/go-sql-driver/mysql/dsn.go
parent7856b1202d6bfd8bded396b87692ea7113fedbed (diff)
vendor: update github.com/go-sql-driver/mysql
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/dsn.go')
-rw-r--r--vendor/github.com/go-sql-driver/mysql/dsn.go214
1 files changed, 156 insertions, 58 deletions
diff --git a/vendor/github.com/go-sql-driver/mysql/dsn.go b/vendor/github.com/go-sql-driver/mysql/dsn.go
index 73138bc5..be014bab 100644
--- a/vendor/github.com/go-sql-driver/mysql/dsn.go
+++ b/vendor/github.com/go-sql-driver/mysql/dsn.go
@@ -10,11 +10,14 @@ package mysql
import (
"bytes"
+ "crypto/rsa"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
+ "sort"
+ "strconv"
"strings"
"time"
)
@@ -26,31 +29,84 @@ var (
errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
)
-// Config is a configuration parsed from a DSN string
+// Config is a configuration parsed from a DSN string.
+// If a new Config is created instead of being parsed from a DSN string,
+// the NewConfig function should be used, which sets default values.
type Config struct {
- User string // Username
- Passwd string // Password (requires User)
- Net string // Network type
- Addr string // Network address (requires Net)
- DBName string // Database name
- Params map[string]string // Connection parameters
- Collation string // Connection collation
- Loc *time.Location // Location for time.Time values
- TLSConfig string // TLS configuration name
- tls *tls.Config // TLS configuration
- Timeout time.Duration // Dial timeout
- ReadTimeout time.Duration // I/O read timeout
- WriteTimeout time.Duration // I/O write timeout
+ User string // Username
+ Passwd string // Password (requires User)
+ Net string // Network type
+ Addr string // Network address (requires Net)
+ DBName string // Database name
+ Params map[string]string // Connection parameters
+ Collation string // Connection collation
+ Loc *time.Location // Location for time.Time values
+ MaxAllowedPacket int // Max packet size allowed
+ ServerPubKey string // Server public key name
+ pubKey *rsa.PublicKey // Server public key
+ TLSConfig string // TLS configuration name
+ tls *tls.Config // TLS configuration
+ Timeout time.Duration // Dial timeout
+ ReadTimeout time.Duration // I/O read timeout
+ WriteTimeout time.Duration // I/O write timeout
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
+ AllowNativePasswords bool // Allows the native password authentication method
AllowOldPasswords bool // Allows the old insecure password method
ClientFoundRows bool // Return number of matching rows instead of rows changed
ColumnsWithAlias bool // Prepend table alias to column names
InterpolateParams bool // Interpolate placeholders into query string
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
- Strict bool // Return warnings as errors
+ RejectReadOnly bool // Reject read-only connections
+}
+
+// NewConfig creates a new Config and sets default values.
+func NewConfig() *Config {
+ return &Config{
+ Collation: defaultCollation,
+ Loc: time.UTC,
+ MaxAllowedPacket: defaultMaxAllowedPacket,
+ AllowNativePasswords: true,
+ }
+}
+
+func (cfg *Config) normalize() error {
+ if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
+ return errInvalidDSNUnsafeCollation
+ }
+
+ // Set default network if empty
+ if cfg.Net == "" {
+ cfg.Net = "tcp"
+ }
+
+ // Set default address if empty
+ if cfg.Addr == "" {
+ switch cfg.Net {
+ case "tcp":
+ cfg.Addr = "127.0.0.1:3306"
+ case "unix":
+ cfg.Addr = "/tmp/mysql.sock"
+ default:
+ return errors.New("default addr for network '" + cfg.Net + "' unknown")
+ }
+
+ } else if cfg.Net == "tcp" {
+ cfg.Addr = ensureHavePort(cfg.Addr)
+ }
+
+ if cfg.tls != nil {
+ if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
+ host, _, err := net.SplitHostPort(cfg.Addr)
+ if err == nil {
+ cfg.tls.ServerName = host
+ }
+ }
+ }
+
+ return nil
}
// FormatDSN formats the given Config into a DSN string which can be passed to
@@ -99,6 +155,15 @@ func (cfg *Config) FormatDSN() string {
}
}
+ if !cfg.AllowNativePasswords {
+ if hasParam {
+ buf.WriteString("&allowNativePasswords=false")
+ } else {
+ hasParam = true
+ buf.WriteString("?allowNativePasswords=false")
+ }
+ }
+
if cfg.AllowOldPasswords {
if hasParam {
buf.WriteString("&allowOldPasswords=true")
@@ -183,13 +248,23 @@ func (cfg *Config) FormatDSN() string {
buf.WriteString(cfg.ReadTimeout.String())
}
- if cfg.Strict {
+ if cfg.RejectReadOnly {
+ if hasParam {
+ buf.WriteString("&rejectReadOnly=true")
+ } else {
+ hasParam = true
+ buf.WriteString("?rejectReadOnly=true")
+ }
+ }
+
+ if len(cfg.ServerPubKey) > 0 {
if hasParam {
- buf.WriteString("&strict=true")
+ buf.WriteString("&serverPubKey=")
} else {
hasParam = true
- buf.WriteString("?strict=true")
+ buf.WriteString("?serverPubKey=")
}
+ buf.WriteString(url.QueryEscape(cfg.ServerPubKey))
}
if cfg.Timeout > 0 {
@@ -222,9 +297,25 @@ func (cfg *Config) FormatDSN() string {
buf.WriteString(cfg.WriteTimeout.String())
}
+ if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {
+ if hasParam {
+ buf.WriteString("&maxAllowedPacket=")
+ } else {
+ hasParam = true
+ buf.WriteString("?maxAllowedPacket=")
+ }
+ buf.WriteString(strconv.Itoa(cfg.MaxAllowedPacket))
+
+ }
+
// other params
if cfg.Params != nil {
- for param, value := range cfg.Params {
+ var params []string
+ for param := range cfg.Params {
+ params = append(params, param)
+ }
+ sort.Strings(params)
+ for _, param := range params {
if hasParam {
buf.WriteByte('&')
} else {
@@ -234,7 +325,7 @@ func (cfg *Config) FormatDSN() string {
buf.WriteString(param)
buf.WriteByte('=')
- buf.WriteString(url.QueryEscape(value))
+ buf.WriteString(url.QueryEscape(cfg.Params[param]))
}
}
@@ -244,10 +335,7 @@ func (cfg *Config) FormatDSN() string {
// ParseDSN parses the DSN string to a Config
func ParseDSN(dsn string) (cfg *Config, err error) {
// New config with some default values
- cfg = &Config{
- Loc: time.UTC,
- Collation: defaultCollation,
- }
+ cfg = NewConfig()
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
// Find the last '/' (since the password or the net addr might contain a '/')
@@ -315,28 +403,9 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
return nil, errInvalidDSNNoSlash
}
- if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
- return nil, errInvalidDSNUnsafeCollation
- }
-
- // Set default network if empty
- if cfg.Net == "" {
- cfg.Net = "tcp"
- }
-
- // Set default address if empty
- if cfg.Addr == "" {
- switch cfg.Net {
- case "tcp":
- cfg.Addr = "127.0.0.1:3306"
- case "unix":
- cfg.Addr = "/tmp/mysql.sock"
- default:
- return nil, errors.New("default addr for network '" + cfg.Net + "' unknown")
- }
-
+ if err = cfg.normalize(); err != nil {
+ return nil, err
}
-
return
}
@@ -351,7 +420,6 @@ func parseDSNParams(cfg *Config, params string) (err error) {
// cfg params
switch value := param[1]; param[0] {
-
// Disable INFILE whitelist / enable all files
case "allowAllFiles":
var isBool bool
@@ -368,6 +436,14 @@ func parseDSNParams(cfg *Config, params string) (err error) {
return errors.New("invalid bool value: " + value)
}
+ // Use native password authentication
+ case "allowNativePasswords":
+ var isBool bool
+ cfg.AllowNativePasswords, isBool = readBool(value)
+ if !isBool {
+ return errors.New("invalid bool value: " + value)
+ }
+
// Use old authentication mode (pre MySQL 4.1)
case "allowOldPasswords":
var isBool bool
@@ -441,14 +517,32 @@ func parseDSNParams(cfg *Config, params string) (err error) {
return
}
- // Strict mode
- case "strict":
+ // Reject read-only connections
+ case "rejectReadOnly":
var isBool bool
- cfg.Strict, isBool = readBool(value)
+ cfg.RejectReadOnly, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
+ // Server public key
+ case "serverPubKey":
+ name, err := url.QueryUnescape(value)
+ if err != nil {
+ return fmt.Errorf("invalid value for server pub key name: %v", err)
+ }
+
+ if pubKey := getServerPubKey(name); pubKey != nil {
+ cfg.ServerPubKey = name
+ cfg.pubKey = pubKey
+ } else {
+ return errors.New("invalid value / unknown server pub key name: " + name)
+ }
+
+ // Strict mode
+ case "strict":
+ panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")
+
// Dial Timeout
case "timeout":
cfg.Timeout, err = time.ParseDuration(value)
@@ -475,14 +569,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
return fmt.Errorf("invalid value for TLS config name: %v", err)
}
- if tlsConfig, ok := tlsConfigRegister[name]; ok {
- if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
- host, _, err := net.SplitHostPort(cfg.Addr)
- if err == nil {
- tlsConfig.ServerName = host
- }
- }
-
+ if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {
cfg.TLSConfig = name
cfg.tls = tlsConfig
} else {
@@ -496,7 +583,11 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if err != nil {
return
}
-
+ case "maxAllowedPacket":
+ cfg.MaxAllowedPacket, err = strconv.Atoi(value)
+ if err != nil {
+ return
+ }
default:
// lazy init
if cfg.Params == nil {
@@ -511,3 +602,10 @@ func parseDSNParams(cfg *Config, params string) (err error) {
return
}
+
+func ensureHavePort(addr string) string {
+ if _, _, err := net.SplitHostPort(addr); err != nil {
+ return net.JoinHostPort(addr, "3306")
+ }
+ return addr
+}