diff options
Diffstat (limited to 'vendor/github.com/go-xorm/xorm/session_update.go')
-rw-r--r-- | vendor/github.com/go-xorm/xorm/session_update.go | 71 |
1 files changed, 42 insertions, 29 deletions
diff --git a/vendor/github.com/go-xorm/xorm/session_update.go b/vendor/github.com/go-xorm/xorm/session_update.go index 27e2deb0..1d77d294 100644 --- a/vendor/github.com/go-xorm/xorm/session_update.go +++ b/vendor/github.com/go-xorm/xorm/session_update.go @@ -169,7 +169,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var isMap = t.Kind() == reflect.Map var isStruct = t.Kind() == reflect.Struct if isStruct { - session.Statement.setRefValue(v) + if err := session.Statement.setRefValue(v); err != nil { + return 0, err + } if len(session.Statement.TableName()) <= 0 { return 0, ErrTableNotFound @@ -253,48 +255,59 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var condSQL string cond := session.Statement.cond.And(autoCond) - doIncVer := false + var doIncVer = (table != nil && table.Version != "" && session.Statement.checkVersion) var verValue *reflect.Value - if table != nil && table.Version != "" && session.Statement.checkVersion { + if doIncVer { verValue, err = table.VersionColumn().ValueOf(bean) if err != nil { return 0, err } cond = cond.And(builder.Eq{session.Engine.Quote(table.Version): verValue.Interface()}) - condSQL, condArgs, _ = builder.ToSQL(cond) - - if len(condSQL) > 0 { - condSQL = "WHERE " + condSQL - } - - if st.LimitN > 0 { - condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) - } + colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1") + } - sqlStr = fmt.Sprintf("UPDATE %v SET %v, %v %v", - session.Engine.Quote(session.Statement.TableName()), - strings.Join(colNames, ", "), - session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1", - condSQL) + condSQL, condArgs, _ = builder.ToSQL(cond) + if len(condSQL) > 0 { + condSQL = "WHERE " + condSQL + } - doIncVer = true - } else { - condSQL, condArgs, _ = builder.ToSQL(cond) - if len(condSQL) > 0 { - condSQL = "WHERE " + condSQL - } + if st.OrderStr != "" { + condSQL = condSQL + fmt.Sprintf(" ORDER BY %v", st.OrderStr) + } - if st.LimitN > 0 { + // TODO: Oracle support needed + var top string + if st.LimitN > 0 { + if st.Engine.dialect.DBType() == core.MYSQL { condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) + } else if st.Engine.dialect.DBType() == core.SQLITE { + tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) + cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", + session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...)) + condSQL, condArgs, _ = builder.ToSQL(cond) + if len(condSQL) > 0 { + condSQL = "WHERE " + condSQL + } + } else if st.Engine.dialect.DBType() == core.POSTGRES { + tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) + cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", + session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...)) + condSQL, condArgs, _ = builder.ToSQL(cond) + if len(condSQL) > 0 { + condSQL = "WHERE " + condSQL + } + } else if st.Engine.dialect.DBType() == core.MSSQL { + top = fmt.Sprintf("top (%d) ", st.LimitN) } - - sqlStr = fmt.Sprintf("UPDATE %v SET %v %v", - session.Engine.Quote(session.Statement.TableName()), - strings.Join(colNames, ", "), - condSQL) } + sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v", + top, + session.Engine.Quote(session.Statement.TableName()), + strings.Join(colNames, ", "), + condSQL) + res, err := session.exec(sqlStr, append(args, condArgs...)...) if err != nil { return 0, err |