| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611 |
- using System.Collections.Immutable;
- using System.Data;
- using System.Diagnostics.CodeAnalysis;
- using System.Linq;
- using System.Net;
- using System.Runtime.CompilerServices;
- using System.Threading;
- using System.Threading.Tasks;
- using Content.Server.Administration.Logs;
- using Content.Server.IP;
- using Content.Shared.CCVar;
- using Content.Shared.Database;
- using Microsoft.EntityFrameworkCore;
- using Robust.Shared.Configuration;
- using Robust.Shared.Network;
- using Robust.Shared.Utility;
namespace Content.Server.Database
{
    /// <summary>
    /// PostgreSQL-backed implementation of <see cref="ServerDbBase"/>.
    /// </summary>
    /// <remarks>
    /// Every database operation obtains its context through <see cref="GetDbImpl"/>. That method:
    /// <list type="bullet">
    /// <item>waits for the startup migration task to finish;</item>
    /// <item>takes a slot from a semaphore sized by <see cref="CCVars.DatabasePgConcurrency"/>;</item>
    /// <item>optionally sleeps for the configured fake lag.</item>
    /// </list>
    /// The slot is held until the returned guard is disposed.
    /// </remarks>
    public sealed partial class ServerDbPostgres : ServerDbBase
    {
        private readonly DbContextOptions<PostgresServerDbContext> _options;

        // Not read in this part of the partial class; presumably used by the notification listener.
        private readonly ISawmill _notifyLog;

        // Bounds the number of simultaneously open contexts (CCVars.DatabasePgConcurrency).
        private readonly SemaphoreSlim _prefsSemaphore;

        // Completes once startup migrations have been applied; awaited before every operation.
        private readonly Task _dbReadyTask;

        // Artificial delay in milliseconds applied before every operation (CCVars.DatabasePgFakeLag).
        private int _msLag;

        /// <summary>
        /// Creates the PostgreSQL backend and starts applying migrations in the background.
        /// </summary>
        /// <param name="options">EF Core options used for every <see cref="PostgresServerDbContext"/> created here.</param>
        /// <param name="connectionString">Passed to <c>InitNotificationListener</c>.</param>
        /// <param name="cfg">Source of the concurrency limit and the fake-lag setting (which is tracked for changes).</param>
        /// <param name="opsLog">Forwarded to the base class.</param>
        /// <param name="notifyLog">Stored for later use by other parts of this class.</param>
        public ServerDbPostgres(DbContextOptions<PostgresServerDbContext> options,
            string connectionString,
            IConfigurationManager cfg,
            ISawmill opsLog,
            ISawmill notifyLog)
            : base(opsLog)
        {
            var concurrency = cfg.GetCVar(CCVars.DatabasePgConcurrency);

            _options = options;
            _notifyLog = notifyLog;
            _prefsSemaphore = new SemaphoreSlim(concurrency, concurrency);

            // Migrate off the calling thread; GetDbImpl awaits this before handing out contexts.
            // `await using` already disposes the context when the lambda exits, so no explicit
            // DisposeAsync call is needed.
            _dbReadyTask = Task.Run(async () =>
            {
                await using var ctx = new PostgresServerDbContext(_options);
                await ctx.Database.MigrateAsync();
            });

            cfg.OnValueChanged(CCVars.DatabasePgFakeLag, v => _msLag = v, true);

            InitNotificationListener(connectionString);
        }

        #region Ban

        /// <inheritdoc/>
        public override async Task<ServerBanDef?> GetServerBanAsync(int id)
        {
            await using var db = await GetDbImpl();

            var query = db.PgDbContext.Ban
                .Include(p => p.Unban)
                .Where(p => p.Id == id);

            var ban = await query.SingleOrDefaultAsync();

            return ConvertBan(ban);
        }

