Go语言database/sql与SQLx:构建健壮的数据访问层

引言

数据库操作是后端开发的核心技能之一。Go语言标准库database/sql提供了通用的SQL数据库接口,而sqlx库在此基础上提供了更便捷的操作方式。本文将深入探讨Go语言数据库编程的各个方面,从连接池管理到事务处理,从SQL注入防护到实际的数据访问层构建。

一、database/sql驱动注册与连接池

1.1 驱动注册机制

Go的database/sql包采用驱动接口分离设计,数据库驱动通过database/sql/driver接口实现注册:

复制代码
package main
​
import (
    "fmt"
    _ "github.com/go-sql-driver/mysql"  // 匿名导入,只注册驱动
    "database/sql"
)
​
func main() {
    // 驱动注册后,Open函数可以创建数据库连接
    // 格式:driver://username:password@host:port/database
    db, err := sql.Open("mysql", "user:password@tcp(localhost:3306)/mydb")
    if err != nil {
        fmt.Printf("Open failed: %v\n", err)
        return
    }
    defer db.Close()
​
    fmt.Printf("Connected to MySQL: %v\n", db.Ping())
}

常见的驱动导入方式:

复制代码
// MySQL
import _ "github.com/go-sql-driver/mysql"
​
// PostgreSQL
import _ "github.com/lib/pq"
​
// SQLite
import _ "github.com/mattn/go-sqlite3"
​
// ClickHouse
import _ "github.com/ClickHouse/clickhouse-go"

1.2 连接池配置

复制代码
package main
​
import (
    "database/sql"
    "fmt"
    "log"
    "time"
​
    _ "github.com/go-sql-driver/mysql"
)
​
func createDBPool() *sql.DB {
    db, err := sql.Open("mysql", "user:password@tcp(localhost:3306)/mydb?parseTime=true")
    if err != nil {
        log.Fatalf("Open failed: %v", err)
    }
​
    // 设置最大打开连接数
    db.SetMaxOpenConns(25)
​
    // 设置最大空闲连接数
    db.SetMaxIdleConns(10)
​
    // 设置连接最大生命周期
    db.SetConnMaxLifetime(5 * time.Minute)
​
    // 设置空闲连接最大生命周期
    db.SetConnMaxIdleTime(1 * time.Minute)
​
    // 验证连接
    if err := db.Ping(); err != nil {
        log.Fatalf("Ping failed: %v", err)
    }
​
    return db
}
​
func main() {
    db := createDBPool()
    defer db.Close()
​
    fmt.Println("Database connection pool created successfully")
}

1.3 连接池工作原理

复制代码
package main
​
import (
    "context"
    "database/sql"
    "fmt"
    "log"
    "sync"
    "time"
​
    _ "github.com/go-sql-driver/mysql"
)
​
func poolDemo() {
    db, _ := sql.Open("mysql", "root:password@tcp(localhost:3306)/test?parseTime=true")
    db.SetMaxOpenConns(3)
    db.SetMaxIdleConns(2)
​
    var wg sync.WaitGroup
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
            defer cancel()
​
            start := time.Now()
            var name string
            err := db.QueryRowContext(ctx, "SELECT 'User'").Scan(&name)
            elapsed := time.Since(start)
​
            fmt.Printf("Goroutine %d: elapsed=%v, err=%v\n", id, elapsed, err)
        }(i)
    }
    wg.Wait()
​
    fmt.Printf("Stats: %+v\n", db.Stats())
}

二、查询结果处理

2.1 Query、QueryRow、Exec区别

复制代码
package main
​
import (
    "database/sql"
    "fmt"
    "log"
​
    _ "github.com/go-sql-driver/mysql"
)
​
// Query - 用于返回多行结果
func queryMultipleRows(db *sql.DB) {
    rows, err := db.Query("SELECT id, name, email FROM users WHERE age > ?", 18)
    if err != nil {
        log.Fatal(err)
    }
    defer rows.Close()
​
    for rows.Next() {
        var id int
        var name, email string
        if err := rows.Scan(&id, &name, &email); err != nil {
            log.Fatal(err)
        }
        fmt.Printf("User: id=%d, name=%s, email=%s\n", id, name, email)
    }
​
    if err := rows.Err(); err != nil {
        log.Fatal(err)
    }
}
​
// QueryRow - 用于返回单行结果
func querySingleRow(db *sql.DB) {
    var name string
    var email string
​
    err := db.QueryRow("SELECT name, email FROM users WHERE id = ?", 1).
        Scan(&name, &email)
    if err != nil {
        if err == sql.ErrNoRows {
            fmt.Println("No user found")
        } else {
            log.Fatal(err)
        }
    } else {
        fmt.Printf("User: name=%s, email=%s\n", name, email)
    }
}
​
// Exec - 用于INSERT、UPDATE、DELETE操作
func executeStatement(db *sql.DB) {
    result, err := db.Exec("INSERT INTO users (name, email, age) VALUES (?, ?, ?)",
        "张三", "zhangsan@example.com", 25)
    if err != nil {
        log.Fatal(err)
    }
​
    id, _ := result.LastInsertId()
    rowsAffected, _ := result.RowsAffected()
    fmt.Printf("Inserted: id=%d, rows_affected=%d\n", id, rowsAffected)
​
    result, err = db.Exec("UPDATE users SET age = age + 1 WHERE id = ?", id)
    if err != nil {
        log.Fatal(err)
    }
    rowsAffected, _ = result.RowsAffected()
    fmt.Printf("Updated: rows_affected=%d\n", rowsAffected)
}

