| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659 |
- using System.Collections.Immutable;
- using System.Diagnostics.CodeAnalysis;
- using System.Linq;
- using System.Net;
- using System.Runtime.CompilerServices;
- using System.Threading;
- using System.Threading.Tasks;
- using Content.Server.Administration.Logs;
- using Content.Server.IP;
- using Content.Server.Preferences.Managers;
- using Content.Shared.CCVar;
- using Content.Shared.Database;
- using Microsoft.EntityFrameworkCore;
- using Robust.Shared.Configuration;
- using Robust.Shared.Network;
- using Robust.Shared.Utility;
namespace Content.Server.Database
{
    /// <summary>
    /// SQLite-backed implementation of <see cref="ServerDbBase"/>.
    /// Provides methods to retrieve and update character preferences, bans, role bans,
    /// connection logs and admin data.
    /// Don't use this directly, go through <see cref="ServerPreferencesManager" /> instead.
    /// </summary>
    public sealed class ServerDbSqlite : ServerDbBase
    {
        // Factory for context options; invoked for every context this class creates.
        private readonly Func<DbContextOptions<SqliteServerDbContext>> _options;

        // Bounds how many database contexts may be open at the same time.
        private readonly ConcurrencySemaphore _prefsSemaphore;

        // Completes once migrations have been applied; every operation awaits it before touching the DB.
        private readonly Task _dbReadyTask;

        // Artificial delay (ms) applied before each operation, tracked from CCVars.DatabaseSqliteDelay.
        private int _msDelay;

        /// <summary>
        /// Creates the SQLite database layer and applies pending migrations.
        /// </summary>
        /// <param name="options">Factory producing the options for each new context.</param>
        /// <param name="inMemory">Whether the database is in-memory; if so, concurrency is forced to 1.</param>
        /// <param name="cfg">Configuration manager providing the SQLite concurrency and delay CVars.</param>
        /// <param name="synchronous">
        /// If true, migrations run before the constructor returns and the concurrency semaphore runs in
        /// synchronous mode (concurrent access throws). Otherwise migrations run on the thread pool.
        /// </param>
        /// <param name="opsLog">Sawmill handed to the base class.</param>
        public ServerDbSqlite(
            Func<DbContextOptions<SqliteServerDbContext>> options,
            bool inMemory,
            IConfigurationManager cfg,
            bool synchronous,
            ISawmill opsLog)
            : base(opsLog)
        {
            _options = options;

            // When inMemory we re-use the same connection, so we can't have any concurrency.
            var concurrency = inMemory ? 1 : cfg.GetCVar(CCVars.DatabaseSqliteConcurrency);
            _prefsSemaphore = new ConcurrencySemaphore(concurrency, synchronous);

            // Created only after the semaphore settings were validated, so a rejected
            // configuration cannot leak this context.
            var prefsCtx = new SqliteServerDbContext(options());

            if (synchronous)
            {
                MigrateAndDispose(prefsCtx);
                _dbReadyTask = Task.CompletedTask;
            }
            else
            {
                _dbReadyTask = Task.Run(() => MigrateAndDispose(prefsCtx));
            }

            cfg.OnValueChanged(CCVars.DatabaseSqliteDelay, v => _msDelay = v, true);
        }

        /// <summary>
        /// Applies pending migrations using <paramref name="ctx"/>, disposing it even if migration throws.
        /// </summary>
        private static void MigrateAndDispose(SqliteServerDbContext ctx)
        {
            using (ctx)
            {
                ctx.Database.Migrate();
            }
        }

        #region Ban

        /// <summary>Looks up a single server ban (with its unban, if any) by ID.</summary>
        /// <returns>The ban, or null if no ban has that ID.</returns>
        public override async Task<ServerBanDef?> GetServerBanAsync(int id)
        {
            await using var db = await GetDbImpl();

            var ban = await db.SqliteDbContext.Ban
                .Include(p => p.Unban)
                .Where(p => p.Id == id)
                .SingleOrDefaultAsync();

            return ConvertBan(ban);
        }

        /// <summary>Returns the first active (not unbanned) ban matching the given identifiers, or null.</summary>
        public override async Task<ServerBanDef?> GetServerBanAsync(
            IPAddress? address,
            NetUserId? userId,
            ImmutableArray<byte>? hwId,
            ImmutableArray<ImmutableArray<byte>>? modernHWIds)
        {
            await using var db = await GetDbImpl();

            return (await GetServerBanQueryAsync(db, address, userId, hwId, modernHWIds, includeUnbanned: false)).FirstOrDefault();
        }

