using Abstractions;
using AutoMapper;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using W542.GandalfReborn.Data.Entities.Base;
using W542.GandalfReborn.Data.Entities.Security;
using W542.GandalfReborn.Data.Entities.Tenant;
using W542.GandalfReborn.Data.Entities.Version;
using W542.GandalfReborn.Data.Extensions;

namespace W542.GandalfReborn.Data.Database;

/// <summary>
/// EF Core context for the GandalfReborn data model.
/// On top of <see cref="CoreContext"/> it:
/// <list type="bullet">
/// <item>marks <c>long Id</c> keys as generated on add;</item>
/// <item>stores enum properties as strings;</item>
/// <item>wires version entities to the entities they track;</item>
/// <item>names every table after its CLR type (without the "Entity" suffix) in the <c>gr</c> schema.</item>
/// </list>
/// <see cref="AddVersionTriggers"/> installs PostgreSQL triggers that copy each insert/update of a
/// versioned entity into its version table.
/// </summary>
public sealed class ApplicationContext(DbContextOptions options, InvokerContext invokerContext) : CoreContext(options)
{
    /// <summary>
    /// Maps each versioned entity type (the type of a version entity's <c>Reference</c> property)
    /// to its version entity type.
    /// Filled while the model is built (<see cref="AddVersionRelations"/>) and read by
    /// <see cref="AddVersionTriggers"/>.
    /// </summary>
    private static readonly Dictionary<Type, Type> EntityToVersionEntityMap = new();

    private const string Schema = "gr";

    /// <inheritdoc />
    protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
    {
        // The interceptor gets the invoker context. Presumably it publishes the current suspect
        // under GrDbConnectionInterceptor.CurrentSuspectKey, which the version triggers read.
        // TODO confirm.
        optionsBuilder.AddInterceptors(new GrDbConnectionInterceptor(invokerContext));
    }

    /// <inheritdoc />
    protected override void OnModelCreating(ModelBuilder builder)
    {
        base.OnModelCreating(builder);

        ConfigureId(builder);
        AddEnumToStringConversion(builder);
        AddVersionRelations(builder);
        SetTableNames(builder);
    }

