Go语言学习之对象关系映射GORM

在 Gin 框架中集成 ORM 时,GORM 是最主流的选择——成熟、功能完善、生态丰富。下面给出一个从零到生产的完整方案。

技术栈选择

| ORM | 特点 | 适用场景 |
| --- | --- | --- |
| GORM | 功能最全、生态最好、支持 Auto Migration | 通用项目、快速开发 |
| Ent | Facebook 出品、代码生成、类型安全 | 大型项目、强类型需求 |
| sqlx | 轻量、原生 SQL、无魔法 | 性能敏感、需要完全掌控 SQL |
| sqlc | 由 SQL 生成类型安全的 Go 代码 | 追求性能和类型安全 |

本篇以 GORM 为核心,覆盖大多数常见的业务场景。

项目结构

复制代码
project/
├── main.go
├── config/
│   └── config.go
├── models/
│   └── user.go          # User、Post、Tag 模型
├── controllers/
│   └── user_controller.go
├── repositories/
│   ├── user_repository.go
│   └── post_repository.go
├── services/
│   ├── user_service.go
│   └── post_service.go
├── middlewares/
│   └── auth.go
└── database/
    └── database.go

数据库连接与配置

复制代码
// database/database.go
package database

import (
    "fmt"
    "log"
    "time"
    
    "gorm.io/driver/mysql"
    "gorm.io/driver/postgres"
    "gorm.io/driver/sqlite"
    "gorm.io/gorm"
    "gorm.io/gorm/logger"
)

// Config describes how NewDB opens the database and sizes its connection pool.
type Config struct {
    // Driver selects the GORM dialector: "mysql", "postgres" or "sqlite".
    // Any other value makes NewDB return an "unsupported driver" error.
    Driver          string
    // DSN is the driver-specific data source name handed to the dialector's Open.
    DSN             string
    // MaxIdleConns is forwarded to sql.DB.SetMaxIdleConns.
    MaxIdleConns    int
    // MaxOpenConns is forwarded to sql.DB.SetMaxOpenConns.
    MaxOpenConns    int
    // ConnMaxLifetime is forwarded to sql.DB.SetConnMaxLifetime
    // (per database/sql, a value <= 0 means connections are not aged out).
    ConnMaxLifetime time.Duration
    // LogLevel is the GORM SQL log level, applied via logger.Default.LogMode.
    LogLevel        logger.LogLevel
}

// NewDB opens a GORM connection for cfg.Driver ("mysql", "postgres" or
// "sqlite"), applies the pool limits from cfg and verifies connectivity with
// a ping.
//
// It returns an error for an unknown driver, a failed open, a failure to
// obtain the underlying *sql.DB, or a failed ping. When the ping fails the
// pool opened by gorm.Open is closed before returning, because the caller
// receives no handle through which it could close it.
func NewDB(cfg Config) (*gorm.DB, error) {
    // Choose the dialector first so gorm.Open and its config appear only once.
    var dialector gorm.Dialector
    switch cfg.Driver {
    case "mysql":
        dialector = mysql.Open(cfg.DSN)
    case "postgres":
        dialector = postgres.Open(cfg.DSN)
    case "sqlite":
        dialector = sqlite.Open(cfg.DSN)
    default:
        return nil, fmt.Errorf("unsupported driver: %s", cfg.Driver)
    }

    db, err := gorm.Open(dialector, &gorm.Config{
        Logger: logger.Default.LogMode(cfg.LogLevel),
    })
    if err != nil {
        return nil, fmt.Errorf("failed to connect database: %w", err)
    }

    // Connection-pool configuration lives on the underlying *sql.DB.
    sqlDB, err := db.DB()
    if err != nil {
        return nil, fmt.Errorf("failed to get sql.DB: %w", err)
    }

    sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
    sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
    sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)

    if err := sqlDB.Ping(); err != nil {
        // Don't leak the pool: nobody else holds a reference to it.
        _ = sqlDB.Close()
        return nil, fmt.Errorf("failed to ping database: %w", err)
    }

    log.Println("Database connected successfully")
    return db, nil
}

模型定义

复制代码
// models/user.go
package models

import (
    "time"
    "gorm.io/gorm"
)

