aboutsummaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
authorᴜɴᴋɴᴡᴏɴ <u@gogs.io>2020-09-06 17:02:25 +0800
committerGitHub <noreply@github.com>2020-09-06 17:02:25 +0800
commit06193ed825094581da4fb5550804d7445fc5ae13 (patch)
treebd2d966235a224b9d28d316ef915481a63fac5f9 /internal/db
parent519e59b5778571ace3f681b81a21b92a38ede890 (diff)
schemadoc: add go:generate to output database schema (#6310)
* schemadoc: add go:generate to output database schema * Check errors * Revert string renames
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/backup.go4
-rw-r--r--internal/db/backup_test.go10
-rw-r--r--internal/db/db.go6
-rw-r--r--internal/db/schemadoc/main.go149
4 files changed, 160 insertions, 9 deletions
diff --git a/internal/db/backup.go b/internal/db/backup.go
index 985b3c0f..2f76bb04 100644
--- a/internal/db/backup.go
+++ b/internal/db/backup.go
@@ -40,7 +40,7 @@ func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error {
return errors.Wrap(err, "dump legacy tables")
}
- for _, table := range tables {
+ for _, table := range Tables {
tableName := getTableType(table)
if verbose {
log.Trace("Dumping table %q...", tableName)
@@ -125,7 +125,7 @@ func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error {
return errors.Wrap(err, "import legacy tables")
}
- for _, table := range tables {
+ for _, table := range Tables {
tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.")
err := func() error {
tableFile := filepath.Join(dirPath, tableName+".json")
diff --git a/internal/db/backup_test.go b/internal/db/backup_test.go
index fafaf28a..a65e6944 100644
--- a/internal/db/backup_test.go
+++ b/internal/db/backup_test.go
@@ -26,11 +26,11 @@ func Test_dumpAndImport(t *testing.T) {
t.Parallel()
- if len(tables) != 3 {
- t.Fatalf("New table has added (want 3 got %d), please add new tests for the table and update this check", len(tables))
+ if len(Tables) != 3 {
+ t.Fatalf("New table has added (want 3 got %d), please add new tests for the table and update this check", len(Tables))
}
- db := initTestDB(t, "dumpAndImport", tables...)
+ db := initTestDB(t, "dumpAndImport", Tables...)
setupDBToDump(t, db)
dumpTables(t, db)
importTables(t, db)
@@ -109,7 +109,7 @@ func setupDBToDump(t *testing.T, db *gorm.DB) {
func dumpTables(t *testing.T, db *gorm.DB) {
t.Helper()
- for _, table := range tables {
+ for _, table := range Tables {
tableName := getTableType(table)
var buf bytes.Buffer
@@ -126,7 +126,7 @@ func dumpTables(t *testing.T, db *gorm.DB) {
func importTables(t *testing.T, db *gorm.DB) {
t.Helper()
- for _, table := range tables {
+ for _, table := range Tables {
tableName := getTableType(table)
err := func() error {
diff --git a/internal/db/db.go b/internal/db/db.go
index 03da40ca..ce130f2c 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -137,8 +137,10 @@ func openDB(opts conf.DatabaseOpts, cfg *gorm.Config) (*gorm.DB, error) {
return gorm.Open(dialector, cfg)
}
+// Tables is the list of struct-to-table mappings.
+//
// NOTE: Lines are sorted in alphabetical order, each letter in its own line.
-var tables = []interface{}{
+var Tables = []interface{}{
new(AccessToken),
new(LFSObject), new(LoginSource),
}
@@ -196,7 +198,7 @@ func Init() (*gorm.DB, error) {
// NOTE: GORM has problem detecting existing columns, see https://github.com/gogs/gogs/issues/6091.
// Therefore only use it to create new tables, and do customized migration with future changes.
- for _, table := range tables {
+ for _, table := range Tables {
if db.Migrator().HasTable(table) {
continue
}
diff --git a/internal/db/schemadoc/main.go b/internal/db/schemadoc/main.go
new file mode 100644
index 00000000..3048f72f
--- /dev/null
+++ b/internal/db/schemadoc/main.go
@@ -0,0 +1,149 @@
+package main
+
+import (
+ "log"
+ "os"
+ "strings"
+
+ "github.com/olekukonko/tablewriter"
+ "github.com/pkg/errors"
+ "gopkg.in/DATA-DOG/go-sqlmock.v2"
+ "gorm.io/driver/mysql"
+ "gorm.io/driver/postgres"
+ "gorm.io/driver/sqlite"
+ "gorm.io/gorm"
+ "gorm.io/gorm/clause"
+ "gorm.io/gorm/schema"
+
+ "gogs.io/gogs/internal/db"
+)
+
+//go:generate go run main.go ../../../docs/dev/database_schema.md
+
+func main() {
+ w, err := os.Create(os.Args[1])
+ if err != nil {
+ log.Fatalf("Failed to create file: %v", err)
+ }
+ defer func() { _ = w.Close() }()
+
+ conn, _, err := sqlmock.New()
+ if err != nil {
+ log.Fatalf("Failed to get mock connection: %v", err)
+ }
+ defer func() { _ = conn.Close() }()
+
+ dialectors := []gorm.Dialector{
+ postgres.New(postgres.Config{
+ Conn: conn,
+ }),
+ mysql.New(mysql.Config{
+ Conn: conn,
+ SkipInitializeWithVersion: true,
+ }),
+ sqlite.Open(""),
+ }
+ collected := make([][]*tableInfo, 0, len(dialectors))
+ for i, dialector := range dialectors {
+ tableInfos, err := generate(dialector)
+ if err != nil {
+ log.Fatalf("Failed to get table info of %d: %v", i, err)
+ }
+
+ collected = append(collected, tableInfos)
+ }
+
+ for i, ti := range collected[0] {
+ _, _ = w.WriteString(`# Table "` + ti.Name + `"`)
+ _, _ = w.WriteString("\n\n")
+
+ _, _ = w.WriteString("```\n")
+
+ table := tablewriter.NewWriter(w)
+ table.SetHeader([]string{"Field", "Column", "PostgreSQL", "MySQL", "SQLite3"})
+ table.SetBorder(false)
+ for j, f := range ti.Fields {
+ table.Append([]string{
+ f.Name, f.Column,
+ strings.ToUpper(f.Type), // PostgreSQL
+ strings.ToUpper(collected[1][i].Fields[j].Type), // MySQL
+ strings.ToUpper(collected[2][i].Fields[j].Type), // SQLite3
+ })
+ }
+ table.Render()
+ _, _ = w.WriteString("\n")
+
+ _, _ = w.WriteString("Primary keys: ")
+ _, _ = w.WriteString(strings.Join(ti.PrimaryKeys, ", "))
+ _, _ = w.WriteString("\n")
+
+ _, _ = w.WriteString("```\n\n")
+ }
+}
+
+type tableField struct {
+ Name string
+ Column string
+ Type string
+}
+
+type tableInfo struct {
+ Name string
+ Fields []*tableField
+ PrimaryKeys []string
+}
+
+// This function is derived from gorm.io/gorm/migrator/migrator.go:Migrator.CreateTable.
+func generate(dialector gorm.Dialector) ([]*tableInfo, error) {
+ conn, err := gorm.Open(dialector, &gorm.Config{
+ NamingStrategy: schema.NamingStrategy{
+ SingularTable: true,
+ },
+ DryRun: true,
+ DisableAutomaticPing: true,
+ })
+ if err != nil {
+ return nil, errors.Wrap(err, "open database")
+ }
+
+ m := conn.Migrator().(interface {
+ RunWithValue(value interface{}, fc func(*gorm.Statement) error) error
+ FullDataTypeOf(*schema.Field) clause.Expr
+ })
+ tableInfos := make([]*tableInfo, 0, len(db.Tables))
+ for _, table := range db.Tables {
+ err = m.RunWithValue(table, func(stmt *gorm.Statement) error {
+ fields := make([]*tableField, 0, len(stmt.Schema.DBNames))
+ for _, field := range stmt.Schema.Fields {
+ if field.DBName == "" {
+ continue
+ }
+
+ fields = append(fields, &tableField{
+ Name: field.Name,
+ Column: field.DBName,
+ Type: m.FullDataTypeOf(field).SQL,
+ })
+ }
+
+ primaryKeys := make([]string, 0, len(stmt.Schema.PrimaryFields))
+ if len(stmt.Schema.PrimaryFields) > 0 {
+ for _, field := range stmt.Schema.PrimaryFields {
+ primaryKeys = append(primaryKeys, field.DBName)
+ }
+ }
+
+ tableInfos = append(tableInfos, &tableInfo{
+ Name: stmt.Table,
+ Fields: fields,
+ PrimaryKeys: primaryKeys,
+ })
+ return nil
+ })
+ if err != nil {
+ return nil, errors.Wrap(err, "gather table information")
+ }
+ }
+
+ return tableInfos, nil
+}