2.2 Scan与类型转换

复制代码
package main
​
import (
    "database/sql"
    "fmt"
    "log"
    "time"
​
    _ "github.com/go-sql-driver/mysql"
)
​
func scanTypes(db *sql.DB) {
    // 创建测试表
    db.Exec("CREATE TEMPORARY TABLE test_types (int_col INT, bigint_col BIGINT, varchar_col VARCHAR(50), bool_col BOOLEAN, time_col TIMESTAMP, decimal_col DECIMAL(10,2))")
​
    // 插入测试数据
    db.Exec("INSERT INTO test_types VALUES (?, ?, ?, ?, ?, ?)",
        42, 1234567890123, "hello", true, time.Now(), 99.99)
​
    // 扫描到不同类型
    var (
        intVal    int
        bigintVal int64
        strVal    string
        boolVal   bool
        timeVal   time.Time
        decimalVal float64
    )
​
    err := db.QueryRow("SELECT * FROM test_types").
        Scan(&intVal, &bigintVal, &strVal, &boolVal, &timeVal, &decimalVal)
    if err != nil {
        log.Fatal(err)
    }
​
    fmt.Printf("intVal=%d, bigintVal=%d, strVal=%s\n", intVal, bigintVal, strVal)
    fmt.Printf("boolVal=%v, timeVal=%v, decimalVal=%.2f\n", boolVal, timeVal, decimalVal)
}
​
// Null类型处理
func handleNulls(db *sql.DB) {
    db.Exec("CREATE TEMPORARY TABLE null_test (name VARCHAR(50), nickname VARCHAR(50))")
    db.Exec("INSERT INTO null_test VALUES ('Alice', NULL), ('Bob', 'Bobby')")
​
    rows, _ := db.Query("SELECT name, nickname FROM null_test")
    for rows.Next() {
        var name string
        var nickname sql.NullString  // 处理NULL值
​
        if err := rows.Scan(&name, &nickname); err != nil {
            log.Fatal(err)
        }
​
        if nickname.Valid {
            fmt.Printf("name=%s, nickname=%s\n", name, nickname.String)
        } else {
            fmt.Printf("name=%s, nickname=NULL\n", name)
        }
    }
}

2.3 通用类型扫描

复制代码
package main
​
import (
    "database/sql"
    "encoding/json"
    "fmt"
    "log"
​
    _ "github.com/go-sql-driver/mysql"
)
​
func genericScan(rows *sql.Rows) ([]map[string]interface{}, error) {
    // 获取列信息
    columns, err := rows.Columns()
    if err != nil {
        return nil, err
    }
​
    // 创建结果切片
    results := make([]map[string]interface{}, 0)
​
    // 创建扫描目标
    values := make([]interface{}, len(columns))
    valuePtrs := make([]interface{}, len(columns))
    for i := range values {
        valuePtrs[i] = &values[i]
    }
​
    for rows.Next() {
        if err := rows.Scan(valuePtrs...); err != nil {
            return nil, err
        }
​
        row := make(map[string]interface{})
        for i, col := range columns {
            val := values[i]
            // 处理字节切片
            if b, ok := val.([]byte); ok {
                row[col] = string(b)
            } else {
                row[col] = val
            }
        }
        results = append(results, row)
    }
​
    return results, rows.Err()
}
​
func main() {
    db, _ := sql.Open("mysql", "root:password@tcp(localhost:3306)/test?parseTime=true")
    defer db.Close()
​
    rows, err := db.Query("SELECT 1 as num, 'hello' as str, true as flag")
    if err != nil {
        log.Fatal(err)
    }
    defer rows.Close()
​
    results, err := genericScan(rows)
    if err != nil {
        log.Fatal(err)
    }
​
    data, _ := json.MarshalIndent(results, "", "  ")
    fmt.Printf("Results: %s\n", data)
}

