GO学习记录三

初始化数据库表结构代码。

这次遇到个新手坑,之前直接用代码是成功的,封装后发现不报错也不能成功创建表了。

排查问题后,发现是因为在封装的时候,把defer db.Close()释放连接的方法写在了封装方法initDB()中。

defer 这个关键字的特点是延迟执行,会在函数结束时,应该时return之前执行,所以写在封装方法initDB中,就会导致initDB()运行结束后,直接就释放了数据库连接,致使之后创建表的函数都是处于未创建数据库连接的环境下运行的,导致最后没有创建成功。

go 复制代码
package main

//引用的包
import (
	"database/sql"
	"fmt"
	"log"
	"strings"

	_ "github.com/lib/pq" //pgsql数据库组件
)

// 定义数据库相关配置
const (
	host     = "localhost"        //数据库ip
	port     = 5432               //数据库端口
	user     = "postgres"         //数据库用户名
	password = "postgres"         //数据库密码
	dbname   = "postgresLearning" //数据库名
)

// 启动函数
func main() {
	//初始化数据库连接
	db := initDB()
	defer db.Close() //defer db.Close() 是 Go 语言中一种确保资源被正确释放的惯用写法,它的作用是:
	//在函数返回前,自动调用 db.Close() 来关闭数据库连接,无论函数是正常返回还是发生 panic
	createUserTable(db)  //创建users表
	createTest1Table(db) //创建test1表
	createTest2Table(db) //创建test2表
}

// 创建users表
func createUserTable(db *sql.DB) {
	createTable := Table{
		Name: "users",
		Columns: []Column{
			{Name: "id", Type: SERIAL, Primary: true},
			{Name: "username", Type: VARCHAR, Length: 50, Unique: true, NotNull: true},
			{Name: "password", Type: VARCHAR, Length: 100, NotNull: true},
			{Name: "email", Type: VARCHAR, Length: 100, Unique: true, NotNull: true},
			{Name: "created_at", Type: TIMESTAMP, Default: "CURRENT_TIMESTAMP"},
			{Name: "updated_at", Type: TIMESTAMP, Default: "CURRENT_TIMESTAMP"},
		},
	}
	success, createTableSQL := CreateTable(createTable)
	if success {
		db.Exec(createTableSQL)
		fmt.Println("创建Users数据表完成")
	}
}

// 创建test1表
func createTest1Table(db *sql.DB) {
	createTable := Table{
		Name: "test1",
		Columns: []Column{
			{Name: "id", Type: SERIAL, Primary: true},
			{Name: "name", Type: VARCHAR, Length: 50, Unique: true, NotNull: true},
			{Name: "age", Type: INT, NotNull: true},
			{Name: "created_at", Type: TIMESTAMP, Default: "CURRENT_TIMESTAMP"},
			{Name: "updated_at", Type: TIMESTAMP, Default: "CURRENT_TIMESTAMP"},
		},
	}
	success, createTableSQL := CreateTable(createTable)
	if success {
		db.Exec(createTableSQL)
		fmt.Println("创建Test1数据表完成")
	}
}

// 创建test2表
func createTest2Table(db *sql.DB) {
	createTable := Table{
		Name: "test2",
		Columns: []Column{
			{Name: "id", Type: SERIAL, Primary: true},
			{Name: "name", Type: VARCHAR, Length: 50, Unique: true, NotNull: true},
			{Name: "age", Type: INT, NotNull: true},
			{Name: "created_at", Type: TIMESTAMP, Default: "CURRENT_TIMESTAMP"},
			{Name: "updated_at", Type: TIMESTAMP, Default: "CURRENT_TIMESTAMP"},
		},
	}
	success, createTableSQL := CreateTable(createTable)
	if success {
		db.Exec(createTableSQL)
		fmt.Println("创建Test1数据表完成")
	}
}

// 初始化数据库连接
func initDB() *sql.DB {
	// 构建连接字符串
	psqlInfo := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
		host, port, user, password, dbname)

	// 连接数据库
	db, err := sql.Open("postgres", psqlInfo)
	if err != nil {
		log.Fatal(err)
	}
	//defer db.Close()  这里有一个注意点,这块代码回直接关闭数据库连接
	// 检查连接
	err = db.Ping()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("Successfully connected to PostgreSQL database!")
	return db
}

// ColumnItemType 定义列类型的自定义类型
type ColumnItemType string

// 支持的列类型常量
const (
	VARCHAR   ColumnItemType = "VARCHAR"
	TIMESTAMP ColumnItemType = "TIMESTAMP"
	SERIAL    ColumnItemType = "SERIAL"
	TEXT      ColumnItemType = "TEXT"
	DECIMAL   ColumnItemType = "DECIMAL"
	INT       ColumnItemType = "INT"
)

// Column 定义字段结构
type Column struct {
	Name    string
	Type    ColumnItemType
	Length  int // 长度,仅对 VARCHAR、DECIMAL 等有效
	NotNull bool
	Unique  bool
	Primary bool
	Default string
}

// Table 定义表结构
type Table struct {
	Name    string
	Columns []Column
}

// CreateTable 生成 CREATE TABLE SQL 语句
func CreateTable(table Table) (bool, string) {
	if table.Name == "" {
		fmt.Printf("表名不能为空")
		return false, ""
	}
	if len(table.Columns) == 0 {
		fmt.Printf("字段列表不能为空")
		return false, ""
	}

	var fieldDefs []string

	for _, col := range table.Columns {
		if col.Name == "" {
			fmt.Printf("字段名不能为空")
			return false, ""
		}

		def := col.Name + " " + string(col.Type) // 注意:col.Type 是 ColumnItemType,需转为 string

		// 处理长度(仅对支持长度的类型)
		if col.Length > 0 {
			// 可以根据 Type 判断是否支持 Length,例如只对 VARCHAR 和 DECIMAL 生效
			def += fmt.Sprintf("(%d)", col.Length)
		}

		// 添加约束
		if col.NotNull {
			def += " NOT NULL"
		}
		if col.Unique {
			def += " UNIQUE"
		}
		if col.Primary {
			def += " PRIMARY KEY"
		}
		if col.Default != "" {
			def += " DEFAULT " + col.Default
		}

		fieldDefs = append(fieldDefs, def)
	}

	// 拼接完整 SQL
	sql := fmt.Sprintf(
		"CREATE TABLE IF NOT EXISTS %s (%s);",
		table.Name,
		strings.Join(fieldDefs, ", "),
	)

	return true, sql
}