造轮子之ORM集成

Dotnet的ORM千千万,还是喜欢用EF CORE

前面的一些基础功能已经完成得差不多了,接下来可以集成数据库了,官方出品的ORM还是比较香,所以接下来就来集成EF CORE。

安装包

首先我们需要安装一下EF CORE的NUGET包,有如下几个:

csharp 复制代码
Microsoft.EntityFrameworkCore.Proxies
Microsoft.EntityFrameworkCore.Sqlite
Microsoft.EntityFrameworkCore.Design
Microsoft.EntityFrameworkCore.Tools

其中Microsoft.EntityFrameworkCore.Sqlite我们可以根据我们实际使用的数据库进行替换。

而Microsoft.EntityFrameworkCore.Proxies则是用于启用EF中的懒加载模式。

Microsoft.EntityFrameworkCore.Design和Microsoft.EntityFrameworkCore.Tools则是用于数据库迁移

DbContext

接下来创建我们的DbContext文件

csharp 复制代码
namespace Wheel.EntityFrameworkCore
{
    /// <summary>
    /// The application's EF Core <see cref="DbContext"/>.
    /// </summary>
    public class WheelDbContext : DbContext
    {
        /// <summary>
        /// Creates the context from the options configured at registration time.
        /// </summary>
        /// <param name="options">Options forwarded unchanged to the <see cref="DbContext"/> base constructor.</param>
        public WheelDbContext(DbContextOptions<WheelDbContext> options) : base(options)
        {
        }

        /// <summary>
        /// Model configuration hook; currently it only delegates to the base implementation.
        /// </summary>
        /// <param name="builder">The model builder supplied by EF Core.</param>
        protected override void OnModelCreating(ModelBuilder builder)
        {
            base.OnModelCreating(builder);
        }
    }
}

在Program中添加DbContext

csharp 复制代码
// Resolve the "Default" connection string up front; a missing entry aborts startup with a clear error.
var connectionString = builder.Configuration.GetConnectionString("Default")
    ?? throw new InvalidOperationException("Connection string 'Default' not found.");

// Register WheelDbContext backed by SQLite, with lazy-loading proxies switched on.
builder.Services.AddDbContext<WheelDbContext>(dbOptions =>
{
    dbOptions.UseSqlite(connectionString);
    dbOptions.UseLazyLoadingProxies();
});

在配置文件中添加连接字符串

csharp 复制代码
"ConnectionStrings": {
  "Default": "Data Source=Wheel.WebApi.Host.db"
}

封装Repository

在AddDbContext之后,我们就可以在程序中直接注入WheelDbContext来操作我们的数据库了。但是为了我们以后可能随时切换ORM,我们还是封装一层Repository来操作我们的数据库。

新增IBasicRepository泛型接口:

