aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/go-xorm/xorm/statement.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/go-xorm/xorm/statement.go')
-rw-r--r--vendor/github.com/go-xorm/xorm/statement.go376
1 files changed, 118 insertions, 258 deletions
diff --git a/vendor/github.com/go-xorm/xorm/statement.go b/vendor/github.com/go-xorm/xorm/statement.go
index 58fa616b..35c4a472 100644
--- a/vendor/github.com/go-xorm/xorm/statement.go
+++ b/vendor/github.com/go-xorm/xorm/statement.go
@@ -73,6 +73,7 @@ type Statement struct {
decrColumns map[string]decrParam
exprColumns map[string]exprParam
cond builder.Cond
+ bufferSize int
}
// Init reset all the statement's fields
@@ -111,6 +112,7 @@ func (statement *Statement) Init() {
statement.decrColumns = make(map[string]decrParam)
statement.exprColumns = make(map[string]exprParam)
statement.cond = builder.NewCond()
+ statement.bufferSize = 0
}
// NoAutoCondition if you do not want convert bean's field as query condition, then use this function
@@ -158,6 +160,9 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme
case string:
cond := builder.Expr(query.(string), args...)
statement.cond = statement.cond.And(cond)
+ case map[string]interface{}:
+ cond := builder.Eq(query.(map[string]interface{}))
+ statement.cond = statement.cond.And(cond)
case builder.Cond:
cond := query.(builder.Cond)
statement.cond = statement.cond.And(cond)
@@ -179,6 +184,9 @@ func (statement *Statement) Or(query interface{}, args ...interface{}) *Statemen
case string:
cond := builder.Expr(query.(string), args...)
statement.cond = statement.cond.Or(cond)
+ case map[string]interface{}:
+ cond := builder.Eq(query.(map[string]interface{}))
+ statement.cond = statement.cond.Or(cond)
case builder.Cond:
cond := query.(builder.Cond)
statement.cond = statement.cond.Or(cond)
@@ -272,6 +280,9 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
fieldValue := *fieldValuePtr
fieldType := reflect.TypeOf(fieldValue.Interface())
+ if fieldType == nil {
+ continue
+ }
requiredField := useAllCols
includeNil := useAllCols
@@ -490,224 +501,6 @@ func (statement *Statement) colName(col *core.Column, tableName string) string {
return statement.Engine.Quote(col.Name)
}
-func buildConds(engine *Engine, table *core.Table, bean interface{},
- includeVersion bool, includeUpdated bool, includeNil bool,
- includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
- mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) {
- var conds []builder.Cond
- for _, col := range table.Columns() {
- if !includeVersion && col.IsVersion {
- continue
- }
- if !includeUpdated && col.IsUpdated {
- continue
- }
- if !includeAutoIncr && col.IsAutoIncrement {
- continue
- }
-
- if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) {
- continue
- }
- if col.SQLType.IsJson() {
- continue
- }
-
- var colName string
- if addedTableName {
- var nm = tableName
- if len(aliasName) > 0 {
- nm = aliasName
- }
- colName = engine.Quote(nm) + "." + engine.Quote(col.Name)
- } else {
- colName = engine.Quote(col.Name)
- }
-
- fieldValuePtr, err := col.ValueOf(bean)
- if err != nil {
- engine.logger.Error(err)
- continue
- }
-
- if col.IsDeleted && !unscoped { // tag "deleted" is enabled
- if engine.dialect.DBType() == core.MSSQL {
- conds = append(conds, builder.IsNull{colName})
- } else {
- conds = append(conds, builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"}))
- }
- }
-
- fieldValue := *fieldValuePtr
- if fieldValue.Interface() == nil {
- continue
- }
-
- fieldType := reflect.TypeOf(fieldValue.Interface())
- requiredField := useAllCols
-
- if b, ok := getFlagForColumn(mustColumnMap, col); ok {
- if b {
- requiredField = true
- } else {
- continue
- }
- }
-
- if fieldType.Kind() == reflect.Ptr {
- if fieldValue.IsNil() {
- if includeNil {
- conds = append(conds, builder.Eq{colName: nil})
- }
- continue
- } else if !fieldValue.IsValid() {
- continue
- } else {
- // dereference ptr type to instance type
- fieldValue = fieldValue.Elem()
- fieldType = reflect.TypeOf(fieldValue.Interface())
- requiredField = true
- }
- }
-
- var val interface{}
- switch fieldType.Kind() {
- case reflect.Bool:
- if allUseBool || requiredField {
- val = fieldValue.Interface()
- } else {
- // if a bool in a struct, it will not be as a condition because it default is false,
- // please use Where() instead
- continue
- }
- case reflect.String:
- if !requiredField && fieldValue.String() == "" {
- continue
- }
- // for MyString, should convert to string or panic
- if fieldType.String() != reflect.String.String() {
- val = fieldValue.String()
- } else {
- val = fieldValue.Interface()
- }
- case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
- if !requiredField && fieldValue.Int() == 0 {
- continue
- }
- val = fieldValue.Interface()
- case reflect.Float32, reflect.Float64:
- if !requiredField && fieldValue.Float() == 0.0 {
- continue
- }
- val = fieldValue.Interface()
- case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
- if !requiredField && fieldValue.Uint() == 0 {
- continue
- }
- t := int64(fieldValue.Uint())
- val = reflect.ValueOf(&t).Interface()
- case reflect.Struct:
- if fieldType.ConvertibleTo(core.TimeType) {
- t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
- if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
- continue
- }
- val = engine.formatColTime(col, t)
- } else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok {
- continue
- } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
- val, _ = valNul.Value()
- if val == nil {
- continue
- }
- } else {
- if col.SQLType.IsJson() {
- if col.SQLType.IsText() {
- bytes, err := json.Marshal(fieldValue.Interface())
- if err != nil {
- engine.logger.Error(err)
- continue
- }
- val = string(bytes)
- } else if col.SQLType.IsBlob() {
- var bytes []byte
- var err error
- bytes, err = json.Marshal(fieldValue.Interface())
- if err != nil {
- engine.logger.Error(err)
- continue
- }
- val = bytes
- }
- } else {
- engine.autoMapType(fieldValue)
- if table, ok := engine.Tables[fieldValue.Type()]; ok {
- if len(table.PrimaryKeys) == 1 {
- pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
- // fix non-int pk issues
- //if pkField.Int() != 0 {
- if pkField.IsValid() && !isZero(pkField.Interface()) {
- val = pkField.Interface()
- } else {
- continue
- }
- } else {
- //TODO: how to handler?
- panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys))
- }
- } else {
- val = fieldValue.Interface()
- }
- }
- }
- case reflect.Array:
- continue
- case reflect.Slice, reflect.Map:
- if fieldValue == reflect.Zero(fieldType) {
- continue
- }
- if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
- continue
- }
-
- if col.SQLType.IsText() {
- bytes, err := json.Marshal(fieldValue.Interface())
- if err != nil {
- engine.logger.Error(err)
- continue
- }
- val = string(bytes)
- } else if col.SQLType.IsBlob() {
- var bytes []byte
- var err error
- if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
- fieldType.Elem().Kind() == reflect.Uint8 {
- if fieldValue.Len() > 0 {
- val = fieldValue.Bytes()
- } else {
- continue
- }
- } else {
- bytes, err = json.Marshal(fieldValue.Interface())
- if err != nil {
- engine.logger.Error(err)
- continue
- }
- val = bytes
- }
- } else {
- continue
- }
- default:
- val = fieldValue.Interface()
- }
-
- conds = append(conds, builder.Eq{colName: val})
- }
-
- return builder.And(conds...), nil
-}
-
// TableName return current tableName
func (statement *Statement) TableName() string {
if statement.AltTableName != "" {
@@ -810,6 +603,22 @@ func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
return newColumns
}
+func (statement *Statement) colmap2NewColsWithQuote() []string {
+ newColumns := make([]string, 0, len(statement.columnMap))
+ for col := range statement.columnMap {
+ fields := strings.Split(strings.TrimSpace(col), ".")
+ if len(fields) == 1 {
+ newColumns = append(newColumns, statement.Engine.quote(fields[0]))
+ } else if len(fields) == 2 {
+ newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
+ statement.Engine.quote(fields[1]))
+ } else {
+ panic(errors.New("unwanted colnames"))
+ }
+ }
+ return newColumns
+}
+
// Distinct generates "DISTINCT col1, col2 " statement
func (statement *Statement) Distinct(columns ...string) *Statement {
statement.IsDistinct = true
@@ -836,7 +645,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
statement.columnMap[strings.ToLower(nc)] = true
}
- newColumns := statement.col2NewColsWithQuote(columns...)
+ newColumns := statement.colmap2NewColsWithQuote()
statement.ColumnStr = strings.Join(newColumns, ", ")
statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
return statement
@@ -1098,32 +907,45 @@ func (statement *Statement) genDelIndexSQL() []string {
func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
quote := statement.Engine.Quote
- sql := fmt.Sprintf("ALTER TABLE %v ADD %v;", quote(statement.TableName()),
+ sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
col.String(statement.Engine.dialect))
+ if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 {
+ sql += " COMMENT '" + col.Comment + "'"
+ }
+ sql += ";"
return sql, []interface{}{}
}
func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
- return buildConds(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
+ return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
}
-func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
+func (statement *Statement) mergeConds(bean interface{}) error {
if !statement.noAutoCondition {
var addedTableName = (len(statement.JoinStr) > 0)
autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
if err != nil {
- return "", nil, err
+ return err
}
statement.cond = statement.cond.And(autoCond)
}
- statement.processIDParam()
+ if err := statement.processIDParam(); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
+ if err := statement.mergeConds(bean); err != nil {
+ return "", nil, err
+ }
return builder.ToSQL(statement.cond)
}
-func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) {
+func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
v := rValue(bean)
isStruct := v.Kind() == reflect.Struct
if isStruct {
@@ -1156,21 +978,37 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{})
columnStr = "*"
}
- var condSQL string
- var condArgs []interface{}
if isStruct {
- condSQL, condArgs, _ = statement.genConds(bean)
- } else {
- condSQL, condArgs, _ = builder.ToSQL(statement.cond)
+ if err := statement.mergeConds(bean); err != nil {
+ return "", nil, err
+ }
+ }
+ condSQL, condArgs, err := builder.ToSQL(statement.cond)
+ if err != nil {
+ return "", nil, err
}
- return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...)
-}
+ sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true)
+ if err != nil {
+ return "", nil, err
+ }
-func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}) {
- statement.setRefValue(rValue(bean))
+ return sqlStr, append(statement.joinArgs, condArgs...), nil
+}
- condSQL, condArgs, _ := statement.genConds(bean)
+func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
+ var condSQL string
+ var condArgs []interface{}
+ var err error
+ if len(beans) > 0 {
+ statement.setRefValue(rValue(beans[0]))
+ condSQL, condArgs, err = statement.genConds(beans[0])
+ } else {
+ condSQL, condArgs, err = builder.ToSQL(statement.cond)
+ }
+ if err != nil {
+ return "", nil, err
+ }
var selectSQL = statement.selectStr
if len(selectSQL) <= 0 {
@@ -1180,10 +1018,15 @@ func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}
selectSQL = "count(*)"
}
}
- return statement.genSelectSQL(selectSQL, condSQL), append(statement.joinArgs, condArgs...)
+ sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false)
+ if err != nil {
+ return "", nil, err
+ }
+
+ return sqlStr, append(statement.joinArgs, condArgs...), nil
}
-func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}) {
+func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
statement.setRefValue(rValue(bean))
var sumStrs = make([]string, 0, len(columns))
@@ -1195,12 +1038,20 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
}
sumSelect := strings.Join(sumStrs, ", ")
- condSQL, condArgs, _ := statement.genConds(bean)
+ condSQL, condArgs, err := statement.genConds(bean)
+ if err != nil {
+ return "", nil, err
+ }
+
+ sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true)
+ if err != nil {
+ return "", nil, err
+ }
- return statement.genSelectSQL(sumSelect, condSQL), append(statement.joinArgs, condArgs...)
+ return sqlStr, append(statement.joinArgs, condArgs...), nil
}
-func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
+func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bool) (a string, err error) {
var distinct string
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
distinct = "DISTINCT "
@@ -1211,7 +1062,9 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
var top string
var mssqlCondi string
- statement.processIDParam()
+ if err := statement.processIDParam(); err != nil {
+ return "", err
+ }
var buf bytes.Buffer
if len(condSQL) > 0 {
@@ -1296,15 +1149,17 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
}
- if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
- if statement.Start > 0 {
- a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
- } else if statement.LimitN > 0 {
- a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
- }
- } else if dialect.DBType() == core.ORACLE {
- if statement.Start != 0 || statement.LimitN != 0 {
- a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
+ if needLimit {
+ if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
+ if statement.Start > 0 {
+ a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
+ } else if statement.LimitN > 0 {
+ a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
+ }
+ } else if dialect.DBType() == core.ORACLE {
+ if statement.Start != 0 || statement.LimitN != 0 {
+ a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
+ }
}
}
if statement.IsForUpdate {
@@ -1314,19 +1169,23 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
return
}
-func (statement *Statement) processIDParam() {
+func (statement *Statement) processIDParam() error {
if statement.idParam == nil {
- return
+ return nil
+ }
+
+ if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
+ return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
+ len(statement.RefTable.PrimaryKeys),
+ len(*statement.idParam),
+ )
}
for i, col := range statement.RefTable.PKColumns() {
var colName = statement.colName(col, statement.TableName())
- if i < len(*(statement.idParam)) {
- statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
- } else {
- statement.cond = statement.cond.And(builder.Eq{colName: ""})
- }
+ statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
}
+ return nil
}
func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
@@ -1360,7 +1219,8 @@ func (statement *Statement) convertIDSQL(sqlStr string) string {
top = fmt.Sprintf("TOP %d ", statement.LimitN)
}
- return fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
+ newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
+ return newsql
}
return ""
}