diff options
Diffstat (limited to 'internal/db/main_test.go')
-rw-r--r-- | internal/db/main_test.go | 108 |
1 files changed, 89 insertions, 19 deletions
diff --git a/internal/db/main_test.go b/internal/db/main_test.go index 9491bda9..cedd7cfc 100644 --- a/internal/db/main_test.go +++ b/internal/db/main_test.go @@ -5,6 +5,7 @@ package db import ( + "database/sql" "flag" "fmt" "os" @@ -12,6 +13,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" "gorm.io/gorm" "gorm.io/gorm/logger" "gorm.io/gorm/schema" @@ -59,16 +61,92 @@ func clearTables(t *testing.T, db *gorm.DB, tables ...interface{}) error { } func initTestDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB { - t.Helper() + dbType := os.Getenv("GOGS_DATABASE_TYPE") + + var dbName string + var dbOpts conf.DatabaseOpts + var cleanup func(db *gorm.DB) + switch dbType { + case "mysql": + dbOpts = conf.DatabaseOpts{ + Type: "mysql", + Host: os.ExpandEnv("$MYSQL_HOST:$MYSQL_PORT"), + Name: dbName, + User: os.Getenv("MYSQL_USER"), + Password: os.Getenv("MYSQL_PASSWORD"), + } + + dsn, err := newDSN(dbOpts) + require.NoError(t, err) + + sqlDB, err := sql.Open("mysql", dsn) + require.NoError(t, err) + + // Set up test database + dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix()) + _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dbName)) + require.NoError(t, err) + + _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dbName)) + require.NoError(t, err) + + dbOpts.Name = dbName + + cleanup = func(db *gorm.DB) { + db.Exec(fmt.Sprintf("DROP DATABASE `%s`", dbName)) + _ = sqlDB.Close() + } + case "postgres": + dbOpts = conf.DatabaseOpts{ + Type: "postgres", + Host: os.ExpandEnv("$PGHOST:$PGPORT"), + Name: dbName, + Schema: "public", + User: os.Getenv("PGUSER"), + Password: os.Getenv("PGPASSWORD"), + SSLMode: os.Getenv("PGSSLMODE"), + } + + dsn, err := newDSN(dbOpts) + require.NoError(t, err) + + sqlDB, err := sql.Open("pgx", dsn) + require.NoError(t, err) + + // Set up test database + dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix()) + _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q", dbName)) + require.NoError(t, err) + + _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE %q", dbName)) + require.NoError(t, err) + + dbOpts.Name = dbName + + cleanup = func(db *gorm.DB) { + db.Exec(fmt.Sprintf(`DROP DATABASE %q`, dbName)) + _ = sqlDB.Close() + } + default: + dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix())) + dbOpts = conf.DatabaseOpts{ + Type: "sqlite3", + Path: dbName, + } + cleanup = func(db *gorm.DB) { + sqlDB, err := db.DB() + if err == nil { + _ = sqlDB.Close() + } + _ = os.Remove(dbName) + } + } - dbpath := filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix())) now := time.Now().UTC().Truncate(time.Second) db, err := openDB( - conf.DatabaseOpts{ - Type: "sqlite3", - Path: dbpath, - }, + dbOpts, &gorm.Config{ + SkipDefaultTransaction: true, NamingStrategy: schema.NamingStrategy{ SingularTable: true, }, @@ -77,27 +155,19 @@ func initTestDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB { }, }, ) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - sqlDB, err := db.DB() - if err == nil { - _ = sqlDB.Close() - } + require.NoError(t, err) + t.Cleanup(func() { if t.Failed() { - t.Logf("Database %q left intact for inspection", dbpath) + t.Logf("Database %q left intact for inspection", dbName) return } - _ = os.Remove(dbpath) + cleanup(db) }) err = db.Migrator().AutoMigrate(tables...) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) return db } |