csharp 复制代码
    /// <summary>
    /// Generic repository contract that hides the ORM behind a fixed set of CRUD and query operations.
    /// </summary>
    /// <remarks>
    /// Write operations take an <c>autoSave</c> flag; the implementation in this file calls
    /// <c>SaveChangesAsync</c> right after the operation when it is <c>true</c>, otherwise the caller
    /// is expected to save later (for example through <see cref="SaveChangeAsync"/>).
    /// </remarks>
    /// <typeparam name="TEntity">Entity type handled by the repository.</typeparam>
    /// <typeparam name="TKey">Type of the value passed to the key-based lookup and delete methods.</typeparam>
    public interface IBasicRepository<TEntity, TKey> where TEntity : class 
    {
        /// <summary>Adds a single entity and returns the entity instance that was added.</summary>
        Task<TEntity> InsertAsync(TEntity entity, bool autoSave = false, CancellationToken cancellationToken = default);
        /// <summary>Adds a list of entities.</summary>
        Task InsertManyAsync(List<TEntity> entities, bool autoSave = false, CancellationToken cancellationToken = default);
        /// <summary>Updates a single entity and returns the updated entity instance.</summary>
        Task<TEntity> UpdateAsync(TEntity entity, bool autoSave = false, CancellationToken cancellationToken = default);
        /// <summary>Updates every entity matching <paramref name="predicate"/> using the given property setters.</summary>
        Task UpdateAsync(Expression<Func<TEntity, bool>> predicate, Expression<Func<SetPropertyCalls<TEntity>, SetPropertyCalls<TEntity>>> setPropertyCalls, bool autoSave = false, CancellationToken cancellationToken = default);
        /// <summary>Updates a list of entities.</summary>
        Task UpdateManyAsync(List<TEntity> entities, bool autoSave = false, CancellationToken cancellationToken = default);
        /// <summary>Deletes the entity identified by <paramref name="id"/>.</summary>
        Task DeleteAsync(TKey id, bool autoSave = false, CancellationToken cancellationToken = default);
        /// <summary>Deletes the given entity.</summary>
        Task DeleteAsync(TEntity entity, bool autoSave = false, CancellationToken cancellationToken = default);
        /// <summary>Deletes every entity matching <paramref name="predicate"/>.</summary>
        Task DeleteAsync(Expression<Func<TEntity, bool>> predicate, bool autoSave = false, CancellationToken cancellationToken = default);
        /// <summary>Deletes a list of entities.</summary>
        Task DeleteManyAsync(List<TEntity> entities, bool autoSave = false, CancellationToken cancellationToken = default);
        /// <summary>Looks up an entity by key; the nullable return type allows for "not found".</summary>
        Task<TEntity?> FindAsync(TKey id, CancellationToken cancellationToken = default);
        /// <summary>Looks up an entity matching <paramref name="predicate"/>; the nullable return type allows for "not found".</summary>
        Task<TEntity?> FindAsync(Expression<Func<TEntity, bool>> predicate, CancellationToken cancellationToken = default);
        /// <summary>Reports whether any entity exists.</summary>
        Task<bool> AnyAsync(CancellationToken cancellationToken = default);
        /// <summary>Reports whether any entity matches <paramref name="predicate"/>.</summary>
        Task<bool> AnyAsync(Expression<Func<TEntity, bool>> predicate, CancellationToken cancellationToken = default);
        /// <summary>Lists the entities matching <paramref name="predicate"/>, including the navigation properties named by <paramref name="propertySelectors"/>.</summary>
        Task<List<TEntity>> GetListAsync(Expression<Func<TEntity, bool>> predicate, CancellationToken cancellationToken = default, params Expression<Func<TEntity, object>>[] propertySelectors);
        /// <summary>Lists a projection (<paramref name="selectPredicate"/>) of the entities matching <paramref name="predicate"/>.</summary>
        Task<List<TSelect>> SelectListAsync<TSelect>(Expression<Func<TEntity, bool>> predicate, Expression<Func<TEntity, TSelect>> selectPredicate, CancellationToken cancellationToken = default);
        /// <summary>Same as the projection overload above, with the navigation properties in <paramref name="propertySelectors"/> included.</summary>
        Task<List<TSelect>> SelectListAsync<TSelect>(Expression<Func<TEntity, bool>> predicate, Expression<Func<TEntity, TSelect>> selectPredicate, CancellationToken cancellationToken = default, params Expression<Func<TEntity, object>>[] propertySelectors);
        /// <summary>Returns one page of projected items together with the total number of matches.</summary>
        /// <param name="skip">Number of items to skip before the page starts.</param>
        /// <param name="take">Maximum number of items in the page.</param>
        /// <param name="orderby">Ordering expression given as a string; defaults to <c>"Id"</c>.</param>
        Task<(List<TSelect> items, long total)> SelectPageListAsync<TSelect>(Expression<Func<TEntity, bool>> predicate, Expression<Func<TEntity, TSelect>> selectPredicate, int skip, int take, string orderby = "Id", CancellationToken cancellationToken = default);
        /// <summary>Same as the paged projection overload above, with the navigation properties in <paramref name="propertySelectors"/> included.</summary>
        Task<(List<TSelect> items, long total)> SelectPageListAsync<TSelect>(Expression<Func<TEntity, bool>> predicate, Expression<Func<TEntity, TSelect>> selectPredicate, int skip, int take, string orderby = "Id", CancellationToken cancellationToken = default, params Expression<Func<TEntity, object>>[] propertySelectors);
        /// <summary>Returns one page of entities together with the total number of matches.</summary>
        Task<(List<TEntity> items, long total)> GetPageListAsync(Expression<Func<TEntity, bool>> predicate, int skip, int take, string orderby = "Id", CancellationToken cancellationToken = default);
        /// <summary>Same as the paged entity overload above, with the navigation properties in <paramref name="propertySelectors"/> included.</summary>
        Task<(List<TEntity> items, long total)> GetPageListAsync(Expression<Func<TEntity, bool>> predicate, int skip, int take, string orderby = "Id", CancellationToken cancellationToken = default, params Expression<Func<TEntity, object>>[] propertySelectors);
        /// <summary>Exposes the entity set as a queryable; <paramref name="noTracking"/> selects a no-tracking query (the default).</summary>
        IQueryable<TEntity> GetQueryable(bool noTracking = true);

        /// <summary>Exposes the entity set as a queryable with the given navigation properties included.</summary>
        IQueryable<TEntity> GetQueryableWithIncludes(params Expression<Func<TEntity, object>>[] propertySelectors);

        /// <summary>Saves pending changes and returns the count reported by the underlying save call.</summary>
        Task<int> SaveChangeAsync(CancellationToken cancellationToken = default);
        /// <summary>
        /// Combines the predicates whose <c>condition</c> flag is <c>true</c> into a single filter expression,
        /// as an alternative to chaining conditional <c>WhereIf</c> calls.
        /// </summary>
        Expression<Func<TEntity, bool>> BuildPredicate(params (bool condition, Expression<Func<TEntity, bool>> predicate)[] conditionPredicates);
    }

    /// <summary>
    /// Repository contract that fixes the key type to <see cref="object"/>; the accompanying article
    /// intends it for tables with a composite primary key.
    /// </summary>
    /// <typeparam name="TEntity">Entity type handled by the repository.</typeparam>
    public interface IBasicRepository<TEntity> : IBasicRepository<TEntity, object> where TEntity : class
    {

    }