        /// <summary>Returns every ban matching the given identifiers.</summary>
        /// <param name="includeUnbanned">Whether unbanned and expired bans are included.</param>
        public override async Task<List<ServerBanDef>> GetServerBansAsync(
            IPAddress? address,
            NetUserId? userId,
            ImmutableArray<byte>? hwId,
            ImmutableArray<ImmutableArray<byte>>? modernHWIds,
            bool includeUnbanned)
        {
            await using var db = await GetDbImpl();

            return (await GetServerBanQueryAsync(db, address, userId, hwId, modernHWIds, includeUnbanned)).ToList();
        }

        /// <summary>
        /// Loads candidate bans into memory and filters them with <c>BanMatcher.BanMatches</c>
        /// against the supplied player information.
        /// </summary>
        private async Task<IEnumerable<ServerBanDef>> GetServerBanQueryAsync(
            DbGuardImpl db,
            IPAddress? address,
            NetUserId? userId,
            ImmutableArray<byte>? hwId,
            ImmutableArray<ImmutableArray<byte>>? modernHWIds,
            bool includeUnbanned)
        {
            var exempt = await GetBanExemptionCore(db, userId);

            var newPlayer = !await db.SqliteDbContext.Player.AnyAsync(p => p.UserId == userId);

            // SQLite can't do the net masking stuff we need to match IP address ranges.
            // So just pull down the whole list into memory.
            var queryBans = await GetAllBans(db.SqliteDbContext, includeUnbanned, exempt);

            var playerInfo = new BanMatcher.PlayerInfo
            {
                Address = address,
                UserId = userId,
                ExemptFlags = exempt ?? default,
                HWId = hwId,
                ModernHWIds = modernHWIds,
                IsNewPlayer = newPlayer,
            };

            // queryBans is already materialized, so deferred enumeration after the guard is disposed is safe.
            return queryBans
                .Select(ConvertBan)
                .Where(b => BanMatcher.BanMatches(b!, playerInfo))!;
        }

        /// <summary>
        /// Fetches all bans (with unbans), optionally restricted to active bans and to bans
        /// whose exempt flags do not overlap <paramref name="exemptFlags"/>.
        /// </summary>
        private static async Task<List<ServerBan>> GetAllBans(
            SqliteServerDbContext db,
            bool includeUnbanned,
            ServerBanExemptFlags? exemptFlags)
        {
            IQueryable<ServerBan> query = db.Ban.Include(p => p.Unban);
            if (!includeUnbanned)
            {
                query = query.Where(p =>
                    p.Unban == null && (p.ExpirationTime == null || p.ExpirationTime.Value > DateTime.UtcNow));
            }

            if (exemptFlags is { } exempt)
            {
                // Any flag to bypass BlacklistedRange bans.
                if (exempt != ServerBanExemptFlags.None)
                    exempt |= ServerBanExemptFlags.BlacklistedRange;

                query = query.Where(b => (b.ExemptFlags & exempt) == 0);
            }

            return await query.ToListAsync();
        }

        /// <summary>Persists a new server ban.</summary>
        public override async Task AddServerBanAsync(ServerBanDef serverBan)
        {
            await using var db = await GetDbImpl();

            db.SqliteDbContext.Ban.Add(new ServerBan
            {
                Address = serverBan.Address.ToNpgsqlInet(),
                Reason = serverBan.Reason,
                Severity = serverBan.Severity,
                BanningAdmin = serverBan.BanningAdmin?.UserId,
                HWId = serverBan.HWId,
                BanTime = serverBan.BanTime.UtcDateTime,
                ExpirationTime = serverBan.ExpirationTime?.UtcDateTime,
                RoundId = serverBan.RoundId,
                PlaytimeAtNote = serverBan.PlaytimeAtNote,
                PlayerUserId = serverBan.UserId?.UserId,
                ExemptFlags = serverBan.ExemptFlags
            });

            await db.SqliteDbContext.SaveChangesAsync();
        }

        /// <summary>Persists an unban record for an existing server ban.</summary>
        public override async Task AddServerUnbanAsync(ServerUnbanDef serverUnban)
        {
            await using var db = await GetDbImpl();

            db.SqliteDbContext.Unban.Add(new ServerUnban
            {
                BanId = serverUnban.BanId,
                UnbanningAdmin = serverUnban.UnbanningAdmin?.UserId,
                UnbanTime = serverUnban.UnbanTime.UtcDateTime
            });

            await db.SqliteDbContext.SaveChangesAsync();
        }
        #endregion

