From 28f74cf1c67cde80ae453a799d76752114fd5e18 Mon Sep 17 00:00:00 2001 From: Unknwon Date: Fri, 9 Mar 2018 00:26:47 -0500 Subject: vendor: update github.com/go-xorm/xorm (#4913) --- vendor/github.com/go-xorm/xorm/dialect_postgres.go | 154 +++++++++++---------- 1 file changed, 82 insertions(+), 72 deletions(-) (limited to 'vendor/github.com/go-xorm/xorm/dialect_postgres.go') diff --git a/vendor/github.com/go-xorm/xorm/dialect_postgres.go b/vendor/github.com/go-xorm/xorm/dialect_postgres.go index 1d4daa27..2b2a0b78 100644 --- a/vendor/github.com/go-xorm/xorm/dialect_postgres.go +++ b/vendor/github.com/go-xorm/xorm/dialect_postgres.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "net/url" - "sort" "strconv" "strings" @@ -765,6 +764,9 @@ var ( "YES": true, "ZONE": true, } + + // DefaultPostgresSchema default postgres schema + DefaultPostgresSchema = "public" ) type postgres struct { @@ -772,7 +774,14 @@ type postgres struct { } func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) + err := db.Base.Init(d, db, uri, drivername, dataSourceName) + if err != nil { + return err + } + if db.Schema == "" { + db.Schema = DefaultPostgresSchema + } + return nil } func (db *postgres) SqlType(c *core.Column) string { @@ -781,6 +790,9 @@ func (db *postgres) SqlType(c *core.Column) string { case core.TinyInt: res = core.SmallInt return res + case core.Bit: + res = core.Boolean + return res case core.MediumInt, core.Int, core.Integer: if c.IsAutoIncrement { return core.Serial @@ -866,29 +878,35 @@ func (db *postgres) IndexOnTable() bool { } func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - args := []interface{}{tableName, idxName} + if len(db.Schema) == 0 { + args := []interface{}{tableName, idxName} + return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args + } + + args := []interface{}{db.Schema, tableName, idxName} return `SELECT indexname FROM pg_indexes ` + - `WHERE tablename = ? AND indexname = ?`, args + `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args } func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{tableName} - return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args + if len(db.Schema) == 0 { + args := []interface{}{tableName} + return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args + } + args := []interface{}{db.Schema, tableName} + return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args } -/*func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{tableName, colName} - return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" + - " AND column_name = ?", args -}*/ - func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string { - return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", - tableName, col.Name, db.SqlType(col)) + if len(db.Schema) == 0 { + return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", + tableName, col.Name, db.SqlType(col)) + } + return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s", + db.Schema, tableName, col.Name, db.SqlType(col)) } func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { - //var unique string quote := db.Quote idxName := index.Name @@ -904,9 +922,14 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { } func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { - args := []interface{}{tableName, colName} - query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + - " AND column_name = $2" + args := []interface{}{db.Schema, tableName, colName} + query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + + " AND column_name = $3" + if len(db.Schema) == 0 { + args = []interface{}{tableName, colName} + query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + + " AND column_name = $2" + } db.LogSQL(query, args) rows, err := db.DB().Query(query, args...) @@ -919,8 +942,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { } func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { - // FIXME: the schema should be replaced by user custom's - args := []interface{}{tableName, "public"} + args := []interface{}{tableName} s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix , CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey @@ -931,7 +953,15 @@ FROM pg_attribute f LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name -WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.attnum > 0 ORDER BY f.attnum;` +WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` + + var f string + if len(db.Schema) != 0 { + args = append(args, db.Schema) + f = "AND s.table_schema = $2" + } + s = fmt.Sprintf(s, f) + db.LogSQL(s, args) rows, err := db.DB().Query(s, args...) @@ -1021,9 +1051,13 @@ WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.att } func (db *postgres) GetTables() ([]*core.Table, error) { - // FIXME: replace public to user customrize schema - args := []interface{}{"public"} - s := fmt.Sprintf("SELECT tablename FROM pg_tables WHERE schemaname = $1") + args := []interface{}{} + s := "SELECT tablename FROM pg_tables" + if len(db.Schema) != 0 { + args = append(args, db.Schema) + s = s + " WHERE schemaname = $1" + } + db.LogSQL(s, args) rows, err := db.DB().Query(s, args...) @@ -1047,10 +1081,13 @@ func (db *postgres) GetTables() ([]*core.Table, error) { } func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { - // FIXME: replace the public schema to user specify schema - args := []interface{}{"public", tableName} - s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE schemaname=$1 AND tablename=$2") + args := []interface{}{tableName} + s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") db.LogSQL(s, args) + if len(db.Schema) != 0 { + args = append(args, db.Schema) + s = s + " AND schemaname=$2" + } rows, err := db.DB().Query(s, args...) if err != nil { @@ -1114,10 +1151,6 @@ func (vs values) Get(k string) (v string) { return vs[k] } -func errorf(s string, args ...interface{}) { - panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) -} - func parseURL(connstr string) (string, error) { u, err := url.Parse(connstr) if err != nil { @@ -1128,46 +1161,18 @@ func parseURL(connstr string) (string, error) { return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) } - var kvs []string escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) - accrue := func(k, v string) { - if v != "" { - kvs = append(kvs, k+"="+escaper.Replace(v)) - } - } - - if u.User != nil { - v := u.User.Username() - accrue("user", v) - - v, _ = u.User.Password() - accrue("password", v) - } - - i := strings.Index(u.Host, ":") - if i < 0 { - accrue("host", u.Host) - } else { - accrue("host", u.Host[:i]) - accrue("port", u.Host[i+1:]) - } if u.Path != "" { - accrue("dbname", u.Path[1:]) + return escaper.Replace(u.Path[1:]), nil } - q := u.Query() - for k := range q { - accrue(k, q.Get(k)) - } - - sort.Strings(kvs) // Makes testing easier (not a performance concern) - return strings.Join(kvs, " "), nil + return "", nil } -func parseOpts(name string, o values) { +func parseOpts(name string, o values) error { if len(name) == 0 { - return + return fmt.Errorf("invalid options: %s", name) } name = strings.TrimSpace(name) @@ -1176,31 +1181,36 @@ func parseOpts(name string, o values) { for _, p := range ps { kv := strings.Split(p, "=") if len(kv) < 2 { - errorf("invalid option: %q", p) + return fmt.Errorf("invalid option: %q", p) } o.Set(kv[0], kv[1]) } + + return nil } func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { db := &core.Uri{DbType: core.POSTGRES} - o := make(values) var err error + if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") { - dataSourceName, err = parseURL(dataSourceName) + db.DbName, err = parseURL(dataSourceName) if err != nil { return nil, err } + } else { + o := make(values) + err = parseOpts(dataSourceName, o) + if err != nil { + return nil, err + } + + db.DbName = o.Get("dbname") } - parseOpts(dataSourceName, o) - db.DbName = o.Get("dbname") if db.DbName == "" { return nil, errors.New("dbname is empty") } - /*db.Schema = o.Get("schema") - if len(db.Schema) == 0 { - db.Schema = "public" - }*/ + return db, nil } -- cgit v1.2.3