IBasicRepository<TEntity, TKey>用于单主键的表结构,IBasicRepository<TEntity> : IBasicRepository<TEntity, object>用于复合主键的表结构。
然后我们来实现一下BasicRepository:

csharp 复制代码
    /// <summary>
    /// EF Core implementation of <see cref="IBasicRepository{TEntity, TKey}"/> backed by <see cref="WheelDbContext"/>.
    /// </summary>
    /// <remarks>
    /// Write methods only call <c>SaveChangesAsync</c> when <c>autoSave</c> is <c>true</c>; otherwise the
    /// caller saves later (for example via <see cref="SaveChangeAsync"/>).
    /// </remarks>
    /// <typeparam name="TEntity">Entity type handled by the repository.</typeparam>
    /// <typeparam name="TKey">Type of the key value accepted by the key-based methods.</typeparam>
    public class EFBasicRepository<TEntity, TKey> : IBasicRepository<TEntity, TKey> where TEntity : class
    {
        private readonly WheelDbContext _dbContext;

        // Single access point for the entity set; every method below goes through it.
        private DbSet<TEntity> DbSet => _dbContext.Set<TEntity>();

        /// <summary>Creates the repository on top of the given context.</summary>
        /// <param name="dbContext">Context used for all queries and saves.</param>
        public EFBasicRepository(WheelDbContext dbContext)
        {
            _dbContext = dbContext;
        }

        /// <summary>Adds <paramref name="entity"/> to the set and returns the tracked entity instance.</summary>
        public async Task<TEntity> InsertAsync(TEntity entity, bool autoSave = false, CancellationToken cancellationToken = default)
        {
            var savedEntity = (await DbSet.AddAsync(entity, cancellationToken)).Entity;
            if (autoSave)
            {
                await _dbContext.SaveChangesAsync(cancellationToken);
            }
            return savedEntity;
        }

        /// <summary>Adds all <paramref name="entities"/> to the set.</summary>
        public async Task InsertManyAsync(List<TEntity> entities, bool autoSave = false, CancellationToken cancellationToken = default)
        {
            await DbSet.AddRangeAsync(entities, cancellationToken);
            if (autoSave)
            {
                await _dbContext.SaveChangesAsync(cancellationToken);
            }
        }

        /// <summary>Marks <paramref name="entity"/> as updated and returns the tracked entity instance.</summary>
        public async Task<TEntity> UpdateAsync(TEntity entity, bool autoSave = false, CancellationToken cancellationToken = default)
        {
            var savedEntity = DbSet.Update(entity).Entity;

            if (autoSave)
            {
                await _dbContext.SaveChangesAsync(cancellationToken);
            }
            return savedEntity;
        }

        /// <summary>
        /// Updates every row matching <paramref name="predicate"/> through <c>ExecuteUpdateAsync</c>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): with <paramref name="autoSave"/> set, the trailing <c>SaveChangesAsync</c> also
        /// saves any other changes already pending on the context - confirm that is intended.
        /// </remarks>
        public async Task UpdateAsync(Expression<Func<TEntity, bool>> predicate, Expression<Func<SetPropertyCalls<TEntity>, SetPropertyCalls<TEntity>>> setPropertyCalls, bool autoSave = false, CancellationToken cancellationToken = default)
        {
            await DbSet.Where(predicate).ExecuteUpdateAsync(setPropertyCalls, cancellationToken);
            if (autoSave)
            {
                await _dbContext.SaveChangesAsync(cancellationToken);
            }
        }

        /// <summary>Marks all <paramref name="entities"/> as updated.</summary>
        public async Task UpdateManyAsync(List<TEntity> entities, bool autoSave = false, CancellationToken cancellationToken = default)
        {
            DbSet.UpdateRange(entities);
            if (autoSave)
            {
                await _dbContext.SaveChangesAsync(cancellationToken);
            }
        }

        /// <summary>Removes the entity identified by <paramref name="id"/>; does nothing when no such entity is found.</summary>
        public async Task DeleteAsync(TKey id, bool autoSave = false, CancellationToken cancellationToken = default)
        {
            // The key is passed as an explicit array so the token is not mistaken for a second key value.
            var entity = await DbSet.FindAsync(ToKeyValues(id), cancellationToken);
            if (entity != null)
            {
                DbSet.Remove(entity);
            }
            if (autoSave)
            {
                await _dbContext.SaveChangesAsync(cancellationToken);
            }
        }

        /// <summary>Removes <paramref name="entity"/> from the set.</summary>
        public async Task DeleteAsync(TEntity entity, bool autoSave = false, CancellationToken cancellationToken = default)
        {
            DbSet.Remove(entity);
            if (autoSave)
            {
                await _dbContext.SaveChangesAsync(cancellationToken);
            }
        }

        /// <summary>
        /// Deletes every row matching <paramref name="predicate"/> through <c>ExecuteDeleteAsync</c>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): with <paramref name="autoSave"/> set, the trailing <c>SaveChangesAsync</c> also
        /// saves any other changes already pending on the context - confirm that is intended.
        /// </remarks>
        public async Task DeleteAsync(Expression<Func<TEntity, bool>> predicate, bool autoSave = false, CancellationToken cancellationToken = default)
        {
            await DbSet.Where(predicate).ExecuteDeleteAsync(cancellationToken);
            if (autoSave)
            {
                await _dbContext.SaveChangesAsync(cancellationToken);
            }
        }

        /// <summary>Removes all <paramref name="entities"/> from the set.</summary>
        public async Task DeleteManyAsync(List<TEntity> entities, bool autoSave = false, CancellationToken cancellationToken = default)
        {
            DbSet.RemoveRange(entities);
            if (autoSave)
            {
                await _dbContext.SaveChangesAsync(cancellationToken);
            }
        }

        /// <summary>Finds an entity by key, or returns <c>null</c> when none exists.</summary>
        public async Task<TEntity?> FindAsync(TKey id, CancellationToken cancellationToken = default)
        {
            // The key is passed as an explicit array so the token is not mistaken for a second key value.
            return await DbSet.FindAsync(ToKeyValues(id), cancellationToken);
        }

        /// <summary>Returns the first entity matching <paramref name="predicate"/> (no-tracking), or <c>null</c>.</summary>
        public async Task<TEntity?> FindAsync(Expression<Func<TEntity, bool>> predicate, CancellationToken cancellationToken = default)
        {
            return await DbSet.AsNoTracking().FirstOrDefaultAsync(predicate, cancellationToken);
        }

        /// <summary>Lists the entities matching <paramref name="predicate"/> using a query on the set itself (no <c>AsNoTracking</c>).</summary>
        public async Task<List<TEntity>> GetListAsync(Expression<Func<TEntity, bool>> predicate, CancellationToken cancellationToken = default)
        {
            return await DbSet.Where(predicate).ToListAsync(cancellationToken);
        }

        /// <summary>Lists the entities matching <paramref name="predicate"/> with the given navigation properties included (no-tracking).</summary>
        public async Task<List<TEntity>> GetListAsync(Expression<Func<TEntity, bool>> predicate, CancellationToken cancellationToken = default, params Expression<Func<TEntity, object>>[] propertySelectors)
        {
            return await GetQueryableWithIncludes(propertySelectors).Where(predicate).ToListAsync(cancellationToken);
        }

        /// <summary>Lists a projection of the matching entities, with the given navigation properties included.</summary>
        public async Task<List<TSelect>> SelectListAsync<TSelect>(Expression<Func<TEntity, bool>> predicate, Expression<Func<TEntity, TSelect>> selectPredicate, CancellationToken cancellationToken = default, params Expression<Func<TEntity, object>>[] propertySelectors)
        {
            return await GetQueryableWithIncludes(propertySelectors).Where(predicate).Select(selectPredicate).ToListAsync(cancellationToken);
        }

        /// <summary>Lists a projection of the matching entities.</summary>
        public async Task<List<TSelect>> SelectListAsync<TSelect>(Expression<Func<TEntity, bool>> predicate, Expression<Func<TEntity, TSelect>> selectPredicate, CancellationToken cancellationToken = default)
        {
            return await GetQueryable().Where(predicate).Select(selectPredicate).ToListAsync(cancellationToken);
        }

        /// <summary>Returns one page of projected items plus the total number of matches.</summary>
        /// <remarks>
        /// The ordering string is applied to the projected query, so it has to name a member of
        /// <typeparamref name="TSelect"/>. NOTE(review): the string-based <c>OrderBy</c> is an extension
        /// defined outside this file - presumably a dynamic LINQ helper; confirm.
        /// </remarks>
        public async Task<(List<TSelect> items, long total)> SelectPageListAsync<TSelect>(Expression<Func<TEntity, bool>> predicate, Expression<Func<TEntity, TSelect>> selectPredicate, int skip, int take, string orderby = "Id", CancellationToken cancellationToken = default)
        {
            var query = GetQueryable().Where(predicate).Select(selectPredicate);
            // Count the full match set before ordering and paging are applied.
            var total = await query.LongCountAsync(cancellationToken);
            var items = await query.OrderBy(orderby)
                .Skip(skip).Take(take)
                .ToListAsync(cancellationToken);
            return (items, total);
        }

        /// <summary>Returns one page of projected items plus the total number of matches, with navigation properties included.</summary>
        public async Task<(List<TSelect> items, long total)> SelectPageListAsync<TSelect>(Expression<Func<TEntity, bool>> predicate, Expression<Func<TEntity, TSelect>> selectPredicate, int skip, int take, string orderby = "Id", CancellationToken cancellationToken = default, params Expression<Func<TEntity, object>>[] propertySelectors)
        {
            var query = GetQueryableWithIncludes(propertySelectors).Where(predicate).Select(selectPredicate);
            // Count the full match set before ordering and paging are applied.
            var total = await query.LongCountAsync(cancellationToken);
            var items = await query.OrderBy(orderby)
                .Skip(skip).Take(take)
                .ToListAsync(cancellationToken);
            return (items, total);
        }

        /// <summary>Returns one page of entities plus the total number of matches.</summary>
        public async Task<(List<TEntity> items, long total)> GetPageListAsync(Expression<Func<TEntity, bool>> predicate, int skip, int take, string orderby = "Id", CancellationToken cancellationToken = default)
        {
            var query = GetQueryable().Where(predicate);
            // Count the full match set before ordering and paging are applied.
            var total = await query.LongCountAsync(cancellationToken);
            var items = await query.OrderBy(orderby)
                .Skip(skip).Take(take)
                .ToListAsync(cancellationToken);
            return (items, total);
        }

        /// <summary>Returns one page of entities plus the total number of matches, with navigation properties included.</summary>
        public async Task<(List<TEntity> items, long total)> GetPageListAsync(Expression<Func<TEntity, bool>> predicate, 
            int skip, int take, string orderby = "Id", CancellationToken cancellationToken = default, params Expression<Func<TEntity, object>>[] propertySelectors)
        {
            var query = GetQueryableWithIncludes(propertySelectors).Where(predicate);
            // Count the full match set before ordering and paging are applied.
            var total = await query.LongCountAsync(cancellationToken);
            var items = await query.OrderBy(orderby)
                .Skip(skip).Take(take)
                .ToListAsync(cancellationToken);
            return (items, total);
        }

        /// <summary>Reports whether the set contains any entity.</summary>
        public Task<bool> AnyAsync(CancellationToken cancellationToken = default)
        {
            return DbSet.AnyAsync(cancellationToken);
        }

        /// <summary>Reports whether any entity matches <paramref name="predicate"/>.</summary>
        public Task<bool> AnyAsync(Expression<Func<TEntity, bool>> predicate, CancellationToken cancellationToken = default)
        {
            return DbSet.AnyAsync(predicate, cancellationToken);
        }

        /// <summary>Returns the entity set as a queryable, with <c>AsNoTracking</c> applied unless <paramref name="noTracking"/> is <c>false</c>.</summary>
        public IQueryable<TEntity> GetQueryable(bool noTracking = true)
        {
            if (noTracking)
            {
                return DbSet.AsNoTracking();
            }
            return DbSet;
        }

        /// <summary>Returns a no-tracking queryable with an <c>Include</c> applied for each selector.</summary>
        public IQueryable<TEntity> GetQueryableWithIncludes(params Expression<Func<TEntity, object>>[] propertySelectors)
        {
            return Includes(GetQueryable(), propertySelectors);
        }

        /// <summary>
        /// Combines the predicates whose <c>condition</c> flag is <c>true</c> into one expression.
        /// When no predicate is selected, a predicate that accepts every entity is returned.
        /// </summary>
        /// <param name="conditionPredicates">Pairs of "apply this filter?" flag and filter expression.</param>
        /// <returns>The combined filter expression.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="conditionPredicates"/> is <c>null</c> or empty.</exception>
        public Expression<Func<TEntity, bool>> BuildPredicate(params (bool condition, Expression<Func<TEntity, bool>> predicate)[] conditionPredicates)
        {
            if (conditionPredicates == null || conditionPredicates.Length == 0)
            {
                // The two-argument constructor keeps the parameter name and the message separate.
                throw new ArgumentNullException(nameof(conditionPredicates), "conditionPredicates can not be null.");
            }
            Expression<Func<TEntity, bool>>? buildPredicate = null;
            foreach (var (condition, predicate) in conditionPredicates)
            {
                if (condition)
                {
                    if (buildPredicate == null)
                    {
                        buildPredicate = predicate;
                    }
                    else if (predicate != null)
                    {
                        // NOTE(review): And(...) is an expression-combining extension defined outside this file.
                        buildPredicate = buildPredicate.And(predicate);
                    }
                }
            }
            if (buildPredicate == null)
            {
                buildPredicate = (o) => true;
            }
            return buildPredicate;
        }

        // Applies one Include per selector; a null or empty selector list leaves the query unchanged.
        private static IQueryable<TEntity> Includes(IQueryable<TEntity> query, Expression<Func<TEntity, object>>[] propertySelectors)
        {
            if (propertySelectors != null && propertySelectors.Length > 0)
            {
                foreach (var propertySelector in propertySelectors)
                {
                    query = query.Include(propertySelector);
                }
            }

            return query;
        }

        // Wraps a key into the array form expected by DbSet.FindAsync(object[], CancellationToken).
        // A key that already is an object[] (composite key) is passed through unchanged.
        private static object?[] ToKeyValues(TKey id)
        {
            if (id is object[] compositeKey)
            {
                return compositeKey;
            }
            return new object?[] { id };
        }

        /// <summary>Saves all pending changes on the context and returns the value reported by <c>SaveChangesAsync</c>.</summary>
        public async Task<int> SaveChangeAsync(CancellationToken cancellationToken = default)
        {
            return await _dbContext.SaveChangesAsync(cancellationToken);
        }

        /// <summary>Gives derived repositories direct access to the entity set.</summary>
        protected DbSet<TEntity> GetDbSet()
        {
            return DbSet;
        }

        /// <summary>Gives derived repositories the context's underlying database connection.</summary>
        protected IDbConnection GetDbConnection()
        {
            return _dbContext.Database.GetDbConnection();
        }

        /// <summary>Gives derived repositories the context's current database transaction, or <c>null</c> when none is active.</summary>
        protected IDbTransaction? GetDbTransaction()
        {
            return _dbContext.Database.CurrentTransaction?.GetDbTransaction();
        }

    }


    /// <summary>
    /// Repository variant with the key type fixed to <see cref="object"/>, implementing
    /// <see cref="IBasicRepository{TEntity}"/>. All behavior is inherited from
    /// <see cref="EFBasicRepository{TEntity, TKey}"/>.
    /// </summary>
    /// <typeparam name="TEntity">Entity type handled by the repository.</typeparam>
    public class EFBasicRepository<TEntity> : EFBasicRepository<TEntity, object>, IBasicRepository<TEntity> where TEntity : class
    {
        /// <summary>Creates the repository on top of the given context.</summary>
        public EFBasicRepository(WheelDbContext dbContext) : base(dbContext)
        {
        }
    }