    /// <summary>
    /// Returns true when <paramref name="type"/> implements <c>IVersionEntity&lt;T&gt;</c> for some <c>T</c>.
    /// </summary>
    private static bool ImplementsVersionEntity(Type type) =>
        type.GetInterfaces().Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IVersionEntity<>));

    /// <summary>
    /// Marks the <c>Id</c> property as <c>ValueGeneratedOnAdd</c> for every entity that has a
    /// <c>long</c> <c>Id</c>. Version entities and mapping entities are excluded because their
    /// keys are configured elsewhere or composite.
    /// </summary>
    private static void ConfigureId(ModelBuilder builder)
    {
        var longKeyEntities = builder.Model
            .GetEntityTypes()
            .Where(x => !ImplementsVersionEntity(x.ClrType))
            .Where(x => !x.ClrType.GetInterfaces().Any(y => y.IsAssignableTo(typeof(IMappingEntity))))
            .Where(x => x.GetProperties().Any(y => y.Name == nameof(IdData.Id) && y.ClrType == typeof(long)))
            .ToList();

        foreach (var entity in longKeyEntities)
        {
            builder.Entity(entity.ClrType)
                .Property(nameof(IdData.Id))
                .ValueGeneratedOnAdd();
        }
    }

    /// <summary>
    /// Configures every concrete version entity type found in this assembly:
    /// <list type="bullet">
    /// <item>the composite key is the referenced entity's primary key plus <c>At</c>;</item>
    /// <item>there are required relations to the referenced entity and to the suspect;</item>
    /// <item>all other navigations are ignored.</item>
    /// </list>
    /// It also records the reference-type to version-type pair in <see cref="EntityToVersionEntityMap"/>.
    /// </summary>
    private static void AddVersionRelations(ModelBuilder builder)
    {
        var versionEntityClrTypes = typeof(ApplicationContext).Assembly
            .GetTypes()
            .Where(x => x is { IsAbstract: false, IsInterface: false } && ImplementsVersionEntity(x));

        foreach (var type in versionEntityClrTypes)
        {
            var entityBuilder = builder.Entity(type);

            var referenceType = type.GetProperties()
                .Where(x => x.Name == nameof(IVersionEntity.Reference))
                .Select(x => x.PropertyType)
                .Single();

            EntityToVersionEntityMap.TryAdd(referenceType, type);

            var referencePrimaryKeys = builder.Model
                .GetEntityTypes()
                .Where(x => x.ClrType == referenceType)
                .Select(x => x.FindPrimaryKey())
                .Single(x => x is not null)!
                .Properties
                .Select(x => x.Name)
                .ToArray();

            // One version row per referenced entity per point in time.
            entityBuilder.HasKey([.. referencePrimaryKeys, nameof(IVersionEntity.At)]);

            entityBuilder
                .HasOne(nameof(IVersionEntity.Reference))
                .WithMany()
                .HasForeignKey(referencePrimaryKeys)
                .IsRequired();

            entityBuilder
                .HasOne(nameof(IVersionEntity.Suspect))
                .WithMany()
                .HasForeignKey(nameof(IVersionEntity.SuspectId))
                .IsRequired();
        }

        // Loop-invariant: the only navigations a version entity keeps.
        string[] allowedNavigationNames = [nameof(IVersionEntity.Reference), nameof(IVersionEntity.Suspect)];

        var versionTypes = builder.Model
            .GetEntityTypes()
            .Where(x => x.ClrType is { IsAbstract: false, IsInterface: false } && ImplementsVersionEntity(x.ClrType))
            .ToList();

        foreach (var versionType in versionTypes)
        {
            // Materialize first: Ignore() mutates the navigations being enumerated.
            var navigationsToIgnore = versionType
                .GetNavigations()
                .Where(x => !allowedNavigationNames.Contains(x.Name))
                .ToList();

            var entityBuilder = builder.Entity(versionType.ClrType);
            foreach (var navigation in navigationsToIgnore)
            {
                entityBuilder.Ignore(navigation.Name);
            }
        }
    }

    /// <summary>
    /// Creates or replaces, for every versioned entity, a PL/pgSQL function and an
    /// <c>AFTER INSERT OR UPDATE ... FOR EACH ROW</c> trigger.
    /// Each trigger inserts into the version table:
    /// <list type="bullet">
    /// <item>the new row's data columns (the public properties of the entity's base type);</item>
    /// <item><c>At</c> = <c>current_timestamp</c>;</item>
    /// <item><c>SuspectId</c> read from the session setting <c>GrDbConnectionInterceptor.CurrentSuspectKey</c>
    /// (<c>current_setting(..., 't')</c> yields NULL when the setting does not exist);</item>
    /// <item><c>Action</c> = Created for inserts, Modified for updates.</item>
    /// </list>
    /// </summary>
    /// <exception cref="InvalidOperationException">A versioned entity type has no base type.</exception>
    public void AddVersionTriggers()
    {
        // EntityToVersionEntityMap is only filled while the model is built (OnModelCreating).
        // Accessing Model forces that build. Otherwise the loop below silently does nothing when
        // this method is the first use of the context in the process.
        _ = Model;

        foreach (var (entityType, versionType) in EntityToVersionEntityMap)
        {
            var dataType = entityType.BaseType
                ?? throw new InvalidOperationException($"Could not find base type for {entityType}");

            var dataTypeColumns = dataType.GetProperties().Select(x => x.Name).ToList();

            // Version-specific columns and the SQL expression that fills each one; kept as ordered
            // pairs so column names and values always line up.
            (string Column, string Value)[] versionTypeColumns =
            [
                (nameof(IVersionEntity.At), "current_timestamp"),
                (nameof(IVersionEntity.SuspectId), $"current_setting('{GrDbConnectionInterceptor.CurrentSuspectKey}', 't')::bigint"),
                (nameof(IVersionEntity.Action), $"(CASE WHEN (tg_op = 'INSERT') THEN '{VersionAction.Created}' ELSE '{VersionAction.Modified}' END)")
            ];

            var rowColumns = string.Join(", ", dataTypeColumns.Concat(versionTypeColumns.Select(c => c.Column)).Select(x => $"\"{x}\""));
            var rowValues = string.Join(", ", dataTypeColumns.Select(x => $"NEW.\"{x}\"").Concat(versionTypeColumns.Select(c => c.Value)));

            // Trust me, never change those names: CREATE OR REPLACE only replaces objects with the
            // same name, so a renamed trigger/function would be created next to the old one
            // instead of replacing it.
            // Invariant lowering keeps the identifier culture-independent.
            var triggerName = $"{GetTableName(entityType).ToLowerInvariant()}_t";
            var functionName = $"p_{triggerName}";

            var sql = $"""
                CREATE OR REPLACE FUNCTION {Schema}.{functionName}()
                    RETURNS trigger
                    LANGUAGE plpgsql
                AS $$
                BEGIN
                    insert into {Schema}."{GetTableName(versionType)}" ({rowColumns})
                    values ({rowValues});
                    RETURN NEW;
                END;
                $$;

                CREATE OR REPLACE TRIGGER {triggerName}
                    AFTER INSERT Or UPDATE
                    ON {Schema}."{GetTableName(entityType)}"
                    FOR EACH ROW
                EXECUTE FUNCTION {Schema}.{functionName}();
                """;

            // Identifiers come from CLR type/property names, not user input.
#pragma warning disable EF1002
            Database.ExecuteSqlRaw(sql);
#pragma warning restore EF1002
        }
    }

    /// <summary>
    /// Stores every enum-typed CLR property of every entity as its string name.
    /// </summary>
    /// <remarks>
    /// NOTE(review): <see cref="Type.IsEnum"/> is false for <c>Nullable&lt;TEnum&gt;</c>, so <c>TEnum?</c>
    /// properties are not matched here and keep EF's default mapping. Confirm this is intended.
    /// </remarks>
    private static void AddEnumToStringConversion(ModelBuilder builder)
    {
        foreach (var type in builder.Model.GetEntityTypes().ToList())
        {
            var entityBuilder = builder.Entity(type.ClrType);

            foreach (var property in type.ClrType.GetProperties().Where(p => p.PropertyType.IsEnum))
            {
                entityBuilder
                    .Property(property.Name)
                    .HasConversion<string>();
            }
        }
    }

    /// <summary>
    /// Maps every entity to a table named by <see cref="GetTableName"/> in the <c>gr</c> schema.
    /// </summary>
    private static void SetTableNames(ModelBuilder builder)
    {
        // Snapshot the entity types before touching the model through the builder.
        foreach (var type in builder.Model.GetEntityTypes().ToList())
        {
            builder.Entity(type.ClrType).ToTable(GetTableName(type.ClrType), Schema);
        }
    }

    /// <summary>
    /// Table name for an entity CLR type: the type name with every occurrence of "Entity" removed
    /// (ordinal, case-sensitive).
    /// </summary>
    private static string GetTableName(Type entityType) =>
        entityType.Name.Replace("Entity", string.Empty, StringComparison.Ordinal);
}