三、预处理语句

3.1 Prepare与Stmt

复制代码
package main
​
import (
    "database/sql"
    "fmt"
    "log"
​
    _ "github.com/go-sql-driver/mysql"
)
​
func preparedStatementDemo(db *sql.DB) {
    // 预处理SQL语句
    stmt, err := db.Prepare("SELECT name, email FROM users WHERE age > ? AND active = ?")
    if err != nil {
        log.Fatal(err)
    }
    defer stmt.Close()
​
    // 多次执行同一预处理语句
    ages := []int{18, 25, 30}
    for _, age := range ages {
        rows, err := stmt.Query(age, true)
        if err != nil {
            log.Printf("Query error for age %d: %v", age, err)
            continue
        }
​
        fmt.Printf("Users older than %d:\n", age)
        for rows.Next() {
            var name, email string
            rows.Scan(&name, &email)
            fmt.Printf("  - %s (%s)\n", name, email)
        }
        rows.Close()
    }
}
​
// Exec也可以使用预处理
func preparedExecDemo(db *sql.DB) {
    stmt, err := db.Prepare("UPDATE users SET last_login = NOW() WHERE id = ?")
    if err != nil {
        log.Fatal(err)
    }
    defer stmt.Close()
​
    userIDs := []int{1, 2, 3, 4, 5}
    for _, id := range userIDs {
        result, err := stmt.Exec(id)
        if err != nil {
            log.Printf("Update error for id %d: %v", id, err)
            continue
        }
        affected, _ := result.RowsAffected()
        fmt.Printf("Updated user %d: %d rows affected\n", id, affected)
    }
}

3.2 预处理语句的优势

  1. 性能提升:服务器只需解析一次SQL结构

  2. SQL注入防护:参数自动转义

  3. 代码复用:减少重复SQL字符串

四、事务处理

4.1 基本事务操作

复制代码
package main
​
import (
    "database/sql"
    "fmt"
    "log"
​
    _ "github.com/go-sql-driver/mysql"
)
​
func transferMoney(db *sql.DB, fromID, toID int, amount float64) error {
    // 开始事务
    tx, err := db.Begin()
    if err != nil {
        return fmt.Errorf("begin transaction failed: %w", err)
    }
​
    // 确保失败时回滚
    defer func() {
        if r := recover(); r != nil {
            tx.Rollback()
            panic(r)
        }
    }()
​
    // 扣款
    result, err := tx.Exec("UPDATE accounts SET balance = balance - ? WHERE id = ? AND balance >= ?",
        amount, fromID, amount)
    if err != nil {
        tx.Rollback()
        return fmt.Errorf("deduct failed: %w", err)
    }
​
    affected, _ := result.RowsAffected()
    if affected == 0 {
        tx.Rollback()
        return fmt.Errorf("insufficient balance or account not found")
    }
​
    // 加款
    _, err = tx.Exec("UPDATE accounts SET balance = balance + ? WHERE id = ?", amount, toID)
    if err != nil {
        tx.Rollback()
        return fmt.Errorf("credit failed: %w", err)
    }
​
    // 提交事务
    if err := tx.Commit(); err != nil {
        return fmt.Errorf("commit failed: %w", err)
    }
​
    return nil
}
​
func main() {
    db, _ := sql.Open("mysql", "root:password@tcp(localhost:3306)/bank?parseTime=true")
    defer db.Close()
​
    if err := transferMoney(db, 1, 2, 100.00); err != nil {
        log.Printf("Transfer failed: %v", err)
    } else {
        fmt.Println("Transfer successful")
    }
}

4.2 Savepoint与嵌套事务

复制代码
package main
​
import (
    "database/sql"
    "fmt"
​
    _ "github.com/go-sql-driver/mysql"
)
​
func savepointDemo(db *sql.DB) {
    tx, _ := db.Begin()
​
    // 创建保存点
    tx.Exec("INSERT INTO logs (message) VALUES ('start transaction')")
    tx.Exec("SAVEPOINT sp1")
​
    // 执行可能失败的操作
    tx.Exec("INSERT INTO logs (message) VALUES ('operation 1')")
​
    // 回滚到保存点
    tx.Exec("ROLLBACK TO SAVEPOINT sp1")
​
    // 继续执行
    tx.Exec("INSERT INTO logs (message) VALUES ('operation 2')")
​
    tx.Commit()
}

4.3 事务隔离级别

