aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/go-xorm/xorm/session_update.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/go-xorm/xorm/session_update.go')
-rw-r--r--vendor/github.com/go-xorm/xorm/session_update.go71
1 files changed, 42 insertions, 29 deletions
diff --git a/vendor/github.com/go-xorm/xorm/session_update.go b/vendor/github.com/go-xorm/xorm/session_update.go
index 27e2deb0..1d77d294 100644
--- a/vendor/github.com/go-xorm/xorm/session_update.go
+++ b/vendor/github.com/go-xorm/xorm/session_update.go
@@ -169,7 +169,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var isMap = t.Kind() == reflect.Map
var isStruct = t.Kind() == reflect.Struct
if isStruct {
- session.Statement.setRefValue(v)
+ if err := session.Statement.setRefValue(v); err != nil {
+ return 0, err
+ }
if len(session.Statement.TableName()) <= 0 {
return 0, ErrTableNotFound
@@ -253,48 +255,59 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var condSQL string
cond := session.Statement.cond.And(autoCond)
- doIncVer := false
+ var doIncVer = (table != nil && table.Version != "" && session.Statement.checkVersion)
var verValue *reflect.Value
- if table != nil && table.Version != "" && session.Statement.checkVersion {
+ if doIncVer {
verValue, err = table.VersionColumn().ValueOf(bean)
if err != nil {
return 0, err
}
cond = cond.And(builder.Eq{session.Engine.Quote(table.Version): verValue.Interface()})
- condSQL, condArgs, _ = builder.ToSQL(cond)
-
- if len(condSQL) > 0 {
- condSQL = "WHERE " + condSQL
- }
-
- if st.LimitN > 0 {
- condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
- }
+ colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1")
+ }
- sqlStr = fmt.Sprintf("UPDATE %v SET %v, %v %v",
- session.Engine.Quote(session.Statement.TableName()),
- strings.Join(colNames, ", "),
- session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1",
- condSQL)
+ condSQL, condArgs, _ = builder.ToSQL(cond)
+ if len(condSQL) > 0 {
+ condSQL = "WHERE " + condSQL
+ }
- doIncVer = true
- } else {
- condSQL, condArgs, _ = builder.ToSQL(cond)
- if len(condSQL) > 0 {
- condSQL = "WHERE " + condSQL
- }
+ if st.OrderStr != "" {
+ condSQL = condSQL + fmt.Sprintf(" ORDER BY %v", st.OrderStr)
+ }
- if st.LimitN > 0 {
+ // TODO: Oracle support needed
+ var top string
+ if st.LimitN > 0 {
+ if st.Engine.dialect.DBType() == core.MYSQL {
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
+ } else if st.Engine.dialect.DBType() == core.SQLITE {
+ tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
+ cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
+ session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
+ condSQL, condArgs, _ = builder.ToSQL(cond)
+ if len(condSQL) > 0 {
+ condSQL = "WHERE " + condSQL
+ }
+ } else if st.Engine.dialect.DBType() == core.POSTGRES {
+ tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
+ cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
+ session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
+ condSQL, condArgs, _ = builder.ToSQL(cond)
+ if len(condSQL) > 0 {
+ condSQL = "WHERE " + condSQL
+ }
+ } else if st.Engine.dialect.DBType() == core.MSSQL {
+ top = fmt.Sprintf("top (%d) ", st.LimitN)
}
-
- sqlStr = fmt.Sprintf("UPDATE %v SET %v %v",
- session.Engine.Quote(session.Statement.TableName()),
- strings.Join(colNames, ", "),
- condSQL)
}
+ sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v",
+ top,
+ session.Engine.Quote(session.Statement.TableName()),
+ strings.Join(colNames, ", "),
+ condSQL)
+
res, err := session.exec(sqlStr, append(args, condArgs...)...)
if err != nil {
return 0, err