aboutsummaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
authorJoe Chen <jc@unknwon.io>2022-06-01 22:51:46 +0800
committerGitHub <noreply@github.com>2022-06-01 22:51:46 +0800
commit5f34265db6548e3be72b865e41f227137b18474c (patch)
tree6452369976b1ef2c2567b68ca73c09e2122a2cd7 /internal/db
parent05cdf8616ba4e9a27a37c68bada31568d495f26a (diff)
ci: run database tests against Postgres, MySQL and SQLite (#6996)
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/backup.go5
-rw-r--r--internal/db/db.go24
-rw-r--r--internal/db/db_test.go4
-rw-r--r--internal/db/lfs_test.go5
-rw-r--r--internal/db/login_sources_test.go5
-rw-r--r--internal/db/main_test.go108
-rw-r--r--internal/db/perms_test.go3
-rw-r--r--internal/db/repos_test.go3
-rw-r--r--internal/db/two_factors_test.go3
-rw-r--r--internal/db/users_test.go7
10 files changed, 133 insertions, 34 deletions
diff --git a/internal/db/backup.go b/internal/db/backup.go
index e03d78c1..836672b1 100644
--- a/internal/db/backup.go
+++ b/internal/db/backup.go
@@ -85,6 +85,11 @@ func dumpTable(db *gorm.DB, table interface{}, w io.Writer) error {
return errors.Wrap(err, "scan rows")
}
+ switch e := elem.(type) {
+ case *LFSObject:
+ e.CreatedAt = e.CreatedAt.UTC()
+ }
+
err = jsoniter.NewEncoder(w).Encode(elem)
if err != nil {
return errors.Wrap(err, "encode JSON")
diff --git a/internal/db/db.go b/internal/db/db.go
index 7d2c12e2..987d90df 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -53,8 +53,8 @@ func parseMSSQLHostPort(info string) (host, port string) {
return host, port
}
-// parseDSN takes given database options and returns parsed DSN.
-func parseDSN(opts conf.DatabaseOpts) (dsn string, err error) {
+// newDSN takes given database options and returns parsed DSN.
+func newDSN(opts conf.DatabaseOpts) (dsn string, err error) {
// In case the database name contains "?" with some parameters
concate := "?"
if strings.Contains(opts.Name, concate) {
@@ -109,7 +109,7 @@ func newLogWriter() (logger.Writer, error) {
}
func openDB(opts conf.DatabaseOpts, cfg *gorm.Config) (*gorm.DB, error) {
- dsn, err := parseDSN(opts)
+ dsn, err := newDSN(opts)
if err != nil {
return nil, errors.Wrap(err, "parse DSN")
}
@@ -151,14 +151,18 @@ func Init(w logger.Writer) (*gorm.DB, error) {
LogLevel: level,
})
- db, err := openDB(conf.Database, &gorm.Config{
- NamingStrategy: schema.NamingStrategy{
- SingularTable: true,
+ db, err := openDB(
+ conf.Database,
+ &gorm.Config{
+ SkipDefaultTransaction: true,
+ NamingStrategy: schema.NamingStrategy{
+ SingularTable: true,
+ },
+ NowFunc: func() time.Time {
+ return time.Now().UTC().Truncate(time.Microsecond)
+ },
},
- NowFunc: func() time.Time {
- return time.Now().UTC().Truncate(time.Microsecond)
- },
- })
+ )
if err != nil {
return nil, errors.Wrap(err, "open database")
}
diff --git a/internal/db/db_test.go b/internal/db/db_test.go
index e16c9486..1f4f0109 100644
--- a/internal/db/db_test.go
+++ b/internal/db/db_test.go
@@ -56,7 +56,7 @@ func Test_parseMSSQLHostPort(t *testing.T) {
func Test_parseDSN(t *testing.T) {
t.Run("bad dialect", func(t *testing.T) {
- _, err := parseDSN(conf.DatabaseOpts{
+ _, err := newDSN(conf.DatabaseOpts{
Type: "bad_dialect",
})
assert.Equal(t, "unrecognized dialect: bad_dialect", fmt.Sprintf("%v", err))
@@ -140,7 +140,7 @@ func Test_parseDSN(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- dsn, err := parseDSN(test.opts)
+ dsn, err := newDSN(test.opts)
if err != nil {
t.Fatal(err)
}
diff --git a/internal/db/lfs_test.go b/internal/db/lfs_test.go
index 4aff4f22..6c7af93a 100644
--- a/internal/db/lfs_test.go
+++ b/internal/db/lfs_test.go
@@ -43,6 +43,9 @@ func Test_lfs(t *testing.T) {
})
tc.test(t, db)
})
+ if t.Failed() {
+ break
+ }
}
}
@@ -60,7 +63,7 @@ func test_lfs_CreateObject(t *testing.T, db *lfs) {
if err != nil {
t.Fatal(err)
}
- assert.Equal(t, db.NowFunc().Format(time.RFC3339), object.CreatedAt.Format(time.RFC3339))
+ assert.Equal(t, db.NowFunc().Format(time.RFC3339), object.CreatedAt.UTC().Format(time.RFC3339))
// Try create second LFS object with same oid should fail
err = db.CreateObject(repoID, oid, 12, lfsutil.StorageLocal)
diff --git a/internal/db/login_sources_test.go b/internal/db/login_sources_test.go
index 280a1bca..f8f9a0c5 100644
--- a/internal/db/login_sources_test.go
+++ b/internal/db/login_sources_test.go
@@ -21,6 +21,7 @@ func TestLoginSource_BeforeSave(t *testing.T) {
now := time.Now()
db := &gorm.DB{
Config: &gorm.Config{
+ SkipDefaultTransaction: true,
NowFunc: func() time.Time {
return now
},
@@ -54,6 +55,7 @@ func TestLoginSource_BeforeCreate(t *testing.T) {
now := time.Now()
db := &gorm.DB{
Config: &gorm.Config{
+ SkipDefaultTransaction: true,
NowFunc: func() time.Time {
return now
},
@@ -108,6 +110,9 @@ func Test_loginSources(t *testing.T) {
})
tc.test(t, db)
})
+ if t.Failed() {
+ break
+ }
}
}
diff --git a/internal/db/main_test.go b/internal/db/main_test.go
index 9491bda9..cedd7cfc 100644
--- a/internal/db/main_test.go
+++ b/internal/db/main_test.go
@@ -5,6 +5,7 @@
package db
import (
+ "database/sql"
"flag"
"fmt"
"os"
@@ -12,6 +13,7 @@ import (
"testing"
"time"
+ "github.com/stretchr/testify/require"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
@@ -59,16 +61,92 @@ func clearTables(t *testing.T, db *gorm.DB, tables ...interface{}) error {
}
func initTestDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB {
- t.Helper()
+ dbType := os.Getenv("GOGS_DATABASE_TYPE")
+
+ var dbName string
+ var dbOpts conf.DatabaseOpts
+ var cleanup func(db *gorm.DB)
+ switch dbType {
+ case "mysql":
+ dbOpts = conf.DatabaseOpts{
+ Type: "mysql",
+ Host: os.ExpandEnv("$MYSQL_HOST:$MYSQL_PORT"),
+ Name: dbName,
+ User: os.Getenv("MYSQL_USER"),
+ Password: os.Getenv("MYSQL_PASSWORD"),
+ }
+
+ dsn, err := newDSN(dbOpts)
+ require.NoError(t, err)
+
+ sqlDB, err := sql.Open("mysql", dsn)
+ require.NoError(t, err)
+
+ // Set up test database
+ dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
+ _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dbName))
+ require.NoError(t, err)
+
+ _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dbName))
+ require.NoError(t, err)
+
+ dbOpts.Name = dbName
+
+ cleanup = func(db *gorm.DB) {
+ db.Exec(fmt.Sprintf("DROP DATABASE `%s`", dbName))
+ _ = sqlDB.Close()
+ }
+ case "postgres":
+ dbOpts = conf.DatabaseOpts{
+ Type: "postgres",
+ Host: os.ExpandEnv("$PGHOST:$PGPORT"),
+ Name: dbName,
+ Schema: "public",
+ User: os.Getenv("PGUSER"),
+ Password: os.Getenv("PGPASSWORD"),
+ SSLMode: os.Getenv("PGSSLMODE"),
+ }
+
+ dsn, err := newDSN(dbOpts)
+ require.NoError(t, err)
+
+ sqlDB, err := sql.Open("pgx", dsn)
+ require.NoError(t, err)
+
+ // Set up test database
+ dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
+ _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q", dbName))
+ require.NoError(t, err)
+
+ _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE %q", dbName))
+ require.NoError(t, err)
+
+ dbOpts.Name = dbName
+
+ cleanup = func(db *gorm.DB) {
+ db.Exec(fmt.Sprintf(`DROP DATABASE %q`, dbName))
+ _ = sqlDB.Close()
+ }
+ default:
+ dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
+ dbOpts = conf.DatabaseOpts{
+ Type: "sqlite3",
+ Path: dbName,
+ }
+ cleanup = func(db *gorm.DB) {
+ sqlDB, err := db.DB()
+ if err == nil {
+ _ = sqlDB.Close()
+ }
+ _ = os.Remove(dbName)
+ }
+ }
- dbpath := filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
now := time.Now().UTC().Truncate(time.Second)
db, err := openDB(
- conf.DatabaseOpts{
- Type: "sqlite3",
- Path: dbpath,
- },
+ dbOpts,
&gorm.Config{
+ SkipDefaultTransaction: true,
NamingStrategy: schema.NamingStrategy{
SingularTable: true,
},
@@ -77,27 +155,19 @@ func initTestDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB {
},
},
)
- if err != nil {
- t.Fatal(err)
- }
- t.Cleanup(func() {
- sqlDB, err := db.DB()
- if err == nil {
- _ = sqlDB.Close()
- }
+ require.NoError(t, err)
+ t.Cleanup(func() {
if t.Failed() {
- t.Logf("Database %q left intact for inspection", dbpath)
+ t.Logf("Database %q left intact for inspection", dbName)
return
}
- _ = os.Remove(dbpath)
+ cleanup(db)
})
err = db.Migrator().AutoMigrate(tables...)
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
return db
}
diff --git a/internal/db/perms_test.go b/internal/db/perms_test.go
index e242f843..9b821e6b 100644
--- a/internal/db/perms_test.go
+++ b/internal/db/perms_test.go
@@ -39,6 +39,9 @@ func Test_perms(t *testing.T) {
})
tc.test(t, db)
})
+ if t.Failed() {
+ break
+ }
}
}
diff --git a/internal/db/repos_test.go b/internal/db/repos_test.go
index 2949549f..d248f3b0 100644
--- a/internal/db/repos_test.go
+++ b/internal/db/repos_test.go
@@ -41,6 +41,9 @@ func Test_repos(t *testing.T) {
})
tc.test(t, db)
})
+ if t.Failed() {
+ break
+ }
}
}
diff --git a/internal/db/two_factors_test.go b/internal/db/two_factors_test.go
index 20dfae61..c8412213 100644
--- a/internal/db/two_factors_test.go
+++ b/internal/db/two_factors_test.go
@@ -42,6 +42,9 @@ func Test_twoFactors(t *testing.T) {
})
tc.test(t, db)
})
+ if t.Failed() {
+ break
+ }
}
}
diff --git a/internal/db/users_test.go b/internal/db/users_test.go
index cb93b036..dac2c208 100644
--- a/internal/db/users_test.go
+++ b/internal/db/users_test.go
@@ -45,6 +45,9 @@ func Test_users(t *testing.T) {
})
tc.test(t, db)
})
+ if t.Failed() {
+ break
+ }
}
}
@@ -136,7 +139,7 @@ func test_users_GetByEmail(t *testing.T, db *users) {
t.Fatal(err)
}
- err = db.Exec(`UPDATE user SET type = ? WHERE id = ?`, UserOrganization, org.ID).Error
+ err = db.Model(&User{}).Where("id", org.ID).UpdateColumn("type", UserOrganization).Error
if err != nil {
t.Fatal(err)
}
@@ -158,7 +161,7 @@ func test_users_GetByEmail(t *testing.T, db *users) {
// Mark user as activated
// TODO: Use UserEmails.Verify to replace SQL hack when the method is available.
- err = db.Exec(`UPDATE user SET is_active = ? WHERE id = ?`, true, alice.ID).Error
+ err = db.Model(&User{}).Where("id", alice.ID).UpdateColumn("is_active", true).Error
if err != nil {
t.Fatal(err)
}