复制代码
package main
​
import (
    "database/sql"
    "fmt"
​
    _ "github.com/go-sql-driver/mysql"
)
​
func isolationLevelDemo(db *sql.DB) {
    // 查看当前隔离级别
    var isolationLevel string
    db.QueryRow("SELECT @@tx_isolation").Scan(&isolationLevel)
    fmt.Printf("Current isolation level: %s\n", isolationLevel)
​
    // 设置隔离级别(MySQL)
    tx, _ := db.Begin()
    tx.Exec("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE")
    // ... 执行操作
    tx.Commit()
}

五、sqlx扩展库的高级用法

5.1 sqlx基础使用

复制代码
package main
​
import (
    "fmt"
    "log"
​
    "github.com/jmoiron/sqlx"
    _ "github.com/go-sql-driver/mysql"
)
​
type User struct {
    ID    int    `db:"id"`
    Name  string `db:"name"`
    Email string `db:"email"`
    Age   int    `db:"age"`
}
​
func sqlxBasicDemo(db *sqlx.DB) {
    // 查询单行
    var user User
    err := db.Get(&user, "SELECT * FROM users WHERE id = ?", 1)
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("User: %+v\n", user)
​
    // 查询多行
    var users []User
    err = db.Select(&users, "SELECT * FROM users WHERE age > ?", 18)
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("Users: %+v\n", users)
​
    // Named Exec
    _, err = db.NamedExec("INSERT INTO users (name, email, age) VALUES (:name, :email, :age)",
        map[string]interface{}{
            "name":  "新用户",
            "email": "new@example.com",
            "age":   30,
        })
    if err != nil {
        log.Fatal(err)
    }
}

5.2 批量操作

复制代码
package main
​
import (
    "fmt"
    "log"
​
    "github.com/jmoiron/sqlx"
    _ "github.com/go-sql-driver/mysql"
)
​
type Product struct {
    ID       int     `db:"id"`
    Name     string  `db:"name"`
    Price    float64 `db:"price"`
    Quantity int     `db:"quantity"`
}
​
func batchInsertDemo(db *sqlx.DB) {
    // 批量插入
    products := []Product{
        {Name: "iPhone", Price: 9999.00, Quantity: 100},
        {Name: "iPad", Price: 5999.00, Quantity: 50},
        {Name: "MacBook", Price: 19999.00, Quantity: 30},
    }
​
    // 使用sqlx.In进行批量操作
    query, args, _ := sqlx.In(
        "INSERT INTO products (name, price, quantity) VALUES (?), (?), (?)",
        products[0].Name, products[1].Name, products[2].Name,
        products[0].Price, products[1].Price, products[2].Price,
        products[0].Quantity, products[1].Quantity, products[2].Quantity,
    )
    fmt.Printf("Query: %s\nArgs: %v\n", query, args)
​
    _, err := db.Exec(query, args...)
    if err != nil {
        log.Fatal(err)
    }
}
​
func batchUpdateDemo(db *sqlx.DB) {
    // 批量更新价格
    priceUpdates := []struct {
        ID    int
        Price float64
    }{
        {1, 8999.00},
        {2, 5499.00},
        {3, 17999.00},
    }
​
    tx, _ := db.Beginx()
    for _, update := range priceUpdates {
        _, err := tx.Exec("UPDATE products SET price = ? WHERE id = ?", update.Price, update.ID)
        if err != nil {
            tx.Rollback()
            log.Fatal(err)
        }
    }
    tx.Commit()
}

5.3 动态查询构建

复制代码
package main
​
import (
    "fmt"
    "log"
    "strings"
​
    "github.com/jmoiron/sqlx"
    _ "github.com/go-sql-driver/mysql"
)
​
type UserFilter struct {
    Name   string
    Email  string
    AgeMin int
    AgeMax int
    Active *bool
}
​
func buildDynamicQuery(db *sqlx.DB, filter UserFilter) {
    // 构建动态SQL
    conditions := []string{}
    args := []interface{}{}
​
    if filter.Name != "" {
        conditions = append(conditions, "name LIKE ?")
        args = append(args, "%"+filter.Name+"%")
    }
​
    if filter.Email != "" {
        conditions = append(conditions, "email LIKE ?")
        args = append(args, "%"+filter.Email+"%")
    }
​
    if filter.AgeMin > 0 {
        conditions = append(conditions, "age >= ?")
        args = append(args, filter.AgeMin)
    }
​
    if filter.AgeMax > 0 {
        conditions = append(conditions, "age <= ?")
        args = append(args, filter.AgeMax)
    }
​
    if filter.Active != nil {
        conditions = append(conditions, "active = ?")
        args = append(args, *filter.Active)
    }
​
    // 构建最终SQL
    whereClause := ""
    if len(conditions) > 0 {
        whereClause = "WHERE " + strings.Join(conditions, " AND ")
    }
​
    query := fmt.Sprintf("SELECT * FROM users %s ORDER BY id", whereClause)
    fmt.Printf("Query: %s\nArgs: %v\n", query, args)
​
    var users []User
    err := db.Select(&users, query, args...)
    if err != nil {
        log.Fatal(err)
    }
​
    fmt.Printf("Found %d users\n", len(users))
}