        /// <inheritdoc/>
        /// <remarks>
        /// Only active bans are considered: no unban, and either no expiration or an expiration
        /// in the future. When several bans match, the one with the latest <c>BanTime</c> is returned.
        /// </remarks>
        /// <exception cref="ArgumentException">
        /// <paramref name="address"/>, <paramref name="userId"/> and <paramref name="hwId"/> are all null.
        /// </exception>
        public override async Task<ServerBanDef?> GetServerBanAsync(
            IPAddress? address,
            NetUserId? userId,
            ImmutableArray<byte>? hwId,
            ImmutableArray<ImmutableArray<byte>>? modernHWIds)
        {
            if (address == null && userId == null && hwId == null)
            {
                throw new ArgumentException("Address, userId, and hwId cannot all be null");
            }

            await using var db = await GetDbImpl();

            var exempt = await GetBanExemptionCore(db, userId);
            // A player with no user ID, or with no player record yet, counts as new.
            var newPlayer = userId == null || !await PlayerRecordExists(db, userId.Value);
            var query = MakeBanLookupQuery(address, userId, hwId, modernHWIds, db, includeUnbanned: false, exempt, newPlayer)
                .OrderByDescending(b => b.BanTime);

            var ban = await query.FirstOrDefaultAsync();

            return ConvertBan(ban);
        }

        /// <inheritdoc/>
        /// <exception cref="ArgumentException">
        /// <paramref name="address"/>, <paramref name="userId"/> and <paramref name="hwId"/> are all null.
        /// </exception>
        public override async Task<List<ServerBanDef>> GetServerBansAsync(IPAddress? address,
            NetUserId? userId,
            ImmutableArray<byte>? hwId,
            ImmutableArray<ImmutableArray<byte>>? modernHWIds,
            bool includeUnbanned)
        {
            if (address == null && userId == null && hwId == null)
            {
                throw new ArgumentException("Address, userId, and hwId cannot all be null");
            }

            await using var db = await GetDbImpl();

            var exempt = await GetBanExemptionCore(db, userId);
            // Same "new player" definition as the single-ban lookup. This also skips a
            // pointless player query when there is no user ID.
            var newPlayer = userId == null || !await PlayerRecordExists(db, userId.Value);
            var query = MakeBanLookupQuery(address, userId, hwId, modernHWIds, db, includeUnbanned, exempt, newPlayer);

            var queryBans = await query.ToArrayAsync();
            var bans = new List<ServerBanDef>(queryBans.Length);

            foreach (var ban in queryBans)
            {
                bans.Add(ConvertBan(ban));
            }

            return bans;
        }

        /// <summary>
        /// Builds the server-ban lookup.
        /// </summary>
        /// <remarks>
        /// The lookup is the union of bans matching the user ID, the legacy HWID, any modern HWID,
        /// and (unless the player is IP-exempt) the address. The result is then filtered by unban
        /// state and exemption flags, and de-duplicated.
        /// </remarks>
        private static IQueryable<ServerBan> MakeBanLookupQuery(
            IPAddress? address,
            NetUserId? userId,
            ImmutableArray<byte>? hwId,
            ImmutableArray<ImmutableArray<byte>>? modernHWIds,
            DbGuardImpl db,
            bool includeUnbanned,
            ServerBanExemptFlags? exemptFlags,
            bool newPlayer)
        {
            DebugTools.Assert(!(address == null && userId == null && hwId == null));

            var query = MakeBanLookupQualityShared<ServerBan, ServerUnban>(
                userId,
                hwId,
                modernHWIds,
                db.PgDbContext.Ban);

            if (address != null && !exemptFlags.GetValueOrDefault(ServerBanExemptFlags.None).HasFlag(ServerBanExemptFlags.IP))
            {
                // ContainsOrEqual matches the stored inet value against the address, so bans on a
                // range that contains the address match too. BlacklistedRange bans only apply to new players.
                var newQ = db.PgDbContext.Ban
                    .Include(p => p.Unban)
                    .Where(b => b.Address != null
                                && EF.Functions.ContainsOrEqual(b.Address.Value, address)
                                && !(b.ExemptFlags.HasFlag(ServerBanExemptFlags.BlacklistedRange) && !newPlayer));

                query = query == null ? newQ : query.Union(newQ);
            }

            DebugTools.Assert(
                query != null,
                "At least one filter item (IP/UserID/HWID) must have been given to make query not null.");

            if (!includeUnbanned)
            {
                query = query.Where(p =>
                    p.Unban == null && (p.ExpirationTime == null || p.ExpirationTime.Value > DateTime.UtcNow));
            }

            if (exemptFlags is { } exempt)
            {
                // Any kind of exemption should bypass BlacklistedRange.
                if (exempt != ServerBanExemptFlags.None)
                    exempt |= ServerBanExemptFlags.BlacklistedRange;

                // Drop bans whose exempt flags overlap the player's exemptions.
                query = query.Where(b => (b.ExemptFlags & exempt) == 0);
            }

            return query.Distinct();
        }

