aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/go-xorm/xorm/dialect_postgres.go
diff options
context:
space:
mode:
authorUnknwon <u@gogs.io>2018-03-09 00:26:47 -0500
committerUnknwon <u@gogs.io>2018-03-09 00:26:47 -0500
commit28f74cf1c67cde80ae453a799d76752114fd5e18 (patch)
tree72b160aef0810492e257c2707884bb3052e1ba51 /vendor/github.com/go-xorm/xorm/dialect_postgres.go
parent83655d5c00110044a4ac9bf46ec039379eded5dd (diff)
vendor: update github.com/go-xorm/xorm (#4913)
Diffstat (limited to 'vendor/github.com/go-xorm/xorm/dialect_postgres.go')
-rw-r--r--vendor/github.com/go-xorm/xorm/dialect_postgres.go154
1 files changed, 82 insertions, 72 deletions
diff --git a/vendor/github.com/go-xorm/xorm/dialect_postgres.go b/vendor/github.com/go-xorm/xorm/dialect_postgres.go
index 1d4daa27..2b2a0b78 100644
--- a/vendor/github.com/go-xorm/xorm/dialect_postgres.go
+++ b/vendor/github.com/go-xorm/xorm/dialect_postgres.go
@@ -8,7 +8,6 @@ import (
"errors"
"fmt"
"net/url"
- "sort"
"strconv"
"strings"
@@ -765,6 +764,9 @@ var (
"YES": true,
"ZONE": true,
}
+
+ // DefaultPostgresSchema default postgres schema
+ DefaultPostgresSchema = "public"
)
type postgres struct {
@@ -772,7 +774,14 @@ type postgres struct {
}
func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
- return db.Base.Init(d, db, uri, drivername, dataSourceName)
+ err := db.Base.Init(d, db, uri, drivername, dataSourceName)
+ if err != nil {
+ return err
+ }
+ if db.Schema == "" {
+ db.Schema = DefaultPostgresSchema
+ }
+ return nil
}
func (db *postgres) SqlType(c *core.Column) string {
@@ -781,6 +790,9 @@ func (db *postgres) SqlType(c *core.Column) string {
case core.TinyInt:
res = core.SmallInt
return res
+ case core.Bit:
+ res = core.Boolean
+ return res
case core.MediumInt, core.Int, core.Integer:
if c.IsAutoIncrement {
return core.Serial
@@ -866,29 +878,35 @@ func (db *postgres) IndexOnTable() bool {
}
func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
- args := []interface{}{tableName, idxName}
+ if len(db.Schema) == 0 {
+ args := []interface{}{tableName, idxName}
+ return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args
+ }
+
+ args := []interface{}{db.Schema, tableName, idxName}
return `SELECT indexname FROM pg_indexes ` +
- `WHERE tablename = ? AND indexname = ?`, args
+ `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args
}
func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
- args := []interface{}{tableName}
- return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
+ if len(db.Schema) == 0 {
+ args := []interface{}{tableName}
+ return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
+ }
+ args := []interface{}{db.Schema, tableName}
+ return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args
}
-/*func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
- args := []interface{}{tableName, colName}
- return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" +
- " AND column_name = ?", args
-}*/
-
func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string {
- return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
- tableName, col.Name, db.SqlType(col))
+ if len(db.Schema) == 0 {
+ return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
+ tableName, col.Name, db.SqlType(col))
+ }
+ return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s",
+ db.Schema, tableName, col.Name, db.SqlType(col))
}
func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
- //var unique string
quote := db.Quote
idxName := index.Name
@@ -904,9 +922,14 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
}
func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
- args := []interface{}{tableName, colName}
- query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
- " AND column_name = $2"
+ args := []interface{}{db.Schema, tableName, colName}
+ query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" +
+ " AND column_name = $3"
+ if len(db.Schema) == 0 {
+ args = []interface{}{tableName, colName}
+ query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
+ " AND column_name = $2"
+ }
db.LogSQL(query, args)
rows, err := db.DB().Query(query, args...)
@@ -919,8 +942,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
}
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
- // FIXME: the schema should be replaced by user custom's
- args := []interface{}{tableName, "public"}
+ args := []interface{}{tableName}
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix ,
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
@@ -931,7 +953,15 @@ FROM pg_attribute f
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
-WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.attnum > 0 ORDER BY f.attnum;`
+WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;`
+
+ var f string
+ if len(db.Schema) != 0 {
+ args = append(args, db.Schema)
+ f = "AND s.table_schema = $2"
+ }
+ s = fmt.Sprintf(s, f)
+
db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...)
@@ -1021,9 +1051,13 @@ WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.att
}
func (db *postgres) GetTables() ([]*core.Table, error) {
- // FIXME: replace public to user customrize schema
- args := []interface{}{"public"}
- s := fmt.Sprintf("SELECT tablename FROM pg_tables WHERE schemaname = $1")
+ args := []interface{}{}
+ s := "SELECT tablename FROM pg_tables"
+ if len(db.Schema) != 0 {
+ args = append(args, db.Schema)
+ s = s + " WHERE schemaname = $1"
+ }
+
db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...)
@@ -1047,10 +1081,13 @@ func (db *postgres) GetTables() ([]*core.Table, error) {
}
func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
- // FIXME: replace the public schema to user specify schema
- args := []interface{}{"public", tableName}
- s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE schemaname=$1 AND tablename=$2")
+ args := []interface{}{tableName}
+ s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
db.LogSQL(s, args)
+ if len(db.Schema) != 0 {
+ args = append(args, db.Schema)
+ s = s + " AND schemaname=$2"
+ }
rows, err := db.DB().Query(s, args...)
if err != nil {
@@ -1114,10 +1151,6 @@ func (vs values) Get(k string) (v string) {
return vs[k]
}
-func errorf(s string, args ...interface{}) {
- panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
-}
-
func parseURL(connstr string) (string, error) {
u, err := url.Parse(connstr)
if err != nil {
@@ -1128,46 +1161,18 @@ func parseURL(connstr string) (string, error) {
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
}
- var kvs []string
escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
- accrue := func(k, v string) {
- if v != "" {
- kvs = append(kvs, k+"="+escaper.Replace(v))
- }
- }
-
- if u.User != nil {
- v := u.User.Username()
- accrue("user", v)
-
- v, _ = u.User.Password()
- accrue("password", v)
- }
-
- i := strings.Index(u.Host, ":")
- if i < 0 {
- accrue("host", u.Host)
- } else {
- accrue("host", u.Host[:i])
- accrue("port", u.Host[i+1:])
- }
if u.Path != "" {
- accrue("dbname", u.Path[1:])
+ return escaper.Replace(u.Path[1:]), nil
}
- q := u.Query()
- for k := range q {
- accrue(k, q.Get(k))
- }
-
- sort.Strings(kvs) // Makes testing easier (not a performance concern)
- return strings.Join(kvs, " "), nil
+ return "", nil
}
-func parseOpts(name string, o values) {
+func parseOpts(name string, o values) error {
if len(name) == 0 {
- return
+ return fmt.Errorf("invalid options: %s", name)
}
name = strings.TrimSpace(name)
@@ -1176,31 +1181,36 @@ func parseOpts(name string, o values) {
for _, p := range ps {
kv := strings.Split(p, "=")
if len(kv) < 2 {
- errorf("invalid option: %q", p)
+ return fmt.Errorf("invalid option: %q", p)
}
o.Set(kv[0], kv[1])
}
+
+ return nil
}
func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
db := &core.Uri{DbType: core.POSTGRES}
- o := make(values)
var err error
+
if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") {
- dataSourceName, err = parseURL(dataSourceName)
+ db.DbName, err = parseURL(dataSourceName)
if err != nil {
return nil, err
}
+ } else {
+ o := make(values)
+ err = parseOpts(dataSourceName, o)
+ if err != nil {
+ return nil, err
+ }
+
+ db.DbName = o.Get("dbname")
}
- parseOpts(dataSourceName, o)
- db.DbName = o.Get("dbname")
if db.DbName == "" {
return nil, errors.New("dbname is empty")
}
- /*db.Schema = o.Get("schema")
- if len(db.Schema) == 0 {
- db.Schema = "public"
- }*/
+
return db, nil
}