Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ public async Task<Result> Handle(CompleteLoginCommand command, CancellationToken
user.UpdateLastSeen(timeProvider.GetUtcNow());
userRepository.Update(user);

var userInfo = await userInfoFactory.CreateUserInfoAsync(user, cancellationToken, session.Id);
var userInfo = await userInfoFactory.CreateUserInfoAsync(user, session.Id, cancellationToken);
authenticationTokenService.CreateAndSetAuthenticationTokens(userInfo, session.Id, session.RefreshTokenJti);

events.CollectEvent(new SessionCreated(session.Id));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ public async Task<Result> Handle(RefreshAuthenticationTokensCommand command, Can

if (!session.IsRefreshTokenValid(jti, refreshTokenVersion, now))
{
logger.LogWarning("Replay attack detected for session '{SessionId}'. Token JTI '{TokenJti}', current JTI '{CurrentJti}'. Token version '{TokenVersion}', current version '{CurrentVersion}'", session.Id, jti, session.RefreshTokenJti, refreshTokenVersion, session.RefreshTokenVersion);
logger.LogWarning(
"Replay attack detected for session '{SessionId}'. Token JTI '{TokenJti}', current JTI '{CurrentJti}'. Token version '{TokenVersion}', current version '{CurrentVersion}'",
session.Id, jti, session.RefreshTokenJti, refreshTokenVersion, session.RefreshTokenVersion
);
session.Revoke(now, SessionRevokedReason.ReplayAttackDetected);
sessionRepository.Update(session);
events.CollectEvent(new SessionReplayDetected(session.Id, refreshTokenVersion, session.RefreshTokenVersion));
Expand All @@ -105,20 +108,43 @@ public async Task<Result> Handle(RefreshAuthenticationTokensCommand command, Can
return Result.Unauthorized($"No user found with user id '{userId}'.");
}

RefreshTokenJti tokenJti;
int tokenVersion;

if (jti == session.RefreshTokenJti && refreshTokenVersion == session.RefreshTokenVersion)
{
session.Refresh();
sessionRepository.Update(session);

user.UpdateLastSeen(now);
userRepository.Update(user);
// Attempt atomic refresh via isolated connection - only one concurrent request can succeed.
// TryRefreshAsync commits immediately via its own connection, independent of UnitOfWorkPipelineBehavior.
var newJti = RefreshTokenJti.NewId();
var refreshed = await sessionRepository.TryRefreshAsync(session.Id, jti, refreshTokenVersion, newJti, now, cancellationToken);

if (refreshed)
{
// Atomic refresh succeeded - update User.LastSeenAt (committed by UnitOfWorkPipelineBehavior)
user.UpdateLastSeen(now);
userRepository.Update(user);
tokenJti = newJti;
tokenVersion = refreshTokenVersion + 1;
}
else
{
// Concurrent request refreshed session after our fetch - re-fetch for updated values.
// Grace period via PreviousRefreshTokenJti ensures this request still succeeds.
session = await sessionRepository.GetByIdUnfilteredAsync(session.Id, cancellationToken)
?? throw new InvalidOperationException("Session revoked during refresh.");
tokenJti = session.RefreshTokenJti;
tokenVersion = session.RefreshTokenVersion;
}
}
else
{
// Grace period request - token validated via PreviousRefreshTokenJti, use current session values
tokenJti = session.RefreshTokenJti;
tokenVersion = session.RefreshTokenVersion;
}

var userInfo = await userInfoFactory.CreateUserInfoAsync(user, cancellationToken, session.Id);
authenticationTokenService.RefreshAuthenticationTokens(userInfo, session.Id, session.RefreshTokenJti, refreshTokenVersion, refreshTokenExpires);

events.CollectEvent(new SessionRefreshed(session.Id));
events.CollectEvent(new AuthenticationTokensRefreshed());
var userInfo = await userInfoFactory.CreateUserInfoAsync(user, session.Id, cancellationToken);
authenticationTokenService.GenerateAuthenticationTokens(userInfo, session.Id, tokenJti, tokenVersion, refreshTokenExpires);

return Result.Success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ public async Task<Result> Handle(SwitchTenantCommand command, CancellationToken
targetUser.UpdateLastSeen(timeProvider.GetUtcNow());
userRepository.Update(targetUser);

var userInfo = await userInfoFactory.CreateUserInfoAsync(targetUser, cancellationToken, session.Id);
authenticationTokenService.CreateAndSetAuthenticationTokens(userInfo, session.Id, session.RefreshTokenJti, currentSession.ExpiresAt);
var userInfo = await userInfoFactory.CreateUserInfoAsync(targetUser, session.Id, cancellationToken);
authenticationTokenService.SwitchTenantAndSetAuthenticationTokens(userInfo, session.Id, session.RefreshTokenJti, currentSession.ExpiresAt);

events.CollectEvent(new SessionCreated(session.Id));
events.CollectEvent(new TenantSwitched(executionContext.TenantId!, command.TenantId, targetUser.Id));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Net;
using JetBrains.Annotations;
using PlatformPlatform.SharedKernel.Authentication.TokenGeneration;
using PlatformPlatform.SharedKernel.Domain;

Expand All @@ -24,10 +25,13 @@ private Session(TenantId tenantId, UserId userId, DeviceType deviceType, string

public UserId UserId { get; private init; }

[UsedImplicitly] // Updated via raw SQL in SessionRepository.TryRefreshAsync to handle concurrent refresh requests atomically
public RefreshTokenJti RefreshTokenJti { get; private set; }

[UsedImplicitly] // Updated via raw SQL in SessionRepository.TryRefreshAsync
public RefreshTokenJti? PreviousRefreshTokenJti { get; private set; }

[UsedImplicitly] // Updated via raw SQL in SessionRepository.TryRefreshAsync
public int RefreshTokenVersion { get; private set; }

public DeviceType DeviceType { get; private init; }
Expand All @@ -52,13 +56,6 @@ public static Session Create(TenantId tenantId, UserId userId, string userAgent,
return new Session(tenantId, userId, deviceType, userAgent, ipAddress.ToString());
}

public void Refresh()
{
PreviousRefreshTokenJti = RefreshTokenJti;
RefreshTokenJti = RefreshTokenJti.NewId();
RefreshTokenVersion++;
}

public void Revoke(DateTimeOffset now, SessionRevokedReason reason)
{
if (IsRevoked) throw new UnreachableException("Session is already revoked.");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Data.Common;
using Microsoft.EntityFrameworkCore;
using PlatformPlatform.AccountManagement.Database;
using PlatformPlatform.SharedKernel.Authentication.TokenGeneration;
Expand All @@ -21,6 +22,12 @@ public interface ISessionRepository : ICrudRepository<Session, SessionId>
/// This method should only be used in the Sessions dialog where users need to see all sessions for their email.
/// </summary>
Task<Session[]> GetActiveSessionsForUsersUnfilteredAsync(UserId[] userIds, CancellationToken cancellationToken);

/// <summary>
/// Attempts to refresh the session token if the current JTI and version match.
/// Returns false if another concurrent request already refreshed the session.
/// </summary>
Task<bool> TryRefreshAsync(SessionId sessionId, RefreshTokenJti currentJti, int currentVersion, RefreshTokenJti newJti, DateTimeOffset now, CancellationToken cancellationToken);
}

public sealed class SessionRepository(AccountManagementDbContext accountManagementDbContext)
Expand All @@ -31,6 +38,43 @@ public sealed class SessionRepository(AccountManagementDbContext accountManageme
return await DbSet.IgnoreQueryFilters().FirstOrDefaultAsync(s => s.Id == sessionId, cancellationToken);
}

/// <summary>
/// Uses an atomic UPDATE via raw ADO.NET with a separate connection to ensure complete isolation.
/// This creates an independent database connection that commits immediately, preventing race conditions
/// when multiple concurrent requests attempt to refresh the same token.
/// </summary>
public async Task<bool> TryRefreshAsync(SessionId sessionId, RefreshTokenJti currentJti, int currentVersion, RefreshTokenJti newJti, DateTimeOffset now, CancellationToken cancellationToken)
{
var existingConnection = accountManagementDbContext.Database.GetDbConnection();

// Create a new connection of the same type to ensure complete isolation from EF Core's transaction.
await using var connection = (DbConnection)Activator.CreateInstance(existingConnection.GetType())!;
connection.ConnectionString = accountManagementDbContext.Database.GetConnectionString();
await connection.OpenAsync(cancellationToken);

await using var command = connection.CreateCommand();
command.CommandText = """
UPDATE Sessions
SET PreviousRefreshTokenJti = RefreshTokenJti,
RefreshTokenJti = @newJti,
RefreshTokenVersion = RefreshTokenVersion + 1,
ModifiedAt = @now
WHERE Id = @sessionId
AND RefreshTokenJti = @currentJti
AND RefreshTokenVersion = @currentVersion
""";

AddParameter(command, "@newJti", newJti.Value);
AddParameter(command, "@now", now.ToString("O"));
AddParameter(command, "@sessionId", sessionId.Value);
AddParameter(command, "@currentJti", currentJti.Value);
AddParameter(command, "@currentVersion", currentVersion);

var rowsAffected = await command.ExecuteNonQueryAsync(cancellationToken);

return rowsAffected == 1;
}

public async Task<Session[]> GetActiveSessionsForUserAsync(UserId userId, CancellationToken cancellationToken)
{
var sessions = await DbSet
Expand All @@ -47,4 +91,12 @@ public async Task<Session[]> GetActiveSessionsForUsersUnfilteredAsync(UserId[] u
.ToArrayAsync(cancellationToken);
return sessions.OrderByDescending(s => s.ModifiedAt ?? s.CreatedAt).ToArray();
}

private static void AddParameter(DbCommand command, string name, object value)
{
var parameter = command.CreateParameter();
parameter.ParameterName = name;
parameter.Value = value;
command.Parameters.Add(parameter);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public async Task<Result> Handle(CompleteSignupCommand command, CancellationToke
user.UpdateLastSeen(timeProvider.GetUtcNow());
userRepository.Update(user);

var userInfo = await userInfoFactory.CreateUserInfoAsync(user, cancellationToken, session.Id);
var userInfo = await userInfoFactory.CreateUserInfoAsync(user, session.Id, cancellationToken);
authenticationTokenService.CreateAndSetAuthenticationTokens(userInfo, session.Id, session.RefreshTokenJti);

events.CollectEvent(new SessionCreated(session.Id));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ namespace PlatformPlatform.AccountManagement.Features;
/// This particular includes the naming of the telemetry events (which should be in past tense) and the properties that
/// are collected with each telemetry event. Since missing or bad data cannot be fixed, it is important to have a good
/// data quality from the start.
public sealed class AuthenticationTokensRefreshed
: TelemetryEvent;

public sealed class EmailConfirmationBlocked(EmailConfirmationId emailConfirmationId, EmailConfirmationType emailConfirmationType, int retryCount)
: TelemetryEvent(("email_confirmation_id", emailConfirmationId), ("email_confirmation_type", emailConfirmationType), ("retry_count", retryCount));

Expand Down Expand Up @@ -47,9 +44,6 @@ public sealed class Logout
public sealed class SessionCreated(SessionId sessionId)
: TelemetryEvent(("session_id", sessionId));

public sealed class SessionRefreshed(SessionId sessionId)
: TelemetryEvent(("session_id", sessionId));

public sealed class SessionReplayDetected(SessionId sessionId, int tokenVersion, int currentVersion)
: TelemetryEvent(("session_id", sessionId), ("token_version", tokenVersion), ("current_version", currentVersion));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ public sealed class UserInfoFactory(ITenantRepository tenantRepository)
/// Creates a UserInfo instance from a User entity, including tenant name.
/// </summary>
/// <param name="user">The user entity</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <param name="sessionId">Optional session ID to include in the UserInfo</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <returns>UserInfo with all required properties including tenant name</returns>
public async Task<UserInfo> CreateUserInfoAsync(User user, CancellationToken cancellationToken, SessionId? sessionId = null)
public async Task<UserInfo> CreateUserInfoAsync(User user, SessionId? sessionId, CancellationToken cancellationToken)
{
var tenant = await tenantRepository.GetByIdAsync(user.TenantId, cancellationToken);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@ public async Task RefreshAuthenticationTokens_WhenValidToken_ShouldRefreshAndInc

var updatedVersion = Connection.ExecuteScalar<long>("SELECT RefreshTokenVersion FROM Sessions WHERE Id = @id", [new { id = sessionId.ToString() }]);
updatedVersion.Should().Be(2);

TelemetryEventsCollectorSpy.CollectedEvents.Count.Should().Be(2);
TelemetryEventsCollectorSpy.CollectedEvents[0].GetType().Name.Should().Be("SessionRefreshed");
TelemetryEventsCollectorSpy.CollectedEvents[1].GetType().Name.Should().Be("AuthenticationTokensRefreshed");
TelemetryEventsCollectorSpy.AreAllEventsDispatched.Should().BeTrue();
}

[Fact]
Expand All @@ -69,10 +64,6 @@ public async Task RefreshAuthenticationTokens_WhenPreviousVersionWithinGracePeri

var sessionVersion = Connection.ExecuteScalar<long>("SELECT RefreshTokenVersion FROM Sessions WHERE Id = @id", [new { id = sessionId.ToString() }]);
sessionVersion.Should().Be(2);

TelemetryEventsCollectorSpy.CollectedEvents.Count.Should().Be(2);
TelemetryEventsCollectorSpy.CollectedEvents[0].GetType().Name.Should().Be("SessionRefreshed");
TelemetryEventsCollectorSpy.CollectedEvents[1].GetType().Name.Should().Be("AuthenticationTokensRefreshed");
}

[Fact]
Expand Down Expand Up @@ -140,6 +131,31 @@ public async Task RefreshAuthenticationTokens_WhenSessionNotFound_ShouldReturnUn
TelemetryEventsCollectorSpy.CollectedEvents.Should().BeEmpty();
}

[Fact]
public async Task RefreshAuthenticationTokens_WhenSequentialRequestsWithSameToken_ShouldBothSucceed()
{
// Arrange - simulate grace period scenario where concurrent request already refreshed the session
var jti = RefreshTokenJti.NewId();
var sessionId = SessionId.NewId();
InsertSession(DatabaseSeeder.Tenant1Owner.TenantId, DatabaseSeeder.Tenant1Owner.Id, sessionId, jti, 1);
var userInfo = DatabaseSeeder.Tenant1Owner.Adapt<UserInfo>();
var refreshToken = _refreshTokenGenerator.Generate(userInfo, sessionId, jti);
TelemetryEventsCollectorSpy.Reset();

// Act - First request refreshes the session
var response1 = await SendRefreshRequest(refreshToken);

// Act - Second request with same token should succeed via grace period
var response2 = await SendRefreshRequest(refreshToken);

// Assert
response1.StatusCode.Should().Be(HttpStatusCode.OK);
response2.StatusCode.Should().Be(HttpStatusCode.OK);

var sessionVersion = Connection.ExecuteScalar<long>("SELECT RefreshTokenVersion FROM Sessions WHERE Id = @id", [new { id = sessionId.ToString() }]);
sessionVersion.Should().Be(2);
}

private async Task<HttpResponseMessage> SendRefreshRequest(string refreshToken)
{
var request = new HttpRequestMessage(HttpMethod.Post, "/internal-api/account-management/authentication/refresh-authentication-tokens");
Expand All @@ -151,14 +167,8 @@ private string GenerateRefreshTokenWithVersion(UserInfo userInfo, SessionId sess
{
using var serviceScope = Provider.CreateScope();
var generator = serviceScope.ServiceProvider.GetRequiredService<RefreshTokenGenerator>();
var token = generator.Generate(userInfo, sessionId, jti);

for (var i = 1; i < version; i++)
{
token = generator.Update(userInfo, sessionId, jti, i, TimeProvider.System.GetUtcNow().AddHours(2160));
}

return token;
var expires = TimeProvider.System.GetUtcNow().AddHours(RefreshTokenGenerator.ValidForHours);
return generator.Generate(userInfo, sessionId, jti, version, expires);
}

private void InsertSession(long tenantId, string userId, SessionId sessionId, RefreshTokenJti jti, int version, bool isRevoked = false)
Expand Down
4 changes: 2 additions & 2 deletions application/account-management/Tests/EndpointBaseTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ protected EndpointBaseTest()
Services.AddLogging();
Services.AddTransient<DatabaseSeeder>();

// Create connection and add DbContext to the service collection
Connection = new SqliteConnection("DataSource=:memory:");
// Create connection using shared cache mode so isolated connections can access the same in-memory database
Connection = new SqliteConnection($"Data Source=TestDb_{Guid.NewGuid():N};Mode=Memory;Cache=Shared");
Connection.Open();

// Configure SQLite to behave more like SQL Server
Expand Down
4 changes: 2 additions & 2 deletions application/back-office/Tests/EndpointBaseTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ protected EndpointBaseTest()
Services.AddLogging();
Services.AddTransient<DatabaseSeeder>();

// Create connection and add DbContext to the service collection
Connection = new SqliteConnection("DataSource=:memory:");
// Create connection using shared cache mode so isolated connections can access the same in-memory database
Connection = new SqliteConnection($"Data Source=TestDb_{Guid.NewGuid():N};Mode=Memory;Cache=Shared");
Connection.Open();

// Configure SQLite to behave more like SQL Server
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@ public void CreateAndSetAuthenticationTokens(UserInfo userInfo, SessionId sessio
SetAuthenticationTokensOnHttpResponse(refreshToken, accessToken);
}

public void CreateAndSetAuthenticationTokens(UserInfo userInfo, SessionId sessionId, RefreshTokenJti jti, DateTimeOffset expires)
/// <summary>Preserves the original expiry to prevent session lifetime extension through repeated tenant switching.</summary>
public void SwitchTenantAndSetAuthenticationTokens(UserInfo userInfo, SessionId sessionId, RefreshTokenJti jti, DateTimeOffset expires)
{
var refreshToken = refreshTokenGenerator.Generate(userInfo, sessionId, jti, expires);
var refreshToken = refreshTokenGenerator.Generate(userInfo, sessionId, jti, 1, expires);
var accessToken = accessTokenGenerator.Generate(userInfo);
SetAuthenticationTokensOnHttpResponse(refreshToken, accessToken);
}

public void RefreshAuthenticationTokens(UserInfo userInfo, SessionId sessionId, RefreshTokenJti jti, int currentRefreshTokenVersion, DateTimeOffset expires)
/// <summary>Used during token refresh to issue new tokens with incremented version while preserving original expiry.</summary>
public void GenerateAuthenticationTokens(UserInfo userInfo, SessionId sessionId, RefreshTokenJti jti, int refreshTokenVersion, DateTimeOffset expires)
{
var refreshToken = refreshTokenGenerator.Update(userInfo, sessionId, jti, currentRefreshTokenVersion, expires);
var refreshToken = refreshTokenGenerator.Generate(userInfo, sessionId, jti, refreshTokenVersion, expires);
var accessToken = accessTokenGenerator.Generate(userInfo);
SetAuthenticationTokensOnHttpResponse(refreshToken, accessToken);
}
Expand Down
Loading
Loading