        /// <summary>
        /// Builds the part of a ban lookup that is shared by server bans and role bans: the user ID,
        /// the legacy HWID and every modern HWID.
        /// </summary>
        /// <returns>
        /// The union of the per-identifier queries, or null if none of the identifiers were usable.
        /// The legacy HWID is only used when it is non-empty.
        /// </returns>
        private static IQueryable<TBan>? MakeBanLookupQualityShared<TBan, TUnban>(
            NetUserId? userId,
            ImmutableArray<byte>? hwId,
            ImmutableArray<ImmutableArray<byte>>? modernHWIds,
            DbSet<TBan> set)
            where TBan : class, IBanCommon<TUnban>
            where TUnban : class, IUnbanCommon
        {
            IQueryable<TBan>? query = null;

            if (userId is { } uid)
            {
                var newQ = set
                    .Include(p => p.Unban)
                    .Where(b => b.PlayerUserId == uid.UserId);

                query = query == null ? newQ : query.Union(newQ);
            }

            if (hwId != null && hwId.Value.Length > 0)
            {
                var newQ = set
                    .Include(p => p.Unban)
                    .Where(b => b.HWId!.Type == HwidType.Legacy && b.HWId!.Hwid.SequenceEqual(hwId.Value.ToArray()));

                query = query == null ? newQ : query.Union(newQ);
            }

            if (modernHWIds != null)
            {
                foreach (var modernHwid in modernHWIds)
                {
                    var newQ = set
                        .Include(p => p.Unban)
                        .Where(b => b.HWId!.Type == HwidType.Modern && b.HWId!.Hwid.SequenceEqual(modernHwid.ToArray()));

                    query = query == null ? newQ : query.Union(newQ);
                }
            }

            return query;
        }

        /// <summary>
        /// Maps a <see cref="ServerBan"/> entity (with its unban loaded) to a <see cref="ServerBanDef"/>.
        /// </summary>
        /// <returns>Null only when <paramref name="ban"/> is null.</returns>
        [return: NotNullIfNotNull(nameof(ban))]
        private static ServerBanDef? ConvertBan(ServerBan? ban)
        {
            if (ban == null)
            {
                return null;
            }

            NetUserId? uid = null;
            if (ban.PlayerUserId is {} guid)
            {
                uid = new NetUserId(guid);
            }

            NetUserId? aUid = null;
            if (ban.BanningAdmin is {} aGuid)
            {
                aUid = new NetUserId(aGuid);
            }

            var unbanDef = ConvertUnban(ban.Unban);

            return new ServerBanDef(
                ban.Id,
                uid,
                ban.Address.ToTuple(),
                ban.HWId,
                ban.BanTime,
                ban.ExpirationTime,
                ban.RoundId,
                ban.PlaytimeAtNote,
                ban.Reason,
                ban.Severity,
                aUid,
                unbanDef,
                ban.ExemptFlags);
        }

        /// <summary>
        /// Maps a <see cref="ServerUnban"/> entity to a <see cref="ServerUnbanDef"/>.
        /// </summary>
        /// <returns>Null when <paramref name="unban"/> is null.</returns>
        private static ServerUnbanDef? ConvertUnban(ServerUnban? unban)
        {
            if (unban == null)
            {
                return null;
            }

            NetUserId? aUid = null;
            if (unban.UnbanningAdmin is {} aGuid)
            {
                aUid = new NetUserId(aGuid);
            }

            return new ServerUnbanDef(
                unban.Id,
                aUid,
                unban.UnbanTime);
        }

        /// <inheritdoc/>
        public override async Task AddServerBanAsync(ServerBanDef serverBan)
        {
            await using var db = await GetDbImpl();

            db.PgDbContext.Ban.Add(new ServerBan
            {
                Address = serverBan.Address.ToNpgsqlInet(),
                HWId = serverBan.HWId,
                Reason = serverBan.Reason,
                Severity = serverBan.Severity,
                BanningAdmin = serverBan.BanningAdmin?.UserId,
                BanTime = serverBan.BanTime.UtcDateTime,
                ExpirationTime = serverBan.ExpirationTime?.UtcDateTime,
                RoundId = serverBan.RoundId,
                PlaytimeAtNote = serverBan.PlaytimeAtNote,
                PlayerUserId = serverBan.UserId?.UserId,
                ExemptFlags = serverBan.ExemptFlags
            });

            await db.PgDbContext.SaveChangesAsync();
        }

