diff options
author | ᴜɴᴋɴᴡᴏɴ <u@gogs.io> | 2020-09-06 17:02:25 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-09-06 17:02:25 +0800 |
commit | 06193ed825094581da4fb5550804d7445fc5ae13 (patch) | |
tree | bd2d966235a224b9d28d316ef915481a63fac5f9 /internal/db | |
parent | 519e59b5778571ace3f681b81a21b92a38ede890 (diff) |
schemadoc: add go:generate to output database schema (#6310)
* schemadoc: add go:generate to output database schema
* Check errors
* Revert string renames
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/backup.go | 4 | ||||
-rw-r--r-- | internal/db/backup_test.go | 10 | ||||
-rw-r--r-- | internal/db/db.go | 6 | ||||
-rw-r--r-- | internal/db/schemadoc/main.go | 149 |
4 files changed, 160 insertions, 9 deletions
diff --git a/internal/db/backup.go b/internal/db/backup.go index 985b3c0f..2f76bb04 100644 --- a/internal/db/backup.go +++ b/internal/db/backup.go @@ -40,7 +40,7 @@ func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error { return errors.Wrap(err, "dump legacy tables") } - for _, table := range tables { + for _, table := range Tables { tableName := getTableType(table) if verbose { log.Trace("Dumping table %q...", tableName) @@ -125,7 +125,7 @@ func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error { return errors.Wrap(err, "import legacy tables") } - for _, table := range tables { + for _, table := range Tables { tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.") err := func() error { tableFile := filepath.Join(dirPath, tableName+".json") diff --git a/internal/db/backup_test.go b/internal/db/backup_test.go index fafaf28a..a65e6944 100644 --- a/internal/db/backup_test.go +++ b/internal/db/backup_test.go @@ -26,11 +26,11 @@ func Test_dumpAndImport(t *testing.T) { t.Parallel() - if len(tables) != 3 { - t.Fatalf("New table has added (want 3 got %d), please add new tests for the table and update this check", len(tables)) + if len(Tables) != 3 { + t.Fatalf("New table has added (want 3 got %d), please add new tests for the table and update this check", len(Tables)) } - db := initTestDB(t, "dumpAndImport", tables...) + db := initTestDB(t, "dumpAndImport", Tables...) setupDBToDump(t, db) dumpTables(t, db) importTables(t, db) @@ -109,7 +109,7 @@ func setupDBToDump(t *testing.T, db *gorm.DB) { func dumpTables(t *testing.T, db *gorm.DB) { t.Helper() - for _, table := range tables { + for _, table := range Tables { tableName := getTableType(table) var buf bytes.Buffer @@ -126,7 +126,7 @@ func dumpTables(t *testing.T, db *gorm.DB) { func importTables(t *testing.T, db *gorm.DB) { t.Helper() - for _, table := range tables { + for _, table := range Tables { tableName := getTableType(table) err := func() error { diff --git a/internal/db/db.go b/internal/db/db.go index 03da40ca..ce130f2c 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -137,8 +137,10 @@ func openDB(opts conf.DatabaseOpts, cfg *gorm.Config) (*gorm.DB, error) { return gorm.Open(dialector, cfg) } +// Tables is the list of struct-to-table mappings. +// // NOTE: Lines are sorted in alphabetical order, each letter in its own line. -var tables = []interface{}{ +var Tables = []interface{}{ new(AccessToken), new(LFSObject), new(LoginSource), } @@ -196,7 +198,7 @@ func Init() (*gorm.DB, error) { // NOTE: GORM has problem detecting existing columns, see https://github.com/gogs/gogs/issues/6091. // Therefore only use it to create new tables, and do customized migration with future changes. - for _, table := range tables { + for _, table := range Tables { if db.Migrator().HasTable(table) { continue } diff --git a/internal/db/schemadoc/main.go b/internal/db/schemadoc/main.go new file mode 100644 index 00000000..3048f72f --- /dev/null +++ b/internal/db/schemadoc/main.go @@ -0,0 +1,149 @@ +package main + +import ( + "log" + "os" + "strings" + + "github.com/olekukonko/tablewriter" + "github.com/pkg/errors" + "gopkg.in/DATA-DOG/go-sqlmock.v2" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + + "gogs.io/gogs/internal/db" +) + +//go:generate go run main.go ../../../docs/dev/database_schema.md + +func main() { + w, err := os.Create(os.Args[1]) + if err != nil { + log.Fatalf("Failed to create file: %v", err) + } + defer func() { _ = w.Close() }() + + conn, _, err := sqlmock.New() + if err != nil { + log.Fatalf("Failed to get mock connection: %v", err) + } + defer func() { _ = conn.Close() }() + + dialectors := []gorm.Dialector{ + postgres.New(postgres.Config{ + Conn: conn, + }), + mysql.New(mysql.Config{ + Conn: conn, + SkipInitializeWithVersion: true, + }), + sqlite.Open(""), + } + collected := make([][]*tableInfo, 0, len(dialectors)) + for i, dialector := range dialectors { + tableInfos, err := generate(dialector) + if err != nil { + log.Fatalf("Failed to get table info of %d: %v", i, err) + } + + collected = append(collected, tableInfos) + } + + for i, ti := range collected[0] { + _, _ = w.WriteString(`# Table "` + ti.Name + `"`) + _, _ = w.WriteString("\n\n") + + _, _ = w.WriteString("```\n") + + table := tablewriter.NewWriter(w) + table.SetHeader([]string{"Field", "Column", "PostgreSQL", "MySQL", "SQLite3"}) + table.SetBorder(false) + for j, f := range ti.Fields { + table.Append([]string{ + f.Name, f.Column, + strings.ToUpper(f.Type), // PostgreSQL + strings.ToUpper(collected[1][i].Fields[j].Type), // MySQL + strings.ToUpper(collected[2][i].Fields[j].Type), // SQLite3 + }) + } + table.Render() + _, _ = w.WriteString("\n") + + _, _ = w.WriteString("Primary keys: ") + _, _ = w.WriteString(strings.Join(ti.PrimaryKeys, ", ")) + _, _ = w.WriteString("\n") + + _, _ = w.WriteString("```\n\n") + } +} + +type tableField struct { + Name string + Column string + Type string +} + +type tableInfo struct { + Name string + Fields []*tableField + PrimaryKeys []string +} + +// This function is derived from gorm.io/gorm/migrator/migrator.go:Migrator.CreateTable. +func generate(dialector gorm.Dialector) ([]*tableInfo, error) { + conn, err := gorm.Open(dialector, &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + SingularTable: true, + }, + DryRun: true, + DisableAutomaticPing: true, + }) + if err != nil { + return nil, errors.Wrap(err, "open database") + } + + m := conn.Migrator().(interface { + RunWithValue(value interface{}, fc func(*gorm.Statement) error) error + FullDataTypeOf(*schema.Field) clause.Expr + }) + tableInfos := make([]*tableInfo, 0, len(db.Tables)) + for _, table := range db.Tables { + err = m.RunWithValue(table, func(stmt *gorm.Statement) error { + fields := make([]*tableField, 0, len(stmt.Schema.DBNames)) + for _, field := range stmt.Schema.Fields { + if field.DBName == "" { + continue + } + + fields = append(fields, &tableField{ + Name: field.Name, + Column: field.DBName, + Type: m.FullDataTypeOf(field).SQL, + }) + } + + primaryKeys := make([]string, 0, len(stmt.Schema.PrimaryFields)) + if len(stmt.Schema.PrimaryFields) > 0 { + for _, field := range stmt.Schema.PrimaryFields { + primaryKeys = append(primaryKeys, field.DBName) + } + } + + tableInfos = append(tableInfos, &tableInfo{ + Name: stmt.Table, + Fields: fields, + PrimaryKeys: primaryKeys, + }) + return nil + }) + if err != nil { + return nil, errors.Wrap(err, "gather table information") + } + } + + return tableInfos, nil +} |