这样我们CRUD操作的Repository就实现好了。

在列表查询和分页查询中,特意实现了SelectList,避免在某些场景下每次查询数据库都查询所有表字段却只使用了其中几个字段。也能有效提高查询性能。

这里分页查询特意使用了元组返回值,避免我们在分页查询时需要写两次操作,一次查总数,一次查真实数据。

还有实现了一个BuildPredicate来拼接我们的条件表达式,个人由于写太多WhereIf有点腻了,所以弄了个方法来干掉WhereIf,虽然这个方法可能不算完美。

实际操作如下图:

当然BuildPredicate这个方法也不只有在查询方法中可以使用,在删除和更新方法中,我们同样可以根据条件这样拼接条件表达式。

添加到依赖注入

由于Autofac的RegisterAssemblyTypes不支持泛型接口注入,所以我们这里需要使用RegisterGeneric来注册我们的泛型仓储。

在WheelAutofacModule添加如下代码即可:

csharp 复制代码
// Open-generic registration for the two-type-parameter repository (entity + key).
builder
    .RegisterGeneric(typeof(EFBasicRepository<,>))
    .As(typeof(IBasicRepository<,>))
    .InstancePerDependency();

// Open-generic registration for the single-type-parameter repository (entity only).
builder
    .RegisterGeneric(typeof(EFBasicRepository<>))
    .As(typeof(IBasicRepository<>))
    .InstancePerDependency();