        #region Role Ban

        /// <summary>Looks up a single role ban (with its unban, if any) by ID.</summary>
        /// <returns>The role ban, or null if no role ban has that ID.</returns>
        public override async Task<ServerRoleBanDef?> GetServerRoleBanAsync(int id)
        {
            await using var db = await GetDbImpl();

            var ban = await db.SqliteDbContext.RoleBan
                .Include(p => p.Unban)
                .Where(p => p.Id == id)
                .SingleOrDefaultAsync();

            return ConvertRoleBan(ban);
        }

        /// <summary>Returns every role ban matching the given identifiers.</summary>
        /// <param name="includeUnbanned">Whether unbanned and expired role bans are included.</param>
        public override async Task<List<ServerRoleBanDef>> GetServerRoleBansAsync(
            IPAddress? address,
            NetUserId? userId,
            ImmutableArray<byte>? hwId,
            ImmutableArray<ImmutableArray<byte>>? modernHWIds,
            bool includeUnbanned)
        {
            await using var db = await GetDbImpl();

            // SQLite can't do the net masking stuff we need to match IP address ranges.
            // So just pull down the whole list into memory.
            var queryBans = await GetAllRoleBans(db.SqliteDbContext, includeUnbanned);

            return queryBans
                .Where(b => RoleBanMatches(b, address, userId, hwId, modernHWIds))
                .Select(ConvertRoleBan)
                .ToList()!;
        }

        /// <summary>Fetches all role bans (with unbans), optionally restricted to active ones.</summary>
        private static async Task<List<ServerRoleBan>> GetAllRoleBans(
            SqliteServerDbContext db,
            bool includeUnbanned)
        {
            IQueryable<ServerRoleBan> query = db.RoleBan.Include(p => p.Unban);
            if (!includeUnbanned)
            {
                query = query.Where(p =>
                    p.Unban == null && (p.ExpirationTime == null || p.ExpirationTime.Value > DateTime.UtcNow));
            }

            return await query.ToListAsync();
        }

        /// <summary>
        /// True if the role ban matches by address subnet, user ID, legacy HWID,
        /// or any of the supplied modern HWIDs (depending on the ban's HWID type).
        /// </summary>
        private static bool RoleBanMatches(
            ServerRoleBan ban,
            IPAddress? address,
            NetUserId? userId,
            ImmutableArray<byte>? hwId,
            ImmutableArray<ImmutableArray<byte>>? modernHWIds)
        {
            if (address != null && ban.Address is not null && address.IsInSubnet(ban.Address.ToTuple().Value))
            {
                return true;
            }

            if (userId is { } id && ban.PlayerUserId == id.UserId)
            {
                return true;
            }

            switch (ban.HWId?.Type)
            {
                case HwidType.Legacy:
                    if (hwId is { Length: > 0 } hwIdVar && hwIdVar.AsSpan().SequenceEqual(ban.HWId.Hwid))
                        return true;
                    break;

                case HwidType.Modern:
                    if (modernHWIds != null)
                    {
                        foreach (var modernHWId in modernHWIds)
                        {
                            if (modernHWId.AsSpan().SequenceEqual(ban.HWId.Hwid))
                                return true;
                        }
                    }

                    break;
            }

            return false;
        }

        /// <summary>Persists a new role ban and returns it as stored (including its generated ID).</summary>
        public override async Task<ServerRoleBanDef> AddServerRoleBanAsync(ServerRoleBanDef serverBan)
        {
            await using var db = await GetDbImpl();

            var ban = new ServerRoleBan
            {
                Address = serverBan.Address.ToNpgsqlInet(),
                Reason = serverBan.Reason,
                Severity = serverBan.Severity,
                BanningAdmin = serverBan.BanningAdmin?.UserId,
                HWId = serverBan.HWId,
                BanTime = serverBan.BanTime.UtcDateTime,
                ExpirationTime = serverBan.ExpirationTime?.UtcDateTime,
                RoundId = serverBan.RoundId,
                PlaytimeAtNote = serverBan.PlaytimeAtNote,
                PlayerUserId = serverBan.UserId?.UserId,
                RoleId = serverBan.Role,
            };
            db.SqliteDbContext.RoleBan.Add(ban);

            await db.SqliteDbContext.SaveChangesAsync();
            return ConvertRoleBan(ban);
        }

