diff --git a/DigitalData.Core.Infrastructure/DependencyInjection.cs b/DigitalData.Core.Infrastructure/DependencyInjection.cs index 41a328a..ed9f42b 100644 --- a/DigitalData.Core.Infrastructure/DependencyInjection.cs +++ b/DigitalData.Core.Infrastructure/DependencyInjection.cs @@ -11,14 +11,13 @@ public static class DependencyInjection { public static IServiceCollection AddDbRepository(this IServiceCollection services, Action options) { + // register services from configuration var cfg = new RepositoryConfiguration(); options.Invoke(cfg); + cfg.RegisterAllServices(services); - // 1. FromAssembly - cfg.RegsFromAssembly.InvokeAll(services); - - // 2. Entities to be able overwrite - cfg.RegsEntity.InvokeAll(services); + // register db repository factory + services.AddSingleton(); return services; } @@ -26,10 +25,25 @@ public static class DependencyInjection public class RepositoryConfiguration { // 1. register from assembly - internal Queue> RegsFromAssembly = new(); + private readonly Queue> RegsFromAssembly = new(); // 2. register entities (can overwrite) - internal Queue> RegsEntity = new(); + private readonly Queue> RegsEntity = new(); + + // 3. register db set factories (can overwrite) + private readonly Queue> RegsDbSetFactory = new(); + + internal void RegisterAllServices(IServiceCollection services) + { + // 1. register from assembly + RegsFromAssembly.InvokeAll(services); + + // 2. register entities (can overwrite) + RegsEntity.InvokeAll(services); + + // 1. register db set factories (can overwrite) + RegsDbSetFactory.InvokeAll(services); + } internal RepositoryConfiguration() { } @@ -68,14 +82,22 @@ public static class DependencyInjection RegsFromAssembly.Enqueue(reg); } - public void RegisterEntity() where TDbContext : DbContext + public void RegisterEntity(Func>? dbSetFactory = null) + where TDbContext : DbContext where TEntity : class { - static void reg(IServiceCollection services) - => services.AddScoped, DbRepository>(); + void reg(IServiceCollection services) + => services + .AddScoped, DbRepository>() + .AddDbSetFactory(dbSetFactory); RegsEntity.Enqueue(reg); } + + public void RegisterDbSetFactory(Func> dbSetFactory) + where TDbContext : DbContext + where TEntity : class + => RegsDbSetFactory.Enqueue(s => s.AddDbSetFactory(dbSetFactory)); } private static void InvokeAll(this Queue> queue, T services)