aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/go-xorm/xorm/statement.go
diff options
context:
space:
mode:
authorpeter zhang <admin@ddatsh.com>2017-04-27 07:47:16 +0800
committer无闻 <u@gogs.io>2017-04-26 19:47:16 -0400
commit10ee2e0dad6471e8912cc833faf33db7eaa55fa8 (patch)
treeb1d2f03768781a31d7a68d2c5a9a7140e170f261 /vendor/github.com/go-xorm/xorm/statement.go
parent6500aafcb867e37379b1f238198e95134b09ac4e (diff)
vendor: update xorm version for fix git clone error build with golang 1.8.1 (#4460)
Diffstat (limited to 'vendor/github.com/go-xorm/xorm/statement.go')
-rw-r--r--vendor/github.com/go-xorm/xorm/statement.go106
1 files changed, 66 insertions, 40 deletions
diff --git a/vendor/github.com/go-xorm/xorm/statement.go b/vendor/github.com/go-xorm/xorm/statement.go
index fb116b94..b6f0baf2 100644
--- a/vendor/github.com/go-xorm/xorm/statement.go
+++ b/vendor/github.com/go-xorm/xorm/statement.go
@@ -39,7 +39,7 @@ type Statement struct {
Engine *Engine
Start int
LimitN int
- IdParam *core.PK
+ idParam *core.PK
OrderStr string
JoinStr string
joinArgs []interface{}
@@ -91,7 +91,7 @@ func (statement *Statement) Init() {
statement.columnMap = make(map[string]bool)
statement.AltTableName = ""
statement.tableName = ""
- statement.IdParam = nil
+ statement.idParam = nil
statement.RawSQL = ""
statement.RawParams = make([]interface{}, 0)
statement.UseCache = true
@@ -195,29 +195,26 @@ func (statement *Statement) Or(query interface{}, args ...interface{}) *Statemen
// In generate "Where column IN (?) " statement
func (statement *Statement) In(column string, args ...interface{}) *Statement {
- if len(args) == 0 {
- return statement
- }
-
- in := builder.In(column, args...)
+ in := builder.In(statement.Engine.Quote(column), args...)
statement.cond = statement.cond.And(in)
return statement
}
// NotIn generate "Where column NOT IN (?) " statement
func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
- if len(args) == 0 {
- return statement
- }
-
- in := builder.NotIn(column, args...)
- statement.cond = statement.cond.And(in)
+ notIn := builder.NotIn(statement.Engine.Quote(column), args...)
+ statement.cond = statement.cond.And(notIn)
return statement
}
-func (statement *Statement) setRefValue(v reflect.Value) {
- statement.RefTable = statement.Engine.autoMapType(reflect.Indirect(v))
+func (statement *Statement) setRefValue(v reflect.Value) error {
+ var err error
+ statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
+ if err != nil {
+ return err
+ }
statement.tableName = statement.Engine.tbName(v)
+ return nil
}
// Table tempororily set table name, the parameter could be a string or a pointer of struct
@@ -227,7 +224,12 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
if t.Kind() == reflect.String {
statement.AltTableName = tableNameOrBean.(string)
} else if t.Kind() == reflect.Struct {
- statement.RefTable = statement.Engine.autoMapType(v)
+ var err error
+ statement.RefTable, err = statement.Engine.autoMapType(v)
+ if err != nil {
+ statement.Engine.logger.Error(err)
+ return statement
+ }
statement.AltTableName = statement.Engine.tbName(v)
}
return statement
@@ -418,7 +420,11 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
if fieldValue == reflect.Zero(fieldType) {
continue
}
- if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
+ if fieldType.Kind() == reflect.Array {
+ if isArrayValueZero(fieldValue) {
+ continue
+ }
+ } else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
continue
}
}
@@ -433,13 +439,16 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
- if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
+ if fieldType.Kind() == reflect.Slice &&
fieldType.Elem().Kind() == reflect.Uint8 {
if fieldValue.Len() > 0 {
val = fieldValue.Bytes()
} else {
continue
}
+ } else if fieldType.Kind() == reflect.Array &&
+ fieldType.Elem().Kind() == reflect.Uint8 {
+ val = fieldValue.Slice(0, 0).Interface()
} else {
bytes, err = json.Marshal(fieldValue.Interface())
if err != nil {
@@ -651,7 +660,9 @@ func buildConds(engine *Engine, table *core.Table, bean interface{},
}
}
}
- case reflect.Array, reflect.Slice, reflect.Map:
+ case reflect.Array:
+ continue
+ case reflect.Slice, reflect.Map:
if fieldValue == reflect.Zero(fieldType) {
continue
}
@@ -706,13 +717,6 @@ func (statement *Statement) TableName() string {
return statement.tableName
}
-// Id generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
-//
-// Deprecated: use ID instead
-func (statement *Statement) Id(id interface{}) *Statement {
- return statement.ID(id)
-}
-
// ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
func (statement *Statement) ID(id interface{}) *Statement {
idValue := reflect.ValueOf(id)
@@ -721,23 +725,23 @@ func (statement *Statement) ID(id interface{}) *Statement {
switch idType {
case ptrPkType:
if pkPtr, ok := (id).(*core.PK); ok {
- statement.IdParam = pkPtr
+ statement.idParam = pkPtr
return statement
}
case pkType:
if pk, ok := (id).(core.PK); ok {
- statement.IdParam = &pk
+ statement.idParam = &pk
return statement
}
}
switch idType.Kind() {
case reflect.String:
- statement.IdParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
+ statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
return statement
}
- statement.IdParam = &core.PK{id}
+ statement.idParam = &core.PK{id}
return statement
}
@@ -1120,7 +1124,11 @@ func (statement *Statement) genConds(bean interface{}) (string, []interface{}, e
}
func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) {
- statement.setRefValue(rValue(bean))
+ v := rValue(bean)
+ isStruct := v.Kind() == reflect.Struct
+ if isStruct {
+ statement.setRefValue(v)
+ }
var columnStr = statement.ColumnStr
if len(statement.selectStr) > 0 {
@@ -1139,14 +1147,22 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{})
if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
- } else {
- columnStr = "*"
}
}
}
}
- condSQL, condArgs, _ := statement.genConds(bean)
+ if len(columnStr) == 0 {
+ columnStr = "*"
+ }
+
+ var condSQL string
+ var condArgs []interface{}
+ if isStruct {
+ condSQL, condArgs, _ = statement.genConds(bean)
+ } else {
+ condSQL, condArgs, _ = builder.ToSQL(statement.cond)
+ }
return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...)
}
@@ -1172,17 +1188,21 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
var sumStrs = make([]string, 0, len(columns))
for _, colName := range columns {
+ if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
+ colName = statement.Engine.Quote(colName)
+ }
sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
}
+ sumSelect := strings.Join(sumStrs, ", ")
condSQL, condArgs, _ := statement.genConds(bean)
- return statement.genSelectSQL(strings.Join(sumStrs, ", "), condSQL), append(statement.joinArgs, condArgs...)
+ return statement.genSelectSQL(sumSelect, condSQL), append(statement.joinArgs, condArgs...)
}
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
var distinct string
- if statement.IsDistinct {
+ if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
distinct = "DISTINCT "
}
@@ -1198,8 +1218,14 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
fmt.Fprintf(&buf, " WHERE %v", condSQL)
}
var whereStr = buf.String()
+ var fromStr = " FROM "
+
+ if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
+ fromStr += statement.TableName()
+ } else {
+ fromStr += quote(statement.TableName())
+ }
- var fromStr = " FROM " + quote(statement.TableName())
if statement.TableAlias != "" {
if dialect.DBType() == core.ORACLE {
fromStr += " " + quote(statement.TableAlias)
@@ -1289,14 +1315,14 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
}
func (statement *Statement) processIDParam() {
- if statement.IdParam == nil {
+ if statement.idParam == nil {
return
}
for i, col := range statement.RefTable.PKColumns() {
var colName = statement.colName(col, statement.TableName())
- if i < len(*(statement.IdParam)) {
- statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.IdParam))[i]})
+ if i < len(*(statement.idParam)) {
+ statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
} else {
statement.cond = statement.cond.And(builder.Eq{colName: ""})
}