        /// <summary>Persists an unban record for an existing role ban.</summary>
        public override async Task AddServerRoleUnbanAsync(ServerRoleUnbanDef serverUnban)
        {
            await using var db = await GetDbImpl();

            db.SqliteDbContext.RoleUnban.Add(new ServerRoleUnban
            {
                BanId = serverUnban.BanId,
                UnbanningAdmin = serverUnban.UnbanningAdmin?.UserId,
                UnbanTime = serverUnban.UnbanTime.UtcDateTime
            });

            await db.SqliteDbContext.SaveChangesAsync();
        }

        /// <summary>Converts a role ban entity to its definition; null maps to null.</summary>
        [return: NotNullIfNotNull(nameof(ban))]
        private static ServerRoleBanDef? ConvertRoleBan(ServerRoleBan? ban)
        {
            if (ban == null)
            {
                return null;
            }

            NetUserId? uid = null;
            if (ban.PlayerUserId is { } guid)
            {
                uid = new NetUserId(guid);
            }

            NetUserId? aUid = null;
            if (ban.BanningAdmin is { } aGuid)
            {
                aUid = new NetUserId(aGuid);
            }

            var unban = ConvertRoleUnban(ban.Unban);

            return new ServerRoleBanDef(
                ban.Id,
                uid,
                ban.Address.ToTuple(),
                ban.HWId,
                // SQLite apparently always reads DateTime as unspecified, but we always write as UTC.
                DateTime.SpecifyKind(ban.BanTime, DateTimeKind.Utc),
                ban.ExpirationTime == null ? null : DateTime.SpecifyKind(ban.ExpirationTime.Value, DateTimeKind.Utc),
                ban.RoundId,
                ban.PlaytimeAtNote,
                ban.Reason,
                ban.Severity,
                aUid,
                unban,
                ban.RoleId);
        }

        /// <summary>Converts a role unban entity to its definition; null maps to null.</summary>
        private static ServerRoleUnbanDef? ConvertRoleUnban(ServerRoleUnban? unban)
        {
            if (unban == null)
            {
                return null;
            }

            NetUserId? aUid = null;
            if (unban.UnbanningAdmin is { } aGuid)
            {
                aUid = new NetUserId(aGuid);
            }

            return new ServerRoleUnbanDef(
                unban.Id,
                aUid,
                // SQLite apparently always reads DateTime as unspecified, but we always write as UTC.
                DateTime.SpecifyKind(unban.UnbanTime, DateTimeKind.Utc));
        }
        #endregion

        /// <summary>Converts a server ban entity to its definition; null maps to null.</summary>
        [return: NotNullIfNotNull(nameof(ban))]
        private static ServerBanDef? ConvertBan(ServerBan? ban)
        {
            if (ban == null)
            {
                return null;
            }

            NetUserId? uid = null;
            if (ban.PlayerUserId is { } guid)
            {
                uid = new NetUserId(guid);
            }

            NetUserId? aUid = null;
            if (ban.BanningAdmin is { } aGuid)
            {
                aUid = new NetUserId(aGuid);
            }

            var unban = ConvertUnban(ban.Unban);

            return new ServerBanDef(
                ban.Id,
                uid,
                ban.Address.ToTuple(),
                ban.HWId,
                // SQLite apparently always reads DateTime as unspecified, but we always write as UTC.
                DateTime.SpecifyKind(ban.BanTime, DateTimeKind.Utc),
                ban.ExpirationTime == null ? null : DateTime.SpecifyKind(ban.ExpirationTime.Value, DateTimeKind.Utc),
                ban.RoundId,
                ban.PlaytimeAtNote,
                ban.Reason,
                ban.Severity,
                aUid,
                unban);
        }

        /// <summary>Converts a server unban entity to its definition; null maps to null.</summary>
        private static ServerUnbanDef? ConvertUnban(ServerUnban? unban)
        {
            if (unban == null)
            {
                return null;
            }

            NetUserId? aUid = null;
            if (unban.UnbanningAdmin is { } aGuid)
            {
                aUid = new NetUserId(aGuid);
            }

            return new ServerUnbanDef(
                unban.Id,
                aUid,
                // SQLite apparently always reads DateTime as unspecified, but we always write as UTC.
                DateTime.SpecifyKind(unban.UnbanTime, DateTimeKind.Utc));
        }