六、SQL注入防护

6.1 注入攻击原理

复制代码
package main
​
import (
    "database/sql"
    "fmt"
​
    _ "github.com/go-sql-driver/mysql"
)
​
// 危险:直接拼接SQL
func dangerousQuery(db *sql.DB, username string) {
    // 恶意输入: "admin' OR '1'='1"
    query := "SELECT * FROM users WHERE username = '" + username + "'"
    fmt.Printf("Query: %s\n", query)
​
    rows, _ := db.Query(query)
    defer rows.Close()
​
    for rows.Next() {
        var id int
        var name string
        rows.Scan(&id, &name)
        fmt.Printf("User: id=%d, name=%s\n", id, name)
    }
}
​
// 安全:使用参数化查询
func safeQuery(db *sql.DB, username string) {
    query := "SELECT * FROM users WHERE username = ?"
    fmt.Printf("Query: %s\n", query)
​
    rows, err := db.Query(query, username)
    if err != nil {
        fmt.Printf("Query error: %v\n", err)
        return
    }
    defer rows.Close()
​
    for rows.Next() {
        var id int
        var name string
        rows.Scan(&id, &name)
        fmt.Printf("User: id=%d, name=%s\n", id, name)
    }
}

6.2 最佳实践

  1. 始终使用参数化查询

  2. 使用ORM或sqlx等库减少手写SQL

  3. 输入验证和过滤

  4. 最小权限原则

复制代码
package main
​
import (
    "fmt"
    "regexp"
    "strings"
​
    "github.com/jmoiron/sqlx"
    _ "github.com/go-sql-driver/mysql"
)
​
// 输入验证
func validateInput(input string) bool {
    // 只允许字母、数字和下划线
    matched, _ := regexp.MatchString(`^[a-zA-Z0-9_]+$`, input)
    return matched
}
​
// 清理输入
func sanitizeInput(input string) string {
    // 转义特殊字符
    input = strings.ReplaceAll(input, "'", "''")
    input = strings.ReplaceAll(input, "\\", "\\\\")
    return input
}
​
func main() {
    db, _ := sqlx.Connect("mysql", "root:password@tcp(localhost:3306)/test?parseTime=true")
    defer db.Close()
​
    username := "admin"
    if !validateInput(username) {
        fmt.Println("Invalid input")
        return
    }
​
    var user struct {
        ID   int    `db:"id"`
        Name string `db:"name"`
    }
    err := db.Get(&user, "SELECT * FROM users WHERE username = ?", username)
    if err != nil {
        fmt.Printf("Error: %v\n", err)
        return
    }
    fmt.Printf("User: %+v\n", user)
}

七、实际案例:构建数据访问层

7.1 项目结构设计

复制代码
dal/
├── dal.go          # 初始化和配置
├── user.go         # 用户数据访问
├── product.go      # 产品数据访问
├── order.go        # 订单数据访问
└── repository.go   # 泛型Repository模式

7.2 数据访问层实现

复制代码
package dal
​
import (
    "context"
    "database/sql"
    "fmt"
    "time"
​
    "github.com/jmoiron/sqlx"
    _ "github.com/go-sql-driver/mysql"
)
​
type Config struct {
    Host            string
    Port            int
    User            string
    Password        string
    Database        string
    MaxOpenConns    int
    MaxIdleConns    int
    ConnMaxLifetime time.Duration
}
​
type DAL struct {
    DB *sqlx.DB
}
​
// NewDAL 创建DAL实例
func NewDAL(cfg Config) (*DAL, error) {
    dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true&loc=Local",
        cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database)
​
    db, err := sqlx.Connect("mysql", dsn)
    if err != nil {
        return nil, fmt.Errorf("connect to database failed: %w", err)
    }
​
    // 配置连接池
    db.SetMaxOpenConns(cfg.MaxOpenConns)
    db.SetMaxIdleConns(cfg.MaxIdleConns)
    db.SetConnMaxLifetime(cfg.ConnMaxLifetime)