// User is an account stored in the "users" table.
//
// NOTE(review): Username and Email carry plain unique indexes, while
// DeletedAt enables GORM soft delete. A soft-deleted row therefore still
// occupies its username/email at the database level, whereas GORM queries
// such as ExistsByUsername skip soft-deleted rows — re-registering a deleted
// user's name would presumably fail at insert time. Confirm against the
// migrated schema.
type User struct {
    ID        uint           `gorm:"primaryKey" json:"id"`
    Username  string         `gorm:"size:50;uniqueIndex;not null" json:"username"`
    Email     string         `gorm:"size:100;uniqueIndex;not null" json:"email"`
    Password  string         `gorm:"size:255;not null" json:"-"`  // bcrypt hash set by UserService; never serialized to JSON
    Nickname  string         `gorm:"size:50" json:"nickname"`
    Avatar    string         `gorm:"size:255" json:"avatar"`
    Status    int8           `gorm:"default:1;index" json:"status"`  // 1:active 2:inactive 3:banned
    Role      string         `gorm:"size:20;default:user" json:"role"`
    LastLogin *time.Time     `json:"last_login"`
    CreatedAt time.Time      `json:"created_at"`
    UpdatedAt time.Time      `json:"updated_at"`
    DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`  // soft delete marker; hidden from JSON
    
    // Associations: posts written by this user (Post.AuthorID).
    Posts     []Post         `gorm:"foreignKey:AuthorID" json:"posts,omitempty"`
}

// TableName pins the table name to "users" rather than relying on GORM's
// naming strategy.
func (User) TableName() string {
    return "users"
}

// Post is an article stored in the "posts" table.
type Post struct {
    ID        uint           `gorm:"primaryKey" json:"id"`
    Title     string         `gorm:"size:200;not null" json:"title"`
    Content   string         `gorm:"type:text" json:"content"`
    AuthorID  uint           `gorm:"not null;index" json:"author_id"`
    Status    int8           `gorm:"default:1;index" json:"status"`  // 1:draft 2:published
    ViewCount int            `gorm:"default:0" json:"view_count"`
    CreatedAt time.Time      `json:"created_at"`
    UpdatedAt time.Time      `json:"updated_at"`
    DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
    
    // Associations.
    // NOTE(review): encoding/json ignores omitempty on non-pointer struct
    // fields, so "author" is always emitted — as a zero User when Author was
    // not preloaded. A *User field would be needed to actually omit it.
    Author    User           `gorm:"foreignKey:AuthorID" json:"author,omitempty"`
    // Tags are linked through the post_tags join table.
    Tags      []Tag          `gorm:"many2many:post_tags;" json:"tags,omitempty"`
}

// TableName pins the table name to "posts".
func (Post) TableName() string {
    return "posts"
}

// Tag is a uniquely named label shared by many posts via the post_tags join table.
type Tag struct {
    ID        uint      `gorm:"primaryKey" json:"id"`
    Name      string    `gorm:"size:50;uniqueIndex;not null" json:"name"`
    CreatedAt time.Time `json:"created_at"`
    
    // Posts carrying this tag (reverse side of Post.Tags).
    Posts     []Post    `gorm:"many2many:post_tags;" json:"posts,omitempty"`
}

// TableName pins the table name to "tags".
func (Tag) TableName() string {
    return "tags"
}

Repository 层(数据访问)

复制代码
// repositories/user_repository.go
package repositories

import (
    "errors"
    "time"
    
    "gorm.io/gorm"
    "your-project/models"
)

// UserRepository is the data-access layer for models.User. Every method runs
// against the *gorm.DB it was constructed with.
type UserRepository struct {
    db *gorm.DB
}

// NewUserRepository wraps db in a UserRepository.
func NewUserRepository(db *gorm.DB) *UserRepository {
    repo := new(UserRepository)
    repo.db = db
    return repo
}

// Create inserts user. Any database error (for example a violation of the
// unique username/email indexes) is returned unchanged.
func (r *UserRepository) Create(user *models.User) error {
    result := r.db.Create(user)
    return result.Error
}

// FindByID loads a user by primary key. Any error from First (per the GORM
// docs, gorm.ErrRecordNotFound when no row matches) is returned unchanged.
func (r *UserRepository) FindByID(id uint) (*models.User, error) {
    return r.firstUser(r.db, id)
}

// FindByUsername loads the user whose username matches exactly.
func (r *UserRepository) FindByUsername(username string) (*models.User, error) {
    return r.firstUser(r.db.Where("username = ?", username))
}

// FindByEmail loads the user whose email matches exactly.
func (r *UserRepository) FindByEmail(email string) (*models.User, error) {
    return r.firstUser(r.db.Where("email = ?", email))
}

// firstUser runs First on tx with optional inline conditions and returns the
// loaded user, or nil together with the error.
func (r *UserRepository) firstUser(tx *gorm.DB, conds ...interface{}) (*models.User, error) {
    user := new(models.User)
    if err := tx.First(user, conds...).Error; err != nil {
        return nil, err
    }
    return user, nil
}

// FindAll returns one page of users ordered by primary key, together with
// the total number of users.
//
// page is 1-based; values below 1 are treated as 1. A pageSize below 1 yields
// an empty page: GORM treats a negative Limit as "no LIMIT", which would
// otherwise load the entire table.
func (r *UserRepository) FindAll(page, pageSize int) ([]models.User, int64, error) {
    var total int64
    if err := r.db.Model(&models.User{}).Count(&total).Error; err != nil {
        return nil, 0, err
    }

    if pageSize < 1 {
        return nil, total, nil
    }
    if page < 1 {
        page = 1
    }

    var users []models.User
    // OFFSET/LIMIT without ORDER BY lets the database return rows in any order,
    // so consecutive pages could overlap or skip rows; order by primary key.
    err := r.db.Order("id ASC").
        Offset((page - 1) * pageSize).
        Limit(pageSize).
        Find(&users).Error
    if err != nil {
        return nil, 0, err
    }

    return users, total, nil
}

// Update persists user with GORM's Save, which writes every column from the
// struct, overwriting whatever is currently stored for that row.
func (r *UserRepository) Update(user *models.User) error {
    result := r.db.Save(user)
    return result.Error
}

// UpdateFields sets only the given columns on the user with the given id.
// Keys are column names, e.g. "last_login".
func (r *UserRepository) UpdateFields(id uint, fields map[string]interface{}) error {
    result := r.db.Model(&models.User{}).
        Where("id = ?", id).
        Updates(fields)
    return result.Error
}

// UpdateLastLogin stamps the user's last_login column with the current time.
func (r *UserRepository) UpdateLastLogin(id uint) error {
    loginAt := time.Now()
    fields := map[string]interface{}{"last_login": &loginAt}
    return r.UpdateFields(id, fields)
}

// Delete soft-deletes the user (models.User has a gorm.DeletedAt field).
// Only the query error is checked, so deleting an id with no matching row is
// not reported as an error.
func (r *UserRepository) Delete(id uint) error {
    result := r.db.Delete(&models.User{}, id)
    return result.Error
}

// HardDelete permanently removes the row, bypassing the soft-delete scope
// via Unscoped.
func (r *UserRepository) HardDelete(id uint) error {
    result := r.db.Unscoped().Delete(&models.User{}, id)
    return result.Error
}

// ExistsByUsername reports whether some user already has username.
func (r *UserRepository) ExistsByUsername(username string) (bool, error) {
    return r.existsWhere("username = ?", username)
}

// ExistsByEmail reports whether some user already has email.
func (r *UserRepository) ExistsByEmail(email string) (bool, error) {
    return r.existsWhere("email = ?", email)
}

// existsWhere counts users matching cond/arg and reports whether any exist.
// The boolean and the error are returned together exactly as GORM produced
// them, so on error the boolean is false.
func (r *UserRepository) existsWhere(cond string, arg interface{}) (bool, error) {
    var n int64
    err := r.db.Model(&models.User{}).Where(cond, arg).Count(&n).Error
    return n > 0, err
}

// FindWithPosts loads a user by primary key with its Posts association
// preloaded, avoiding one query per post.
func (r *UserRepository) FindWithPosts(id uint) (*models.User, error) {
    user := new(models.User)
    if err := r.db.Preload("Posts").First(user, id).Error; err != nil {
        return nil, err
    }
    return user, nil
}

// FindActiveUsers returns up to limit users with status 1 (active, per the
// status comment on models.User). No ordering is applied.
func (r *UserRepository) FindActiveUsers(limit int) ([]models.User, error) {
    var active []models.User
    tx := r.db.Where("status = ?", 1).Limit(limit)
    err := tx.Find(&active).Error
    return active, err
}

Service 层(业务逻辑)

复制代码
// services/user_service.go
package services

import (
    "errors"
    "time"
    
    "golang.org/x/crypto/bcrypt"
    "your-project/models"
    "your-project/repositories"
)

// Sentinel errors returned (unwrapped) by UserService so callers can pick an
// HTTP status by comparing against them.
var (
    ErrUserNotFound       = errors.New("user not found")
    ErrUsernameExists     = errors.New("username already exists")
    ErrEmailExists        = errors.New("email already exists")
    ErrInvalidCredentials = errors.New("invalid credentials")
)

// UserService holds the user business rules (uniqueness checks, password
// hashing, response shaping) on top of UserRepository.
type UserService struct {
    repo *repositories.UserRepository
}

// NewUserService builds a UserService backed by repo.
func NewUserService(repo *repositories.UserRepository) *UserService {
    return &UserService{repo: repo}
}

// CreateUserInput is the registration request body. The binding tags are
// validation rules checked by Gin's ShouldBindJSON.
// NOTE(review): bcrypt only uses the first 72 bytes of a password (newer
// golang.org/x/crypto versions reject longer input with an error); consider
// an upper bound on Password.
type CreateUserInput struct {
    Username string `json:"username" binding:"required,min=3,max=50"`
    Email    string `json:"email" binding:"required,email"`
    Password string `json:"password" binding:"required,min=6"`
    Nickname string `json:"nickname"`
}

// UpdateUserInput is the profile-update body. Empty fields mean "leave
// unchanged" (UserService.Update skips them).
type UpdateUserInput struct {
    Nickname string `json:"nickname"`
    Avatar   string `json:"avatar"`
}

// LoginInput is the login request body.
type LoginInput struct {
    Email    string `json:"email" binding:"required,email"`
    Password string `json:"password" binding:"required"`
}

// UserResponse is the public view of a user returned by the API. It omits
// the password hash, UpdatedAt and the soft-delete marker.
type UserResponse struct {
    ID        uint      `json:"id"`
    Username  string    `json:"username"`
    Email     string    `json:"email"`
    Nickname  string    `json:"nickname"`
    Avatar    string    `json:"avatar"`
    Status    int8      `json:"status"`
    Role      string    `json:"role"`
    LastLogin *time.Time `json:"last_login"`
    CreatedAt time.Time `json:"created_at"`
}

// Create registers a new user.
//
// It rejects a taken username or email with ErrUsernameExists /
// ErrEmailExists, hashes the password with bcrypt and stores the user as
// active (status 1) with role "user". Database errors from the existence
// checks are returned rather than silently ignored.
//
// The pre-checks are advisory: two concurrent registrations can both pass
// them, in which case the unique index makes repo.Create fail and that error
// is returned as-is.
func (s *UserService) Create(input CreateUserInput) (*UserResponse, error) {
    exists, err := s.repo.ExistsByUsername(input.Username)
    if err != nil {
        return nil, err
    }
    if exists {
        return nil, ErrUsernameExists
    }

    exists, err = s.repo.ExistsByEmail(input.Email)
    if err != nil {
        return nil, err
    }
    if exists {
        return nil, ErrEmailExists
    }

    hashedPassword, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost)
    if err != nil {
        return nil, err
    }

    user := &models.User{
        Username: input.Username,
        Email:    input.Email,
        Password: string(hashedPassword),
        Nickname: input.Nickname,
        Status:   1,
        Role:     "user",
    }

    if err := s.repo.Create(user); err != nil {
        return nil, err
    }

    return s.toResponse(user), nil
}

// Login authenticates a user by email and password.
//
// Every failure (unknown email, lookup error, wrong password) is reported as
// ErrInvalidCredentials, so callers cannot tell which part was wrong. On
// success the last-login time is recorded on a best-effort basis; the
// returned response still carries the LastLogin value loaded before that
// update.
func (s *UserService) Login(input LoginInput) (*UserResponse, error) {
    user, lookupErr := s.repo.FindByEmail(input.Email)
    if lookupErr != nil {
        return nil, ErrInvalidCredentials
    }

    mismatch := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(input.Password))
    if mismatch != nil {
        return nil, ErrInvalidCredentials
    }

    // Best effort: failing to record the login time must not fail the login.
    _ = s.repo.UpdateLastLogin(user.ID)

    return s.toResponse(user), nil
}

// GetByID returns the public view of the user with the given id. Any lookup
// failure, including a database error, is reported as ErrUserNotFound.
func (s *UserService) GetByID(id uint) (*UserResponse, error) {
    if user, err := s.repo.FindByID(id); err == nil {
        return s.toResponse(user), nil
    }
    return nil, ErrUserNotFound
}

// List returns one page of users converted to UserResponse, plus the total
// count reported by the repository.
func (s *UserService) List(page, pageSize int) ([]UserResponse, int64, error) {
    users, total, err := s.repo.FindAll(page, pageSize)
    if err != nil {
        return nil, 0, err
    }

    out := make([]UserResponse, 0, len(users))
    for i := range users {
        out = append(out, *s.toResponse(&users[i]))
    }
    return out, total, nil
}

// Update changes the user's nickname and/or avatar. Empty input fields are
// left unchanged.
//
// Only the columns that actually change are written (via UpdateFields).
// Previously the whole row was re-saved from the copy read a moment earlier,
// which could silently overwrite concurrent changes to other columns such as
// status, role or password. If nothing changes, no write is issued.
//
// Returns ErrUserNotFound when the lookup fails, or the repository error if
// the write fails.
func (s *UserService) Update(id uint, input UpdateUserInput) (*UserResponse, error) {
    user, err := s.repo.FindByID(id)
    if err != nil {
        return nil, ErrUserNotFound
    }

    fields := make(map[string]interface{}, 2)
    if input.Nickname != "" {
        fields["nickname"] = input.Nickname
        user.Nickname = input.Nickname
    }
    if input.Avatar != "" {
        fields["avatar"] = input.Avatar
        user.Avatar = input.Avatar
    }

    if len(fields) > 0 {
        if err := s.repo.UpdateFields(id, fields); err != nil {
            return nil, err
        }
    }

    return s.toResponse(user), nil
}

// Delete soft-deletes the user with the given id.
//
// The repository's Delete reports no error when no row matches, so the user
// is looked up first. A missing user (or a failed lookup, matching GetByID
// and Update) yields ErrUserNotFound instead of a false success.
func (s *UserService) Delete(id uint) error {
    if _, err := s.repo.FindByID(id); err != nil {
        return ErrUserNotFound
    }
    return s.repo.Delete(id)
}

// toResponse copies the public fields of user into a new UserResponse. The
// password hash, UpdatedAt and DeletedAt are deliberately not copied.
func (s *UserService) toResponse(user *models.User) *UserResponse {
    resp := new(UserResponse)
    resp.ID = user.ID
    resp.Username = user.Username
    resp.Email = user.Email
    resp.Nickname = user.Nickname
    resp.Avatar = user.Avatar
    resp.Status = user.Status
    resp.Role = user.Role
    resp.LastLogin = user.LastLogin
    resp.CreatedAt = user.CreatedAt
    return resp
}

Controller 层

复制代码
// controllers/user_controller.go
package controllers

import (
    "net/http"
    "strconv"
    
    "github.com/gin-gonic/gin"
    "your-project/services"
)

// UserController adapts UserService to Gin HTTP handlers.
type UserController struct {
    service *services.UserService
}

// NewUserController returns a controller that delegates to service.
func NewUserController(service *services.UserService) *UserController {
    ctrl := &UserController{}
    ctrl.service = service
    return ctrl
}

// Register handles POST /register. The JSON body is validated against
// services.CreateUserInput and the user is created.
//
// Responses:
//   - 400 for an invalid body
//   - 409 when the username or email is already taken
//   - 500 for any other failure; the internal error is attached to the Gin
//     context for logging instead of being sent to the client
//   - 201 with the created user on success
func (c *UserController) Register(ctx *gin.Context) {
    var input services.CreateUserInput
    if err := ctx.ShouldBindJSON(&input); err != nil {
        ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        return
    }

    user, err := c.service.Create(input)
    if err != nil {
        // UserService returns its sentinel errors unwrapped, so a direct
        // comparison identifies them.
        switch err {
        case services.ErrUsernameExists, services.ErrEmailExists:
            ctx.JSON(http.StatusConflict, gin.H{"error": err.Error()})
        default:
            _ = ctx.Error(err)
            ctx.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create user"})
        }
        return
    }

    ctx.JSON(http.StatusCreated, gin.H{
        "message": "user created successfully",
        "data":    user,
    })
}

// Login handles POST /login.
//
// An invalid body yields 400. Any authentication failure yields 401 with the
// service's error text. On success the user is returned. Issuing an auth
// token is still to be done (see the TODO below).
func (c *UserController) Login(ctx *gin.Context) {
    var input services.LoginInput
    if bindErr := ctx.ShouldBindJSON(&input); bindErr != nil {
        ctx.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
        return
    }

    user, loginErr := c.service.Login(input)
    if loginErr != nil {
        ctx.JSON(http.StatusUnauthorized, gin.H{"error": loginErr.Error()})
        return
    }

    // TODO: issue a JWT here (e.g. token, _ := generateJWT(user.ID)) and
    // include it in the response under "token".
    ctx.JSON(http.StatusOK, gin.H{
        "message": "login successful",
        "data":    user,
    })
}

// GetProfile handles GET /profile for the authenticated caller.
//
// The caller's id is read from the "user_id" context key, which the auth
// middleware is expected to set as a uint. A missing value, or one of any
// other type, yields 401. An unchecked type assertion would otherwise panic
// here. An unknown user yields 404.
func (c *UserController) GetProfile(ctx *gin.Context) {
    raw, _ := ctx.Get("user_id")
    userID, ok := raw.(uint) // ok is false for a missing (nil) or mistyped value
    if !ok {
        ctx.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
        return
    }

    user, err := c.service.GetByID(userID)
    if err != nil {
        ctx.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
        return
    }

    ctx.JSON(http.StatusOK, gin.H{"data": user})
}

// GetUser handles GET /users/:id. An id that is not a decimal number fitting
// in 32 bits yields 400. A failed lookup yields 404. Otherwise the user is
// returned.
func (c *UserController) GetUser(ctx *gin.Context) {
    rawID := ctx.Param("id")
    id, parseErr := strconv.ParseUint(rawID, 10, 32)
    if parseErr != nil {
        ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
        return
    }

    user, err := c.service.GetByID(uint(id))
    if err != nil {
        ctx.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
        return
    }
    ctx.JSON(http.StatusOK, gin.H{"data": user})
}

// ListUsers handles GET /users with optional page and page_size query
// parameters.
//
// Parse errors are ignored on purpose: strconv.Atoi then yields 0 for
// malformed input and the bounds checks reset it. page falls back to 1.
// page_size must lie in [1, 100], otherwise 10 is used. A service error
// yields 500.
func (c *UserController) ListUsers(ctx *gin.Context) {
    page, _ := strconv.Atoi(ctx.DefaultQuery("page", "1"))
    size, _ := strconv.Atoi(ctx.DefaultQuery("page_size", "10"))

    if page < 1 {
        page = 1
    }
    if outOfRange := size < 1 || size > 100; outOfRange {
        size = 10
    }

    users, total, err := c.service.List(page, size)
    if err != nil {
        ctx.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get users"})
        return
    }

    meta := gin.H{
        "page":      page,
        "page_size": size,
        "total":     total,
    }
    ctx.JSON(http.StatusOK, gin.H{"data": users, "meta": meta})
}

// UpdateUser handles PUT /users/:id on the authenticated route group.
//
// The body only carries profile fields (nickname, avatar), so a caller may
// update only their own record. The :id path parameter must equal the uint
// stored under "user_id" by the auth middleware. Without this check any
// logged-in user could rewrite any other user's profile.
//
// Responses:
//   - 400 for a malformed id or body
//   - 401 when no valid user id is in the context
//   - 403 when targeting another user
//   - 404 when the user does not exist
//   - 500 for any other failure (details go to the Gin context, not the client)
//   - 200 with the updated user on success
//
// NOTE(review): if administrators must edit other users, extend the
// ownership check with a role check.
func (c *UserController) UpdateUser(ctx *gin.Context) {
    id, err := strconv.ParseUint(ctx.Param("id"), 10, 32)
    if err != nil {
        ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
        return
    }

    raw, _ := ctx.Get("user_id")
    callerID, ok := raw.(uint)
    if !ok {
        ctx.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
        return
    }
    if callerID != uint(id) {
        ctx.JSON(http.StatusForbidden, gin.H{"error": "forbidden"})
        return
    }

    var input services.UpdateUserInput
    if err := ctx.ShouldBindJSON(&input); err != nil {
        ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        return
    }

    user, err := c.service.Update(uint(id), input)
    if err != nil {
        // UserService returns its sentinel errors unwrapped.
        if err == services.ErrUserNotFound {
            ctx.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
            return
        }
        _ = ctx.Error(err)
        ctx.JSON(http.StatusInternalServerError, gin.H{"error": "failed to update user"})
        return
    }

    ctx.JSON(http.StatusOK, gin.H{
        "message": "user updated successfully",
        "data":    user,
    })
}

// DeleteUser handles DELETE /users/:id on the authenticated route group by
// soft-deleting the user.
//
// Responses:
//   - 400 for a malformed id
//   - 404 when the service reports services.ErrUserNotFound
//   - 500 for any other failure (details go to the Gin context, not the client)
//   - 200 on success
//
// NOTE(review): any authenticated caller can delete any user. No ownership
// or role check is visible here; confirm whether AuthMiddleware or an admin
// guard is meant to enforce one.
func (c *UserController) DeleteUser(ctx *gin.Context) {
    id, err := strconv.ParseUint(ctx.Param("id"), 10, 32)
    if err != nil {
        ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
        return
    }

    if err := c.service.Delete(uint(id)); err != nil {
        // UserService returns its sentinel errors unwrapped.
        if err == services.ErrUserNotFound {
            ctx.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
            return
        }
        _ = ctx.Error(err)
        ctx.JSON(http.StatusInternalServerError, gin.H{"error": "failed to delete user"})
        return
    }

    ctx.JSON(http.StatusOK, gin.H{"message": "user deleted successfully"})
}

主程序集成

复制代码
// main.go
package main

import (
    "log"

    "github.com/gin-gonic/gin"
    "gorm.io/gorm"
    "gorm.io/gorm/logger"

    "your-project/config"
    "your-project/controllers"
    "your-project/database"
    "your-project/middlewares"
    "your-project/models"
    "your-project/repositories"
    "your-project/services"
)

// main wires configuration, database, repositories, services, controllers
// and routes together, then starts the HTTP server. Any startup failure is
// fatal.
func main() {
    cfg := config.Load()

    // database.Config.LogLevel has type logger.LogLevel from gorm.io/gorm/logger;
    // the gorm package itself declares no LogLevel type.
    db, err := database.NewDB(database.Config{
        Driver:          cfg.DB.Driver,
        DSN:             cfg.DB.DSN,
        MaxIdleConns:    cfg.DB.MaxIdleConns,
        MaxOpenConns:    cfg.DB.MaxOpenConns,
        ConnMaxLifetime: cfg.DB.ConnMaxLifetime,
        LogLevel:        logger.LogLevel(cfg.DB.LogLevel),
    })
    if err != nil {
        log.Fatalf("Failed to connect database: %v", err)
    }

    // Create/alter tables for all models before serving traffic.
    if err := autoMigrate(db); err != nil {
        log.Fatalf("Failed to migrate database: %v", err)
    }

    // Dependency wiring: repository -> service -> controller.
    userRepo := repositories.NewUserRepository(db)
    userService := services.NewUserService(userRepo)
    userController := controllers.NewUserController(userService)

    r := gin.Default()
    setupRoutes(r, userController)

    log.Printf("Server starting on :%s", cfg.Server.Port)
    if err := r.Run(":" + cfg.Server.Port); err != nil {
        log.Fatalf("Failed to start server: %v", err)
    }
}

// autoMigrate creates or alters the tables for the User, Post and Tag models.
func autoMigrate(db *gorm.DB) error {
    tables := []interface{}{
        &models.User{},
        &models.Post{},
        &models.Tag{},
    }
    return db.AutoMigrate(tables...)
}

// setupRoutes registers the /api/v1 endpoints. Registration, login and the
// read-only user endpoints are public. Profile, update and delete run
// behind middlewares.AuthMiddleware.
func setupRoutes(r *gin.Engine, userCtrl *controllers.UserController) {
    api := r.Group("/api/v1")
    api.POST("/register", userCtrl.Register)
    api.POST("/login", userCtrl.Login)
    api.GET("/users", userCtrl.ListUsers)
    api.GET("/users/:id", userCtrl.GetUser)

    authed := r.Group("/api/v1", middlewares.AuthMiddleware())
    authed.GET("/profile", userCtrl.GetProfile)
    authed.PUT("/users/:id", userCtrl.UpdateUser)
    authed.DELETE("/users/:id", userCtrl.DeleteUser)
}

高级查询示例

复制代码
// repositories/post_repository.go
package repositories

import (
    "gorm.io/gorm"
    "your-project/models"
)

// PostRepository is the data-access layer for models.Post.
type PostRepository struct {
    db *gorm.DB
}

// NewPostRepository wraps db in a PostRepository.
func NewPostRepository(db *gorm.DB) *PostRepository {
    repo := new(PostRepository)
    repo.db = db
    return repo
}

// FindPublished returns one page of published posts (status 2), newest first,
// with Author and Tags preloaded, plus the total number of published posts.
//
// page is 1-based; values below 1 are treated as 1. A pageSize below 1 yields
// an empty page, because GORM would treat a negative Limit as "no LIMIT".
func (r *PostRepository) FindPublished(page, pageSize int) ([]models.Post, int64, error) {
    // Session makes the base query safe to reuse for both Count and Find.
    query := r.db.Model(&models.Post{}).
        Where("status = ?", 2). // 2: published
        Session(&gorm.Session{})

    var total int64
    if err := query.Count(&total).Error; err != nil {
        return nil, 0, err
    }
    if pageSize < 1 {
        return nil, total, nil
    }
    if page < 1 {
        page = 1
    }

    var posts []models.Post
    // id breaks ties between equal created_at values so pages are stable.
    err := query.Preload("Author").
        Preload("Tags").
        Order("created_at DESC, id DESC").
        Offset((page - 1) * pageSize).
        Limit(pageSize).
        Find(&posts).Error

    return posts, total, err
}

// Search returns one page of published posts whose title or content contains
// keyword, newest first, with Author preloaded, plus the total match count.
//
// keyword is bound as a parameter, so it is not an injection risk.
// NOTE(review): '%' and '_' inside keyword still act as LIKE wildcards.
//
// page is 1-based; values below 1 are treated as 1. A pageSize below 1 yields
// an empty page, because GORM would treat a negative Limit as "no LIMIT".
func (r *PostRepository) Search(keyword string, page, pageSize int) ([]models.Post, int64, error) {
    pattern := "%" + keyword + "%"

    // Parentheses keep the OR from escaping the status filter. Session makes
    // the base query safe to reuse for both Count and Find.
    query := r.db.Model(&models.Post{}).
        Where("status = ?", 2).
        Where("(title LIKE ? OR content LIKE ?)", pattern, pattern).
        Session(&gorm.Session{})

    var total int64
    if err := query.Count(&total).Error; err != nil {
        return nil, 0, err
    }
    if pageSize < 1 {
        return nil, total, nil
    }
    if page < 1 {
        page = 1
    }

    var posts []models.Post
    err := query.Preload("Author").
        Order("created_at DESC, id DESC").
        Offset((page - 1) * pageSize).
        Limit(pageSize).
        Find(&posts).Error

    return posts, total, err
}

// FindByTagID returns one page of published posts linked to tagID through the
// post_tags join table, newest first, with Author and Tags preloaded, plus
// the total count.
//
// page is 1-based; values below 1 are treated as 1. A pageSize below 1 yields
// an empty page, because GORM would treat a negative Limit as "no LIMIT".
func (r *PostRepository) FindByTagID(tagID uint, page, pageSize int) ([]models.Post, int64, error) {
    // Columns are table-qualified because of the join. Session makes the
    // base query safe to reuse for both Count and Find.
    query := r.db.Model(&models.Post{}).
        Joins("JOIN post_tags ON post_tags.post_id = posts.id").
        Where("post_tags.tag_id = ? AND posts.status = ?", tagID, 2).
        Session(&gorm.Session{})

    var total int64
    if err := query.Count(&total).Error; err != nil {
        return nil, 0, err
    }
    if pageSize < 1 {
        return nil, total, nil
    }
    if page < 1 {
        page = 1
    }

    var posts []models.Post
    err := query.Preload("Author").
        Preload("Tags").
        Order("posts.created_at DESC, posts.id DESC").
        Offset((page - 1) * pageSize).
        Limit(pageSize).
        Find(&posts).Error

    return posts, total, err
}

// IncrementViewCount adds one to the post's view_count with a single UPDATE
// computed in the database (view_count = view_count + 1). Concurrent views
// are therefore not lost to a read-modify-write race. UpdateColumn is used
// so hooks and updated_at tracking are skipped.
func (r *PostRepository) IncrementViewCount(id uint) error {
    bump := gorm.Expr("view_count + ?", 1)
    result := r.db.Model(&models.Post{}).
        Where("id = ?", id).
        UpdateColumn("view_count", bump)
    return result.Error
}

// FindHotPosts returns up to limit published posts (status 2), highest
// view_count first. Associations are not preloaded.
func (r *PostRepository) FindHotPosts(limit int) ([]models.Post, error) {
    var hot []models.Post
    tx := r.db.Where("status = ?", 2).Order("view_count DESC").Limit(limit)
    if err := tx.Find(&hot).Error; err != nil {
        return hot, err
    }
    return hot, nil
}

// FindByAuthorID returns one page of authorID's posts in any status (drafts
// included), newest first, plus the author's total post count. Associations
// are not preloaded.
//
// page is 1-based; values below 1 are treated as 1. A pageSize below 1 yields
// an empty page, because GORM would treat a negative Limit as "no LIMIT".
func (r *PostRepository) FindByAuthorID(authorID uint, page, pageSize int) ([]models.Post, int64, error) {
    // Session makes the base query safe to reuse for both Count and Find.
    query := r.db.Model(&models.Post{}).
        Where("author_id = ?", authorID).
        Session(&gorm.Session{})

    var total int64
    if err := query.Count(&total).Error; err != nil {
        return nil, 0, err
    }
    if pageSize < 1 {
        return nil, total, nil
    }
    if page < 1 {
        page = 1
    }

    var posts []models.Post
    err := query.Order("created_at DESC, id DESC").
        Offset((page - 1) * pageSize).
        Limit(pageSize).
        Find(&posts).Error

    return posts, total, err
}

事务处理

复制代码
// services/post_service.go
package services

import (
    "errors"
    "gorm.io/gorm"
    "your-project/models"
    "your-project/repositories"
)

// PostService implements post use cases. It keeps the raw *gorm.DB so that
// Create and Publish can run inside transactions.
// NOTE(review): postRepo and tagRepo are stored, but neither Create nor
// Publish uses them.
type PostService struct {
    db       *gorm.DB
    postRepo *repositories.PostRepository
    tagRepo  *repositories.TagRepository
}

// NewPostService assembles a PostService from its dependencies.
func NewPostService(db *gorm.DB, postRepo *repositories.PostRepository, tagRepo *repositories.TagRepository) *PostService {
    svc := new(PostService)
    svc.db, svc.postRepo, svc.tagRepo = db, postRepo, tagRepo
    return svc
}

// CreatePostInput is the request body for PostService.Create. Title and
// Content are required by the binding rules. Tags holds tag names; tags that
// do not exist yet are created by PostService.Create.
type CreatePostInput struct {
    Title   string   `json:"title" binding:"required"`
    Content string   `json:"content" binding:"required"`
    Tags    []string `json:"tags"`
}

// Create inserts a post written by authorID and attaches its tags, all in one
// transaction: if any step fails, nothing is written.
//
// The post is created directly with status 2 (published). Each tag name is
// looked up and created when missing. Empty names are skipped and duplicate
// names in input.Tags are collapsed, so each tag is linked exactly once.
func (s *PostService) Create(authorID uint, input CreatePostInput) (*models.Post, error) {
    var post *models.Post

    err := s.db.Transaction(func(tx *gorm.DB) error {
        post = &models.Post{
            Title:    input.Title,
            Content:  input.Content,
            AuthorID: authorID,
            Status:   2, // published
        }
        if err := tx.Create(post).Error; err != nil {
            return err
        }

        tags := make([]models.Tag, 0, len(input.Tags))
        seen := make(map[string]struct{}, len(input.Tags))
        for _, name := range input.Tags {
            if name == "" {
                continue
            }
            if _, dup := seen[name]; dup {
                continue
            }
            seen[name] = struct{}{}

            // Find-or-create by name. NOTE: two concurrent requests creating the
            // same new tag can collide on the unique index; the losing
            // transaction then fails and its error is returned.
            var tag models.Tag
            if err := tx.Where("name = ?", name).FirstOrCreate(&tag, models.Tag{Name: name}).Error; err != nil {
                return err
            }
            tags = append(tags, tag)
        }

        if len(tags) == 0 {
            return nil
        }
        return tx.Model(post).Association("Tags").Replace(tags)
    })
    if err != nil {
        return nil, err
    }

    return post, nil
}

// Publish moves a post to status 2 (published) inside a transaction. It
// returns the lookup error if the post cannot be loaded, and an error if the
// post is already published.
func (s *PostService) Publish(id uint) error {
    return s.db.Transaction(func(tx *gorm.DB) error {
        var post models.Post
        err := tx.First(&post, id).Error
        switch {
        case err != nil:
            return err
        case post.Status == 2:
            return errors.New("post already published")
        default:
            return tx.Model(&post).Update("status", 2).Error
        }
    })
}

常见问题与优化

复制代码
// 1. N+1 queries
// ❌ Wrong: querying the association inside a loop
posts, _ := postRepo.FindAll()
for _, post := range posts {
    author, _ := userRepo.FindByID(post.AuthorID)  // one query per post: N queries
}

// ✅ Right: preload the association
db.Preload("Author").Find(&posts)

// 2. Pagination
// ❌ Large offsets are slow: the database typically still walks the skipped rows
db.Offset(100000).Limit(10).Find(&posts)

// ✅ Cursor (keyset) pagination. lastID is the id of the last row on the
// previous page. The ORDER BY must match the cursor column; without it the
// database may return any 10 rows with id > lastID, so pages can skip or
// repeat rows.
db.Where("id > ?", lastID).Order("id ASC").Limit(10).Find(&posts)

// 3. Batch inserts
// ❌ Inserting in a loop
for _, user := range users {
    db.Create(&user)  // N database round trips
}

// ✅ Batch insert
db.CreateInBatches(users, 100)  // 100 rows per batch

// 4. Updates
// ❌ Full-row update
db.Save(&user)  // writes every column

// ✅ Update only the columns you mean to change
db.Model(&user).Select("nickname", "avatar").Updates(user)

// 5. Queries
// ✅ Select only the columns you need
db.Select("id", "title", "author_id").Find(&posts)

// ✅ Make sure an index supports the filter: this query benefits from a
// composite index on (status, created_at); the models above only index
// status on its own.
db.Where("status = ? AND created_at > ?", 2, yesterday).Find(&posts)

这个方案提供了从数据库连接到业务逻辑的完整分层架构,适合中大型项目的长期维护。

你的项目现在用的是什么 ORM?是否有特殊的性能需求或复杂查询场景?我可以针对具体场景给出更深入的优化建议。

相关推荐
网络工程小王2 小时前
【Transformer架构详解】(学习笔记)
笔记·学习
白毛大侠3 小时前
理解 Go 接口:eface 与 iface 的区别及动态性解析
开发语言·网络·golang
倒酒小生4 小时前
今日算法学习小结
学习
醇氧4 小时前
【学习】【说人话版】子网划分
学习
不灭锦鲤5 小时前
网络安全学习(面试)
学习·安全·web安全
世人万千丶6 小时前
Flutter 框架跨平台鸿蒙开发 - 鸿蒙版本五子棋游戏应用
学习·flutter·游戏·华为·harmonyos·鸿蒙
Aktx20FNz6 小时前
一文学习 Spring AOP 源码全过程
java·学习·spring
Jay Kay6 小时前
生成式推荐模型学习记录part1
学习
so2F32hj27 小时前
一款Go语言Gin框架DDD脚手架,适合快速搭建项目
开发语言·golang·gin