        /// <inheritdoc/>
        public override async Task AddServerUnbanAsync(ServerUnbanDef serverUnban)
        {
            await using var db = await GetDbImpl();

            db.PgDbContext.Unban.Add(new ServerUnban
            {
                BanId = serverUnban.BanId,
                UnbanningAdmin = serverUnban.UnbanningAdmin?.UserId,
                UnbanTime = serverUnban.UnbanTime.UtcDateTime
            });

            await db.PgDbContext.SaveChangesAsync();
        }

        #endregion

        #region Role Ban

        /// <inheritdoc/>
        public override async Task<ServerRoleBanDef?> GetServerRoleBanAsync(int id)
        {
            await using var db = await GetDbImpl();

            var query = db.PgDbContext.RoleBan
                .Include(p => p.Unban)
                .Where(p => p.Id == id);

            var ban = await query.SingleOrDefaultAsync();

            return ConvertRoleBan(ban);
        }

        /// <inheritdoc/>
        /// <remarks>Results are ordered newest <c>BanTime</c> first.</remarks>
        /// <exception cref="ArgumentException">
        /// <paramref name="address"/>, <paramref name="userId"/> and <paramref name="hwId"/> are all null.
        /// </exception>
        public override async Task<List<ServerRoleBanDef>> GetServerRoleBansAsync(IPAddress? address,
            NetUserId? userId,
            ImmutableArray<byte>? hwId,
            ImmutableArray<ImmutableArray<byte>>? modernHWIds,
            bool includeUnbanned)
        {
            if (address == null && userId == null && hwId == null)
            {
                throw new ArgumentException("Address, userId, and hwId cannot all be null");
            }

            await using var db = await GetDbImpl();

            var query = MakeRoleBanLookupQuery(address, userId, hwId, modernHWIds, db, includeUnbanned)
                .OrderByDescending(b => b.BanTime);

            return await QueryRoleBans(query);
        }

        /// <summary>
        /// Executes a role-ban query and converts every row to a <see cref="ServerRoleBanDef"/>.
        /// </summary>
        private static async Task<List<ServerRoleBanDef>> QueryRoleBans(IQueryable<ServerRoleBan> query)
        {
            var queryRoleBans = await query.ToArrayAsync();
            var bans = new List<ServerRoleBanDef>(queryRoleBans.Length);

            foreach (var ban in queryRoleBans)
            {
                bans.Add(ConvertRoleBan(ban));
            }

            return bans;
        }

        /// <summary>
        /// Builds the role-ban lookup.
        /// </summary>
        /// <remarks>
        /// The lookup is the union of role bans matching the user ID, the legacy HWID, any modern HWID
        /// and the address (inet containment). Unlike server bans, no exemption flags apply here.
        /// </remarks>
        private static IQueryable<ServerRoleBan> MakeRoleBanLookupQuery(
            IPAddress? address,
            NetUserId? userId,
            ImmutableArray<byte>? hwId,
            ImmutableArray<ImmutableArray<byte>>? modernHWIds,
            DbGuardImpl db,
            bool includeUnbanned)
        {
            var query = MakeBanLookupQualityShared<ServerRoleBan, ServerRoleUnban>(
                userId,
                hwId,
                modernHWIds,
                db.PgDbContext.RoleBan);

            if (address != null)
            {
                var newQ = db.PgDbContext.RoleBan
                    .Include(p => p.Unban)
                    .Where(b => b.Address != null && EF.Functions.ContainsOrEqual(b.Address.Value, address));

                query = query == null ? newQ : query.Union(newQ);
            }

            DebugTools.Assert(
                query != null,
                "At least one filter item (IP/UserID/HWID) must have been given to make query not null.");

            if (!includeUnbanned)
            {
                query = query.Where(p =>
                    p.Unban == null && (p.ExpirationTime == null || p.ExpirationTime.Value > DateTime.UtcNow));
            }

            return query.Distinct();
        }