​
    return &DAL{DB: db}, nil
}
​
// Close 关闭数据库连接
func (d *DAL) Close() error {
    return d.DB.Close()
}
​
// WithTransaction 执行事务
func (d *DAL) WithTransaction(ctx context.Context, fn func(*sqlx.Tx) error) error {
    tx, err := d.DB.BeginTxx(ctx, nil)
    if err != nil {
        return fmt.Errorf("begin transaction failed: %w", err)
    }
​
    if err := fn(tx); err != nil {
        tx.Rollback()
        return err
    }
​
    if err := tx.Commit(); err != nil {
        return fmt.Errorf("commit failed: %w", err)
    }
​
    return nil
}

7.3 用户仓储实现

复制代码
package dal
​
import (
    "context"
    "database/sql"
    "errors"
    "fmt"
​
    "github.com/jmoiron/sqlx"
)
​
type User struct {
    ID        int       `db:"id" json:"id"`
    Username  string    `db:"username" json:"username"`
    Email     string    `db:"email" json:"email"`
    Password  string    `db:"password_hash" json:"-"`
    Age       int       `db:"age" json:"age"`
    Active    bool      `db:"active" json:"active"`
    CreatedAt time.Time `db:"created_at" json:"created_at"`
    UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
​
type UserRepository struct {
    dal *DAL
}
​
func NewUserRepository(dal *DAL) *UserRepository {
    return &UserRepository{dal: dal}
}
​
// GetByID 根据ID获取用户
func (r *UserRepository) GetByID(ctx context.Context, id int) (*User, error) {
    var user User
    query := "SELECT * FROM users WHERE id = ?"
​
    err := r.dal.DB.GetContext(ctx, &user, query, id)
    if err != nil {
        if errors.Is(err, sql.ErrNoRows) {
            return nil, nil
        }
        return nil, fmt.Errorf("get user by id failed: %w", err)
    }
​
    return &user, nil
}
​
// GetByUsername 根据用户名获取用户
func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*User, error) {
    var user User
    query := "SELECT * FROM users WHERE username = ?"
​
    err := r.dal.DB.GetContext(ctx, &user, query, username)
    if err != nil {
        if errors.Is(err, sql.ErrNoRows) {
            return nil, nil
        }
        return nil, fmt.Errorf("get user by username failed: %w", err)
    }
​
    return &user, nil
}
​
// List 分页获取用户列表
func (r *UserRepository) List(ctx context.Context, offset, limit int) ([]User, int, error) {
    var users []User
    query := "SELECT * FROM users ORDER BY id LIMIT ? OFFSET ?"
​
    err := r.dal.DB.SelectContext(ctx, &users, query, limit, offset)
    if err != nil {
        return nil, 0, fmt.Errorf("list users failed: %w", err)
    }
​
    var total int
    countQuery := "SELECT COUNT(*) FROM users"
    if err := r.dal.DB.GetContext(ctx, &total, countQuery); err != nil {
        return nil, 0, fmt.Errorf("count users failed: %w", err)
    }
​
    return users, total, nil
}
​
// Create 创建用户
func (r *UserRepository) Create(ctx context.Context, user *User) error {
    query := `INSERT INTO users (username, email, password_hash, age, active, created_at, updated_at)
              VALUES (?, ?, ?, ?, ?, NOW(), NOW())`
​
    result, err := r.dal.DB.ExecContext(ctx, query,
        user.Username, user.Email, user.Password, user.Age, user.Active)
    if err != nil {
        return fmt.Errorf("create user failed: %w", err)
    }
​
    id, err := result.LastInsertId()
    if err != nil {
        return fmt.Errorf("get last insert id failed: %w", err)
    }
​
    user.ID = int(id)
    return nil
}
​
// Update 更新用户
func (r *UserRepository) Update(ctx context.Context, user *User) error {
    query := `UPDATE users SET username = ?, email = ?, age = ?, active = ?, updated_at = NOW()
              WHERE id = ?`
​
    result, err := r.dal.DB.ExecContext(ctx, query,
        user.Username, user.Email, user.Age, user.Active, user.ID)
    if err != nil {
        return fmt.Errorf("update user failed: %w", err)
    }
​
    affected, _ := result.RowsAffected()
    if affected == 0 {
        return errors.New("user not found")
    }
​
    return nil
}
​
// Delete 删除用户
func (r *UserRepository) Delete(ctx context.Context, id int) error {
    query := "DELETE FROM users WHERE id = ?"
​
    result, err := r.dal.DB.ExecContext(ctx, query, id)
    if err != nil {
        return fmt.Errorf("delete user failed: %w", err)
    }
​
    affected, _ := result.RowsAffected()
    if affected == 0 {
        return errors.New("user not found")
    }
​
    return nil
}
​
// WithTransaction 在事务中操作
func (r *UserRepository) CreateWithProfile(ctx context.Context, user *User, profile map[string]string) error {
    return r.dal.WithTransaction(ctx, func(tx *sqlx.Tx) error {
        // 创建用户
        query := `INSERT INTO users (username, email, password_hash, age, active, created_at, updated_at)
                  VALUES (?, ?, ?, ?, ?, NOW(), NOW())`
​
        result, err := tx.ExecContext(ctx, query,
            user.Username, user.Email, user.Password, user.Age, user.Active)
        if err != nil {
            return fmt.Errorf("create user failed: %w", err)
        }
​
        userID, _ := result.LastInsertId()
        user.ID = int(userID)
​
        // 创建用户资料
        profileQuery := "INSERT INTO user_profiles (user_id, key, value) VALUES (?, ?, ?)"
        for k, v := range profile {
            _, err := tx.ExecContext(ctx, profileQuery, userID, k, v)
            if err != nil {
                return fmt.Errorf("create profile failed: %w", err)
            }
        }
​
        return nil
    })
}