        /// <summary>Records a connection attempt, timestamped with the current UTC time.</summary>
        /// <returns>The ID assigned to the new connection log row.</returns>
        public override async Task<int> AddConnectionLogAsync(
            NetUserId userId,
            string userName,
            IPAddress address,
            ImmutableTypedHwid? hwId,
            float trust,
            ConnectionDenyReason? denied,
            int serverId)
        {
            await using var db = await GetDbImpl();

            var connectionLog = new ConnectionLog
            {
                Address = address,
                Time = DateTime.UtcNow,
                UserId = userId.UserId,
                UserName = userName,
                HWId = hwId,
                Denied = denied,
                ServerId = serverId,
                Trust = trust,
            };
            db.SqliteDbContext.ConnectionLog.Add(connectionLog);

            await db.SqliteDbContext.SaveChangesAsync();

            return connectionLog.Id;
        }

        /// <summary>
        /// Returns all admins (with flags) paired with their last seen user name, left-joined
        /// against the player table, plus all admin ranks (with flags).
        /// </summary>
        public override async Task<((Admin, string? lastUserName)[] admins, AdminRank[])> GetAllAdminAndRanksAsync(
            CancellationToken cancel)
        {
            await using var db = await GetDbImpl(cancel);

            var admins = await db.SqliteDbContext.Admin
                .Include(a => a.Flags)
                .GroupJoin(db.SqliteDbContext.Player, a => a.UserId, p => p.UserId, (a, grouping) => new {a, grouping})
                .SelectMany(t => t.grouping.DefaultIfEmpty(), (t, p) => new {t.a, p!.LastSeenUserName})
                .ToArrayAsync(cancel);

            var adminRanks = await db.DbContext.AdminRank.Include(a => a.Flags).ToArrayAsync(cancel);

            return (admins.Select(p => (p.a, p.LastSeenUserName)).ToArray(), adminRanks)!;
        }

        /// <summary>
        /// Starts an admin log query; when the filter has a search string, restricts it with a
        /// LIKE '%search%' match on the message.
        /// </summary>
        protected override IQueryable<AdminLog> StartAdminLogsQuery(ServerDbContext db, LogFilter? filter = null)
        {
            IQueryable<AdminLog> query = db.AdminLog;
            if (filter?.Search != null)
                query = query.Where(log => EF.Functions.Like(log.Message, $"%{filter.Search}%"));

            return query;
        }

        /// <summary>
        /// Assigns the note the next free ID (current max + 1, or 1 for an empty table), then
        /// delegates to the base implementation.
        /// NOTE(review): presumably needed because SQLite does not generate these IDs here — confirm.
        /// The ID is computed in a separate database operation from the insert, so concurrent adds
        /// may pick the same ID when concurrency is above 1.
        /// </summary>
        public override async Task<int> AddAdminNote(AdminNote note)
        {
            await using (var db = await GetDb())
            {
                // MAX over a nullable projection yields null on an empty table: one query instead of ANY + MAX.
                var maxId = await db.DbContext.AdminNotes.MaxAsync(adminNote => (int?) adminNote.Id);
                note.Id = (maxId ?? 0) + 1;
            }

            return await base.AddAdminNote(note);
        }

        /// <summary>
        /// Assigns the watchlist entry the next free ID (current max + 1, or 1 for an empty table),
        /// then delegates to the base implementation. Same caveats as <see cref="AddAdminNote"/>.
        /// </summary>
        public override async Task<int> AddAdminWatchlist(AdminWatchlist watchlist)
        {
            await using (var db = await GetDb())
            {
                var maxId = await db.DbContext.AdminWatchlists.MaxAsync(adminWatchlist => (int?) adminWatchlist.Id);
                watchlist.Id = (maxId ?? 0) + 1;
            }

            return await base.AddAdminWatchlist(watchlist);
        }

        /// <summary>
        /// Assigns the message the next free ID (current max + 1, or 1 for an empty table),
        /// then delegates to the base implementation. Same caveats as <see cref="AddAdminNote"/>.
        /// </summary>
        public override async Task<int> AddAdminMessage(AdminMessage message)
        {
            await using (var db = await GetDb())
            {
                var maxId = await db.DbContext.AdminMessages.MaxAsync(adminMessage => (int?) adminMessage.Id);
                message.Id = (maxId ?? 0) + 1;
            }

            return await base.AddAdminMessage(message);
        }

        /// <summary>No-op: database notifications are not implemented on SQLite.</summary>
        public override Task SendNotification(DatabaseNotification notification)
        {
            // Notifications not implemented on SQLite.
            return Task.CompletedTask;
        }