        /// <summary>
        /// Maps a <see cref="ServerRoleBan"/> entity (with its unban loaded) to a <see cref="ServerRoleBanDef"/>.
        /// </summary>
        /// <returns>Null only when <paramref name="ban"/> is null.</returns>
        [return: NotNullIfNotNull(nameof(ban))]
        private static ServerRoleBanDef? ConvertRoleBan(ServerRoleBan? ban)
        {
            if (ban == null)
            {
                return null;
            }

            NetUserId? uid = null;
            if (ban.PlayerUserId is {} guid)
            {
                uid = new NetUserId(guid);
            }

            NetUserId? aUid = null;
            if (ban.BanningAdmin is {} aGuid)
            {
                aUid = new NetUserId(aGuid);
            }

            var unbanDef = ConvertRoleUnban(ban.Unban);

            return new ServerRoleBanDef(
                ban.Id,
                uid,
                ban.Address.ToTuple(),
                ban.HWId,
                ban.BanTime,
                ban.ExpirationTime,
                ban.RoundId,
                ban.PlaytimeAtNote,
                ban.Reason,
                ban.Severity,
                aUid,
                unbanDef,
                ban.RoleId);
        }

        /// <summary>
        /// Maps a <see cref="ServerRoleUnban"/> entity to a <see cref="ServerRoleUnbanDef"/>.
        /// </summary>
        /// <returns>Null when <paramref name="unban"/> is null.</returns>
        private static ServerRoleUnbanDef? ConvertRoleUnban(ServerRoleUnban? unban)
        {
            if (unban == null)
            {
                return null;
            }

            NetUserId? aUid = null;
            if (unban.UnbanningAdmin is {} aGuid)
            {
                aUid = new NetUserId(aGuid);
            }

            return new ServerRoleUnbanDef(
                unban.Id,
                aUid,
                unban.UnbanTime);
        }

        /// <inheritdoc/>
        /// <remarks>Returns the stored ban, including the database-generated ID after saving.</remarks>
        public override async Task<ServerRoleBanDef> AddServerRoleBanAsync(ServerRoleBanDef serverRoleBan)
        {
            await using var db = await GetDbImpl();

            var ban = new ServerRoleBan
            {
                Address = serverRoleBan.Address.ToNpgsqlInet(),
                HWId = serverRoleBan.HWId,
                Reason = serverRoleBan.Reason,
                Severity = serverRoleBan.Severity,
                BanningAdmin = serverRoleBan.BanningAdmin?.UserId,
                BanTime = serverRoleBan.BanTime.UtcDateTime,
                ExpirationTime = serverRoleBan.ExpirationTime?.UtcDateTime,
                RoundId = serverRoleBan.RoundId,
                PlaytimeAtNote = serverRoleBan.PlaytimeAtNote,
                PlayerUserId = serverRoleBan.UserId?.UserId,
                RoleId = serverRoleBan.Role,
            };

            db.PgDbContext.RoleBan.Add(ban);

            await db.PgDbContext.SaveChangesAsync();

            return ConvertRoleBan(ban);
        }

        /// <inheritdoc/>
        public override async Task AddServerRoleUnbanAsync(ServerRoleUnbanDef serverRoleUnban)
        {
            await using var db = await GetDbImpl();

            db.PgDbContext.RoleUnban.Add(new ServerRoleUnban
            {
                BanId = serverRoleUnban.BanId,
                UnbanningAdmin = serverRoleUnban.UnbanningAdmin?.UserId,
                UnbanTime = serverRoleUnban.UnbanTime.UtcDateTime
            });

            await db.PgDbContext.SaveChangesAsync();
        }

        #endregion

        /// <inheritdoc/>
        /// <returns>The database-generated ID of the new connection log row.</returns>
        public override async Task<int> AddConnectionLogAsync(
            NetUserId userId,
            string userName,
            IPAddress address,
            ImmutableTypedHwid? hwId,
            float trust,
            ConnectionDenyReason? denied,
            int serverId)
        {
            await using var db = await GetDbImpl();

            var connectionLog = new ConnectionLog
            {
                Address = address,
                Time = DateTime.UtcNow,
                UserId = userId.UserId,
                UserName = userName,
                HWId = hwId,
                Denied = denied,
                ServerId = serverId,
                Trust = trust,
            };
            db.PgDbContext.ConnectionLog.Add(connectionLog);

            await db.PgDbContext.SaveChangesAsync();

            return connectionLog.Id;
        }