7.4 订单仓储实现

复制代码
package dal
​
import (
    "context"
    "errors"
    "fmt"
​
    "github.com/jmoiron/sqlx"
)
​
type OrderItem struct {
    ID        int     `db:"id" json:"id"`
    OrderID   int     `db:"order_id" json:"order_id"`
    ProductID int     `db:"product_id" json:"product_id"`
    ProductName string `db:"product_name" json:"product_name"`
    Price     float64 `db:"price" json:"price"`
    Quantity  int     `db:"quantity" json:"quantity"`
}
​
type Order struct {
    ID         int         `db:"id" json:"id"`
    UserID     int         `db:"user_id" json:"user_id"`
    Status     string      `db:"status" json:"status"`
    Total      float64     `db:"total" json:"total"`
    Items      []OrderItem `db:"-" json:"items"`
    CreatedAt time.Time   `db:"created_at" json:"created_at"`
    UpdatedAt time.Time   `db:"updated_at" json:"updated_at"`
}
​
type OrderRepository struct {
    dal *DAL
}
​
func NewOrderRepository(dal *DAL) *OrderRepository {
    return &OrderRepository{dal: dal}
}
​
// Create 创建订单
func (r *OrderRepository) Create(ctx context.Context, order *Order) error {
    return r.dal.WithTransaction(ctx, func(tx *sqlx.Tx) error {
        // 创建订单头
        orderQuery := `INSERT INTO orders (user_id, status, total, created_at, updated_at)
                       VALUES (?, ?, ?, NOW(), NOW())`
​
        result, err := tx.ExecContext(ctx, orderQuery, order.UserID, order.Status, order.Total)
        if err != nil {
            return fmt.Errorf("create order failed: %w", err)
        }
​
        orderID, _ := result.LastInsertId()
        order.ID = int(orderID)
​
        // 创建订单明细
        itemQuery := `INSERT INTO order_items (order_id, product_id, product_name, price, quantity)
                      VALUES (?, ?, ?, ?, ?)`
​
        for i := range order.Items {
            item := &order.Items[i]
            item.OrderID = int(orderID)
​
            result, err := tx.ExecContext(ctx, itemQuery,
                item.OrderID, item.ProductID, item.ProductName, item.Price, item.Quantity)
            if err != nil {
                return fmt.Errorf("create order item failed: %w", err)
            }
​
            itemID, _ := result.LastInsertId()
            item.ID = int(itemID)
        }
​
        return nil
    })
}
​
// GetByID 获取订单详情
func (r *OrderRepository) GetByID(ctx context.Context, id int) (*Order, error) {
    var order Order
    query := "SELECT * FROM orders WHERE id = ?"
​
    if err := r.dal.DB.GetContext(ctx, &order, query, id); err != nil {
        if errors.Is(err, sql.ErrNoRows) {
            return nil, nil
        }
        return nil, fmt.Errorf("get order failed: %w", err)
    }
​
    // 获取订单明细
    var items []OrderItem
    itemsQuery := "SELECT * FROM order_items WHERE order_id = ?"
    if err := r.dal.DB.SelectContext(ctx, &items, itemsQuery, id); err != nil {
        return nil, fmt.Errorf("get order items failed: %w", err)
    }
​
    order.Items = items
    return &order, nil
}
​
// GetByUserID 获取用户的订单列表
func (r *OrderRepository) GetByUserID(ctx context.Context, userID int) ([]Order, error) {
    var orders []Order
    query := "SELECT * FROM orders WHERE user_id = ? ORDER BY created_at DESC"
​
    if err := r.dal.DB.SelectContext(ctx, &orders, query, userID); err != nil {
        return nil, fmt.Errorf("get user orders failed: %w", err)
    }
​
    return orders, nil
}
​
// UpdateStatus 更新订单状态
func (r *OrderRepository) UpdateStatus(ctx context.Context, id int, status string) error {
    query := "UPDATE orders SET status = ?, updated_at = NOW() WHERE id = ?"
​
    result, err := r.dal.DB.ExecContext(ctx, query, status, id)
    if err != nil {
        return fmt.Errorf("update order status failed: %w", err)
    }
​
    affected, _ := result.RowsAffected()
    if affected == 0 {
        return errors.New("order not found")
    }
​
    return nil
}