工作单元UOW

工作单元模式用于协调多个仓储的操作并确保它们在一个事务中进行。

这里我们来实现一个简单的工作单元模式。

首先实现一个DbTransaction:

csharp 复制代码
namespace Wheel.Uow
{
    /// <summary>
    /// Abstraction over a database transaction: begin, commit and roll back.
    /// Disposable both synchronously and asynchronously.
    /// </summary>
    public interface IDbTransaction : IDisposable, IAsyncDisposable
    {
        /// <summary>Begins a transaction and returns the EF Core transaction object.</summary>
        Task<IDbContextTransaction> BeginTransactionAsync(CancellationToken cancellationToken = default);
        /// <summary>Commits the transaction.</summary>
        Task CommitAsync(CancellationToken cancellationToken = default);
        /// <summary>Rolls the transaction back.</summary>
        Task RollbackAsync(CancellationToken cancellationToken = default);
    }
    /// <summary>
    /// <see cref="IDbTransaction"/> implementation that drives a transaction on a <see cref="DbContext"/>.
    /// </summary>
    /// <remarks>
    /// Disposing while a transaction is still open (neither committed nor rolled back) commits it.
    /// NOTE(review): this also means a transaction abandoned because of an exception is committed on
    /// dispose - confirm this auto-commit is the intended policy.
    /// </remarks>
    public class DbTransaction : IDbTransaction
    {
        private readonly DbContext _dbContext;

