引言
数据库操作是后端开发的核心技能之一。Go语言标准库database/sql提供了通用的SQL数据库接口,而sqlx库在此基础上提供了更便捷的操作方式。本文将深入探讨Go语言数据库编程的各个方面,从连接池管理到事务处理,从SQL注入防护到实际的数据访问层构建。
一、database/sql驱动注册与连接池
1.1 驱动注册机制
Go的database/sql包采用驱动接口分离设计:数据库驱动实现database/sql/driver包中定义的接口,并在自身的init函数中调用sql.Register以驱动名完成注册。因此只需匿名导入驱动包(触发其init),之后即可通过驱动名调用sql.Open:
package main
import (
"fmt"
_ "github.com/go-sql-driver/mysql" // 匿名导入,只注册驱动
"database/sql"
)
// main opens a MySQL handle through the driver registered by the blank import
// above and verifies connectivity with Ping.
func main() {
	// sql.Open may only validate its arguments without connecting; the first
	// real connection is made by Ping (or the first query).
	// go-sql-driver/mysql DSN format:
	//   username:password@protocol(address)/dbname?param=value
	db, err := sql.Open("mysql", "user:password@tcp(localhost:3306)/mydb")
	if err != nil {
		fmt.Printf("Open failed: %v\n", err)
		return
	}
	defer db.Close()

	// Report a Ping failure instead of printing it as a successful connection.
	if err := db.Ping(); err != nil {
		fmt.Printf("Ping failed: %v\n", err)
		return
	}
	fmt.Println("Connected to MySQL")
}
常见的驱动导入方式:
// MySQL
import _ "github.com/go-sql-driver/mysql"
// PostgreSQL
import _ "github.com/lib/pq"
// SQLite
import _ "github.com/mattn/go-sqlite3"
// ClickHouse
import _ "github.com/ClickHouse/clickhouse-go"
1.2 连接池配置
package main
import (
"database/sql"
"fmt"
"log"
"time"
_ "github.com/go-sql-driver/mysql"
)
// createDBPool opens a MySQL handle, tunes its connection pool and verifies
// that the server is reachable. Any failure terminates the process.
func createDBPool() *sql.DB {
	const dsn = "user:password@tcp(localhost:3306)/mydb?parseTime=true"
	pool, openErr := sql.Open("mysql", dsn)
	if openErr != nil {
		log.Fatalf("Open failed: %v", openErr)
	}

	// Sizing: at most 25 connections in use at once, up to 10 kept idle.
	pool.SetMaxOpenConns(25)
	pool.SetMaxIdleConns(10)
	// Recycling: close a connection after 5 minutes total, or 1 minute idle.
	pool.SetConnMaxLifetime(5 * time.Minute)
	pool.SetConnMaxIdleTime(time.Minute)

	// Ping forces a real connection so misconfiguration surfaces here.
	if pingErr := pool.Ping(); pingErr != nil {
		log.Fatalf("Ping failed: %v", pingErr)
	}
	return pool
}
// main builds the pool, reports success and closes the pool on exit.
func main() {
	pool := createDBPool()
	defer pool.Close()
	fmt.Println("Database connection pool created successfully")
}
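1.3 连接池工作原理
下面的示例把最大打开连接数限制为3,同时发起5个查询。拿不到连接的goroutine会阻塞,直到其他查询归还连接或context超时。这类等待会体现在db.Stats()返回的WaitCount和WaitDuration中;查询很快时它们也可能为0。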
package main
import (
"context"
"database/sql"
"fmt"
"log"
"sync"
"time"
_ "github.com/go-sql-driver/mysql"
)
// poolDemo runs five concurrent single-row queries against a pool capped at
// three open connections, then prints the pool statistics. A goroutine that
// cannot get a connection waits until one is released or its 5s context
// deadline expires.
func poolDemo() {
	db, err := sql.Open("mysql", "root:password@tcp(localhost:3306)/test?parseTime=true")
	if err != nil {
		log.Fatalf("Open failed: %v", err)
	}
	// Release every pooled connection when the demo finishes.
	defer db.Close()
	db.SetMaxOpenConns(3)
	db.SetMaxIdleConns(2)

	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()

			start := time.Now()
			var name string
			err := db.QueryRowContext(ctx, "SELECT 'User'").Scan(&name)
			elapsed := time.Since(start)
			fmt.Printf("Goroutine %d: elapsed=%v, err=%v\n", id, elapsed, err)
		}(i)
	}
	wg.Wait()
	fmt.Printf("Stats: %+v\n", db.Stats())
}
二、查询结果处理
2.1 Query、QueryRow、Exec区别
package main
import (
"database/sql"
"fmt"
"log"
_ "github.com/go-sql-driver/mysql"
)
// queryMultipleRows uses Query, the API for statements that return many rows.
// It lists every user older than 18 and aborts the process on any error.
func queryMultipleRows(db *sql.DB) {
	rows, err := db.Query("SELECT id, name, email FROM users WHERE age > ?", 18)
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	for rows.Next() {
		var (
			userID      int
			name, email string
		)
		if scanErr := rows.Scan(&userID, &name, &email); scanErr != nil {
			log.Fatal(scanErr)
		}
		fmt.Printf("User: id=%d, name=%s, email=%s\n", userID, name, email)
	}
	// Next returns false both at the end and on failure; Err tells them apart.
	if err = rows.Err(); err != nil {
		log.Fatal(err)
	}
}
// querySingleRow uses QueryRow, the API for statements expected to return at
// most one row. sql.ErrNoRows (no matching user) is reported rather than
// treated as fatal; any other error aborts the process.
func querySingleRow(db *sql.DB) {
	var name, email string
	err := db.QueryRow("SELECT name, email FROM users WHERE id = ?", 1).Scan(&name, &email)
	switch {
	case err == sql.ErrNoRows:
		fmt.Println("No user found")
	case err != nil:
		log.Fatal(err)
	default:
		fmt.Printf("User: name=%s, email=%s\n", name, email)
	}
}
// executeStatement uses Exec, the API for INSERT/UPDATE/DELETE statements that
// return no rows. It inserts a user, then increments that user's age, and
// prints the sql.Result metadata of each step.
func executeStatement(db *sql.DB) {
	result, err := db.Exec("INSERT INTO users (name, email, age) VALUES (?, ?, ?)",
		"张三", "zhangsan@example.com", 25)
	if err != nil {
		log.Fatal(err)
	}
	// LastInsertId/RowsAffected can fail (not every driver supports them).
	// The id is reused below, so a silently ignored error would make the
	// UPDATE target id 0.
	id, err := result.LastInsertId()
	if err != nil {
		log.Fatal(err)
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("Inserted: id=%d, rows_affected=%d\n", id, rowsAffected)

	result, err = db.Exec("UPDATE users SET age = age + 1 WHERE id = ?", id)
	if err != nil {
		log.Fatal(err)
	}
	rowsAffected, err = result.RowsAffected()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("Updated: rows_affected=%d\n", rowsAffected)
}
2.2 Scan与类型转换
package main
import (
"database/sql"
"fmt"
"log"
"time"
_ "github.com/go-sql-driver/mysql"
)
// scanTypes shows how Scan converts MySQL column types into Go types. It
// creates a TEMPORARY table, inserts one row and scans it back.
//
// A TEMPORARY table is visible only to the connection (session) that created
// it, but *sql.DB is a pool and may run each Exec/QueryRow on a different
// connection. Running every statement through one transaction pins them all
// to a single connection.
//
// Scanning the TIMESTAMP column into time.Time requires parseTime=true in the
// go-sql-driver/mysql DSN.
func scanTypes(db *sql.DB) {
	tx, err := db.Begin()
	if err != nil {
		log.Fatal(err)
	}
	// The demo data is throwaway, so the transaction is never committed.
	defer tx.Rollback()

	// Create the test table.
	if _, err := tx.Exec("CREATE TEMPORARY TABLE test_types (int_col INT, bigint_col BIGINT, varchar_col VARCHAR(50), bool_col BOOLEAN, time_col TIMESTAMP, decimal_col DECIMAL(10,2))"); err != nil {
		log.Fatal(err)
	}
	// The temp table lives as long as the session, not the transaction. Drop
	// it so the connection returns to the pool clean. Defers run LIFO, so this
	// runs before the Rollback above.
	defer tx.Exec("DROP TEMPORARY TABLE IF EXISTS test_types")

	// Insert one row of test data.
	if _, err := tx.Exec("INSERT INTO test_types VALUES (?, ?, ?, ?, ?, ?)",
		42, 1234567890123, "hello", true, time.Now(), 99.99); err != nil {
		log.Fatal(err)
	}

	// Scan the row into variables of different Go types.
	var (
		intVal     int
		bigintVal  int64
		strVal     string
		boolVal    bool
		timeVal    time.Time
		decimalVal float64
	)
	err = tx.QueryRow("SELECT * FROM test_types").
		Scan(&intVal, &bigintVal, &strVal, &boolVal, &timeVal, &decimalVal)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("intVal=%d, bigintVal=%d, strVal=%s\n", intVal, bigintVal, strVal)
	fmt.Printf("boolVal=%v, timeVal=%v, decimalVal=%.2f\n", boolVal, timeVal, decimalVal)
}
// handleNulls shows how to scan nullable columns. sql.NullString records
// whether the value was NULL (Valid == false); scanning NULL into a plain
// string would make Scan fail.
//
// The demo uses a TEMPORARY table, which exists only on the connection that
// created it. Running every statement inside one transaction keeps them on a
// single pooled connection.
func handleNulls(db *sql.DB) {
	tx, err := db.Begin()
	if err != nil {
		log.Fatal(err)
	}
	// Throwaway data: never committed.
	defer tx.Rollback()

	if _, err := tx.Exec("CREATE TEMPORARY TABLE null_test (name VARCHAR(50), nickname VARCHAR(50))"); err != nil {
		log.Fatal(err)
	}
	// Drop the session-scoped table before the connection returns to the pool.
	defer tx.Exec("DROP TEMPORARY TABLE IF EXISTS null_test")

	if _, err := tx.Exec("INSERT INTO null_test VALUES ('Alice', NULL), ('Bob', 'Bobby')"); err != nil {
		log.Fatal(err)
	}

	rows, err := tx.Query("SELECT name, nickname FROM null_test")
	if err != nil {
		log.Fatal(err)
	}
	// Deferred last, so it runs first: rows must be closed before the DROP
	// can run on the same connection.
	defer rows.Close()

	for rows.Next() {
		var name string
		var nickname sql.NullString // tolerates NULL
		if err := rows.Scan(&name, &nickname); err != nil {
			log.Fatal(err)
		}
		if nickname.Valid {
			fmt.Printf("name=%s, nickname=%s\n", name, nickname.String)
		} else {
			fmt.Printf("name=%s, nickname=NULL\n", name)
		}
	}
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
}
2.3 通用类型扫描
package main
import (
"database/sql"
"encoding/json"
"fmt"
"log"
_ "github.com/go-sql-driver/mysql"
)
// genericScan reads every remaining row of rows into a slice of maps keyed by
// column name. It is meant for result sets whose shape is unknown at compile
// time. []byte values (commonly how drivers deliver text) are converted to
// string; every other value, including nil for SQL NULL, is stored as
// scanned. The result is never nil, even with zero rows. The caller remains
// responsible for closing rows.
func genericScan(rows *sql.Rows) ([]map[string]interface{}, error) {
	cols, err := rows.Columns()
	if err != nil {
		return nil, err
	}
	n := len(cols)

	// Scan needs one pointer per column; point each at a slot in cells.
	cells := make([]interface{}, n)
	targets := make([]interface{}, n)
	for i := 0; i < n; i++ {
		targets[i] = &cells[i]
	}

	out := []map[string]interface{}{}
	for rows.Next() {
		if err := rows.Scan(targets...); err != nil {
			return nil, err
		}
		record := make(map[string]interface{}, n)
		for i, name := range cols {
			switch v := cells[i].(type) {
			case []byte:
				record[name] = string(v)
			default:
				record[name] = v
			}
		}
		out = append(out, record)
	}
	return out, rows.Err()
}
// main runs a literal SELECT and prints the generically scanned rows as JSON.
func main() {
	db, err := sql.Open("mysql", "root:password@tcp(localhost:3306)/test?parseTime=true")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	rows, err := db.Query("SELECT 1 as num, 'hello' as str, true as flag")
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	results, err := genericScan(rows)
	if err != nil {
		log.Fatal(err)
	}
	data, err := json.MarshalIndent(results, "", " ")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("Results: %s\n", data)
}
三、预处理语句
3.1 Prepare与Stmt
package main
import (
"database/sql"
"fmt"
"log"
_ "github.com/go-sql-driver/mysql"
)
// preparedStatementDemo prepares one SELECT and runs it for several minimum
// ages, printing the active users older than each. Per-age failures are logged
// and skipped rather than aborting the demo.
func preparedStatementDemo(db *sql.DB) {
	// Prepare once; the statement is reused for every age below.
	stmt, err := db.Prepare("SELECT name, email FROM users WHERE age > ? AND active = ?")
	if err != nil {
		log.Fatal(err)
	}
	defer stmt.Close()

	ages := []int{18, 25, 30}
	for _, age := range ages {
		rows, err := stmt.Query(age, true)
		if err != nil {
			log.Printf("Query error for age %d: %v", age, err)
			continue
		}
		fmt.Printf("Users older than %d:\n", age)
		for rows.Next() {
			var name, email string
			if err := rows.Scan(&name, &email); err != nil {
				log.Printf("Scan error for age %d: %v", age, err)
				break
			}
			fmt.Printf(" - %s (%s)\n", name, email)
		}
		// Distinguish "no more rows" from an error that ended iteration early.
		if err := rows.Err(); err != nil {
			log.Printf("Rows error for age %d: %v", age, err)
		}
		// Closed explicitly (not deferred) so each iteration frees its
		// connection immediately.
		rows.Close()
	}
}
// preparedExecDemo shows that Exec can use prepared statements too. It prepares
// an UPDATE once, executes it for several user ids, and logs per-id failures
// without aborting.
func preparedExecDemo(db *sql.DB) {
	stmt, err := db.Prepare("UPDATE users SET last_login = NOW() WHERE id = ?")
	if err != nil {
		log.Fatal(err)
	}
	defer stmt.Close()

	userIDs := []int{1, 2, 3, 4, 5}
	for _, id := range userIDs {
		result, err := stmt.Exec(id)
		if err != nil {
			log.Printf("Update error for id %d: %v", id, err)
			continue
		}
		// Report the error rather than printing a misleading 0.
		affected, err := result.RowsAffected()
		if err != nil {
			log.Printf("RowsAffected error for id %d: %v", id, err)
			continue
		}
		fmt.Printf("Updated user %d: %d rows affected\n", id, affected)
	}
}
3.2 预处理语句的优势
- 性能提升:同一SQL结构只需在服务器端解析一次,之后可多次执行(注意:在*sql.DB上创建的Stmt需要在新的底层连接上执行时,会自动在该连接上重新预处理)。
- SQL注入防护:参数与SQL文本分开发送给服务器,参数值不会被当作SQL语法解析。
- 代码复用:减少重复的SQL字符串。
四、事务处理
4.1 基本事务操作
package main
import (
"database/sql"
"fmt"
"log"
_ "github.com/go-sql-driver/mysql"
)
// transferMoney atomically moves amount from account fromID to account toID.
//
// It returns an error, and changes nothing, when:
//   - amount is not positive (including NaN); a negative amount would
//     silently reverse the transfer direction;
//   - the source account is missing or its balance is below amount;
//   - the destination account is missing, which would otherwise debit money
//     that is never credited anywhere.
func transferMoney(db *sql.DB, fromID, toID int, amount float64) error {
	// !(amount > 0) also rejects NaN, for which every comparison is false.
	if !(amount > 0) {
		return fmt.Errorf("invalid transfer amount: %v", amount)
	}

	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("begin transaction failed: %w", err)
	}
	// Roll back on every exit path, including panics. After a successful
	// Commit this returns sql.ErrTxDone, which is safe to ignore.
	defer tx.Rollback()

	// Debit: the balance guard makes an overdraft affect zero rows.
	result, err := tx.Exec("UPDATE accounts SET balance = balance - ? WHERE id = ? AND balance >= ?",
		amount, fromID, amount)
	if err != nil {
		return fmt.Errorf("deduct failed: %w", err)
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("read deduct result failed: %w", err)
	}
	if affected == 0 {
		return fmt.Errorf("insufficient balance or account not found")
	}

	// Credit: zero affected rows means the destination does not exist.
	result, err = tx.Exec("UPDATE accounts SET balance = balance + ? WHERE id = ?", amount, toID)
	if err != nil {
		return fmt.Errorf("credit failed: %w", err)
	}
	affected, err = result.RowsAffected()
	if err != nil {
		return fmt.Errorf("read credit result failed: %w", err)
	}
	if affected == 0 {
		return fmt.Errorf("target account %d not found", toID)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit failed: %w", err)
	}
	return nil
}
// main transfers 100.00 from account 1 to account 2 and reports the outcome.
func main() {
	db, err := sql.Open("mysql", "root:password@tcp(localhost:3306)/bank?parseTime=true")
	if err != nil {
		log.Fatalf("Open failed: %v", err)
	}
	defer db.Close()

	if err := transferMoney(db, 1, 2, 100.00); err != nil {
		log.Printf("Transfer failed: %v", err)
		return
	}
	fmt.Println("Transfer successful")
}
4.2 Savepoint与嵌套事务
package main
import (
"database/sql"
"fmt"
_ "github.com/go-sql-driver/mysql"
)
// savepointDemo shows how a SAVEPOINT lets part of a transaction be undone.
// The 'operation 1' insert is rolled back to sp1, while 'start transaction'
// and 'operation 2' are committed. database/sql has no nested-transaction
// API, so savepoints are issued as plain SQL on the same *sql.Tx. The first
// failing step aborts the demo and rolls everything back.
func savepointDemo(db *sql.DB) {
	tx, err := db.Begin()
	if err != nil {
		fmt.Printf("Begin failed: %v\n", err)
		return
	}
	// No-op (returns sql.ErrTxDone) once Commit has succeeded.
	defer tx.Rollback()

	steps := []string{
		"INSERT INTO logs (message) VALUES ('start transaction')",
		"SAVEPOINT sp1", // create the savepoint
		"INSERT INTO logs (message) VALUES ('operation 1')", // work that may need undoing
		"ROLLBACK TO SAVEPOINT sp1", // undo back to the savepoint only
		"INSERT INTO logs (message) VALUES ('operation 2')", // continue the transaction
	}
	for _, stmt := range steps {
		if _, err := tx.Exec(stmt); err != nil {
			fmt.Printf("Exec %q failed: %v\n", stmt, err)
			return
		}
	}
	if err := tx.Commit(); err != nil {
		fmt.Printf("Commit failed: %v\n", err)
		return
	}
	fmt.Println("Transaction committed")
}
4.3 事务隔离级别
package main
import (
	"context"
	"database/sql"
	"fmt"
	"log"

	_ "github.com/go-sql-driver/mysql"
)
// isolationLevelDemo prints the session's current isolation level, then runs
// a transaction at SERIALIZABLE.
//
// The level is requested through sql.TxOptions instead of a SET TRANSACTION
// statement inside the transaction. MySQL rejects changing transaction
// characteristics while a transaction is in progress. A session-level SET
// would only affect whichever pooled connection happened to execute it.
func isolationLevelDemo(db *sql.DB) {
	ctx := context.Background()

	// @@tx_isolation was removed in MySQL 8.0; @@transaction_isolation is
	// available since MySQL 5.7.20.
	var isolationLevel string
	if err := db.QueryRowContext(ctx, "SELECT @@transaction_isolation").Scan(&isolationLevel); err != nil {
		log.Printf("query isolation level failed: %v", err)
		return
	}
	fmt.Printf("Current isolation level: %s\n", isolationLevel)

	tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
	if err != nil {
		log.Printf("begin serializable transaction failed: %v", err)
		return
	}
	// No-op (returns sql.ErrTxDone) once Commit has succeeded.
	defer tx.Rollback()

	// ... run statements on tx here

	if err := tx.Commit(); err != nil {
		log.Printf("commit failed: %v", err)
	}
}
五、sqlx扩展库的高级用法
5.1 sqlx基础使用
package main
import (
"fmt"
"log"
"github.com/jmoiron/sqlx"
_ "github.com/go-sql-driver/mysql"
)
// User holds the users columns read in this section. Each `db` tag is the
// column name sqlx matches when scanning a result row into the struct.
type User struct {
ID int `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Age int `db:"age"`
}
// sqlxBasicDemo shows three everyday sqlx helpers:
//   - Get scans one row into a struct;
//   - Select scans many rows into a slice;
//   - NamedExec binds :name placeholders from a map.
//
// Get and Select return an error ("missing destination name") when the
// result contains a column with no matching db tag on the struct. The
// queries therefore list the User columns explicitly instead of using
// SELECT *, which would break as soon as the table has extra columns.
func sqlxBasicDemo(db *sqlx.DB) {
	// Single row.
	var user User
	err := db.Get(&user, "SELECT id, name, email, age FROM users WHERE id = ?", 1)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("User: %+v\n", user)

	// Multiple rows.
	var users []User
	err = db.Select(&users, "SELECT id, name, email, age FROM users WHERE age > ?", 18)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("Users: %+v\n", users)

	// Named parameters bound from a map.
	_, err = db.NamedExec("INSERT INTO users (name, email, age) VALUES (:name, :email, :age)",
		map[string]interface{}{
			"name":  "新用户",
			"email": "new@example.com",
			"age":   30,
		})
	if err != nil {
		log.Fatal(err)
	}
}
5.2 批量操作
package main
import (
"fmt"
"log"
"github.com/jmoiron/sqlx"
_ "github.com/go-sql-driver/mysql"
)
// Product is one row of the products table. Each `db` tag is the column name
// sqlx uses when scanning into the struct.
type Product struct {
ID int `db:"id"`
Name string `db:"name"`
Price float64 `db:"price"`
Quantity int `db:"quantity"`
}
// batchInsertDemo inserts several products with one multi-row INSERT.
//
// sqlx.In expands a slice argument bound to "(?)" into "(?, ?, ...)". Each
// product is therefore passed as a single []interface{}{name, price, quantity},
// one per "(?)" group, which keeps every value with its own row. Passing the
// values as a flat list of scalars gives sqlx.In nothing to expand, leaving
// more arguments than placeholders.
func batchInsertDemo(db *sqlx.DB) {
	products := []Product{
		{Name: "iPhone", Price: 9999.00, Quantity: 100},
		{Name: "iPad", Price: 5999.00, Quantity: 50},
		{Name: "MacBook", Price: 19999.00, Quantity: 30},
	}
	if len(products) == 0 {
		return // nothing to insert; "VALUES " alone would be invalid SQL
	}

	query := "INSERT INTO products (name, price, quantity) VALUES "
	args := make([]interface{}, 0, len(products))
	for i, p := range products {
		if i > 0 {
			query += ", "
		}
		query += "(?)"
		args = append(args, []interface{}{p.Name, p.Price, p.Quantity})
	}

	query, args, err := sqlx.In(query, args...)
	if err != nil {
		log.Fatal(err)
	}
	// Convert ? to the driver's bindvar style (a no-op for MySQL).
	query = db.Rebind(query)
	fmt.Printf("Query: %s\nArgs: %v\n", query, args)

	if _, err := db.Exec(query, args...); err != nil {
		log.Fatal(err)
	}
}
// batchUpdateDemo applies several price changes atomically. Every UPDATE runs
// in one transaction, which is rolled back if any statement fails.
func batchUpdateDemo(db *sqlx.DB) {
	priceUpdates := []struct {
		ID    int
		Price float64
	}{
		{1, 8999.00},
		{2, 5499.00},
		{3, 17999.00},
	}

	tx, err := db.Beginx()
	if err != nil {
		log.Fatal(err)
	}
	for _, update := range priceUpdates {
		if _, err := tx.Exec("UPDATE products SET price = ? WHERE id = ?", update.Price, update.ID); err != nil {
			// Roll back explicitly: log.Fatal exits without running defers.
			tx.Rollback()
			log.Fatal(err)
		}
	}
	if err := tx.Commit(); err != nil {
		log.Fatal(err)
	}
}
5.3 动态查询构建
package main
import (
"fmt"
"log"
"strings"
"github.com/jmoiron/sqlx"
_ "github.com/go-sql-driver/mysql"
)
// UserFilter holds optional search criteria for buildDynamicQuery. Unset
// fields add no condition: empty strings and non-positive ages are skipped.
// Active is a pointer so that "not set" (nil) can be told apart from false.
type UserFilter struct {
Name string // substring match on the name column
Email string // substring match on the email column
AgeMin int // inclusive lower bound on age
AgeMax int // inclusive upper bound on age
Active *bool
}
// likeEscaper escapes the LIKE wildcards % and _ (and the escape character
// itself) so user-supplied filter text is matched literally. '!' is the
// escape character, declared with ESCAPE '!', because backslash handling
// inside SQL string literals depends on the server's SQL mode.
var likeEscaper = strings.NewReplacer("!", "!!", "%", "!%", "_", "!_")

// buildDynamicQuery builds a SELECT over users from whichever filter fields are
// set, ANDs the conditions together, and prints how many users matched.
// User input only ever travels as placeholder arguments; the SQL text is
// assembled solely from fixed fragments.
func buildDynamicQuery(db *sqlx.DB, filter UserFilter) {
	var (
		conditions []string
		args       []interface{}
	)
	if filter.Name != "" {
		conditions = append(conditions, "name LIKE ? ESCAPE '!'")
		args = append(args, "%"+likeEscaper.Replace(filter.Name)+"%")
	}
	if filter.Email != "" {
		conditions = append(conditions, "email LIKE ? ESCAPE '!'")
		args = append(args, "%"+likeEscaper.Replace(filter.Email)+"%")
	}
	if filter.AgeMin > 0 {
		conditions = append(conditions, "age >= ?")
		args = append(args, filter.AgeMin)
	}
	if filter.AgeMax > 0 {
		conditions = append(conditions, "age <= ?")
		args = append(args, filter.AgeMax)
	}
	if filter.Active != nil {
		conditions = append(conditions, "active = ?")
		args = append(args, *filter.Active)
	}

	whereClause := ""
	if len(conditions) > 0 {
		whereClause = "WHERE " + strings.Join(conditions, " AND ")
	}
	// Explicit columns: sqlx rejects result columns that have no db tag on User.
	query := fmt.Sprintf("SELECT id, name, email, age FROM users %s ORDER BY id", whereClause)
	fmt.Printf("Query: %s\nArgs: %v\n", query, args)

	var users []User
	if err := db.Select(&users, query, args...); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("Found %d users\n", len(users))
}
六、SQL注入防护
6.1 注入攻击原理
package main
import (
"database/sql"
"fmt"
_ "github.com/go-sql-driver/mysql"
)
// dangerousQuery is a deliberately UNSAFE example of string-built SQL: the
// username is concatenated into the query text. Input such as
// "admin' OR '1'='1" rewrites the WHERE clause and returns every user. Never
// do this in real code; safeQuery shows the parameterized version.
func dangerousQuery(db *sql.DB, username string) {
	// VULNERABLE on purpose: user input becomes part of the SQL syntax.
	query := "SELECT id, name FROM users WHERE username = '" + username + "'"
	fmt.Printf("Query: %s\n", query)

	rows, err := db.Query(query)
	if err != nil {
		// On error rows is nil; deferring Close on it would panic.
		fmt.Printf("Query error: %v\n", err)
		return
	}
	defer rows.Close()

	for rows.Next() {
		var id int
		var name string
		if err := rows.Scan(&id, &name); err != nil {
			fmt.Printf("Scan error: %v\n", err)
			return
		}
		fmt.Printf("User: id=%d, name=%s\n", id, name)
	}
	if err := rows.Err(); err != nil {
		fmt.Printf("Rows error: %v\n", err)
	}
}
// safeQuery looks up users by username with a parameterized query. The value
// is sent separately from the SQL text, so it can never change the query's
// structure.
func safeQuery(db *sql.DB, username string) {
	// Select exactly the two columns scanned below; SELECT * breaks Scan as
	// soon as the table has more columns.
	query := "SELECT id, name FROM users WHERE username = ?"
	fmt.Printf("Query: %s\n", query)

	rows, err := db.Query(query, username)
	if err != nil {
		fmt.Printf("Query error: %v\n", err)
		return
	}
	defer rows.Close()

	for rows.Next() {
		var id int
		var name string
		if err := rows.Scan(&id, &name); err != nil {
			fmt.Printf("Scan error: %v\n", err)
			return
		}
		fmt.Printf("User: id=%d, name=%s\n", id, name)
	}
	if err := rows.Err(); err != nil {
		fmt.Printf("Rows error: %v\n", err)
	}
}
6.2 最佳实践
- 始终使用参数化查询
- 使用ORM或sqlx等库减少手写SQL
- 对输入做校验(优先使用白名单),作为纵深防御,而不是参数化查询的替代品
- 最小权限原则:应用使用的数据库账号只授予必需的权限
package main
import (
"fmt"
"regexp"
"strings"
"github.com/jmoiron/sqlx"
_ "github.com/go-sql-driver/mysql"
)
// validIdentRE matches strings made only of ASCII letters, digits and
// underscores. It is compiled once at package initialization rather than on
// every call.
var validIdentRE = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)

// validateInput reports whether input is non-empty and consists only of ASCII
// letters, digits and underscores. It is an allow-list check, useful as
// defence in depth; it does not replace parameterized queries.
func validateInput(input string) bool {
	return validIdentRE.MatchString(input)
}
// sqlQuoteEscaper doubles single quotes and backslashes in a single pass.
var sqlQuoteEscaper = strings.NewReplacer(`'`, `''`, `\`, `\\`)

// sanitizeInput doubles every single quote and backslash in input so the value
// can sit inside a single-quoted MySQL string literal.
//
// WARNING: hand-rolled escaping is not a substitute for parameterized queries.
// It handles only ' and \ (not, e.g., NUL or \x1a). Its safety also depends
// on the connection character set and the server's SQL mode
// (NO_BACKSLASH_ESCAPES). Prefer placeholders.
func sanitizeInput(input string) string {
	return sqlQuoteEscaper.Replace(input)
}
// main validates a username with the allow-list check, then looks the user up
// with a parameterized query through sqlx.
func main() {
	db, err := sqlx.Connect("mysql", "root:password@tcp(localhost:3306)/test?parseTime=true")
	if err != nil {
		// db is nil on failure; any later use would panic.
		fmt.Printf("Connect failed: %v\n", err)
		return
	}
	defer db.Close()

	username := "admin"
	if !validateInput(username) {
		fmt.Println("Invalid input")
		return
	}

	var user struct {
		ID   int    `db:"id"`
		Name string `db:"name"`
	}
	// Select exactly the columns the struct maps. With SELECT *, sqlx would
	// fail on any extra column (such as username) that has no db tag.
	err = db.Get(&user, "SELECT id, name FROM users WHERE username = ?", username)
	if err != nil {
		fmt.Printf("Error: %v\n", err)
		return
	}
	fmt.Printf("User: %+v\n", user)
}
七、实际案例:构建数据访问层
7.1 项目结构设计
dal/
├── dal.go # 初始化和配置
├── user.go # 用户数据访问
├── product.go # 产品数据访问
├── order.go # 订单数据访问
└── repository.go # 泛型Repository模式
7.2 数据访问层实现
package dal
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/jmoiron/sqlx"
_ "github.com/go-sql-driver/mysql"
)
// Config holds the MySQL connection settings and pool limits consumed by NewDAL.
type Config struct {
Host string
Port int
User string
Password string
Database string
MaxOpenConns int // pool limit applied by NewDAL (SetMaxOpenConns)
MaxIdleConns int // pool limit applied by NewDAL (SetMaxIdleConns)
ConnMaxLifetime time.Duration // applied by NewDAL (SetConnMaxLifetime)
}
// DAL owns the sqlx connection pool; the repositories in this package hold a
// *DAL and run their queries through DB.
type DAL struct {
DB *sqlx.DB
}
// NewDAL connects to MySQL using cfg and returns a DAL wrapping the pool.
// sqlx.Connect opens the handle and pings the server, so an unreachable
// database is reported here.
//
// Pool settings left at zero in cfg keep the database/sql defaults instead of
// being applied literally. A literal SetMaxIdleConns(0) would disable idle
// connection reuse, so a zero-value Config would reconnect on every query.
// Negative values are still passed through: database/sql treats them as
// "no idle connections" / "unlimited".
func NewDAL(cfg Config) (*DAL, error) {
	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true&loc=Local",
		cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database)

	db, err := sqlx.Connect("mysql", dsn)
	if err != nil {
		return nil, fmt.Errorf("connect to database failed: %w", err)
	}

	if cfg.MaxOpenConns != 0 {
		db.SetMaxOpenConns(cfg.MaxOpenConns)
	}
	if cfg.MaxIdleConns != 0 {
		db.SetMaxIdleConns(cfg.MaxIdleConns)
	}
	if cfg.ConnMaxLifetime != 0 {
		db.SetConnMaxLifetime(cfg.ConnMaxLifetime)
	}
	return &DAL{DB: db}, nil
}
// Close closes the underlying connection pool; afterwards no new queries can
// start through this DAL.
func (d *DAL) Close() error {
	pool := d.DB
	return pool.Close()
}
// WithTransaction runs fn inside a transaction bound to ctx.
//
// The transaction is committed if fn returns nil. It is rolled back if fn
// returns an error or panics; a panic is re-raised after the rollback. If the
// rollback itself fails, that failure is appended to fn's error. fn should
// not call Commit or Rollback on the transaction itself.
func (d *DAL) WithTransaction(ctx context.Context, fn func(*sqlx.Tx) error) error {
	tx, err := d.DB.BeginTxx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin transaction failed: %w", err)
	}
	// Without this, a panic in fn would leave the transaction, and the
	// connection it holds, open instead of being rolled back.
	defer func() {
		if p := recover(); p != nil {
			_ = tx.Rollback()
			panic(p)
		}
	}()

	if err := fn(tx); err != nil {
		// sql.ErrTxDone means the transaction already ended (e.g. ctx was
		// canceled and database/sql rolled it back), so nothing more to report.
		if rbErr := tx.Rollback(); rbErr != nil && rbErr != sql.ErrTxDone {
			return fmt.Errorf("%w (rollback failed: %v)", err, rbErr)
		}
		return err
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit failed: %w", err)
	}
	return nil
}
7.3 用户仓储实现
package dal
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/jmoiron/sqlx"
)
// User is one row of the users table. `db` tags name the columns sqlx maps,
// and `json` tags control serialization. The password hash is excluded from
// JSON via json:"-".
// NOTE(review): time.Time needs "time" in this file's imports, which the import
// block above does not list; add it for the file to compile.
type User struct {
ID int `db:"id" json:"id"`
Username string `db:"username" json:"username"`
Email string `db:"email" json:"email"`
Password string `db:"password_hash" json:"-"`
Age int `db:"age" json:"age"`
Active bool `db:"active" json:"active"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
// UserRepository provides CRUD access to the users table through a shared DAL.
type UserRepository struct {
	dal *DAL
}

// NewUserRepository returns a UserRepository that runs its queries via dal.
func NewUserRepository(dal *DAL) *UserRepository {
	repo := &UserRepository{}
	repo.dal = dal
	return repo
}
// GetByID returns the user with the given id. A missing row is not an error:
// it yields (nil, nil). Any other database failure is wrapped and returned.
func (r *UserRepository) GetByID(ctx context.Context, id int) (*User, error) {
	u := new(User)
	err := r.dal.DB.GetContext(ctx, u, "SELECT * FROM users WHERE id = ?", id)
	switch {
	case err == nil:
		return u, nil
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	default:
		return nil, fmt.Errorf("get user by id failed: %w", err)
	}
}
// GetByUsername looks up a user by username. A missing row yields (nil, nil);
// any other database failure is wrapped and returned.
func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*User, error) {
	u := new(User)
	err := r.dal.DB.GetContext(ctx, u, "SELECT * FROM users WHERE username = ?", username)
	switch {
	case err == nil:
		return u, nil
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	default:
		return nil, fmt.Errorf("get user by username failed: %w", err)
	}
}
// List returns one page of users ordered by id (at most limit rows, skipping
// offset) together with the total number of users. The page and the count
// come from two separate statements, so they can disagree if rows change in
// between.
func (r *UserRepository) List(ctx context.Context, offset, limit int) ([]User, int, error) {
	db := r.dal.DB

	var page []User
	if err := db.SelectContext(ctx, &page, "SELECT * FROM users ORDER BY id LIMIT ? OFFSET ?", limit, offset); err != nil {
		return nil, 0, fmt.Errorf("list users failed: %w", err)
	}

	var total int
	if err := db.GetContext(ctx, &total, "SELECT COUNT(*) FROM users"); err != nil {
		return nil, 0, fmt.Errorf("count users failed: %w", err)
	}
	return page, total, nil
}
// Create inserts user, with created_at/updated_at set by the database via
// NOW(), and stores the generated primary key in user.ID.
func (r *UserRepository) Create(ctx context.Context, user *User) error {
	const insertSQL = `INSERT INTO users (username, email, password_hash, age, active, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, NOW(), NOW())`

	res, err := r.dal.DB.ExecContext(ctx, insertSQL,
		user.Username, user.Email, user.Password, user.Age, user.Active)
	if err != nil {
		return fmt.Errorf("create user failed: %w", err)
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return fmt.Errorf("get last insert id failed: %w", err)
	}
	user.ID = int(newID)
	return nil
}
// Update overwrites username, email, age and active for the row identified by
// user.ID and sets updated_at to NOW(). It returns an error if no row was
// affected.
//
// NOTE(review): by default MySQL reports rows *changed*, not rows matched. An
// update that writes identical values (e.g. twice within the same second, if
// updated_at has second precision) may therefore report 0 and surface as
// "user not found". go-sql-driver/mysql's DSN option clientFoundRows=true
// reports matched rows instead.
func (r *UserRepository) Update(ctx context.Context, user *User) error {
	query := `UPDATE users SET username = ?, email = ?, age = ?, active = ?, updated_at = NOW()
WHERE id = ?`
	result, err := r.dal.DB.ExecContext(ctx, query,
		user.Username, user.Email, user.Age, user.Active, user.ID)
	if err != nil {
		return fmt.Errorf("update user failed: %w", err)
	}
	// An ignored error here would read as 0 and be misreported as "not found".
	affected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("read affected rows failed: %w", err)
	}
	if affected == 0 {
		return errors.New("user not found")
	}
	return nil
}
// Delete removes the user with the given id and returns an error if no such
// row existed.
func (r *UserRepository) Delete(ctx context.Context, id int) error {
	query := "DELETE FROM users WHERE id = ?"
	result, err := r.dal.DB.ExecContext(ctx, query, id)
	if err != nil {
		return fmt.Errorf("delete user failed: %w", err)
	}
	// An ignored error here would read as 0 and be misreported as "not found".
	affected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("read affected rows failed: %w", err)
	}
	if affected == 0 {
		return errors.New("user not found")
	}
	return nil
}
// CreateWithProfile inserts user and its profile key/value pairs in a single
// transaction, so either all rows are written or none are. On success user.ID
// holds the new primary key. On failure user is left unchanged, because the
// ID is only assigned after the transaction commits.
func (r *UserRepository) CreateWithProfile(ctx context.Context, user *User, profile map[string]string) error {
	var newID int64
	err := r.dal.WithTransaction(ctx, func(tx *sqlx.Tx) error {
		// Create the user row.
		query := `INSERT INTO users (username, email, password_hash, age, active, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, NOW(), NOW())`
		result, err := tx.ExecContext(ctx, query,
			user.Username, user.Email, user.Password, user.Age, user.Active)
		if err != nil {
			return fmt.Errorf("create user failed: %w", err)
		}
		userID, err := result.LastInsertId()
		if err != nil {
			return fmt.Errorf("get last insert id failed: %w", err)
		}

		// Create the profile rows. KEY is a reserved word in MySQL, so the
		// column name must be quoted with backticks.
		profileQuery := "INSERT INTO user_profiles (user_id, `key`, value) VALUES (?, ?, ?)"
		for k, v := range profile {
			if _, err := tx.ExecContext(ctx, profileQuery, userID, k, v); err != nil {
				return fmt.Errorf("create profile failed: %w", err)
			}
		}
		newID = userID
		return nil
	})
	if err != nil {
		return err
	}
	user.ID = int(newID)
	return nil
}
7.4 订单仓储实现
package dal
import (
"context"
"errors"
"fmt"
"github.com/jmoiron/sqlx"
)
// OrderItem is one row of the order_items table: a product line belonging to
// the order identified by OrderID.
type OrderItem struct {
ID int `db:"id" json:"id"`
OrderID int `db:"order_id" json:"order_id"`
ProductID int `db:"product_id" json:"product_id"`
ProductName string `db:"product_name" json:"product_name"`
Price float64 `db:"price" json:"price"`
Quantity int `db:"quantity" json:"quantity"`
}
// Order is one row of the orders table. Items is not a column (`db:"-"`); it
// is loaded separately from order_items.
// NOTE(review): this file uses time.Time here, and sql.ErrNoRows in GetByID,
// but its import block lists neither "time" nor "database/sql"; add them for
// the file to compile.
type Order struct {
ID int `db:"id" json:"id"`
UserID int `db:"user_id" json:"user_id"`
Status string `db:"status" json:"status"`
Total float64 `db:"total" json:"total"`
Items []OrderItem `db:"-" json:"items"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
// OrderRepository provides access to orders and their items through a shared DAL.
type OrderRepository struct {
	dal *DAL
}

// NewOrderRepository returns an OrderRepository that runs its queries via dal.
func NewOrderRepository(dal *DAL) *OrderRepository {
	repo := &OrderRepository{}
	repo.dal = dal
	return repo
}
// Create inserts order and all of its items in one transaction.
//
// On success order.ID and each item's ID and OrderID hold the generated keys.
// If the transaction fails, some of these fields may already have been
// assigned even though nothing was committed. order.Total is stored as given;
// it is not recomputed from the items.
func (r *OrderRepository) Create(ctx context.Context, order *Order) error {
	return r.dal.WithTransaction(ctx, func(tx *sqlx.Tx) error {
		// Order header.
		orderQuery := `INSERT INTO orders (user_id, status, total, created_at, updated_at)
VALUES (?, ?, ?, NOW(), NOW())`
		result, err := tx.ExecContext(ctx, orderQuery, order.UserID, order.Status, order.Total)
		if err != nil {
			return fmt.Errorf("create order failed: %w", err)
		}
		// The order id links every item row, so a failure here must abort.
		orderID, err := result.LastInsertId()
		if err != nil {
			return fmt.Errorf("get order id failed: %w", err)
		}
		order.ID = int(orderID)

		// Order items.
		itemQuery := `INSERT INTO order_items (order_id, product_id, product_name, price, quantity)
VALUES (?, ?, ?, ?, ?)`
		for i := range order.Items {
			item := &order.Items[i] // pointer so generated keys are written back
			item.OrderID = int(orderID)
			res, err := tx.ExecContext(ctx, itemQuery,
				item.OrderID, item.ProductID, item.ProductName, item.Price, item.Quantity)
			if err != nil {
				return fmt.Errorf("create order item failed: %w", err)
			}
			itemID, err := res.LastInsertId()
			if err != nil {
				return fmt.Errorf("get order item id failed: %w", err)
			}
			item.ID = int(itemID)
		}
		return nil
	})
}
// GetByID loads the order header with the given id plus all of its line
// items. A missing order yields (nil, nil); any other failure is wrapped.
func (r *OrderRepository) GetByID(ctx context.Context, id int) (*Order, error) {
	db := r.dal.DB

	order := new(Order)
	err := db.GetContext(ctx, order, "SELECT * FROM orders WHERE id = ?", id)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("get order failed: %w", err)
	}

	var items []OrderItem
	if err := db.SelectContext(ctx, &items, "SELECT * FROM order_items WHERE order_id = ?", id); err != nil {
		return nil, fmt.Errorf("get order items failed: %w", err)
	}
	order.Items = items
	return order, nil
}
// GetByUserID lists every order placed by userID, newest first. Only the
// order headers are loaded; Items is left unpopulated.
func (r *OrderRepository) GetByUserID(ctx context.Context, userID int) ([]Order, error) {
	const q = "SELECT * FROM orders WHERE user_id = ? ORDER BY created_at DESC"
	var list []Order
	if err := r.dal.DB.SelectContext(ctx, &list, q, userID); err != nil {
		return nil, fmt.Errorf("get user orders failed: %w", err)
	}
	return list, nil
}
// UpdateStatus sets the status of order id and bumps updated_at. It returns
// an error if no row was affected.
//
// NOTE(review): MySQL reports rows changed rather than matched unless the DSN
// sets clientFoundRows=true, so a no-op update may be misreported as
// "order not found".
func (r *OrderRepository) UpdateStatus(ctx context.Context, id int, status string) error {
	query := "UPDATE orders SET status = ?, updated_at = NOW() WHERE id = ?"
	result, err := r.dal.DB.ExecContext(ctx, query, status, id)
	if err != nil {
		return fmt.Errorf("update order status failed: %w", err)
	}
	// An ignored error here would read as 0 and be misreported as "not found".
	affected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("read affected rows failed: %w", err)
	}
	if affected == 0 {
		return errors.New("order not found")
	}
	return nil
}
7.5 使用示例
package main
import (
"context"
"fmt"
"log"
"time"
"mypackage/dal"
)
// main wires up the DAL and repositories, then creates and re-reads a user
// and an order to demonstrate the data-access layer end to end.
func main() {
	database, err := dal.NewDAL(dal.Config{
		Host:            "localhost",
		Port:            3306,
		User:            "root",
		Password:        "password",
		Database:        "ecommerce",
		MaxOpenConns:    25,
		MaxIdleConns:    10,
		ConnMaxLifetime: 5 * time.Minute,
	})
	if err != nil {
		log.Fatalf("Failed to connect to database: %v", err)
	}
	defer database.Close()

	userRepo := dal.NewUserRepository(database)
	orderRepo := dal.NewOrderRepository(database)
	ctx := context.Background()

	// Create a user.
	user := &dal.User{
		Username: "testuser",
		Email:    "test@example.com",
		Password: "hashed_password",
		Age:      25,
		Active:   true,
	}
	if err := userRepo.Create(ctx, user); err != nil {
		log.Fatalf("Failed to create user: %v", err)
	}
	fmt.Printf("User created: ID=%d\n", user.ID)

	// Read it back. GetByID reports a missing row as (nil, nil), not an error.
	fetchedUser, err := userRepo.GetByID(ctx, user.ID)
	if err != nil {
		log.Fatalf("Failed to get user: %v", err)
	}
	if fetchedUser == nil {
		log.Fatalf("User %d not found after insert", user.ID)
	}
	fmt.Printf("User fetched: %+v\n", fetchedUser)

	// Create an order. The header total is derived from the line items so
	// the two stay consistent.
	items := []dal.OrderItem{
		{ProductID: 1, ProductName: "iPhone", Price: 9999.00, Quantity: 1},
	}
	var total float64
	for _, it := range items {
		total += it.Price * float64(it.Quantity)
	}
	order := &dal.Order{
		UserID: user.ID,
		Status: "pending",
		Total:  total,
		Items:  items,
	}
	if err := orderRepo.Create(ctx, order); err != nil {
		log.Fatalf("Failed to create order: %v", err)
	}
	fmt.Printf("Order created: ID=%d\n", order.ID)

	// Read the order back.
	fetchedOrder, err := orderRepo.GetByID(ctx, order.ID)
	if err != nil {
		log.Fatalf("Failed to get order: %v", err)
	}
	if fetchedOrder == nil {
		log.Fatalf("Order %d not found after insert", order.ID)
	}
	fmt.Printf("Order fetched: %+v\n", fetchedOrder)
}
总结
本文全面介绍了Go语言数据库编程的各个方面:
- 驱动注册与连接池:理解database/sql的驱动接口分离设计和连接池配置。
- 查询处理:掌握Query、QueryRow、Exec的使用场景,以及正确的Scan方式。
- 预处理语句:使用Prepare创建预处理语句,提升性能并防止SQL注入。
- 事务处理:使用Begin、Commit、Rollback进行事务管理,掌握保存点的使用。
- sqlx扩展:利用sqlx库的便捷功能简化数据库操作。
- SQL注入防护:始终使用参数化查询,避免SQL拼接。
- 数据访问层设计:构建结构清晰、错误处理完善、事务支持良好的DAL层。
良好的数据访问层设计是构建可维护、可扩展后端系统的基础。通过本文介绍的技术和模式,开发者可以编写出高效、安全、健壮的数据库操作代码。