7.5 使用示例

复制代码
package main
​
import (
    "context"
    "fmt"
    "log"
    "time"
​
    "mypackage/dal"
)
​
func main() {
    // 初始化DAL
    database, err := dal.NewDAL(dal.Config{
        Host:            "localhost",
        Port:            3306,
        User:            "root",
        Password:        "password",
        Database:        "ecommerce",
        MaxOpenConns:    25,
        MaxIdleConns:    10,
        ConnMaxLifetime: 5 * time.Minute,
    })
    if err != nil {
        log.Fatalf("Failed to connect to database: %v", err)
    }
    defer database.Close()
​
    userRepo := dal.NewUserRepository(database)
    orderRepo := dal.NewOrderRepository(database)
​
    ctx := context.Background()
​
    // 创建用户
    user := &dal.User{
        Username: "testuser",
        Email:    "test@example.com",
        Password: "hashed_password",
        Age:      25,
        Active:   true,
    }
​
    if err := userRepo.Create(ctx, user); err != nil {
        log.Fatalf("Failed to create user: %v", err)
    }
    fmt.Printf("User created: ID=%d\n", user.ID)
​
    // 查询用户
    fetchedUser, err := userRepo.GetByID(ctx, user.ID)
    if err != nil {
        log.Fatalf("Failed to get user: %v", err)
    }
    fmt.Printf("User fetched: %+v\n", fetchedUser)
​
    // 创建订单
    order := &dal.Order{
        UserID: user.ID,
        Status: "pending",
        Total:  299.99,
        Items: []dal.OrderItem{
            {ProductID: 1, ProductName: "iPhone", Price: 9999.00, Quantity: 1},
        },
    }
​
    if err := orderRepo.Create(ctx, order); err != nil {
        log.Fatalf("Failed to create order: %v", err)
    }
    fmt.Printf("Order created: ID=%d\n", order.ID)
​
    // 查询订单
    fetchedOrder, err := orderRepo.GetByID(ctx, order.ID)
    if err != nil {
        log.Fatalf("Failed to get order: %v", err)
    }
    fmt.Printf("Order fetched: %+v\n", fetchedOrder)
}

总结

本文全面介绍了Go语言数据库编程的各个方面:

  1. 驱动注册与连接池 :理解database/sql的驱动接口分离设计和连接池配置。

  2. 查询处理 :掌握QueryQueryRowExec的使用场景,以及正确的Scan方式。

  3. 预处理语句 :使用Prepare创建预处理语句,提升性能并防止SQL注入。

  4. 事务处理 :使用BeginCommitRollback进行事务管理,掌握保存点的使用。

  5. sqlx扩展 :利用sqlx库的便捷功能简化数据库操作。

  6. SQL注入防护:始终使用参数化查询,避免SQL拼接。

  7. 数据访问层设计:构建结构清晰、错误处理完善、事务支持良好的DAL层。

良好的数据访问层设计是构建可维护、可扩展后端系统的基础。通过本文介绍的技术和模式,开发者可以编写出高效、安全、健壮的数据库操作代码。

相关推荐
晚风吹红霞1 小时前
C++异常处理核心知识点全解析
开发语言·c++
CoderCodingNo1 小时前
【信奥业余科普】C++ 的奇妙之旅 | 17:面的铺展与文本的本质——二维数组与字符串
开发语言·c++
J2虾虾1 小时前
Java Lambda 表达式详解文档
java·开发语言
csbysj20201 小时前
CSS 网格元素
开发语言
lly2024061 小时前
DOM 元素:深入理解与高效运用
开发语言
鸟儿不吃草1 小时前
安卓实现左右布局聊天界面
android·开发语言·python
jieyucx2 小时前
Go 零基础数据结构:顺序表(像「排抽屉」一样学增删改查)
java·数据结构·golang
曦夜日长2 小时前
C++ STL容器string(一):string的变量细节、默认函数的认识以及常用接口的使用
java·开发语言·c++
代码中介商2 小时前
C++ STL 标准模板库完全指南:从容器到迭代器
开发语言·c++·stl