        // Transaction started by BeginTransactionAsync; kept until Dispose/DisposeAsync releases it.
        IDbContextTransaction? CurrentDbContextTransaction;

        bool isCommit = false;
        bool isRollback = false;

        /// <summary>Creates the wrapper for the given context.</summary>
        /// <param name="dbContext">Context whose database transaction is managed.</param>
        public DbTransaction(DbContext dbContext)
        {
            _dbContext = dbContext;
        }

        /// <summary>Begins a transaction on the context and remembers it for later disposal.</summary>
        public async Task<IDbContextTransaction> BeginTransactionAsync(CancellationToken cancellationToken = default)
        {
            CurrentDbContextTransaction = await _dbContext.Database.BeginTransactionAsync(cancellationToken);
            return CurrentDbContextTransaction;
        }

        /// <summary>Saves pending changes and commits the context's current transaction.</summary>
        public async Task CommitAsync(CancellationToken cancellationToken = default)
        {
            // The token is forwarded to both calls so a cancellation request is honored.
            await _dbContext.SaveChangesAsync(cancellationToken);
            await _dbContext.Database.CommitTransactionAsync(cancellationToken);
            isCommit = true;
        }

        /// <summary>Synchronous counterpart of <see cref="CommitAsync"/>: saves pending changes and commits.</summary>
        public void Commit()
        {
            // Save first, so the sync and async commit paths persist the same data.
            _dbContext.SaveChanges();
            _dbContext.Database.CommitTransaction();
            isCommit = true;
        }

