ServerDbPostgres.cs 20 KB


  1. using System.Collections.Immutable;
  2. using System.Data;
  3. using System.Diagnostics.CodeAnalysis;
  4. using System.Linq;
  5. using System.Net;
  6. using System.Runtime.CompilerServices;
  7. using System.Threading;
  8. using System.Threading.Tasks;
  9. using Content.Server.Administration.Logs;
  10. using Content.Server.IP;
  11. using Content.Shared.CCVar;
  12. using Content.Shared.Database;
  13. using Microsoft.EntityFrameworkCore;
  14. using Robust.Shared.Configuration;
  15. using Robust.Shared.Network;
  16. using Robust.Shared.Utility;
  17. namespace Content.Server.Database
  18. {
  19. public sealed partial class ServerDbPostgres : ServerDbBase
  20. {
  21. private readonly DbContextOptions<PostgresServerDbContext> _options;
  22. private readonly ISawmill _notifyLog;
  23. private readonly SemaphoreSlim _prefsSemaphore;
  24. private readonly Task _dbReadyTask;
  25. private int _msLag;
  26. public ServerDbPostgres(DbContextOptions<PostgresServerDbContext> options,
  27. string connectionString,
  28. IConfigurationManager cfg,
  29. ISawmill opsLog,
  30. ISawmill notifyLog)
  31. : base(opsLog)
  32. {
  33. var concurrency = cfg.GetCVar(CCVars.DatabasePgConcurrency);
  34. _options = options;
  35. _notifyLog = notifyLog;
  36. _prefsSemaphore = new SemaphoreSlim(concurrency, concurrency);
  37. _dbReadyTask = Task.Run(async () =>
  38. {
  39. await using var ctx = new PostgresServerDbContext(_options);
  40. try
  41. {
  42. await ctx.Database.MigrateAsync();
  43. }
  44. finally
  45. {
  46. await ctx.DisposeAsync();
  47. }
  48. });
  49. cfg.OnValueChanged(CCVars.DatabasePgFakeLag, v => _msLag = v, true);
  50. InitNotificationListener(connectionString);
  51. }
  52. #region Ban
  53. public override async Task<ServerBanDef?> GetServerBanAsync(int id)
  54. {
  55. await using var db = await GetDbImpl();
  56. var query = db.PgDbContext.Ban
  57. .Include(p => p.Unban)
  58. .Where(p => p.Id == id);
  59. var ban = await query.SingleOrDefaultAsync();
  60. return ConvertBan(ban);
  61. }
  62. public override async Task<ServerBanDef?> GetServerBanAsync(
  63. IPAddress? address,
  64. NetUserId? userId,
  65. ImmutableArray<byte>? hwId,
  66. ImmutableArray<ImmutableArray<byte>>? modernHWIds)
  67. {
  68. if (address == null && userId == null && hwId == null)
  69. {
  70. throw new ArgumentException("Address, userId, and hwId cannot all be null");
  71. }
  72. await using var db = await GetDbImpl();
  73. var exempt = await GetBanExemptionCore(db, userId);
  74. var newPlayer = userId == null || !await PlayerRecordExists(db, userId.Value);
  75. var query = MakeBanLookupQuery(address, userId, hwId, modernHWIds, db, includeUnbanned: false, exempt, newPlayer)
  76. .OrderByDescending(b => b.BanTime);
  77. var ban = await query.FirstOrDefaultAsync();
  78. return ConvertBan(ban);
  79. }
  80. public override async Task<List<ServerBanDef>> GetServerBansAsync(IPAddress? address,
  81. NetUserId? userId,
  82. ImmutableArray<byte>? hwId,
  83. ImmutableArray<ImmutableArray<byte>>? modernHWIds,
  84. bool includeUnbanned)
  85. {
  86. if (address == null && userId == null && hwId == null)
  87. {
  88. throw new ArgumentException("Address, userId, and hwId cannot all be null");
  89. }
  90. await using var db = await GetDbImpl();
  91. var exempt = await GetBanExemptionCore(db, userId);
  92. var newPlayer = !await db.PgDbContext.Player.AnyAsync(p => p.UserId == userId);
  93. var query = MakeBanLookupQuery(address, userId, hwId, modernHWIds, db, includeUnbanned, exempt, newPlayer);
  94. var queryBans = await query.ToArrayAsync();
  95. var bans = new List<ServerBanDef>(queryBans.Length);
  96. foreach (var ban in queryBans)
  97. {
  98. var banDef = ConvertBan(ban);
  99. if (banDef != null)
  100. {
  101. bans.Add(banDef);
  102. }
  103. }
  104. return bans;
  105. }
        /// <summary>
        /// Builds the server-ban lookup query. It matches bans by user ID, legacy HWID or modern HWIDs
        /// (via MakeBanLookupQualityShared), plus bans whose address contains or equals
        /// <paramref name="address"/>. It then applies the active-ban filter and the player's exemption flags.
        /// </summary>
        /// <param name="address">IP to match against ban addresses; skipped when the player has the IP exemption.</param>
        /// <param name="userId">Player user ID, if known.</param>
        /// <param name="hwId">Legacy hardware ID, if known.</param>
        /// <param name="modernHWIds">Modern hardware IDs, if any.</param>
        /// <param name="db">Open database guard whose context the query is built on.</param>
        /// <param name="includeUnbanned">When false, keep only bans with no unban and no past expiration.</param>
        /// <param name="exemptFlags">The player's exemption flags, or <c>null</c> if none were found.</param>
        /// <param name="newPlayer">Callers pass true when the player has no existing record; IP bans flagged
        /// BlacklistedRange only match in that case.</param>
        /// <returns>An un-executed, de-duplicated query over matching bans.</returns>
        private static IQueryable<ServerBan> MakeBanLookupQuery(
            IPAddress? address,
            NetUserId? userId,
            ImmutableArray<byte>? hwId,
            ImmutableArray<ImmutableArray<byte>>? modernHWIds,
            DbGuardImpl db,
            bool includeUnbanned,
            ServerBanExemptFlags? exemptFlags,
            bool newPlayer)
        {
            DebugTools.Assert(!(address == null && userId == null && hwId == null));

            // Identifier-based matches (user ID / HWIDs); null if none of those filters applied.
            var query = MakeBanLookupQualityShared<ServerBan, ServerUnban>(
                userId,
                hwId,
                modernHWIds,
                db.PgDbContext.Ban);

            // IP/range bans are not considered at all for players with the IP exemption.
            if (address != null && !exemptFlags.GetValueOrDefault(ServerBanExemptFlags.None).HasFlag(ServerBanExemptFlags.IP))
            {
                var newQ = db.PgDbContext.Ban
                    .Include(p => p.Unban)
                    .Where(b => b.Address != null
                        && EF.Functions.ContainsOrEqual(b.Address.Value, address)
                        // BlacklistedRange bans only apply to new players.
                        && !(b.ExemptFlags.HasFlag(ServerBanExemptFlags.BlacklistedRange) && !newPlayer));

                query = query == null ? newQ : query.Union(newQ);
            }

            DebugTools.Assert(
                query != null,
                "At least one filter item (IP/UserID/HWID) must have been given to make query not null.");

            // Active bans only: never unbanned, and either permanent or not yet expired.
            if (!includeUnbanned)
            {
                query = query.Where(p =>
                    p.Unban == null && (p.ExpirationTime == null || p.ExpirationTime.Value > DateTime.UtcNow));
            }

            if (exemptFlags is { } exempt)
            {
                if (exempt != ServerBanExemptFlags.None)
                    exempt |= ServerBanExemptFlags.BlacklistedRange; // Any kind of exemption should bypass BlacklistedRange

                // Drop every ban carrying at least one flag the player is exempt from.
                query = query.Where(b => (b.ExemptFlags & exempt) == 0);
            }

            // Ensure each ban appears at most once in the result.
            return query.Distinct();
        }
  147. private static IQueryable<TBan>? MakeBanLookupQualityShared<TBan, TUnban>(
  148. NetUserId? userId,
  149. ImmutableArray<byte>? hwId,
  150. ImmutableArray<ImmutableArray<byte>>? modernHWIds,
  151. DbSet<TBan> set)
  152. where TBan : class, IBanCommon<TUnban>
  153. where TUnban : class, IUnbanCommon
  154. {
  155. IQueryable<TBan>? query = null;
  156. if (userId is { } uid)
  157. {
  158. var newQ = set
  159. .Include(p => p.Unban)
  160. .Where(b => b.PlayerUserId == uid.UserId);
  161. query = query == null ? newQ : query.Union(newQ);
  162. }
  163. if (hwId != null && hwId.Value.Length > 0)
  164. {
  165. var newQ = set
  166. .Include(p => p.Unban)
  167. .Where(b => b.HWId!.Type == HwidType.Legacy && b.HWId!.Hwid.SequenceEqual(hwId.Value.ToArray()));
  168. query = query == null ? newQ : query.Union(newQ);
  169. }
  170. if (modernHWIds != null)
  171. {
  172. foreach (var modernHwid in modernHWIds)
  173. {
  174. var newQ = set
  175. .Include(p => p.Unban)
  176. .Where(b => b.HWId!.Type == HwidType.Modern && b.HWId!.Hwid.SequenceEqual(modernHwid.ToArray()));
  177. query = query == null ? newQ : query.Union(newQ);
  178. }
  179. }
  180. return query;
  181. }
  182. private static ServerBanDef? ConvertBan(ServerBan? ban)
  183. {
  184. if (ban == null)
  185. {
  186. return null;
  187. }
  188. NetUserId? uid = null;
  189. if (ban.PlayerUserId is {} guid)
  190. {
  191. uid = new NetUserId(guid);
  192. }
  193. NetUserId? aUid = null;
  194. if (ban.BanningAdmin is {} aGuid)
  195. {
  196. aUid = new NetUserId(aGuid);
  197. }
  198. var unbanDef = ConvertUnban(ban.Unban);
  199. return new ServerBanDef(
  200. ban.Id,
  201. uid,
  202. ban.Address.ToTuple(),
  203. ban.HWId,
  204. ban.BanTime,
  205. ban.ExpirationTime,
  206. ban.RoundId,
  207. ban.PlaytimeAtNote,
  208. ban.Reason,
  209. ban.Severity,
  210. aUid,
  211. unbanDef,
  212. ban.ExemptFlags);
  213. }
  214. private static ServerUnbanDef? ConvertUnban(ServerUnban? unban)
  215. {
  216. if (unban == null)
  217. {
  218. return null;
  219. }
  220. NetUserId? aUid = null;
  221. if (unban.UnbanningAdmin is {} aGuid)
  222. {
  223. aUid = new NetUserId(aGuid);
  224. }
  225. return new ServerUnbanDef(
  226. unban.Id,
  227. aUid,
  228. unban.UnbanTime);
  229. }
  230. public override async Task AddServerBanAsync(ServerBanDef serverBan)
  231. {
  232. await using var db = await GetDbImpl();
  233. db.PgDbContext.Ban.Add(new ServerBan
  234. {
  235. Address = serverBan.Address.ToNpgsqlInet(),
  236. HWId = serverBan.HWId,
  237. Reason = serverBan.Reason,
  238. Severity = serverBan.Severity,
  239. BanningAdmin = serverBan.BanningAdmin?.UserId,
  240. BanTime = serverBan.BanTime.UtcDateTime,
  241. ExpirationTime = serverBan.ExpirationTime?.UtcDateTime,
  242. RoundId = serverBan.RoundId,
  243. PlaytimeAtNote = serverBan.PlaytimeAtNote,
  244. PlayerUserId = serverBan.UserId?.UserId,
  245. ExemptFlags = serverBan.ExemptFlags
  246. });
  247. await db.PgDbContext.SaveChangesAsync();
  248. }
  249. public override async Task AddServerUnbanAsync(ServerUnbanDef serverUnban)
  250. {
  251. await using var db = await GetDbImpl();
  252. db.PgDbContext.Unban.Add(new ServerUnban
  253. {
  254. BanId = serverUnban.BanId,
  255. UnbanningAdmin = serverUnban.UnbanningAdmin?.UserId,
  256. UnbanTime = serverUnban.UnbanTime.UtcDateTime
  257. });
  258. await db.PgDbContext.SaveChangesAsync();
  259. }
  260. #endregion
  261. #region Role Ban
  262. public override async Task<ServerRoleBanDef?> GetServerRoleBanAsync(int id)
  263. {
  264. await using var db = await GetDbImpl();
  265. var query = db.PgDbContext.RoleBan
  266. .Include(p => p.Unban)
  267. .Where(p => p.Id == id);
  268. var ban = await query.SingleOrDefaultAsync();
  269. return ConvertRoleBan(ban);
  270. }
  271. public override async Task<List<ServerRoleBanDef>> GetServerRoleBansAsync(IPAddress? address,
  272. NetUserId? userId,
  273. ImmutableArray<byte>? hwId,
  274. ImmutableArray<ImmutableArray<byte>>? modernHWIds,
  275. bool includeUnbanned)
  276. {
  277. if (address == null && userId == null && hwId == null)
  278. {
  279. throw new ArgumentException("Address, userId, and hwId cannot all be null");
  280. }
  281. await using var db = await GetDbImpl();
  282. var query = MakeRoleBanLookupQuery(address, userId, hwId, modernHWIds, db, includeUnbanned)
  283. .OrderByDescending(b => b.BanTime);
  284. return await QueryRoleBans(query);
  285. }
  286. private static async Task<List<ServerRoleBanDef>> QueryRoleBans(IQueryable<ServerRoleBan> query)
  287. {
  288. var queryRoleBans = await query.ToArrayAsync();
  289. var bans = new List<ServerRoleBanDef>(queryRoleBans.Length);
  290. foreach (var ban in queryRoleBans)
  291. {
  292. var banDef = ConvertRoleBan(ban);
  293. if (banDef != null)
  294. {
  295. bans.Add(banDef);
  296. }
  297. }
  298. return bans;
  299. }
  300. private static IQueryable<ServerRoleBan> MakeRoleBanLookupQuery(
  301. IPAddress? address,
  302. NetUserId? userId,
  303. ImmutableArray<byte>? hwId,
  304. ImmutableArray<ImmutableArray<byte>>? modernHWIds,
  305. DbGuardImpl db,
  306. bool includeUnbanned)
  307. {
  308. var query = MakeBanLookupQualityShared<ServerRoleBan, ServerRoleUnban>(
  309. userId,
  310. hwId,
  311. modernHWIds,
  312. db.PgDbContext.RoleBan);
  313. if (address != null)
  314. {
  315. var newQ = db.PgDbContext.RoleBan
  316. .Include(p => p.Unban)
  317. .Where(b => b.Address != null && EF.Functions.ContainsOrEqual(b.Address.Value, address));
  318. query = query == null ? newQ : query.Union(newQ);
  319. }
  320. if (!includeUnbanned)
  321. {
  322. query = query?.Where(p =>
  323. p.Unban == null && (p.ExpirationTime == null || p.ExpirationTime.Value > DateTime.UtcNow));
  324. }
  325. query = query!.Distinct();
  326. return query;
  327. }
  328. [return: NotNullIfNotNull(nameof(ban))]
  329. private static ServerRoleBanDef? ConvertRoleBan(ServerRoleBan? ban)
  330. {
  331. if (ban == null)
  332. {
  333. return null;
  334. }
  335. NetUserId? uid = null;
  336. if (ban.PlayerUserId is {} guid)
  337. {
  338. uid = new NetUserId(guid);
  339. }
  340. NetUserId? aUid = null;
  341. if (ban.BanningAdmin is {} aGuid)
  342. {
  343. aUid = new NetUserId(aGuid);
  344. }
  345. var unbanDef = ConvertRoleUnban(ban.Unban);
  346. return new ServerRoleBanDef(
  347. ban.Id,
  348. uid,
  349. ban.Address.ToTuple(),
  350. ban.HWId,
  351. ban.BanTime,
  352. ban.ExpirationTime,
  353. ban.RoundId,
  354. ban.PlaytimeAtNote,
  355. ban.Reason,
  356. ban.Severity,
  357. aUid,
  358. unbanDef,
  359. ban.RoleId);
  360. }
  361. private static ServerRoleUnbanDef? ConvertRoleUnban(ServerRoleUnban? unban)
  362. {
  363. if (unban == null)
  364. {
  365. return null;
  366. }
  367. NetUserId? aUid = null;
  368. if (unban.UnbanningAdmin is {} aGuid)
  369. {
  370. aUid = new NetUserId(aGuid);
  371. }
  372. return new ServerRoleUnbanDef(
  373. unban.Id,
  374. aUid,
  375. unban.UnbanTime);
  376. }
  377. public override async Task<ServerRoleBanDef> AddServerRoleBanAsync(ServerRoleBanDef serverRoleBan)
  378. {
  379. await using var db = await GetDbImpl();
  380. var ban = new ServerRoleBan
  381. {
  382. Address = serverRoleBan.Address.ToNpgsqlInet(),
  383. HWId = serverRoleBan.HWId,
  384. Reason = serverRoleBan.Reason,
  385. Severity = serverRoleBan.Severity,
  386. BanningAdmin = serverRoleBan.BanningAdmin?.UserId,
  387. BanTime = serverRoleBan.BanTime.UtcDateTime,
  388. ExpirationTime = serverRoleBan.ExpirationTime?.UtcDateTime,
  389. RoundId = serverRoleBan.RoundId,
  390. PlaytimeAtNote = serverRoleBan.PlaytimeAtNote,
  391. PlayerUserId = serverRoleBan.UserId?.UserId,
  392. RoleId = serverRoleBan.Role,
  393. };
  394. db.PgDbContext.RoleBan.Add(ban);
  395. await db.PgDbContext.SaveChangesAsync();
  396. return ConvertRoleBan(ban);
  397. }
  398. public override async Task AddServerRoleUnbanAsync(ServerRoleUnbanDef serverRoleUnban)
  399. {
  400. await using var db = await GetDbImpl();
  401. db.PgDbContext.RoleUnban.Add(new ServerRoleUnban
  402. {
  403. BanId = serverRoleUnban.BanId,
  404. UnbanningAdmin = serverRoleUnban.UnbanningAdmin?.UserId,
  405. UnbanTime = serverRoleUnban.UnbanTime.UtcDateTime
  406. });
  407. await db.PgDbContext.SaveChangesAsync();
  408. }
  409. #endregion
  410. public override async Task<int> AddConnectionLogAsync(
  411. NetUserId userId,
  412. string userName,
  413. IPAddress address,
  414. ImmutableTypedHwid? hwId,
  415. float trust,
  416. ConnectionDenyReason? denied,
  417. int serverId)
  418. {
  419. await using var db = await GetDbImpl();
  420. var connectionLog = new ConnectionLog
  421. {
  422. Address = address,
  423. Time = DateTime.UtcNow,
  424. UserId = userId.UserId,
  425. UserName = userName,
  426. HWId = hwId,
  427. Denied = denied,
  428. ServerId = serverId,
  429. Trust = trust,
  430. };
  431. db.PgDbContext.ConnectionLog.Add(connectionLog);
  432. await db.PgDbContext.SaveChangesAsync();
  433. return connectionLog.Id;
  434. }
  435. public override async Task<((Admin, string? lastUserName)[] admins, AdminRank[])>
  436. GetAllAdminAndRanksAsync(CancellationToken cancel)
  437. {
  438. await using var db = await GetDbImpl();
  439. // Honestly this probably doesn't even matter but whatever.
  440. await using var tx =
  441. await db.DbContext.Database.BeginTransactionAsync(IsolationLevel.RepeatableRead, cancel);
  442. // Join with the player table to find their last seen username, if they have one.
  443. var admins = await db.PgDbContext.Admin
  444. .Include(a => a.Flags)
  445. .GroupJoin(db.PgDbContext.Player, a => a.UserId, p => p.UserId, (a, grouping) => new {a, grouping})
  446. .SelectMany(t => t.grouping.DefaultIfEmpty(), (t, p) => new {t.a, p!.LastSeenUserName})
  447. .ToArrayAsync(cancel);
  448. var adminRanks = await db.DbContext.AdminRank.Include(a => a.Flags).ToArrayAsync(cancel);
  449. return (admins.Select(p => (p.a, p.LastSeenUserName)).ToArray(), adminRanks)!;
  450. }
  451. protected override IQueryable<AdminLog> StartAdminLogsQuery(ServerDbContext db, LogFilter? filter = null)
  452. {
  453. // https://learn.microsoft.com/en-us/ef/core/querying/sql-queries#passing-parameters
  454. // Read the link above for parameterization before changing this method or you get the bullet
  455. if (!string.IsNullOrWhiteSpace(filter?.Search))
  456. {
  457. return db.AdminLog.FromSql($"""
  458. SELECT a.admin_log_id, a.round_id, a.date, a.impact, a.json, a.message, a.type FROM admin_log AS a
  459. WHERE to_tsvector('english'::regconfig, a.message) @@ websearch_to_tsquery('english'::regconfig, {filter.Search})
  460. """);
  461. }
  462. return db.AdminLog;
  463. }
  464. protected override DateTime NormalizeDatabaseTime(DateTime time)
  465. {
  466. DebugTools.Assert(time.Kind == DateTimeKind.Utc);
  467. return time;
  468. }
  469. private async Task<DbGuardImpl> GetDbImpl(
  470. CancellationToken cancel = default,
  471. [CallerMemberName] string? name = null)
  472. {
  473. LogDbOp(name);
  474. await _dbReadyTask;
  475. await _prefsSemaphore.WaitAsync(cancel);
  476. if (_msLag > 0)
  477. await Task.Delay(_msLag, cancel);
  478. return new DbGuardImpl(this, new PostgresServerDbContext(_options));
  479. }
  480. protected override async Task<DbGuard> GetDb(
  481. CancellationToken cancel = default,
  482. [CallerMemberName] string? name = null)
  483. {
  484. return await GetDbImpl(cancel, name);
  485. }
  486. private sealed class DbGuardImpl : DbGuard
  487. {
  488. private readonly ServerDbPostgres _db;
  489. public DbGuardImpl(ServerDbPostgres db, PostgresServerDbContext dbC)
  490. {
  491. _db = db;
  492. PgDbContext = dbC;
  493. }
  494. public PostgresServerDbContext PgDbContext { get; }
  495. public override ServerDbContext DbContext => PgDbContext;
  496. public override async ValueTask DisposeAsync()
  497. {
  498. await DbContext.DisposeAsync();
  499. _db._prefsSemaphore.Release();
  500. }
  501. }
  502. }
  503. }