        /// <summary>
        /// Marks a time read back from SQLite as UTC. SQLite reads DateTime as unspecified,
        /// but this class always writes UTC.
        /// </summary>
        protected override DateTime NormalizeDatabaseTime(DateTime time)
        {
            DebugTools.Assert(time.Kind == DateTimeKind.Unspecified);
            return DateTime.SpecifyKind(time, DateTimeKind.Utc);
        }

        /// <summary>
        /// Waits for migrations, applies the configured artificial delay, takes a concurrency slot,
        /// and opens a fresh context. The slot is released when the returned guard is disposed.
        /// </summary>
        private async Task<DbGuardImpl> GetDbImpl(
            CancellationToken cancel = default,
            [CallerMemberName] string? name = null)
        {
            LogDbOp(name);

            await _dbReadyTask;
            if (_msDelay > 0)
                await Task.Delay(_msDelay, cancel);

            await _prefsSemaphore.WaitAsync(cancel);

            try
            {
                var dbContext = new SqliteServerDbContext(_options());
                return new DbGuardImpl(this, dbContext);
            }
            catch
            {
                // No guard was handed out, so nobody else would ever release this slot.
                _prefsSemaphore.Release();
                throw;
            }
        }

        /// <inheritdoc />
        protected override async Task<DbGuard> GetDb(
            CancellationToken cancel = default,
            [CallerMemberName] string? name = null)
        {
            return await GetDbImpl(cancel, name).ConfigureAwait(false);
        }

        /// <summary>
        /// Owns one SQLite context plus one concurrency slot; disposing it disposes the context
        /// and always returns the slot.
        /// </summary>
        private sealed class DbGuardImpl : DbGuard
        {
            private readonly ServerDbSqlite _db;
            private readonly SqliteServerDbContext _ctx;

            public DbGuardImpl(ServerDbSqlite db, SqliteServerDbContext dbContext)
            {
                _db = db;
                _ctx = dbContext;
            }

            public override ServerDbContext DbContext => _ctx;
            public SqliteServerDbContext SqliteDbContext => _ctx;

            public override async ValueTask DisposeAsync()
            {
                try
                {
                    await _ctx.DisposeAsync();
                }
                finally
                {
                    // Release even if context disposal throws, otherwise the slot leaks.
                    _db._prefsSemaphore.Release();
                }
            }
        }

        /// <summary>
        /// Semaphore limiting concurrent database access. In synchronous mode (max count 1) it never
        /// waits: contention throws immediately, with diagnostics about which threads collided.
        /// </summary>
        private sealed class ConcurrencySemaphore
        {
            private readonly bool _synchronous;
            private readonly SemaphoreSlim _semaphore;
            private Thread? _holdingThread;

            public ConcurrencySemaphore(int maxCount, bool synchronous)
            {
                if (synchronous && maxCount != 1)
                    throw new ArgumentException("If synchronous, max concurrency must be 1");

                _synchronous = synchronous;
                _semaphore = new SemaphoreSlim(maxCount, maxCount);
            }

            public Task WaitAsync(CancellationToken cancel = default)
            {
                if (!_synchronous)
                    return _semaphore.WaitAsync(cancel);

                // Try to take the slot without queueing a waiter. A queued WaitAsync abandoned by the
                // throws below would later grab the only slot and never release it, breaking every
                // subsequent request.
                if (!_semaphore.Wait(0, cancel))
                {
                    if (Thread.CurrentThread == _holdingThread)
                    {
                        throw new InvalidOperationException(
                            "Multiple database requests from same thread on synchronous database!");
                    }

                    throw new InvalidOperationException(
                        $"Different threads trying to access the database at once! " +
                        $"Holding thread: {DiagThread(_holdingThread)}, " +
                        $"current thread: {DiagThread(Thread.CurrentThread)}");
                }

                // Only record the holder once the slot has actually been acquired.
                _holdingThread = Thread.CurrentThread;
                return Task.CompletedTask;
            }

            public void Release()
            {
                if (_synchronous)
                {
                    if (Thread.CurrentThread != _holdingThread)
                        throw new InvalidOperationException("Released on different thread than took lock???");

                    _holdingThread = null;
                }

                _semaphore.Release();
            }

            private static string DiagThread(Thread? thread)
            {
                if (thread != null)
                    return $"{thread.Name} ({thread.ManagedThreadId})";

                return "<null thread>";
            }
        }
    }
}
|