        /// <summary>Rolls back the context's current transaction.</summary>
        public async Task RollbackAsync(CancellationToken cancellationToken = default)
        {
            await _dbContext.Database.RollbackTransactionAsync(cancellationToken);
            isRollback = true;
        }

        /// <summary>
        /// Commits a still-open transaction, then disposes the transaction object. Safe to call repeatedly.
        /// </summary>
        public void Dispose()
        {
            // Take a local reference and clear the field first: repeated calls become no-ops and
            // the transaction can still be disposed after the commit.
            var transaction = CurrentDbContextTransaction;
            if (transaction == null)
            {
                return;
            }
            CurrentDbContextTransaction = null;

            try
            {
                if (!isCommit && !isRollback)
                {
                    Commit();
                }
            }
            finally
            {
                transaction.Dispose();
            }
        }

        /// <summary>
        /// Asynchronous counterpart of <see cref="Dispose"/>.
        /// </summary>
        public async ValueTask DisposeAsync()
        {
            var transaction = CurrentDbContextTransaction;
            if (transaction == null)
            {
                return;
            }
            CurrentDbContextTransaction = null;

            try
            {
                if (!isCommit && !isRollback)
                {
                    await CommitAsync();
                }
            }
            finally
            {
                await transaction.DisposeAsync();
            }
        }

    }
}

DbTransaction负责操作开启事务,提交事务以及回滚事务。

实现UnitOfWork:

csharp 复制代码
namespace Wheel.Uow
{
    /// <summary>
    /// Unit of work: saves changes and controls a transaction so that several repository
    /// operations can be committed or rolled back together.
    /// </summary>
    public interface IUnitOfWork : IScopeDependency, IDisposable, IAsyncDisposable
    {
        /// <summary>Saves pending changes and returns the count reported by the underlying save call.</summary>
        Task<int> SaveChangesAsync(CancellationToken cancellationToken = default); 
        /// <summary>Begins a transaction and returns its handle.</summary>
        Task<IDbTransaction> BeginTransactionAsync(CancellationToken cancellationToken = default);
        /// <summary>Commits the transaction started with <see cref="BeginTransactionAsync"/>.</summary>
        Task CommitAsync(CancellationToken cancellationToken = default);
        /// <summary>Rolls back the transaction started with <see cref="BeginTransactionAsync"/>.</summary>
        Task RollbackAsync(CancellationToken cancellationToken = default);
    }
    /// <summary>
    /// <see cref="IUnitOfWork"/> implementation on top of <see cref="WheelDbContext"/>; transaction
    /// handling is delegated to a <see cref="DbTransaction"/>.
    /// </summary>
    public class UnitOfWork : IUnitOfWork
    {
        private readonly WheelDbContext _dbContext;
        private IDbTransaction? Transaction = null;

        /// <summary>Creates the unit of work for the given context.</summary>
        public UnitOfWork(WheelDbContext dbContext)
        {
            _dbContext = dbContext;
        }

        /// <summary>Saves all pending changes on the context.</summary>
        /// <returns>The value reported by <c>SaveChangesAsync</c>.</returns>
        public async Task<int> SaveChangesAsync(CancellationToken cancellationToken = default)
        {
            return await _dbContext.SaveChangesAsync(cancellationToken);
        }

        /// <summary>Begins a new transaction on the context and returns its handle.</summary>
        public async Task<IDbTransaction> BeginTransactionAsync(CancellationToken cancellationToken = default)
        {
            // Publish the transaction only once it has actually started, so a failed begin
            // does not leave a half-initialized handle behind for CommitAsync/RollbackAsync.
            var transaction = new DbTransaction(_dbContext);
            await transaction.BeginTransactionAsync(cancellationToken);
            Transaction = transaction;
            return transaction;
        }

        /// <summary>Commits the transaction started with <see cref="BeginTransactionAsync"/>.</summary>
        /// <exception cref="InvalidOperationException">No transaction has been started.</exception>
        public async Task CommitAsync(CancellationToken cancellationToken = default)
        {
            if (Transaction == null)
            {
                throw new InvalidOperationException("Transaction is null, Please BeginTransaction");
            }
            await Transaction.CommitAsync(cancellationToken);
        }

        /// <summary>Rolls back the transaction started with <see cref="BeginTransactionAsync"/>.</summary>
        /// <exception cref="InvalidOperationException">No transaction has been started.</exception>
        public async Task RollbackAsync(CancellationToken cancellationToken = default)
        {
            if (Transaction == null)
            {
                throw new InvalidOperationException("Transaction is null, Please BeginTransaction");
            }
            await Transaction.RollbackAsync(cancellationToken);
        }

        /// <summary>Disposes the transaction (if any) and then the context.</summary>
        /// <remarks>
        /// NOTE(review): the context is injected rather than created here, yet it is disposed by this
        /// class - confirm nothing else in the same scope still needs it afterwards.
        /// </remarks>
        public void Dispose()
        {
            if (Transaction != null)
            {
                Transaction.Dispose();
            }
            _dbContext.Dispose();
        }

