aboutsummaryrefslogtreecommitdiff
path: root/internal/db/main_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/main_test.go')
-rw-r--r--internal/db/main_test.go108
1 files changed, 89 insertions, 19 deletions
diff --git a/internal/db/main_test.go b/internal/db/main_test.go
index 9491bda9..cedd7cfc 100644
--- a/internal/db/main_test.go
+++ b/internal/db/main_test.go
@@ -5,6 +5,7 @@
package db
import (
+ "database/sql"
"flag"
"fmt"
"os"
@@ -12,6 +13,7 @@ import (
"testing"
"time"
+ "github.com/stretchr/testify/require"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
@@ -59,16 +61,92 @@ func clearTables(t *testing.T, db *gorm.DB, tables ...interface{}) error {
}
func initTestDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB {
- t.Helper()
+ dbType := os.Getenv("GOGS_DATABASE_TYPE")
+
+ var dbName string
+ var dbOpts conf.DatabaseOpts
+ var cleanup func(db *gorm.DB)
+ switch dbType {
+ case "mysql":
+ dbOpts = conf.DatabaseOpts{
+ Type: "mysql",
+ Host: os.ExpandEnv("$MYSQL_HOST:$MYSQL_PORT"),
+ Name: dbName,
+ User: os.Getenv("MYSQL_USER"),
+ Password: os.Getenv("MYSQL_PASSWORD"),
+ }
+
+ dsn, err := newDSN(dbOpts)
+ require.NoError(t, err)
+
+ sqlDB, err := sql.Open("mysql", dsn)
+ require.NoError(t, err)
+
+ // Set up test database
+ dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
+ _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dbName))
+ require.NoError(t, err)
+
+ _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dbName))
+ require.NoError(t, err)
+
+ dbOpts.Name = dbName
+
+ cleanup = func(db *gorm.DB) {
+ db.Exec(fmt.Sprintf("DROP DATABASE `%s`", dbName))
+ _ = sqlDB.Close()
+ }
+ case "postgres":
+ dbOpts = conf.DatabaseOpts{
+ Type: "postgres",
+ Host: os.ExpandEnv("$PGHOST:$PGPORT"),
+ Name: dbName,
+ Schema: "public",
+ User: os.Getenv("PGUSER"),
+ Password: os.Getenv("PGPASSWORD"),
+ SSLMode: os.Getenv("PGSSLMODE"),
+ }
+
+ dsn, err := newDSN(dbOpts)
+ require.NoError(t, err)
+
+ sqlDB, err := sql.Open("pgx", dsn)
+ require.NoError(t, err)
+
+ // Set up test database
+ dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
+ _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q", dbName))
+ require.NoError(t, err)
+
+ _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE %q", dbName))
+ require.NoError(t, err)
+
+ dbOpts.Name = dbName
+
+ cleanup = func(db *gorm.DB) {
+ db.Exec(fmt.Sprintf(`DROP DATABASE %q`, dbName))
+ _ = sqlDB.Close()
+ }
+ default:
+ dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
+ dbOpts = conf.DatabaseOpts{
+ Type: "sqlite3",
+ Path: dbName,
+ }
+ cleanup = func(db *gorm.DB) {
+ sqlDB, err := db.DB()
+ if err == nil {
+ _ = sqlDB.Close()
+ }
+ _ = os.Remove(dbName)
+ }
+ }
- dbpath := filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
now := time.Now().UTC().Truncate(time.Second)
db, err := openDB(
- conf.DatabaseOpts{
- Type: "sqlite3",
- Path: dbpath,
- },
+ dbOpts,
&gorm.Config{
+ SkipDefaultTransaction: true,
NamingStrategy: schema.NamingStrategy{
SingularTable: true,
},
@@ -77,27 +155,19 @@ func initTestDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB {
},
},
)
- if err != nil {
- t.Fatal(err)
- }
- t.Cleanup(func() {
- sqlDB, err := db.DB()
- if err == nil {
- _ = sqlDB.Close()
- }
+ require.NoError(t, err)
+ t.Cleanup(func() {
if t.Failed() {
- t.Logf("Database %q left intact for inspection", dbpath)
+ t.Logf("Database %q left intact for inspection", dbName)
return
}
- _ = os.Remove(dbpath)
+ cleanup(db)
})
err = db.Migrator().AutoMigrate(tables...)
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
return db
}