        /// <inheritdoc/>
        public override async Task<((Admin, string? lastUserName)[] admins, AdminRank[])>
            GetAllAdminAndRanksAsync(CancellationToken cancel)
        {
            // Flow the token so a cancelled caller doesn't keep waiting for a semaphore slot.
            await using var db = await GetDbImpl(cancel);

            // RepeatableRead so the admin query and the rank query below see the same snapshot.
            await using var tx =
                await db.DbContext.Database.BeginTransactionAsync(IsolationLevel.RepeatableRead, cancel);

            // Left-join the player table to find each admin's last seen username, if any.
            var admins = await db.PgDbContext.Admin
                .Include(a => a.Flags)
                .GroupJoin(db.PgDbContext.Player, a => a.UserId, p => p.UserId, (a, grouping) => new {a, grouping})
                .SelectMany(t => t.grouping.DefaultIfEmpty(), (t, p) => new {t.a, p!.LastSeenUserName})
                .ToArrayAsync(cancel);

            var adminRanks = await db.DbContext.AdminRank.Include(a => a.Flags).ToArrayAsync(cancel);

            return (admins.Select(p => (p.a, p.LastSeenUserName)).ToArray(), adminRanks)!;
        }

        /// <inheritdoc/>
        /// <remarks>
        /// When <c>filter.Search</c> is non-blank, the message column is full-text searched with
        /// PostgreSQL's <c>to_tsvector</c>/<c>websearch_to_tsquery</c> (English configuration).
        /// </remarks>
        protected override IQueryable<AdminLog> StartAdminLogsQuery(ServerDbContext db, LogFilter? filter = null)
        {
            // https://learn.microsoft.com/en-us/ef/core/querying/sql-queries#passing-parameters
            // FromSql turns the interpolated {filter.Search} into a SQL parameter. Read the link above
            // before changing this method: string concatenation, or FromSqlRaw with an interpolated
            // string, would introduce SQL injection.
            if (!string.IsNullOrWhiteSpace(filter?.Search))
            {
                return db.AdminLog.FromSql($"""
                    SELECT a.admin_log_id, a.round_id, a.date, a.impact, a.json, a.message, a.type FROM admin_log AS a
                    WHERE to_tsvector('english'::regconfig, a.message) @@ websearch_to_tsquery('english'::regconfig, {filter.Search})
                    """);
            }

            return db.AdminLog;
        }

        /// <inheritdoc/>
        /// <remarks>Times are expected to already be UTC (asserted in debug builds) and are returned unchanged.</remarks>
        protected override DateTime NormalizeDatabaseTime(DateTime time)
        {
            DebugTools.Assert(time.Kind == DateTimeKind.Utc);
            return time;
        }

        /// <summary>
        /// Acquires a database context guarded by a concurrency slot.
        /// </summary>
        /// <remarks>
        /// Waits for startup migrations, then for a free semaphore slot, then for the optional fake lag.
        /// If anything fails after the slot is taken (for example the lag delay being cancelled), the
        /// slot is released before the exception propagates. Otherwise each such failure would
        /// permanently reduce the available concurrency.
        /// </remarks>
        /// <param name="cancel">Cancels the semaphore wait and the fake-lag delay.</param>
        /// <param name="name">Caller name used for operation logging; filled in by the compiler.</param>
        /// <returns>A guard that must be disposed to close the context and release the slot.</returns>
        private async Task<DbGuardImpl> GetDbImpl(
            CancellationToken cancel = default,
            [CallerMemberName] string? name = null)
        {
            LogDbOp(name);

            await _dbReadyTask;
            await _prefsSemaphore.WaitAsync(cancel);

            try
            {
                if (_msLag > 0)
                    await Task.Delay(_msLag, cancel);

                return new DbGuardImpl(this, new PostgresServerDbContext(_options));
            }
            catch
            {
                _prefsSemaphore.Release();
                throw;
            }
        }

        /// <inheritdoc/>
        protected override async Task<DbGuard> GetDb(
            CancellationToken cancel = default,
            [CallerMemberName] string? name = null)
        {
            return await GetDbImpl(cancel, name);
        }

        /// <summary>
        /// Owns one <see cref="PostgresServerDbContext"/> and one concurrency slot. Disposing it
        /// disposes the context and returns the slot.
        /// </summary>
        private sealed class DbGuardImpl : DbGuard
        {
            private readonly ServerDbPostgres _db;

            public DbGuardImpl(ServerDbPostgres db, PostgresServerDbContext dbC)
            {
                _db = db;
                PgDbContext = dbC;
            }

            public PostgresServerDbContext PgDbContext { get; }

            public override ServerDbContext DbContext => PgDbContext;

            public override async ValueTask DisposeAsync()
            {
                try
                {
                    await DbContext.DisposeAsync();
                }
                finally
                {
                    // Return the slot even if disposing the context throws; otherwise the
                    // semaphore would permanently lose capacity.
                    _db._prefsSemaphore.Release();
                }
            }
        }
    }
}
|