        /// <summary>Asynchronous counterpart of <see cref="Dispose"/>.</summary>
        public async ValueTask DisposeAsync()
        {
            if (Transaction != null)
            {
                await Transaction.DisposeAsync();
            }
            await _dbContext.DisposeAsync();
        }
    }
}

UnitOfWork负责控制DbTransaction的操作以及数据库SaveChanges。

EF拦截器

在数据库操作中,我们经常有一些数据是希望可以自动记录的,如插入数据自动根据当前时间给创建时间字段赋值,修改时自动根据当前时间修改最近更新时间字段。亦或者当需要软删除操作时,我们正常调用Delete方法,实际是修改表数据,而不是从表中物理删除数据。

添加软删除,创建时间以及更新时间接口:

csharp 复制代码
/// <summary>
/// Marks an entity as soft-deletable: a delete is recorded through <see cref="IsDeleted"/>.
/// </summary>
public interface ISoftDelete
{
    /// <summary>
    /// Whether the entity has been deleted.
    /// </summary>
    public bool IsDeleted { get; set; }
}
csharp 复制代码
/// <summary>
/// Marks an entity that records when it was last modified.
/// </summary>
public interface IHasUpdateTime
{
    /// <summary>
    /// Time of the most recent modification.
    /// </summary>
    DateTimeOffset UpdateTime { get; set; }
}
csharp 复制代码
/// <summary>
/// Marks an entity that records when it was created.
/// </summary>
public interface IHasCreationTime
{
    /// <summary>
    /// Creation time.
    /// </summary>
    DateTimeOffset CreationTime { get; set; }
}

实现WheelEFCoreInterceptor,继承SaveChangesInterceptor,当调用SaveChanges方法时就会执行拦截器的逻辑操作。

csharp 复制代码
namespace Wheel.EntityFrameworkCore
{
    /// <summary>
    /// EF Core save interceptor that turns deletes of <see cref="ISoftDelete"/> entities into updates
    /// and stamps <see cref="IHasUpdateTime"/> / <see cref="IHasCreationTime"/> entities before a save.
    /// </summary>
    /// <remarks>
    /// Both the synchronous and the asynchronous save hooks are overridden, so the same rules apply
    /// to <c>SaveChanges</c> and <c>SaveChangesAsync</c>.
    /// </remarks>
    public sealed class WheelEFCoreInterceptor : SaveChangesInterceptor
    {
        /// <summary>Hook for synchronous saves: applies the entity rules, then continues with the base behavior.</summary>
        public override InterceptionResult<int> SavingChanges(DbContextEventData eventData, InterceptionResult<int> result)
        {
            OnSavingChanges(eventData);
            return base.SavingChanges(eventData, result);
        }

        /// <summary>Hook for asynchronous saves: applies the entity rules, then continues with the base behavior.</summary>
        /// <remarks>
        /// Without this override the rules would be skipped for every <c>SaveChangesAsync</c> call.
        /// Types are fully qualified so the override needs no additional <c>using</c> directives.
        /// </remarks>
        public override System.Threading.Tasks.ValueTask<InterceptionResult<int>> SavingChangesAsync(
            DbContextEventData eventData,
            InterceptionResult<int> result,
            System.Threading.CancellationToken cancellationToken = default)
        {
            OnSavingChanges(eventData);
            return base.SavingChangesAsync(eventData, result, cancellationToken);
        }

        /// <summary>
        /// Applies soft delete and timestamp rules to every entry tracked by the context being saved.
        /// </summary>
        /// <param name="eventData">Event data carrying the context that is about to save.</param>
        /// <exception cref="ArgumentNullException">The event data has no context.</exception>
        public static void OnSavingChanges(DbContextEventData eventData)
        {
            ArgumentNullException.ThrowIfNull(eventData.Context);
            eventData.Context.ChangeTracker.DetectChanges();
            foreach (var entityEntry in eventData.Context.ChangeTracker.Entries())
            {
                if (entityEntry is { State: EntityState.Deleted, Entity: ISoftDelete softDeleteEntity })
                {
                    // Convert the delete into an update that flags the entity as deleted.
                    softDeleteEntity.IsDeleted = true;
                    entityEntry.State = EntityState.Modified;
                }
                // Evaluated after the soft-delete conversion, so a soft-deleted entity also gets a new update time.
                if (entityEntry is { State: EntityState.Modified, Entity: IHasUpdateTime hasUpdateTimeEntity })
                {
                    hasUpdateTimeEntity.UpdateTime = DateTimeOffset.Now;
                }
                if (entityEntry is { State: EntityState.Added, Entity: IHasCreationTime hasCreationTimeEntity })
                {
                    hasCreationTimeEntity.CreationTime = DateTimeOffset.Now;
                }
            }
        }
    }
}

在AddDbContext添加我们的拦截器:

csharp 复制代码
// Register WheelDbContext backed by SQLite, with the save interceptor attached and lazy-loading proxies switched on.
builder.Services.AddDbContext<WheelDbContext>(dbOptions =>
{
    dbOptions.UseSqlite(connectionString);
    dbOptions.AddInterceptors(new WheelEFCoreInterceptor());
    dbOptions.UseLazyLoadingProxies();
});

这样就完成了我们ORM的集成了。

欢迎进群催更。