From 2a8a2013237447be2a180d78ec034042dc519538 Mon Sep 17 00:00:00 2001 From: aaronburtle <93220300+aaronburtle@users.noreply.github.com> Date: Thu, 20 Nov 2025 00:06:25 +0000 Subject: [PATCH 1/8] Update variable replacement during deserialization to use replacement settings class and add AKV replacement logic. (#2882) ## Why make this change? Adds AKV variable replacement and expands our design for doing variable replacements to be more extensible when new variable replacement logic is added. Closes #2708 Closes #2748 Related to #2863 ## What is this change? Change the way that variable replacement is handled to instead of simply using a `bool` to indicate that we want env variable replacement, we add a class which holds all of the replacement settings. This will hold whether or not we will do replacement for each kind of variable that we will handle replacement for during deserialization. We also include the replacement failure mode, and put the logic for handling the replacements into a strategy dictionary which pairs the replacement variable type with the strategy for doing that replacement. Because Azure Key Vault secret replacement requires having the retry and connection settings in order to do the AKV replacement, we must do a first pass where we only do non-AKV replacement and get the required settings so that if AKV replacement is used we have the required settings to do that replacement. We also have to keep in mind that the legacy of the `Configuration Controller` will ignore all variable replacement, so we construct the replacement settings for this code path to not use any variable replacement at all. ## How was this tested? We have updated the logic for the tests to use the new system, however manual testing using an actual AKV is still required. ## Sample Request(s) - Example REST and/or GraphQL request to demonstrate modifications - Example of CLI usage to demonstrate modifications --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Aniruddh Munde --- src/Cli.Tests/EndToEndTests.cs | 6 +- src/Cli.Tests/EnvironmentTests.cs | 8 +- src/Cli/ConfigGenerator.cs | 22 +- src/Cli/Exporter.cs | 3 +- src/Config/Azure.DataApiBuilder.Config.csproj | 1 + .../AKVRetryPolicyOptionsConverterFactory.cs | 34 +-- .../AzureKeyVaultOptionsConverterFactory.cs | 128 +++++++++++ .../AzureLogAnalyticsAuthOptionsConverter.cs | 19 +- ...zureLogAnalyticsOptionsConverterFactory.cs | 32 ++- .../Converters/DataSourceConverterFactory.cs | 34 ++- ...DatasourceHealthOptionsConvertorFactory.cs | 16 +- .../EntityCacheOptionsConverterFactory.cs | 22 +- .../EntityGraphQLOptionsConverterFactory.cs | 38 ++-- .../EntityRestOptionsConverterFactory.cs | 34 ++- .../EntitySourceConverterFactory.cs | 30 ++- .../EnumMemberJsonEnumConverterFactory.cs | 2 +- src/Config/Converters/FileSinkConverter.cs | 21 +- .../GraphQLRuntimeOptionsConverterFactory.cs | 28 +-- .../McpRuntimeOptionsConverterFactory.cs | 16 +- .../RuntimeHealthOptionsConvertorFactory.cs | 16 +- .../Converters/StringJsonConverterFactory.cs | 73 ++---- .../Converters/Utf8JsonReaderExtensions.cs | 10 +- ...erializationVariableReplacementSettings.cs | 213 ++++++++++++++++++ src/Config/FileSystemRuntimeConfigLoader.cs | 32 ++- .../ObjectModel/AzureKeyVaultOptions.cs | 37 +++ src/Config/ObjectModel/RuntimeConfig.cs | 7 +- src/Config/RuntimeConfigLoader.cs | 122 +++++++--- .../Configurations/RuntimeConfigProvider.cs | 9 +- src/Directory.Packages.props | 1 + .../Caching/CachingConfigProcessingTests.cs | 41 +--- .../Configuration/ConfigurationTests.cs | 41 ++-- .../UnitTests/MySqlQueryExecutorUnitTests.cs | 3 +- .../PostgreSqlQueryExecutorUnitTests.cs | 3 +- ...untimeConfigLoaderJsonDeserializerTests.cs | 17 +- .../SerializationDeserializationTests.cs | 66 +++++- .../UnitTests/SqlQueryExecutorUnitTests.cs | 3 +- .../Controllers/ConfigurationController.cs | 4 +- 37 files changed, 835 insertions(+), 357 deletions(-) create mode 100644 src/Config/Converters/AzureKeyVaultOptionsConverterFactory.cs create mode 100644 src/Config/DeserializationVariableReplacementSettings.cs diff --git a/src/Cli.Tests/EndToEndTests.cs b/src/Cli.Tests/EndToEndTests.cs index 7fe017501f..5dbf97ca5e 100644 --- a/src/Cli.Tests/EndToEndTests.cs +++ b/src/Cli.Tests/EndToEndTests.cs @@ -116,10 +116,11 @@ public void TestInitializingRestAndGraphQLGlobalSettings() string[] args = { "init", "-c", TEST_RUNTIME_CONFIG_FILE, "--connection-string", SAMPLE_TEST_CONN_STRING, "--database-type", "mssql", "--rest.path", "/rest-api", "--rest.enabled", "false", "--graphql.path", "/graphql-api" }; Program.Execute(args, _cliLogger!, _fileSystem!, _runtimeConfigLoader!); + DeserializationVariableReplacementSettings replacementSettings = new(azureKeyVaultOptions: null, doReplaceEnvVar: true, doReplaceAkvVar: true); Assert.IsTrue(_runtimeConfigLoader!.TryLoadConfig( TEST_RUNTIME_CONFIG_FILE, out RuntimeConfig? runtimeConfig, - replaceEnvVar: true)); + replacementSettings: replacementSettings)); SqlConnectionStringBuilder builder = new(runtimeConfig.DataSource.ConnectionString); Assert.AreEqual(ProductInfo.GetDataApiBuilderUserAgent(), builder.ApplicationName); @@ -195,10 +196,11 @@ public void TestEnablingMultipleCreateOperation(CliBool isMultipleCreateEnabled, Program.Execute(args.ToArray(), _cliLogger!, _fileSystem!, _runtimeConfigLoader!); + DeserializationVariableReplacementSettings replacementSettings = new(azureKeyVaultOptions: null, doReplaceEnvVar: true, doReplaceAkvVar: true); Assert.IsTrue(_runtimeConfigLoader!.TryLoadConfig( TEST_RUNTIME_CONFIG_FILE, out RuntimeConfig? runtimeConfig, - replaceEnvVar: true)); + replacementSettings: replacementSettings)); Assert.IsNotNull(runtimeConfig); Assert.AreEqual(expectedDbType, runtimeConfig.DataSource.DatabaseType); diff --git a/src/Cli.Tests/EnvironmentTests.cs b/src/Cli.Tests/EnvironmentTests.cs index 151d5babb2..2d6378cf74 100644 --- a/src/Cli.Tests/EnvironmentTests.cs +++ b/src/Cli.Tests/EnvironmentTests.cs @@ -19,7 +19,13 @@ public class EnvironmentTests [TestInitialize] public void TestInitialize() { - StringJsonConverterFactory converterFactory = new(EnvironmentVariableReplacementFailureMode.Throw); + DeserializationVariableReplacementSettings replacementSettings = new( + azureKeyVaultOptions: null, + doReplaceEnvVar: true, + doReplaceAkvVar: false, + envFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + + StringJsonConverterFactory converterFactory = new(replacementSettings); _options = new() { PropertyNameCaseInsensitive = true diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index 9a56f83c4a..7c35335089 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -2700,9 +2700,10 @@ private static bool TryUpdateConfiguredAzureKeyVaultOptions( // Azure Key Vault Endpoint if (options.AzureKeyVaultEndpoint is not null) { + // Ensure endpoint flag is marked user provided so converter writes it. updatedAzureKeyVaultOptions = updatedAzureKeyVaultOptions is not null - ? updatedAzureKeyVaultOptions with { Endpoint = options.AzureKeyVaultEndpoint } - : new AzureKeyVaultOptions { Endpoint = options.AzureKeyVaultEndpoint }; + ? updatedAzureKeyVaultOptions with { Endpoint = options.AzureKeyVaultEndpoint, UserProvidedEndpoint = true } + : new AzureKeyVaultOptions(endpoint: options.AzureKeyVaultEndpoint); _logger.LogInformation("Updated RuntimeConfig with azure-key-vault.endpoint as '{endpoint}'", options.AzureKeyVaultEndpoint); } @@ -2711,7 +2712,7 @@ private static bool TryUpdateConfiguredAzureKeyVaultOptions( { updatedRetryPolicyOptions = updatedRetryPolicyOptions is not null ? updatedRetryPolicyOptions with { Mode = options.AzureKeyVaultRetryPolicyMode.Value, UserProvidedMode = true } - : new AKVRetryPolicyOptions { Mode = options.AzureKeyVaultRetryPolicyMode.Value, UserProvidedMode = true }; + : new AKVRetryPolicyOptions(mode: options.AzureKeyVaultRetryPolicyMode.Value); _logger.LogInformation("Updated RuntimeConfig with azure-key-vault.retry-policy.mode as '{mode}'", options.AzureKeyVaultRetryPolicyMode.Value); } @@ -2726,7 +2727,7 @@ private static bool TryUpdateConfiguredAzureKeyVaultOptions( updatedRetryPolicyOptions = updatedRetryPolicyOptions is not null ? updatedRetryPolicyOptions with { MaxCount = options.AzureKeyVaultRetryPolicyMaxCount.Value, UserProvidedMaxCount = true } - : new AKVRetryPolicyOptions { MaxCount = options.AzureKeyVaultRetryPolicyMaxCount.Value, UserProvidedMaxCount = true }; + : new AKVRetryPolicyOptions(maxCount: options.AzureKeyVaultRetryPolicyMaxCount.Value); _logger.LogInformation("Updated RuntimeConfig with azure-key-vault.retry-policy.max-count as '{maxCount}'", options.AzureKeyVaultRetryPolicyMaxCount.Value); } @@ -2741,7 +2742,7 @@ private static bool TryUpdateConfiguredAzureKeyVaultOptions( updatedRetryPolicyOptions = updatedRetryPolicyOptions is not null ? updatedRetryPolicyOptions with { DelaySeconds = options.AzureKeyVaultRetryPolicyDelaySeconds.Value, UserProvidedDelaySeconds = true } - : new AKVRetryPolicyOptions { DelaySeconds = options.AzureKeyVaultRetryPolicyDelaySeconds.Value, UserProvidedDelaySeconds = true }; + : new AKVRetryPolicyOptions(delaySeconds: options.AzureKeyVaultRetryPolicyDelaySeconds.Value); _logger.LogInformation("Updated RuntimeConfig with azure-key-vault.retry-policy.delay-seconds as '{delaySeconds}'", options.AzureKeyVaultRetryPolicyDelaySeconds.Value); } @@ -2756,7 +2757,7 @@ private static bool TryUpdateConfiguredAzureKeyVaultOptions( updatedRetryPolicyOptions = updatedRetryPolicyOptions is not null ? updatedRetryPolicyOptions with { MaxDelaySeconds = options.AzureKeyVaultRetryPolicyMaxDelaySeconds.Value, UserProvidedMaxDelaySeconds = true } - : new AKVRetryPolicyOptions { MaxDelaySeconds = options.AzureKeyVaultRetryPolicyMaxDelaySeconds.Value, UserProvidedMaxDelaySeconds = true }; + : new AKVRetryPolicyOptions(maxDelaySeconds: options.AzureKeyVaultRetryPolicyMaxDelaySeconds.Value); _logger.LogInformation("Updated RuntimeConfig with azure-key-vault.retry-policy.max-delay-seconds as '{maxDelaySeconds}'", options.AzureKeyVaultRetryPolicyMaxDelaySeconds.Value); } @@ -2771,16 +2772,17 @@ private static bool TryUpdateConfiguredAzureKeyVaultOptions( updatedRetryPolicyOptions = updatedRetryPolicyOptions is not null ? updatedRetryPolicyOptions with { NetworkTimeoutSeconds = options.AzureKeyVaultRetryPolicyNetworkTimeoutSeconds.Value, UserProvidedNetworkTimeoutSeconds = true } - : new AKVRetryPolicyOptions { NetworkTimeoutSeconds = options.AzureKeyVaultRetryPolicyNetworkTimeoutSeconds.Value, UserProvidedNetworkTimeoutSeconds = true }; + : new AKVRetryPolicyOptions(networkTimeoutSeconds: options.AzureKeyVaultRetryPolicyNetworkTimeoutSeconds.Value); _logger.LogInformation("Updated RuntimeConfig with azure-key-vault.retry-policy.network-timeout-seconds as '{networkTimeoutSeconds}'", options.AzureKeyVaultRetryPolicyNetworkTimeoutSeconds.Value); } - // Update Azure Key Vault options with retry policy if retry policy was modified + // Update Azure Key Vault options with retry policy if modified if (updatedRetryPolicyOptions is not null) { + // Ensure outer AKV object marks retry policy as user provided so it serializes. updatedAzureKeyVaultOptions = updatedAzureKeyVaultOptions is not null - ? updatedAzureKeyVaultOptions with { RetryPolicy = updatedRetryPolicyOptions } - : new AzureKeyVaultOptions { RetryPolicy = updatedRetryPolicyOptions }; + ? updatedAzureKeyVaultOptions with { RetryPolicy = updatedRetryPolicyOptions, UserProvidedRetryPolicy = true } + : new AzureKeyVaultOptions(retryPolicy: updatedRetryPolicyOptions); } // Update runtime config if Azure Key Vault options were modified diff --git a/src/Cli/Exporter.cs b/src/Cli/Exporter.cs index d4f103e868..896b485692 100644 --- a/src/Cli/Exporter.cs +++ b/src/Cli/Exporter.cs @@ -44,7 +44,8 @@ public static bool Export(ExportOptions options, ILogger logger, FileSystemRunti } // Load the runtime configuration from the file - if (!loader.TryLoadConfig(runtimeConfigFile, out RuntimeConfig? runtimeConfig, replaceEnvVar: true)) + DeserializationVariableReplacementSettings replacementSettings = new(azureKeyVaultOptions: null, doReplaceEnvVar: true, doReplaceAkvVar: true); + if (!loader.TryLoadConfig(runtimeConfigFile, out RuntimeConfig? runtimeConfig, replacementSettings: replacementSettings)) { logger.LogError("Failed to read the config file: {0}.", runtimeConfigFile); return false; diff --git a/src/Config/Azure.DataApiBuilder.Config.csproj b/src/Config/Azure.DataApiBuilder.Config.csproj index a494bc38ae..6b5bdf0955 100644 --- a/src/Config/Azure.DataApiBuilder.Config.csproj +++ b/src/Config/Azure.DataApiBuilder.Config.csproj @@ -15,6 +15,7 @@ + diff --git a/src/Config/Converters/AKVRetryPolicyOptionsConverterFactory.cs b/src/Config/Converters/AKVRetryPolicyOptionsConverterFactory.cs index 06d00b64d3..553e43db53 100644 --- a/src/Config/Converters/AKVRetryPolicyOptionsConverterFactory.cs +++ b/src/Config/Converters/AKVRetryPolicyOptionsConverterFactory.cs @@ -12,9 +12,9 @@ namespace Azure.DataApiBuilder.Config.Converters; /// internal class AKVRetryPolicyOptionsConverterFactory : JsonConverterFactory { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + // Currently allows for Azure Key Vault (via @akv('secret-name')) and Environment Variable replacement. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -25,34 +25,34 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new AKVRetryPolicyOptionsConverter(_replaceEnvVar); + return new AKVRetryPolicyOptionsConverter(_replacementSettings); } - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal AKVRetryPolicyOptionsConverterFactory(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal AKVRetryPolicyOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class AKVRetryPolicyOptionsConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + // Currently allows for Azure Key Vault (via @akv('')) and Environment Variable replacement. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - public AKVRetryPolicyOptionsConverter(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + public AKVRetryPolicyOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// /// Defines how DAB reads AKV Retry Policy options and defines which values are /// used to instantiate those options. /// - /// Thrown when improperly formatted cache options are provided. + /// Thrown when improperly formatted retry policy options are provided. public override AKVRetryPolicyOptions? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { if (reader.TokenType is JsonTokenType.StartObject) @@ -82,7 +82,7 @@ public AKVRetryPolicyOptionsConverter(bool replaceEnvVar) } else { - mode = EnumExtensions.Deserialize(reader.DeserializeString(_replaceEnvVar)!); + mode = EnumExtensions.Deserialize(reader.DeserializeString(_replacementSettings)!); } break; diff --git a/src/Config/Converters/AzureKeyVaultOptionsConverterFactory.cs b/src/Config/Converters/AzureKeyVaultOptionsConverterFactory.cs new file mode 100644 index 0000000000..92ed0c1a85 --- /dev/null +++ b/src/Config/Converters/AzureKeyVaultOptionsConverterFactory.cs @@ -0,0 +1,128 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json; +using System.Text.Json.Serialization; +using Azure.DataApiBuilder.Config.ObjectModel; + +namespace Azure.DataApiBuilder.Config.Converters; + +/// +/// Converter factory for AzureKeyVaultOptions that can optionally perform variable replacement. +/// +internal class AzureKeyVaultOptionsConverterFactory : JsonConverterFactory +{ + // Determines whether to replace environment variable with its + // value or not while deserializing. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; + + /// How to handle variable replacement during deserialization. + internal AzureKeyVaultOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) + { + _replacementSettings = replacementSettings; + } + + /// + public override bool CanConvert(Type typeToConvert) + { + return typeToConvert.IsAssignableTo(typeof(AzureKeyVaultOptions)); + } + + /// + public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) + { + return new AzureKeyVaultOptionsConverter(_replacementSettings); + } + + private class AzureKeyVaultOptionsConverter : JsonConverter + { + // Determines whether to replace environment variable with its + // value or not while deserializing. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; + + /// Whether to replace environment variable with its + /// value or not while deserializing. + public AzureKeyVaultOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) + { + _replacementSettings = replacementSettings; + } + + /// + /// Reads AzureKeyVaultOptions with optional variable replacement. + /// + public override AzureKeyVaultOptions? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType is JsonTokenType.Null) + { + return null; + } + + if (reader.TokenType is JsonTokenType.StartObject) + { + string? endpoint = null; + AKVRetryPolicyOptions? retryPolicy = null; + + while (reader.Read()) + { + if (reader.TokenType is JsonTokenType.EndObject) + { + return new AzureKeyVaultOptions(endpoint, retryPolicy); + } + + string? property = reader.GetString(); + reader.Read(); + + switch (property) + { + case "endpoint": + if (reader.TokenType is JsonTokenType.String) + { + endpoint = reader.DeserializeString(_replacementSettings); + } + + break; + + case "retry-policy": + if (reader.TokenType is JsonTokenType.StartObject) + { + // Uses the AKVRetryPolicyOptionsConverter to read the retry-policy object. + retryPolicy = JsonSerializer.Deserialize(ref reader, options); + } + + break; + + default: + throw new JsonException($"Unexpected property {property}"); + } + } + } + + throw new JsonException("Invalid AzureKeyVaultOptions format"); + } + + /// + /// When writing the AzureKeyVaultOptions back to a JSON file, only write the properties + /// if they are user provided. This avoids polluting the written JSON file with properties + /// the user most likely omitted when writing the original DAB runtime config file. + /// This Write operation is only used when a RuntimeConfig object is serialized to JSON. + /// + public override void Write(Utf8JsonWriter writer, AzureKeyVaultOptions value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + + if (value?.UserProvidedEndpoint is true) + { + writer.WritePropertyName("endpoint"); + JsonSerializer.Serialize(writer, value.Endpoint, options); + } + + if (value?.UserProvidedRetryPolicy is true) + { + writer.WritePropertyName("retry-policy"); + JsonSerializer.Serialize(writer, value.RetryPolicy, options); + } + + writer.WriteEndObject(); + } + } +} diff --git a/src/Config/Converters/AzureLogAnalyticsAuthOptionsConverter.cs b/src/Config/Converters/AzureLogAnalyticsAuthOptionsConverter.cs index 1428c0d75f..d4b7623aa2 100644 --- a/src/Config/Converters/AzureLogAnalyticsAuthOptionsConverter.cs +++ b/src/Config/Converters/AzureLogAnalyticsAuthOptionsConverter.cs @@ -9,15 +9,14 @@ namespace Azure.DataApiBuilder.Config.Converters; internal class AzureLogAnalyticsAuthOptionsConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - public AzureLogAnalyticsAuthOptionsConverter(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + public AzureLogAnalyticsAuthOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -48,7 +47,7 @@ public AzureLogAnalyticsAuthOptionsConverter(bool replaceEnvVar) case "custom-table-name": if (reader.TokenType is not JsonTokenType.Null) { - customTableName = reader.DeserializeString(_replaceEnvVar); + customTableName = reader.DeserializeString(_replacementSettings); } break; @@ -56,7 +55,7 @@ public AzureLogAnalyticsAuthOptionsConverter(bool replaceEnvVar) case "dcr-immutable-id": if (reader.TokenType is not JsonTokenType.Null) { - dcrImmutableId = reader.DeserializeString(_replaceEnvVar); + dcrImmutableId = reader.DeserializeString(_replacementSettings); } break; @@ -64,7 +63,7 @@ public AzureLogAnalyticsAuthOptionsConverter(bool replaceEnvVar) case "dce-endpoint": if (reader.TokenType is not JsonTokenType.Null) { - dceEndpoint = reader.DeserializeString(_replaceEnvVar); + dceEndpoint = reader.DeserializeString(_replacementSettings); } break; diff --git a/src/Config/Converters/AzureLogAnalyticsOptionsConverterFactory.cs b/src/Config/Converters/AzureLogAnalyticsOptionsConverterFactory.cs index 3fcbe8c7bd..fc7c72d655 100644 --- a/src/Config/Converters/AzureLogAnalyticsOptionsConverterFactory.cs +++ b/src/Config/Converters/AzureLogAnalyticsOptionsConverterFactory.cs @@ -12,9 +12,8 @@ namespace Azure.DataApiBuilder.Config.Converters; /// internal class AzureLogAnalyticsOptionsConverterFactory : JsonConverterFactory { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -25,27 +24,26 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new AzureLogAnalyticsOptionsConverter(_replaceEnvVar); + return new AzureLogAnalyticsOptionsConverter(_replacementSettings); } - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal AzureLogAnalyticsOptionsConverterFactory(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal AzureLogAnalyticsOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class AzureLogAnalyticsOptionsConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal AzureLogAnalyticsOptionsConverter(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal AzureLogAnalyticsOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -57,7 +55,7 @@ internal AzureLogAnalyticsOptionsConverter(bool replaceEnvVar) { if (reader.TokenType is JsonTokenType.StartObject) { - AzureLogAnalyticsAuthOptionsConverter authOptionsConverter = new(_replaceEnvVar); + AzureLogAnalyticsAuthOptionsConverter authOptionsConverter = new(_replacementSettings); bool? enabled = null; AzureLogAnalyticsAuthOptions? auth = null; @@ -91,7 +89,7 @@ internal AzureLogAnalyticsOptionsConverter(bool replaceEnvVar) case "dab-identifier": if (reader.TokenType is not JsonTokenType.Null) { - logType = reader.DeserializeString(_replaceEnvVar); + logType = reader.DeserializeString(_replacementSettings); } break; diff --git a/src/Config/Converters/DataSourceConverterFactory.cs b/src/Config/Converters/DataSourceConverterFactory.cs index dabbee405e..1788ebf2b4 100644 --- a/src/Config/Converters/DataSourceConverterFactory.cs +++ b/src/Config/Converters/DataSourceConverterFactory.cs @@ -9,9 +9,8 @@ namespace Azure.DataApiBuilder.Config.Converters; internal class DataSourceConverterFactory : JsonConverterFactory { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -22,27 +21,26 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new DataSourceConverter(_replaceEnvVar); + return new DataSourceConverter(_replacementSettings); } - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal DataSourceConverterFactory(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal DataSourceConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class DataSourceConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - public DataSourceConverter(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + public DataSourceConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } public override DataSource? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) @@ -69,11 +67,11 @@ public DataSourceConverter(bool replaceEnvVar) switch (propertyName) { case "database-type": - databaseType = EnumExtensions.Deserialize(reader.DeserializeString(_replaceEnvVar)!); + databaseType = EnumExtensions.Deserialize(reader.DeserializeString(_replacementSettings)!); break; case "connection-string": - connectionString = reader.DeserializeString(replaceEnvVar: _replaceEnvVar)!; + connectionString = reader.DeserializeString(_replacementSettings)!; break; case "health": @@ -106,7 +104,7 @@ public DataSourceConverter(bool replaceEnvVar) if (reader.TokenType is JsonTokenType.String) { // Determine whether to resolve the environment variable or keep as-is. - string stringValue = reader.DeserializeString(replaceEnvVar: _replaceEnvVar)!; + string stringValue = reader.DeserializeString(_replacementSettings)!; if (bool.TryParse(stringValue, out bool boolValue)) { diff --git a/src/Config/Converters/DatasourceHealthOptionsConvertorFactory.cs b/src/Config/Converters/DatasourceHealthOptionsConvertorFactory.cs index 52272c57a7..d8286ff7a0 100644 --- a/src/Config/Converters/DatasourceHealthOptionsConvertorFactory.cs +++ b/src/Config/Converters/DatasourceHealthOptionsConvertorFactory.cs @@ -11,7 +11,7 @@ internal class DataSourceHealthOptionsConvertorFactory : JsonConverterFactory { // Determines whether to replace environment variable with its // value or not while deserializing. - private bool _replaceEnvVar; + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -22,27 +22,27 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new HealthCheckOptionsConverter(_replaceEnvVar); + return new HealthCheckOptionsConverter(_replacementSettings); } /// Whether to replace environment variable with its /// value or not while deserializing. - internal DataSourceHealthOptionsConvertorFactory(bool replaceEnvVar) + internal DataSourceHealthOptionsConvertorFactory(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class HealthCheckOptionsConverter : JsonConverter { // Determines whether to replace environment variable with its // value or not while deserializing. - private bool _replaceEnvVar; + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// Whether to replace environment variable with its /// value or not while deserializing. - public HealthCheckOptionsConverter(bool replaceEnvVar) + public HealthCheckOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -85,7 +85,7 @@ public HealthCheckOptionsConverter(bool replaceEnvVar) case "name": if (reader.TokenType is not JsonTokenType.Null) { - name = reader.DeserializeString(_replaceEnvVar); + name = reader.DeserializeString(_replacementSettings); } break; diff --git a/src/Config/Converters/EntityCacheOptionsConverterFactory.cs b/src/Config/Converters/EntityCacheOptionsConverterFactory.cs index 32a616ab81..641efd062f 100644 --- a/src/Config/Converters/EntityCacheOptionsConverterFactory.cs +++ b/src/Config/Converters/EntityCacheOptionsConverterFactory.cs @@ -14,7 +14,7 @@ internal class EntityCacheOptionsConverterFactory : JsonConverterFactory { // Determines whether to replace environment variable with its // value or not while deserializing. - private bool _replaceEnvVar; + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -25,27 +25,25 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new EntityCacheOptionsConverter(_replaceEnvVar); + return new EntityCacheOptionsConverter(_replacementSettings); } - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal EntityCacheOptionsConverterFactory(bool replaceEnvVar) + /// The replacement settings to use while deserializing. + internal EntityCacheOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class EntityCacheOptionsConverter : JsonConverter { // Determines whether to replace environment variable with its // value or not while deserializing. - private bool _replaceEnvVar; + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - public EntityCacheOptionsConverter(bool replaceEnvVar) + /// The replacement settings to use while deserializing. + public EntityCacheOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -110,7 +108,7 @@ public EntityCacheOptionsConverter(bool replaceEnvVar) throw new JsonException("level property cannot be null."); } - level = EnumExtensions.Deserialize(reader.DeserializeString(_replaceEnvVar)!); + level = EnumExtensions.Deserialize(reader.DeserializeString(_replacementSettings)!); break; } diff --git a/src/Config/Converters/EntityGraphQLOptionsConverterFactory.cs b/src/Config/Converters/EntityGraphQLOptionsConverterFactory.cs index 576850b1cb..abe094e970 100644 --- a/src/Config/Converters/EntityGraphQLOptionsConverterFactory.cs +++ b/src/Config/Converters/EntityGraphQLOptionsConverterFactory.cs @@ -9,9 +9,8 @@ namespace Azure.DataApiBuilder.Config.Converters; internal class EntityGraphQLOptionsConverterFactory : JsonConverterFactory { - /// Determines whether to replace environment variable with its - /// value or not while deserializing. - private bool _replaceEnvVar; + /// Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -22,27 +21,26 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new EntityGraphQLOptionsConverter(_replaceEnvVar); + return new EntityGraphQLOptionsConverter(_replacementSettings); } - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal EntityGraphQLOptionsConverterFactory(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal EntityGraphQLOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class EntityGraphQLOptionsConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - public EntityGraphQLOptionsConverter(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + public EntityGraphQLOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -73,7 +71,7 @@ public EntityGraphQLOptionsConverter(bool replaceEnvVar) case "type": if (reader.TokenType is JsonTokenType.String) { - singular = reader.DeserializeString(_replaceEnvVar) ?? string.Empty; + singular = reader.DeserializeString(_replacementSettings) ?? string.Empty; } else if (reader.TokenType is JsonTokenType.StartObject) { @@ -95,10 +93,10 @@ public EntityGraphQLOptionsConverter(bool replaceEnvVar) switch (property2) { case "singular": - singular = reader.DeserializeString(_replaceEnvVar) ?? string.Empty; + singular = reader.DeserializeString(_replacementSettings) ?? string.Empty; break; case "plural": - plural = reader.DeserializeString(_replaceEnvVar) ?? string.Empty; + plural = reader.DeserializeString(_replacementSettings) ?? string.Empty; break; } } @@ -112,7 +110,7 @@ public EntityGraphQLOptionsConverter(bool replaceEnvVar) break; case "operation": - string? op = reader.DeserializeString(_replaceEnvVar); + string? op = reader.DeserializeString(_replacementSettings); if (op is not null) { @@ -136,7 +134,7 @@ public EntityGraphQLOptionsConverter(bool replaceEnvVar) if (reader.TokenType is JsonTokenType.String) { - string? singular = reader.DeserializeString(_replaceEnvVar); + string? singular = reader.DeserializeString(_replacementSettings); return new EntityGraphQLOptions(singular ?? string.Empty, string.Empty); } diff --git a/src/Config/Converters/EntityRestOptionsConverterFactory.cs b/src/Config/Converters/EntityRestOptionsConverterFactory.cs index cc33943caa..f8c9096673 100644 --- a/src/Config/Converters/EntityRestOptionsConverterFactory.cs +++ b/src/Config/Converters/EntityRestOptionsConverterFactory.cs @@ -9,9 +9,8 @@ namespace Azure.DataApiBuilder.Config.Converters; internal class EntityRestOptionsConverterFactory : JsonConverterFactory { - /// Determines whether to replace environment variable with its - /// value or not while deserializing. - private bool _replaceEnvVar; + /// Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -22,27 +21,26 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new EntityRestOptionsConverter(_replaceEnvVar); + return new EntityRestOptionsConverter(_replacementSettings); } - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal EntityRestOptionsConverterFactory(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal EntityRestOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } internal class EntityRestOptionsConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - public EntityRestOptionsConverter(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + public EntityRestOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -67,7 +65,7 @@ public EntityRestOptionsConverter(bool replaceEnvVar) if (reader.TokenType is JsonTokenType.String || reader.TokenType is JsonTokenType.Null) { - restOptions = restOptions with { Path = reader.DeserializeString(_replaceEnvVar) }; + restOptions = restOptions with { Path = reader.DeserializeString(_replacementSettings) }; break; } @@ -87,7 +85,7 @@ public EntityRestOptionsConverter(bool replaceEnvVar) break; } - methods.Add(EnumExtensions.Deserialize(reader.DeserializeString(replaceEnvVar: true)!)); + methods.Add(EnumExtensions.Deserialize(reader.DeserializeString(new DeserializationVariableReplacementSettings())!)); } restOptions = restOptions with { Methods = methods.ToArray() }; @@ -107,7 +105,7 @@ public EntityRestOptionsConverter(bool replaceEnvVar) if (reader.TokenType is JsonTokenType.String) { - return new EntityRestOptions(Array.Empty(), reader.DeserializeString(_replaceEnvVar), true); + return new EntityRestOptions(Array.Empty(), reader.DeserializeString(_replacementSettings), true); } if (reader.TokenType is JsonTokenType.True || reader.TokenType is JsonTokenType.False) diff --git a/src/Config/Converters/EntitySourceConverterFactory.cs b/src/Config/Converters/EntitySourceConverterFactory.cs index a748382e01..2edafe31e1 100644 --- a/src/Config/Converters/EntitySourceConverterFactory.cs +++ b/src/Config/Converters/EntitySourceConverterFactory.cs @@ -9,9 +9,8 @@ namespace Azure.DataApiBuilder.Config.Converters; internal class EntitySourceConverterFactory : JsonConverterFactory { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -22,34 +21,33 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new EntitySourceConverter(_replaceEnvVar); + return new EntitySourceConverter(_replacementSettings); } - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal EntitySourceConverterFactory(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal EntitySourceConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class EntitySourceConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - public EntitySourceConverter(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + public EntitySourceConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } public override EntitySource? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { if (reader.TokenType == JsonTokenType.String) { - string? obj = reader.DeserializeString(_replaceEnvVar); + string? obj = reader.DeserializeString(_replacementSettings); return new EntitySource(obj ?? string.Empty, EntitySourceType.Table, new(), Array.Empty()); } diff --git a/src/Config/Converters/EnumMemberJsonEnumConverterFactory.cs b/src/Config/Converters/EnumMemberJsonEnumConverterFactory.cs index 1d6dd9f7c4..4455a474e1 100644 --- a/src/Config/Converters/EnumMemberJsonEnumConverterFactory.cs +++ b/src/Config/Converters/EnumMemberJsonEnumConverterFactory.cs @@ -114,7 +114,7 @@ public JsonStringEnumConverterEx() public override TEnum Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { // Always replace env variable in case of Enum otherwise string to enum conversion will fail. - string? stringValue = reader.DeserializeString(replaceEnvVar: true); + string? stringValue = reader.DeserializeString(new(doReplaceEnvVar: true)); if (stringValue == null) { diff --git a/src/Config/Converters/FileSinkConverter.cs b/src/Config/Converters/FileSinkConverter.cs index cc7d138a1b..4299fb913b 100644 --- a/src/Config/Converters/FileSinkConverter.cs +++ b/src/Config/Converters/FileSinkConverter.cs @@ -7,18 +7,17 @@ using Serilog; namespace Azure.DataApiBuilder.Config.Converters; + class FileSinkConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; - - /// - /// Whether to replace environment variable with its value or not while deserializing. - /// - public FileSinkConverter(bool replaceEnvVar) + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; + + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + public FileSinkConverter(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -59,7 +58,7 @@ public FileSinkConverter(bool replaceEnvVar) case "path": if (reader.TokenType is not JsonTokenType.Null) { - path = reader.DeserializeString(_replaceEnvVar); + path = reader.DeserializeString(_replacementSettings); } break; @@ -67,7 +66,7 @@ public FileSinkConverter(bool replaceEnvVar) case "rolling-interval": if (reader.TokenType is not JsonTokenType.Null) { - rollingInterval = EnumExtensions.Deserialize(reader.DeserializeString(_replaceEnvVar)!); + rollingInterval = EnumExtensions.Deserialize(reader.DeserializeString(_replacementSettings)!); } break; diff --git a/src/Config/Converters/GraphQLRuntimeOptionsConverterFactory.cs b/src/Config/Converters/GraphQLRuntimeOptionsConverterFactory.cs index 082c982e7e..109caef0d5 100644 --- a/src/Config/Converters/GraphQLRuntimeOptionsConverterFactory.cs +++ b/src/Config/Converters/GraphQLRuntimeOptionsConverterFactory.cs @@ -9,9 +9,8 @@ namespace Azure.DataApiBuilder.Config.Converters; internal class GraphQLRuntimeOptionsConverterFactory : JsonConverterFactory { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -22,25 +21,26 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new GraphQLRuntimeOptionsConverter(_replaceEnvVar); + return new GraphQLRuntimeOptionsConverter(_replacementSettings); } - internal GraphQLRuntimeOptionsConverterFactory(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal GraphQLRuntimeOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class GraphQLRuntimeOptionsConverter : JsonConverter { - // Determines whether to replace environment variable with its - // value or not while deserializing. - private bool _replaceEnvVar; + // Settings for variable replacement during deserialization. + private readonly DeserializationVariableReplacementSettings? _replacementSettings; - /// Whether to replace environment variable with its - /// value or not while deserializing. - internal GraphQLRuntimeOptionsConverter(bool replaceEnvVar) + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. + internal GraphQLRuntimeOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } public override GraphQLRuntimeOptions? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) @@ -117,7 +117,7 @@ internal GraphQLRuntimeOptionsConverter(bool replaceEnvVar) case "path": if (reader.TokenType is JsonTokenType.String) { - string? path = reader.DeserializeString(_replaceEnvVar); + string? path = reader.DeserializeString(_replacementSettings); if (path is null) { path = "/graphql"; diff --git a/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs b/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs index db9acfa603..d75cbbef5a 100644 --- a/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs +++ b/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs @@ -14,7 +14,7 @@ internal class McpRuntimeOptionsConverterFactory : JsonConverterFactory { // Determines whether to replace environment variable with its // value or not while deserializing. - private bool _replaceEnvVar; + private DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -25,25 +25,25 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new McpRuntimeOptionsConverter(_replaceEnvVar); + return new McpRuntimeOptionsConverter(_replacementSettings); } - internal McpRuntimeOptionsConverterFactory(bool replaceEnvVar) + internal McpRuntimeOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class McpRuntimeOptionsConverter : JsonConverter { // Determines whether to replace environment variable with its // value or not while deserializing. - private bool _replaceEnvVar; + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// Whether to replace environment variable with its /// value or not while deserializing. - internal McpRuntimeOptionsConverter(bool replaceEnvVar) + internal McpRuntimeOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -89,7 +89,7 @@ internal McpRuntimeOptionsConverter(bool replaceEnvVar) case "path": if (reader.TokenType is not JsonTokenType.Null) { - path = reader.DeserializeString(_replaceEnvVar); + path = reader.DeserializeString(_replacementSettings); } break; diff --git a/src/Config/Converters/RuntimeHealthOptionsConvertorFactory.cs b/src/Config/Converters/RuntimeHealthOptionsConvertorFactory.cs index d49cc264e7..9c5f46dce2 100644 --- a/src/Config/Converters/RuntimeHealthOptionsConvertorFactory.cs +++ b/src/Config/Converters/RuntimeHealthOptionsConvertorFactory.cs @@ -11,7 +11,7 @@ internal class RuntimeHealthOptionsConvertorFactory : JsonConverterFactory { // Determines whether to replace environment variable with its // value or not while deserializing. - private bool _replaceEnvVar; + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// public override bool CanConvert(Type typeToConvert) @@ -22,25 +22,25 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new HealthCheckOptionsConverter(_replaceEnvVar); + return new HealthCheckOptionsConverter(_replacementSettings); } - internal RuntimeHealthOptionsConvertorFactory(bool replaceEnvVar) + internal RuntimeHealthOptionsConvertorFactory(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } private class HealthCheckOptionsConverter : JsonConverter { // Determines whether to replace environment variable with its // value or not while deserializing. - private bool _replaceEnvVar; + private readonly DeserializationVariableReplacementSettings? _replacementSettings; /// Whether to replace environment variable with its /// value or not while deserializing. - internal HealthCheckOptionsConverter(bool replaceEnvVar) + internal HealthCheckOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) { - _replaceEnvVar = replaceEnvVar; + _replacementSettings = replacementSettings; } /// @@ -102,7 +102,7 @@ internal HealthCheckOptionsConverter(bool replaceEnvVar) { if (reader.TokenType == JsonTokenType.String) { - string? currentRole = reader.DeserializeString(_replaceEnvVar); + string? currentRole = reader.DeserializeString(_replacementSettings); if (!string.IsNullOrEmpty(currentRole)) { stringList.Add(currentRole); diff --git a/src/Config/Converters/StringJsonConverterFactory.cs b/src/Config/Converters/StringJsonConverterFactory.cs index 078b611789..c3f5333237 100644 --- a/src/Config/Converters/StringJsonConverterFactory.cs +++ b/src/Config/Converters/StringJsonConverterFactory.cs @@ -4,21 +4,20 @@ using System.Text.Json; using System.Text.Json.Serialization; using System.Text.RegularExpressions; -using Azure.DataApiBuilder.Service.Exceptions; namespace Azure.DataApiBuilder.Config.Converters; /// -/// Custom string json converter factory to replace environment variables of the pattern -/// @env('ENV_NAME') with their value during deserialization. +/// Custom string json converter factory to replace environment variables and other variable patterns +/// during deserialization using the DeserializationVariableReplacementSettings. /// public class StringJsonConverterFactory : JsonConverterFactory { - private EnvironmentVariableReplacementFailureMode _replacementFailureMode; + private readonly DeserializationVariableReplacementSettings _replacementSettings; - public StringJsonConverterFactory(EnvironmentVariableReplacementFailureMode replacementFailureMode) + public StringJsonConverterFactory(DeserializationVariableReplacementSettings replacementSettings) { - _replacementFailureMode = replacementFailureMode; + _replacementSettings = replacementSettings; } public override bool CanConvert(Type typeToConvert) @@ -28,32 +27,16 @@ public override bool CanConvert(Type typeToConvert) public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new StringJsonConverter(_replacementFailureMode); + return new StringJsonConverter(_replacementSettings); } class StringJsonConverter : JsonConverter { - // @env\(' : match @env(' - // .*? : lazy match any character except newline 0 or more times - // (?='\)) : look ahead for ') which will combine with our lazy match - // ie: in @env('hello')goodbye') we match @env('hello') - // '\) : consume the ') into the match (look ahead doesn't capture) - // This pattern lazy matches any string that starts with @env(' and ends with ') - // ie: fooBAR@env('hello-world')bash)FOO') match: @env('hello-world') - // This matching pattern allows for the @env('') to be safely nested - // within strings that contain ') after our match. - // ie: if the environment variable "Baz" has the value of "Bar" - // fooBarBaz: "('foo@env('Baz')Baz')" would parse into - // fooBarBaz: "('fooBarBaz')" - // Note that there is no escape character currently for ') to exist - // within the name of the environment variable, but that ') is not - // a valid environment variable name in certain shells. - const string ENV_PATTERN = @"@env\('.*?(?='\))'\)"; - private EnvironmentVariableReplacementFailureMode _replacementFailureMode; + private DeserializationVariableReplacementSettings _replacementSettings; - public StringJsonConverter(EnvironmentVariableReplacementFailureMode replacementFailureMode) + public StringJsonConverter(DeserializationVariableReplacementSettings replacementSettings) { - _replacementFailureMode = replacementFailureMode; + _replacementSettings = replacementSettings; } public override string? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) @@ -61,7 +44,18 @@ public StringJsonConverter(EnvironmentVariableReplacementFailureMode replacement if (reader.TokenType == JsonTokenType.String) { string? value = reader.GetString(); - return Regex.Replace(value!, ENV_PATTERN, new MatchEvaluator(ReplaceMatchWithEnvVariable)); + if (string.IsNullOrWhiteSpace(value)) + { + return value; + } + + // Apply all replacement strategies configured in the settings + foreach (KeyValuePair> strategy in _replacementSettings.ReplacementStrategies) + { + value = strategy.Key.Replace(value, new MatchEvaluator(strategy.Value)); + } + + return value; } if (reader.TokenType == JsonTokenType.Null) @@ -76,30 +70,5 @@ public override void Write(Utf8JsonWriter writer, string value, JsonSerializerOp { writer.WriteStringValue(value); } - - private string ReplaceMatchWithEnvVariable(Match match) - { - // [^@env\(] : any substring that is not @env( - // .* : any char except newline any number of times - // (?=\)) : look ahead for end char of ) - // This pattern greedy matches all characters that are not a part of @env() - // ie: @env('hello@env('goodbye')world') match: 'hello@env('goodbye')world' - string innerPattern = @"[^@env\(].*(?=\))"; - - // strips first and last characters, ie: '''hello'' --> ''hello' - string envName = Regex.Match(match.Value, innerPattern).Value[1..^1]; - string? envValue = Environment.GetEnvironmentVariable(envName); - if (_replacementFailureMode == EnvironmentVariableReplacementFailureMode.Throw) - { - return envValue is not null ? envValue : - throw new DataApiBuilderException(message: $"Environmental Variable, {envName}, not found.", - statusCode: System.Net.HttpStatusCode.ServiceUnavailable, - subStatusCode: DataApiBuilderException.SubStatusCodes.ErrorInInitialization); - } - else - { - return envValue ?? match.Value; - } - } } } diff --git a/src/Config/Converters/Utf8JsonReaderExtensions.cs b/src/Config/Converters/Utf8JsonReaderExtensions.cs index 20c6821d02..5e16357227 100644 --- a/src/Config/Converters/Utf8JsonReaderExtensions.cs +++ b/src/Config/Converters/Utf8JsonReaderExtensions.cs @@ -13,14 +13,12 @@ static internal class Utf8JsonReaderExtensions /// substitution is applied. /// /// The reader that we want to pull the string from. - /// Whether to replace environment variable with its - /// value or not while deserializing. + /// The replacement settings to use while deserializing. /// The failure mode to use when replacing environment variables. /// The result of deserialization. /// Thrown if the is not String. public static string? DeserializeString(this Utf8JsonReader reader, - bool replaceEnvVar, - EnvironmentVariableReplacementFailureMode replacementFailureMode = EnvironmentVariableReplacementFailureMode.Throw) + DeserializationVariableReplacementSettings? replacementSettings) { if (reader.TokenType is JsonTokenType.Null) { @@ -34,9 +32,9 @@ static internal class Utf8JsonReaderExtensions // Add the StringConverterFactory so that we can do environment variable substitution. JsonSerializerOptions options = new(); - if (replaceEnvVar) + if (replacementSettings is not null) { - options.Converters.Add(new StringJsonConverterFactory(replacementFailureMode)); + options.Converters.Add(new StringJsonConverterFactory(replacementSettings)); } return JsonSerializer.Deserialize(ref reader, options); diff --git a/src/Config/DeserializationVariableReplacementSettings.cs b/src/Config/DeserializationVariableReplacementSettings.cs new file mode 100644 index 0000000000..d4c02b5252 --- /dev/null +++ b/src/Config/DeserializationVariableReplacementSettings.cs @@ -0,0 +1,213 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.RegularExpressions; +using Azure.Core; +using Azure.DataApiBuilder.Config.Converters; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Service.Exceptions; +using Azure.Identity; +using Azure.Security.KeyVault.Secrets; + +namespace Azure.DataApiBuilder.Config +{ + public class DeserializationVariableReplacementSettings + { + public bool DoReplaceEnvVar { get; set; } + public bool DoReplaceAkvVar { get; set; } + public EnvironmentVariableReplacementFailureMode EnvFailureMode { get; set; } = EnvironmentVariableReplacementFailureMode.Throw; + + // @env\(' : match @env(' + // @akv\(' : match @akv(' + // .*? : lazy match any character except newline 0 or more times + // (?='\)) : look ahead for ')' which will combine with our lazy match + // ie: in @env('hello')goodbye') we match @env('hello') + // '\) : consume the ') into the match (look ahead doesn't capture) + // This pattern lazy matches any string that starts with @env(' and ends with ') OR @akv(' and ends with ') + // Example: fooBAR@env('hello-world')bash)FOO') match: @env('hello-world') + // Example: fooBAR@akv('secret-name')bash)FOO') match: @akv('secret-name') + // This matching pattern allows for the @env('') / @akv('') to be safely nested + // within strings that contain ')' after our match. + // Note that there is no escape character currently for ')' to exist within the name of the variable. + public const string OUTER_ENV_PATTERN = @"@env\('.*?(?='\))'\)"; + public const string OUTER_AKV_PATTERN = @"@akv\('.*?(?='\))'\)"; + + // [^@env\(] : any substring that is not @env( + // [^@akv\(] : any substring that is not @akv( + // .* : any char except newline any number of times + // (?=\)) : look ahead for end char of ) + // This pattern greedy matches all characters that are not a part of @env() / @akv() + // ie: @env('hello@env('goodbye')world') match: 'hello@env('goodbye')world' + public const string INNER_ENV_PATTERN = @"[^@env\(].*(?=\))"; + public const string INNER_AKV_PATTERN = @"[^@akv\(].*(?=\))"; + + private readonly AzureKeyVaultOptions? _azureKeyVaultOptions; + private readonly SecretClient? _akvClient; + + public Dictionary> ReplacementStrategies { get; private set; } = new(); + + public DeserializationVariableReplacementSettings( + AzureKeyVaultOptions? azureKeyVaultOptions = null, + bool doReplaceEnvVar = false, + bool doReplaceAkvVar = false, + EnvironmentVariableReplacementFailureMode envFailureMode = EnvironmentVariableReplacementFailureMode.Throw) + { + _azureKeyVaultOptions = azureKeyVaultOptions; + DoReplaceEnvVar = doReplaceEnvVar; + DoReplaceAkvVar = doReplaceAkvVar; + EnvFailureMode = envFailureMode; + + if (DoReplaceEnvVar) + { + ReplacementStrategies.Add( + new Regex(OUTER_ENV_PATTERN, RegexOptions.Compiled), + ReplaceEnvVariable); + } + + if (DoReplaceAkvVar && _azureKeyVaultOptions is not null) + { + _akvClient = CreateSecretClient(_azureKeyVaultOptions); + ReplacementStrategies.Add( + new Regex(OUTER_AKV_PATTERN, RegexOptions.Compiled), + ReplaceAkvVariable); + } + } + + private string ReplaceEnvVariable(Match match) + { + // strips first and last characters, ie: '''hello'' --> ''hello' + string name = Regex.Match(match.Value, INNER_ENV_PATTERN).Value[1..^1]; + string? value = Environment.GetEnvironmentVariable(name); + if (EnvFailureMode is EnvironmentVariableReplacementFailureMode.Throw) + { + return value is not null ? value : + throw new DataApiBuilderException( + message: $"Environmental Variable, {name}, not found.", + statusCode: System.Net.HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ErrorInInitialization); + } + else + { + return value ?? match.Value; + } + } + + private string ReplaceAkvVariable(Match match) + { + // strips first and last characters, ie: '''hello'' --> ''hello' + string name = Regex.Match(match.Value, INNER_AKV_PATTERN).Value[1..^1]; + + // Validate AKV secret name per rules: + // Allowed: alphanumeric and hyphen (-) + // Disallowed: spaces or any other symbols + // Must start and end with alphanumeric + // Length: 1 to 127 chars + if (!IsValidAkvSecretName(name, out string validationError)) + { + throw new DataApiBuilderException( + message: $"Azure Key Vault secret name '{name}' is invalid. {validationError} Requirements: allowed characters are alphanumeric and hyphen (-); must start and end with an alphanumeric character; length 1-127 characters; case-insensitive.", + statusCode: System.Net.HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ErrorInInitialization); + } + + string? value = GetAkvVariable(name); + if (EnvFailureMode == EnvironmentVariableReplacementFailureMode.Throw) + { + return value is not null ? value : + throw new DataApiBuilderException(message: $"Azure Key Vault Variable, '{name}', not found.", + statusCode: System.Net.HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ErrorInInitialization); + } + else + { + return value ?? match.Value; + } + } + + private static bool IsValidAkvSecretName(string name, out string error) + { + error = string.Empty; + if (string.IsNullOrEmpty(name)) + { + error = "Name cannot be null or empty."; + return false; + } + + if (name.Length < 1 || name.Length > 127) + { + error = $"Length {name.Length} is outside allowed range (1-127)."; + return false; + } + + // Must start and end with alphanumeric + if (!char.IsLetterOrDigit(name[0]) || !char.IsLetterOrDigit(name[^1])) + { + error = "Must start and end with an alphanumeric character."; + return false; + } + + // Allowed characters: letters, digits, hyphen. + for (int i = 0; i < name.Length; i++) + { + char c = name[i]; + if (!(char.IsLetterOrDigit(c) || c == '-')) + { + error = $"Invalid character '{c}' at position {i}."; + return false; + } + } + + return true; + } + + private static SecretClient CreateSecretClient(AzureKeyVaultOptions options) + { + if (string.IsNullOrWhiteSpace(options.Endpoint)) + { + throw new DataApiBuilderException( + "Missing 'endpoint' property is required to connect to Azure Key Vault.", + System.Net.HttpStatusCode.InternalServerError, + DataApiBuilderException.SubStatusCodes.ErrorInInitialization); + } + + SecretClientOptions clientOptions = new(); + + if (options.RetryPolicy is not null) + { + // Convert AKVRetryPolicyMode to RetryMode + RetryMode retryMode = options.RetryPolicy.Mode switch + { + AKVRetryPolicyMode.Fixed => RetryMode.Fixed, + AKVRetryPolicyMode.Exponential => RetryMode.Exponential, + null => RetryMode.Exponential, + _ => RetryMode.Exponential + }; + + clientOptions.Retry.Mode = retryMode; + clientOptions.Retry.MaxRetries = options.RetryPolicy.MaxCount ?? AKVRetryPolicyOptions.DEFAULT_MAX_COUNT; + clientOptions.Retry.Delay = TimeSpan.FromSeconds(options.RetryPolicy.DelaySeconds ?? AKVRetryPolicyOptions.DEFAULT_DELAY_SECONDS); + clientOptions.Retry.MaxDelay = TimeSpan.FromSeconds(options.RetryPolicy.MaxDelaySeconds ?? AKVRetryPolicyOptions.DEFAULT_MAX_DELAY_SECONDS); + clientOptions.Retry.NetworkTimeout = TimeSpan.FromSeconds(options.RetryPolicy.NetworkTimeoutSeconds ?? AKVRetryPolicyOptions.DEFAULT_NETWORK_TIMEOUT_SECONDS); + } + + return new SecretClient(new Uri(options.Endpoint), new DefaultAzureCredential(), clientOptions); + } + + private string? GetAkvVariable(string name) + { + if (_akvClient is null) + { + throw new InvalidOperationException("Azure Key Vault client is not initialized."); + } + + try + { + return _akvClient.GetSecret(name).Value.Value; + } + catch (Azure.RequestFailedException ex) when (ex.Status == 404) + { + return null; + } + } + } +} diff --git a/src/Config/FileSystemRuntimeConfigLoader.cs b/src/Config/FileSystemRuntimeConfigLoader.cs index 9c2a8e50b5..614cfbd11c 100644 --- a/src/Config/FileSystemRuntimeConfigLoader.cs +++ b/src/Config/FileSystemRuntimeConfigLoader.cs @@ -182,17 +182,16 @@ private void OnNewFileContentsDetected(object? sender, EventArgs e) /// /// The path to the dab-config.json file. /// The loaded RuntimeConfig, or null if none was loaded. - /// Whether to replace environment variable with its - /// value or not while deserializing. /// ILogger for logging errors. /// When not null indicates we need to overwrite mode and how to do so. + /// Settings for variable replacement during deserialization. If null, uses default settings with environment variable replacement disabled. /// True if the config was loaded, otherwise false. public bool TryLoadConfig( string path, [NotNullWhen(true)] out RuntimeConfig? config, - bool replaceEnvVar = false, ILogger? logger = null, - bool? isDevMode = null) + bool? isDevMode = null, + DeserializationVariableReplacementSettings? replacementSettings = null) { if (_fileSystem.File.Exists(path)) { @@ -226,7 +225,15 @@ public bool TryLoadConfig( } } - if (!string.IsNullOrEmpty(json) && TryParseConfig(json, out RuntimeConfig, connectionString: _connectionString, replaceEnvVar: replaceEnvVar)) + // Use default replacement settings if none provided + replacementSettings ??= new DeserializationVariableReplacementSettings(); + + if (!string.IsNullOrEmpty(json) && TryParseConfig( + json, + out RuntimeConfig, + replacementSettings, + logger: null, + connectionString: _connectionString)) { if (TrySetupConfigFileWatcher()) { @@ -292,12 +299,13 @@ public bool TryLoadConfig( /// Tries to load the config file using the filename known to the RuntimeConfigLoader and for the default environment. /// /// The loaded RuntimeConfig, or null if none was loaded. - /// Whether to replace environment variable with its - /// value or not while deserializing. + /// Settings for variable replacement during deserialization. If null, uses default settings with environment variable replacement disabled. /// True if the config was loaded, otherwise false. public override bool TryLoadKnownConfig([NotNullWhen(true)] out RuntimeConfig? config, bool replaceEnvVar = false) { - return TryLoadConfig(ConfigFilePath, out config, replaceEnvVar); + // Convert legacy replaceEnvVar parameter to replacement settings for backward compatibility + DeserializationVariableReplacementSettings? replacementSettings = new(azureKeyVaultOptions: null, doReplaceEnvVar: replaceEnvVar, doReplaceAkvVar: replaceEnvVar); + return TryLoadConfig(ConfigFilePath, out config, replacementSettings: replacementSettings); } /// @@ -307,7 +315,11 @@ public override bool TryLoadKnownConfig([NotNullWhen(true)] out RuntimeConfig? c private void HotReloadConfig(bool isDevMode, ILogger? logger = null) { logger?.LogInformation(message: "Starting hot-reload process for config: {ConfigFilePath}", ConfigFilePath); - if (!TryLoadConfig(ConfigFilePath, out _, replaceEnvVar: true, isDevMode: isDevMode)) + + // Use default replacement settings for hot reload + DeserializationVariableReplacementSettings replacementSettings = new(azureKeyVaultOptions: null, doReplaceEnvVar: true, doReplaceAkvVar: true); + + if (!TryLoadConfig(ConfigFilePath, out _, logger: logger, isDevMode: isDevMode, replacementSettings: replacementSettings)) { throw new DataApiBuilderException( message: "Deserialization of the configuration file failed.", @@ -467,7 +479,7 @@ public override string GetPublishedDraftSchemaLink() string? schemaPath = _fileSystem.Path.Combine(assemblyDirectory, "dab.draft.schema.json"); string schemaFileContent = _fileSystem.File.ReadAllText(schemaPath); - Dictionary? jsonDictionary = JsonSerializer.Deserialize>(schemaFileContent, GetSerializationOptions()); + Dictionary? jsonDictionary = JsonSerializer.Deserialize>(schemaFileContent, GetSerializationOptions(replacementSettings: null)); if (jsonDictionary is null) { diff --git a/src/Config/ObjectModel/AzureKeyVaultOptions.cs b/src/Config/ObjectModel/AzureKeyVaultOptions.cs index 27094cd16f..ebd1e909c1 100644 --- a/src/Config/ObjectModel/AzureKeyVaultOptions.cs +++ b/src/Config/ObjectModel/AzureKeyVaultOptions.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; namespace Azure.DataApiBuilder.Config.ObjectModel; @@ -12,4 +13,40 @@ public record AzureKeyVaultOptions [JsonPropertyName("retry-policy")] public AKVRetryPolicyOptions? RetryPolicy { get; init; } + + /// + /// Flag which informs CLI and JSON serializer whether to write endpoint + /// property and value to the runtime config file. + /// When user doesn't provide the endpoint property/value, which signals DAB to use the default, + /// the DAB CLI should not write the default value to a serialized config. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(Endpoint))] + public bool UserProvidedEndpoint { get; init; } = false; + + /// + /// Flag which informs CLI and JSON serializer whether to write retry-policy + /// property and value to the runtime config file. + /// When user doesn't provide the retry-policy property/value, which signals DAB to use the default, + /// the DAB CLI should not write the default value to a serialized config. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(RetryPolicy))] + public bool UserProvidedRetryPolicy { get; init; } = false; + + [JsonConstructor] + public AzureKeyVaultOptions(string? endpoint = null, AKVRetryPolicyOptions? retryPolicy = null) + { + if (endpoint is not null) + { + Endpoint = endpoint; + UserProvidedEndpoint = true; + } + + if (retryPolicy is not null) + { + RetryPolicy = retryPolicy; + UserProvidedRetryPolicy = true; + } + } } diff --git a/src/Config/ObjectModel/RuntimeConfig.cs b/src/Config/ObjectModel/RuntimeConfig.cs index a450e1265c..6896d82161 100644 --- a/src/Config/ObjectModel/RuntimeConfig.cs +++ b/src/Config/ObjectModel/RuntimeConfig.cs @@ -298,7 +298,10 @@ public RuntimeConfig( foreach (string dataSourceFile in DataSourceFiles.SourceFiles) { - if (loader.TryLoadConfig(dataSourceFile, out RuntimeConfig? config, replaceEnvVar: true)) + // Use default replacement settings for environment variable replacement + DeserializationVariableReplacementSettings replacementSettings = new(azureKeyVaultOptions: null, doReplaceEnvVar: true, doReplaceAkvVar: true); + + if (loader.TryLoadConfig(dataSourceFile, out RuntimeConfig? config, replacementSettings: replacementSettings)) { try { @@ -448,7 +451,7 @@ public bool CheckDataSourceExists(string dataSourceName) public string ToJson(JsonSerializerOptions? jsonSerializerOptions = null) { // get default serializer options if none provided. - jsonSerializerOptions = jsonSerializerOptions ?? RuntimeConfigLoader.GetSerializationOptions(); + jsonSerializerOptions = jsonSerializerOptions ?? RuntimeConfigLoader.GetSerializationOptions(replacementSettings: null); return JsonSerializer.Serialize(this, jsonSerializerOptions); } diff --git a/src/Config/RuntimeConfigLoader.cs b/src/Config/RuntimeConfigLoader.cs index f78c32ebc1..bad5aa8680 100644 --- a/src/Config/RuntimeConfigLoader.cs +++ b/src/Config/RuntimeConfigLoader.cs @@ -129,25 +129,86 @@ protected void SignalConfigChanged(string message = "") /// public abstract string GetPublishedDraftSchemaLink(); + /// + /// Extracts AzureKeyVaultOptions from JSON string with configurable variable replacement. + /// + /// JSON that represents the config file. + /// Whether to enable environment variable replacement during extraction. + /// Failure mode for environment variable replacement if enabled. + /// AzureKeyVaultOptions if present, null otherwise. + private static AzureKeyVaultOptions? ExtractAzureKeyVaultOptions( + string json, + bool enableEnvReplacement, + EnvironmentVariableReplacementFailureMode replacementFailureMode = EnvironmentVariableReplacementFailureMode.Throw) + { + JsonSerializerOptions options = new() + { + PropertyNameCaseInsensitive = false, + PropertyNamingPolicy = new HyphenatedNamingPolicy(), + ReadCommentHandling = JsonCommentHandling.Skip + }; + DeserializationVariableReplacementSettings envOnlySettings = new( + azureKeyVaultOptions: null, + doReplaceEnvVar: enableEnvReplacement, + doReplaceAkvVar: false, + envFailureMode: replacementFailureMode); + options.Converters.Add(new StringJsonConverterFactory(envOnlySettings)); + options.Converters.Add(new EnumMemberJsonEnumConverterFactory()); + options.Converters.Add(new AzureKeyVaultOptionsConverterFactory(replacementSettings: envOnlySettings)); + options.Converters.Add(new AKVRetryPolicyOptionsConverterFactory(replacementSettings: envOnlySettings)); + + try + { + using JsonDocument doc = JsonDocument.Parse(json); + if (doc.RootElement.TryGetProperty("azure-key-vault", out JsonElement akvElement)) + { + return JsonSerializer.Deserialize(akvElement.GetRawText(), options); + } + } + catch + { + // If we can't extract AKV options, return null and proceed without AKV variable replacement + return null; + } + + return null; + } + /// /// Parses a JSON string into a RuntimeConfig object for single database scenario. /// /// JSON that represents the config file. /// The parsed config, or null if it parsed unsuccessfully. - /// True if the config was parsed, otherwise false. + /// Settings for variable replacement during deserialization. If null, no variable replacement will be performed. /// logger to log messages /// connectionString to add to config if specified - /// Whether to replace environment variable with its - /// value or not while deserializing. By default, no replacement happens. - /// Determines failure mode for env variable replacement. + /// True if the config was parsed, otherwise false. public static bool TryParseConfig(string json, [NotNullWhen(true)] out RuntimeConfig? config, + DeserializationVariableReplacementSettings? replacementSettings = null, ILogger? logger = null, - string? connectionString = null, - bool replaceEnvVar = false, - EnvironmentVariableReplacementFailureMode replacementFailureMode = EnvironmentVariableReplacementFailureMode.Throw) + string? connectionString = null) { - JsonSerializerOptions options = GetSerializationOptions(replaceEnvVar, replacementFailureMode); + // First pass: extract AzureKeyVault options if AKV replacement is requested + if (replacementSettings?.DoReplaceAkvVar is true) + { + AzureKeyVaultOptions? azureKeyVaultOptions = ExtractAzureKeyVaultOptions( + json: json, + enableEnvReplacement: replacementSettings.DoReplaceEnvVar, + replacementFailureMode: replacementSettings.EnvFailureMode); + + // Update replacement settings with the extracted AKV options + if (azureKeyVaultOptions is not null) + { + replacementSettings = new DeserializationVariableReplacementSettings( + azureKeyVaultOptions: azureKeyVaultOptions, + doReplaceEnvVar: replacementSettings.DoReplaceEnvVar, + doReplaceAkvVar: replacementSettings.DoReplaceAkvVar, + envFailureMode: replacementSettings.EnvFailureMode); + } + } + + JsonSerializerOptions options = GetSerializationOptions(replacementSettings); try { @@ -180,11 +241,11 @@ public static bool TryParseConfig(string json, DataSource ds = config.GetDataSourceFromDataSourceName(dataSourceKey); // Add Application Name for telemetry for MsSQL or PgSql - if (ds.DatabaseType is DatabaseType.MSSQL && replaceEnvVar) + if (ds.DatabaseType is DatabaseType.MSSQL && replacementSettings?.DoReplaceEnvVar == true) { updatedConnection = GetConnectionStringWithApplicationName(connectionValue); } - else if (ds.DatabaseType is DatabaseType.PostgreSQL && replaceEnvVar) + else if (ds.DatabaseType is DatabaseType.PostgreSQL && replacementSettings?.DoReplaceEnvVar == true) { updatedConnection = GetPgSqlConnectionStringWithApplicationName(connectionValue); } @@ -225,11 +286,10 @@ ex is JsonException || /// /// Get Serializer options for the config file. /// - /// Whether to replace environment variable with value or not while deserializing. - /// By default, no replacement happens. + /// Settings for variable replacement during deserialization. + /// If null, no variable replacement will be performed. public static JsonSerializerOptions GetSerializationOptions( - bool replaceEnvVar = false, - EnvironmentVariableReplacementFailureMode replacementFailureMode = EnvironmentVariableReplacementFailureMode.Throw) + DeserializationVariableReplacementSettings? replacementSettings = null) { JsonSerializerOptions options = new() { @@ -241,33 +301,37 @@ public static JsonSerializerOptions GetSerializationOptions( Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping }; options.Converters.Add(new EnumMemberJsonEnumConverterFactory()); - options.Converters.Add(new RuntimeHealthOptionsConvertorFactory(replaceEnvVar)); - options.Converters.Add(new DataSourceHealthOptionsConvertorFactory(replaceEnvVar)); + options.Converters.Add(new RuntimeHealthOptionsConvertorFactory(replacementSettings)); + options.Converters.Add(new DataSourceHealthOptionsConvertorFactory(replacementSettings)); options.Converters.Add(new EntityHealthOptionsConvertorFactory()); options.Converters.Add(new RestRuntimeOptionsConverterFactory()); - options.Converters.Add(new GraphQLRuntimeOptionsConverterFactory(replaceEnvVar)); - options.Converters.Add(new McpRuntimeOptionsConverterFactory(replaceEnvVar)); + options.Converters.Add(new GraphQLRuntimeOptionsConverterFactory(replacementSettings)); + options.Converters.Add(new McpRuntimeOptionsConverterFactory(replacementSettings)); options.Converters.Add(new DmlToolsConfigConverter()); - options.Converters.Add(new EntitySourceConverterFactory(replaceEnvVar)); - options.Converters.Add(new EntityGraphQLOptionsConverterFactory(replaceEnvVar)); - options.Converters.Add(new EntityRestOptionsConverterFactory(replaceEnvVar)); + options.Converters.Add(new EntitySourceConverterFactory(replacementSettings)); + options.Converters.Add(new EntityGraphQLOptionsConverterFactory(replacementSettings)); + options.Converters.Add(new EntityRestOptionsConverterFactory(replacementSettings)); options.Converters.Add(new EntityActionConverterFactory()); options.Converters.Add(new DataSourceFilesConverter()); - options.Converters.Add(new EntityCacheOptionsConverterFactory(replaceEnvVar)); + options.Converters.Add(new EntityCacheOptionsConverterFactory(replacementSettings)); options.Converters.Add(new RuntimeCacheOptionsConverterFactory()); options.Converters.Add(new RuntimeCacheLevel2OptionsConverterFactory()); options.Converters.Add(new MultipleCreateOptionsConverter()); options.Converters.Add(new MultipleMutationOptionsConverter(options)); - options.Converters.Add(new DataSourceConverterFactory(replaceEnvVar)); + options.Converters.Add(new DataSourceConverterFactory(replacementSettings)); options.Converters.Add(new HostOptionsConvertorFactory()); - options.Converters.Add(new AKVRetryPolicyOptionsConverterFactory(replaceEnvVar)); - options.Converters.Add(new AzureLogAnalyticsOptionsConverterFactory(replaceEnvVar)); - options.Converters.Add(new AzureLogAnalyticsAuthOptionsConverter(replaceEnvVar)); - options.Converters.Add(new FileSinkConverter(replaceEnvVar)); + options.Converters.Add(new AKVRetryPolicyOptionsConverterFactory(replacementSettings)); + options.Converters.Add(new AzureLogAnalyticsOptionsConverterFactory(replacementSettings)); + options.Converters.Add(new AzureLogAnalyticsAuthOptionsConverter(replacementSettings)); + options.Converters.Add(new FileSinkConverter(replacementSettings)); + + // Add AzureKeyVaultOptionsConverterFactory to ensure AKV config is deserialized properly + options.Converters.Add(new AzureKeyVaultOptionsConverterFactory(replacementSettings)); - if (replaceEnvVar) + // Only add the extensible string converter if we have replacement settings + if (replacementSettings is not null) { - options.Converters.Add(new StringJsonConverterFactory(replacementFailureMode)); + options.Converters.Add(new StringJsonConverterFactory(replacementSettings)); } return options; diff --git a/src/Core/Configurations/RuntimeConfigProvider.cs b/src/Core/Configurations/RuntimeConfigProvider.cs index faeb2b94d0..b46a716f48 100644 --- a/src/Core/Configurations/RuntimeConfigProvider.cs +++ b/src/Core/Configurations/RuntimeConfigProvider.cs @@ -6,7 +6,6 @@ using System.IO.Abstractions; using System.Net; using Azure.DataApiBuilder.Config; -using Azure.DataApiBuilder.Config.Converters; using Azure.DataApiBuilder.Config.NamingPolicies; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Service.Exceptions; @@ -189,8 +188,7 @@ public async Task Initialize( if (RuntimeConfigLoader.TryParseConfig( configuration, out RuntimeConfig? runtimeConfig, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Ignore)) + replacementSettings: null)) { _configLoader.RuntimeConfig = runtimeConfig; @@ -257,8 +255,7 @@ public async Task Initialize( string? graphQLSchema, string connectionString, string? accessToken, - bool replaceEnvVar = true, - EnvironmentVariableReplacementFailureMode replacementFailureMode = EnvironmentVariableReplacementFailureMode.Throw) + DeserializationVariableReplacementSettings? replacementSettings) { if (string.IsNullOrEmpty(connectionString)) { @@ -272,7 +269,7 @@ public async Task Initialize( IsLateConfigured = true; - if (RuntimeConfigLoader.TryParseConfig(jsonConfig, out RuntimeConfig? runtimeConfig, replaceEnvVar: replaceEnvVar, replacementFailureMode: replacementFailureMode)) + if (RuntimeConfigLoader.TryParseConfig(jsonConfig, out RuntimeConfig? runtimeConfig, replacementSettings)) { _configLoader.RuntimeConfig = runtimeConfig.DataSource.DatabaseType switch { diff --git a/src/Directory.Packages.props b/src/Directory.Packages.props index 14f097915c..542508f71f 100644 --- a/src/Directory.Packages.props +++ b/src/Directory.Packages.props @@ -5,6 +5,7 @@ + diff --git a/src/Service.Tests/Caching/CachingConfigProcessingTests.cs b/src/Service.Tests/Caching/CachingConfigProcessingTests.cs index 2780af63c5..a6daebf3e4 100644 --- a/src/Service.Tests/Caching/CachingConfigProcessingTests.cs +++ b/src/Service.Tests/Caching/CachingConfigProcessingTests.cs @@ -5,7 +5,6 @@ using System.Text; using System.Text.Json; using Azure.DataApiBuilder.Config; -using Azure.DataApiBuilder.Config.Converters; using Azure.DataApiBuilder.Config.ObjectModel; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -56,10 +55,7 @@ public void EntityCacheOptionsDeserialization_ValidJson( RuntimeConfigLoader.TryParseConfig( json: fullConfig, out RuntimeConfig? config, - logger: null, - connectionString: null, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + replacementSettings: null); // Assert Assert.IsNotNull(config, message: "Config must not be null, runtime config JSON deserialization failed."); @@ -103,10 +99,7 @@ public void EntityCacheOptionsDeserialization_InvalidValues(string entityCacheCo bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( json: fullConfig, out _, - logger: null, - connectionString: null, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + replacementSettings: null); // Assert Assert.IsFalse(isParsingSuccessful, message: "Expected JSON parsing to fail."); @@ -141,10 +134,7 @@ public void GlobalCacheOptionsDeserialization_ValidValues( RuntimeConfigLoader.TryParseConfig( json: fullConfig, out RuntimeConfig? config, - logger: null, - connectionString: null, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + replacementSettings: null); // Assert Assert.IsNotNull(config, message: "Config must not be null, runtime config JSON deserialization failed."); @@ -187,10 +177,7 @@ public void GlobalCacheOptionsDeserialization_InvalidValues(string globalCacheCo bool parsingSuccessful = RuntimeConfigLoader.TryParseConfig( json: fullConfig, out _, - logger: null, - connectionString: null, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + replacementSettings: null); // Assert Assert.IsFalse(parsingSuccessful, message: "Expected JSON parsing to fail."); @@ -216,10 +203,7 @@ public void GlobalCacheOptionsOverridesEntityCacheOptions(string globalCacheConf RuntimeConfigLoader.TryParseConfig( json: fullConfig, out RuntimeConfig? config, - logger: null, - connectionString: null, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + replacementSettings: null); // Assert Assert.IsNotNull(config, message: "Config must not be null, runtime config JSON deserialization failed."); @@ -252,10 +236,7 @@ public void UserDefinedTtlWrittenToSerializedJsonConfigFile(bool expectIsUserDef RuntimeConfigLoader.TryParseConfig( json: fullConfig, out RuntimeConfig? config, - logger: null, - connectionString: null, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + replacementSettings: null); Assert.IsNotNull(config, message: "Test setup failure. Config must not be null, runtime config JSON deserialization failed."); // Act @@ -300,10 +281,7 @@ public void CachePropertyNotWrittenToSerializedJsonConfigFile(string cacheConfig RuntimeConfigLoader.TryParseConfig( json: fullConfig, out RuntimeConfig? config, - logger: null, - connectionString: null, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + replacementSettings: null); Assert.IsNotNull(config, message: "Test setup failure. Config must not be null, runtime config JSON deserialization failed."); // Act @@ -342,10 +320,7 @@ public void DefaultTtlNotWrittenToSerializedJsonConfigFile(string cacheConfig) RuntimeConfigLoader.TryParseConfig( json: fullConfig, out RuntimeConfig? config, - logger: null, - connectionString: null, - replaceEnvVar: false, - replacementFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + replacementSettings: null); Assert.IsNotNull(config, message: "Test setup failure. Config must not be null, runtime config JSON deserialization failed."); // Act diff --git a/src/Service.Tests/Configuration/ConfigurationTests.cs b/src/Service.Tests/Configuration/ConfigurationTests.cs index 65f6e6643b..0614e7688f 100644 --- a/src/Service.Tests/Configuration/ConfigurationTests.cs +++ b/src/Service.Tests/Configuration/ConfigurationTests.cs @@ -838,9 +838,9 @@ public void MsSqlConnStringSupplementedWithAppNameProperty( // Act bool configParsed = RuntimeConfigLoader.TryParseConfig( - runtimeConfig.ToJson(), - out RuntimeConfig updatedRuntimeConfig, - replaceEnvVar: true); + json: runtimeConfig.ToJson(), + config: out RuntimeConfig updatedRuntimeConfig, + replacementSettings: new(doReplaceEnvVar: true)); // Assert Assert.AreEqual( @@ -891,9 +891,9 @@ public void PgSqlConnStringSupplementedWithAppNameProperty( // Act bool configParsed = RuntimeConfigLoader.TryParseConfig( - runtimeConfig.ToJson(), - out RuntimeConfig updatedRuntimeConfig, - replaceEnvVar: true); + json: runtimeConfig.ToJson(), + config: out RuntimeConfig updatedRuntimeConfig, + replacementSettings: new(doReplaceEnvVar: true)); // Assert Assert.AreEqual( @@ -956,9 +956,9 @@ public void TestConnectionStringIsCorrectlyUpdatedWithApplicationName( // Act bool configParsed = RuntimeConfigLoader.TryParseConfig( - runtimeConfig.ToJson(), - out RuntimeConfig updatedRuntimeConfig, - replaceEnvVar: true); + json: runtimeConfig.ToJson(), + config: out RuntimeConfig updatedRuntimeConfig, + replacementSettings: new(doReplaceEnvVar: true)); // Assert Assert.AreEqual( @@ -2346,7 +2346,12 @@ public async Task TestSPRestDefaultsForManuallyConstructedConfigs( HttpStatusCode expectedResponseStatusCode) { string configJson = TestHelper.AddPropertiesToJson(TestHelper.BASE_CONFIG, entityJson); - RuntimeConfigLoader.TryParseConfig(configJson, out RuntimeConfig deserializedConfig, logger: null, GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL)); + RuntimeConfigLoader.TryParseConfig( + configJson, + out RuntimeConfig deserializedConfig, + replacementSettings: new(), + logger: null, + GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL)); string configFileName = "custom-config.json"; File.WriteAllText(configFileName, deserializedConfig.ToJson()); string[] args = new[] @@ -2429,7 +2434,12 @@ public async Task SanityTestForRestAndGQLRequestsWithoutMultipleMutationFeatureF // The configuration file is constructed by merging hard-coded JSON strings to simulate the scenario where users manually edit the // configuration file (instead of using CLI). string configJson = TestHelper.AddPropertiesToJson(TestHelper.BASE_CONFIG, BOOK_ENTITY_JSON); - Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(configJson, out RuntimeConfig deserializedConfig, logger: null, GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL))); + Assert.IsTrue(RuntimeConfigLoader.TryParseConfig( + configJson, + out RuntimeConfig deserializedConfig, + replacementSettings: new(), + logger: null, + GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL))); string configFileName = "custom-config.json"; File.WriteAllText(configFileName, deserializedConfig.ToJson()); string[] args = new[] @@ -3290,7 +3300,12 @@ public async Task ValidateStrictModeAsDefaultForRestRequestBody(bool includeExtr // The BASE_CONFIG omits the rest.request-body-strict option in the runtime section. string configJson = TestHelper.AddPropertiesToJson(TestHelper.BASE_CONFIG, entityJson); - RuntimeConfigLoader.TryParseConfig(configJson, out RuntimeConfig deserializedConfig, logger: null, GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL)); + RuntimeConfigLoader.TryParseConfig( + configJson, + out RuntimeConfig deserializedConfig, + replacementSettings: new(), + logger: null, + GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL)); const string CUSTOM_CONFIG = "custom-config.json"; File.WriteAllText(CUSTOM_CONFIG, deserializedConfig.ToJson()); string[] args = new[] @@ -5494,7 +5509,7 @@ public static string GetConnectionStringFromEnvironmentConfig(string environment string sqlFile = new FileSystemRuntimeConfigLoader(fileSystem).GetFileNameForEnvironment(environment, considerOverrides: true); string configPayload = File.ReadAllText(sqlFile); - RuntimeConfigLoader.TryParseConfig(configPayload, out RuntimeConfig runtimeConfig, replaceEnvVar: true); + RuntimeConfigLoader.TryParseConfig(configPayload, out RuntimeConfig runtimeConfig, replacementSettings: new()); return runtimeConfig.DataSource.ConnectionString; } diff --git a/src/Service.Tests/UnitTests/MySqlQueryExecutorUnitTests.cs b/src/Service.Tests/UnitTests/MySqlQueryExecutorUnitTests.cs index cbfef36664..63deed78d3 100644 --- a/src/Service.Tests/UnitTests/MySqlQueryExecutorUnitTests.cs +++ b/src/Service.Tests/UnitTests/MySqlQueryExecutorUnitTests.cs @@ -81,7 +81,8 @@ await provider.Initialize( provider.GetConfig().ToJson(), graphQLSchema: null, connectionString: connectionString, - accessToken: CONFIG_TOKEN); + accessToken: CONFIG_TOKEN, + replacementSettings: new()); mySqlQueryExecutor = new(provider, dbExceptionParser.Object, queryExecutorLogger.Object, httpContextAccessor.Object); } } diff --git a/src/Service.Tests/UnitTests/PostgreSqlQueryExecutorUnitTests.cs b/src/Service.Tests/UnitTests/PostgreSqlQueryExecutorUnitTests.cs index ccaa90b353..6039c46a72 100644 --- a/src/Service.Tests/UnitTests/PostgreSqlQueryExecutorUnitTests.cs +++ b/src/Service.Tests/UnitTests/PostgreSqlQueryExecutorUnitTests.cs @@ -89,7 +89,8 @@ await provider.Initialize( provider.GetConfig().ToJson(), graphQLSchema: null, connectionString: connectionString, - accessToken: CONFIG_TOKEN); + accessToken: CONFIG_TOKEN, + replacementSettings: new()); postgreSqlQueryExecutor = new(provider, dbExceptionParser.Object, queryExecutorLogger.Object, httpContextAccessor.Object); } } diff --git a/src/Service.Tests/UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs b/src/Service.Tests/UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs index b98de993e2..47629ca4c8 100644 --- a/src/Service.Tests/UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs +++ b/src/Service.Tests/UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs @@ -79,18 +79,18 @@ public void CheckConfigEnvParsingTest( if (replaceEnvVar) { Assert.IsTrue(RuntimeConfigLoader.TryParseConfig( - GetModifiedJsonString(repValues, @"""postgresql"""), out expectedConfig, replaceEnvVar: replaceEnvVar), + GetModifiedJsonString(repValues, @"""postgresql"""), out expectedConfig, replacementSettings: new DeserializationVariableReplacementSettings(azureKeyVaultOptions: null, doReplaceEnvVar: replaceEnvVar, doReplaceAkvVar: false)), "Should read the expected config"); } else { Assert.IsTrue(RuntimeConfigLoader.TryParseConfig( - GetModifiedJsonString(repKeys, @"""postgresql"""), out expectedConfig, replaceEnvVar: replaceEnvVar), + GetModifiedJsonString(repKeys, @"""postgresql"""), out expectedConfig, replacementSettings: new DeserializationVariableReplacementSettings(azureKeyVaultOptions: null, doReplaceEnvVar: replaceEnvVar, doReplaceAkvVar: false)), "Should read the expected config"); } Assert.IsTrue(RuntimeConfigLoader.TryParseConfig( - GetModifiedJsonString(repKeys, @"""@env('enumVarName')"""), out RuntimeConfig actualConfig, replaceEnvVar: replaceEnvVar), + GetModifiedJsonString(repKeys, @"""@env('enumVarName')"""), out RuntimeConfig actualConfig, replacementSettings: new DeserializationVariableReplacementSettings(azureKeyVaultOptions: null, doReplaceEnvVar: replaceEnvVar, doReplaceAkvVar: false)), "Should read actual config"); Assert.AreEqual(expectedConfig.ToJson(), actualConfig.ToJson()); } @@ -130,7 +130,7 @@ public void TestConfigParsingWithEnvVarReplacement(bool replaceEnvVar, string da string configWithEnvVar = _configWithVariableDataSource.Replace("{0}", GetDataSourceConfigForGivenDatabase(databaseType)); bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( - configWithEnvVar, out RuntimeConfig runtimeConfig, replaceEnvVar: replaceEnvVar); + configWithEnvVar, out RuntimeConfig runtimeConfig, replacementSettings: new DeserializationVariableReplacementSettings(azureKeyVaultOptions: null, doReplaceEnvVar: replaceEnvVar, doReplaceAkvVar: true)); // Assert Assert.IsTrue(isParsingSuccessful); @@ -178,7 +178,7 @@ public void TestConfigParsingWhenDataSourceOptionsForCosmosDBContainsInvalidValu string configWithEnvVar = _configWithVariableDataSource.Replace("{0}", GetDataSourceOptionsForCosmosDBWithInvalidValues()); bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( - configWithEnvVar, out RuntimeConfig runtimeConfig, replaceEnvVar: true); + configWithEnvVar, out RuntimeConfig runtimeConfig, replacementSettings: new DeserializationVariableReplacementSettings(azureKeyVaultOptions: null, doReplaceEnvVar: true, doReplaceAkvVar: true)); // Assert Assert.IsTrue(isParsingSuccessful); @@ -302,7 +302,7 @@ public void CheckConfigEnvParsingThrowExceptions(string invalidEnvVarName) { string json = @"{ ""foo"" : ""@env('envVarName'), @env('" + invalidEnvVarName + @"')"" }"; SetEnvVariables(); - StringJsonConverterFactory stringConverterFactory = new(EnvironmentVariableReplacementFailureMode.Throw); + StringJsonConverterFactory stringConverterFactory = new(new(doReplaceEnvVar: true, envFailureMode: EnvironmentVariableReplacementFailureMode.Throw)); JsonSerializerOptions options = new() { PropertyNameCaseInsensitive = true }; options.Converters.Add(stringConverterFactory); Assert.ThrowsException(() => JsonSerializer.Deserialize(json, options)); @@ -324,7 +324,7 @@ public void TestDataSourceDeserializationFailures(string dbType, string connecti ""entities"":{ } }"; // replaceEnvVar: true is needed to make sure we do post-processing for the connection string case - Assert.IsFalse(RuntimeConfigLoader.TryParseConfig(configJson, out RuntimeConfig deserializedConfig, replaceEnvVar: true)); + Assert.IsFalse(RuntimeConfigLoader.TryParseConfig(configJson, out RuntimeConfig deserializedConfig, replacementSettings: new DeserializationVariableReplacementSettings(azureKeyVaultOptions: null, doReplaceEnvVar: true, doReplaceAkvVar: true))); Assert.IsNull(deserializedConfig); } @@ -343,7 +343,8 @@ public void TestLoadRuntimeConfigFailures( MockFileSystem fileSystem = new(); FileSystemRuntimeConfigLoader loader = new(fileSystem); - Assert.IsFalse(loader.TryLoadConfig(configFileName, out RuntimeConfig _)); + // Use null replacement settings for this test + Assert.IsFalse(loader.TryLoadConfig(configFileName, out RuntimeConfig _, replacementSettings: null)); } /// diff --git a/src/Service.Tests/UnitTests/SerializationDeserializationTests.cs b/src/Service.Tests/UnitTests/SerializationDeserializationTests.cs index 44978cd6aa..74d548fef4 100644 --- a/src/Service.Tests/UnitTests/SerializationDeserializationTests.cs +++ b/src/Service.Tests/UnitTests/SerializationDeserializationTests.cs @@ -8,6 +8,7 @@ using System.Reflection; using System.Text.Json; using System.Text.Json.Serialization; +using Azure.DataApiBuilder.Config; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core.Services.MetadataProviders.Converters; @@ -187,7 +188,7 @@ public void TestSourceDefinitionCyclicObjectsSerializationDeserialization() _sourceDefinition.SourceEntityRelationshipMap.Add("persons", metadata); - // In serialization options we need ReferenceHandler = ReferenceHandler.Preserve, or else it doesnot seialize objects with cycle references + // In serialization options we need ReferenceHandler = ReferenceHandler.Preserve, or else it does not serialize objects with cycle references // SourceDefinition -> RelationShipMetadata -> ForeignKeyDefinition RelationshipPair ->DatabaseTable -> SourceDefinition Assert.ThrowsException(() => { @@ -489,5 +490,68 @@ private RelationShipPair GetRelationShipPair() }; return new(_databaseTable, table2); } + + /// + /// Verifies that when merging multiple runtime configs, if the child config omits + /// the azure-key-vault section, the merged result still contains the AzureKeyVaultOptions (including retry-policy) + /// inherited from the parent config. + /// + [TestMethod] + public void TestMergedConfigInheritsAzureKeyVaultOptions() + { + // Arrange + + // Parent config with AKV section. + string parentConfig = @"{ + ""data-source"": { ""database-type"": ""mssql"", ""connection-string"": ""Server=.;Database=Parent;Trusted_Connection=True;"" }, + ""runtime"": { ""rest"": { ""enabled"": true }, ""graphql"": { ""enabled"": true } }, + ""entities"": {}, + ""azure-key-vault"": { + ""endpoint"": ""https://myvault.vault.azure.net/"", + ""retry-policy"": { + ""mode"": ""fixed"", + ""max-count"": 7, + ""delay-seconds"": 3, + ""max-delay-seconds"": 15, + ""network-timeout-seconds"": 20 + } + } +}"; + + // Child config overrides some properties but omits azure-key-vault entirely. + string childConfig = @"{ + ""data-source"": { ""database-type"": ""mssql"", ""connection-string"": ""Server=.;Database=Child;Trusted_Connection=True;"" }, + ""runtime"": { ""rest"": { ""enabled"": true }, ""graphql"": { ""enabled"": true } }, + ""entities"": {} +}"; + // Act + + // Merge child over parent. + string mergedJson = MergeJsonProvider.Merge(parentConfig, childConfig); + + // Parse with AKV replacement enabled so extraction path executes. + DeserializationVariableReplacementSettings replacementSettings = new( + azureKeyVaultOptions: null, + doReplaceEnvVar: false, + doReplaceAkvVar: true); + + // Assert + + Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(mergedJson, out RuntimeConfig mergedConfig, replacementSettings: replacementSettings), "Merged runtime config failed to parse."); + Assert.IsNotNull(mergedConfig, "Merged runtime config is null."); + + // Validate AKV inheritance. + Assert.IsNotNull(mergedConfig.AzureKeyVault, "AzureKeyVaultOptions should be inherited from base config."); + Assert.AreEqual("https://myvault.vault.azure.net/", mergedConfig.AzureKeyVault!.Endpoint, "Inherited AKV endpoint mismatch."); + Assert.IsNotNull(mergedConfig.AzureKeyVault.RetryPolicy, "RetryPolicy should be inherited."); + Assert.AreEqual(AKVRetryPolicyMode.Fixed, mergedConfig.AzureKeyVault.RetryPolicy!.Mode, "Inherited retry-policy mode mismatch."); + Assert.AreEqual(7, mergedConfig.AzureKeyVault.RetryPolicy.MaxCount, "Inherited retry-policy max-count mismatch."); + Assert.AreEqual(3, mergedConfig.AzureKeyVault.RetryPolicy.DelaySeconds, "Inherited retry-policy delay-seconds mismatch."); + Assert.AreEqual(15, mergedConfig.AzureKeyVault.RetryPolicy.MaxDelaySeconds, "Inherited retry-policy max-delay-seconds mismatch."); + Assert.AreEqual(20, mergedConfig.AzureKeyVault.RetryPolicy.NetworkTimeoutSeconds, "Inherited retry-policy network-timeout-seconds mismatch."); + + // Ensure child override for connection-string applied while AKV remained from base. + Assert.AreEqual("Server=.;Database=Child;Trusted_Connection=True;", mergedConfig.DataSource.ConnectionString, "Child connection-string override not applied."); + } } } diff --git a/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs b/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs index 908b7019c4..f549b1dd3e 100644 --- a/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs +++ b/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs @@ -115,7 +115,8 @@ await provider.Initialize( provider.GetConfig().ToJson(), graphQLSchema: null, connectionString: connectionString, - accessToken: CONFIG_TOKEN); + accessToken: CONFIG_TOKEN, + replacementSettings: new()); msSqlQueryExecutor = new(provider, dbExceptionParser.Object, queryExecutorLogger.Object, httpContextAccessor.Object); } } diff --git a/src/Service/Controllers/ConfigurationController.cs b/src/Service/Controllers/ConfigurationController.cs index be3f9bd727..4ad8fb40f4 100644 --- a/src/Service/Controllers/ConfigurationController.cs +++ b/src/Service/Controllers/ConfigurationController.cs @@ -91,8 +91,8 @@ public async Task Index([FromBody] ConfigurationPostParameters con configuration.Schema, configuration.ConnectionString, configuration.AccessToken, - replaceEnvVar: false, - replacementFailureMode: Config.Converters.EnvironmentVariableReplacementFailureMode.Ignore); + replacementSettings: new(azureKeyVaultOptions: null, doReplaceEnvVar: false, doReplaceAkvVar: false, envFailureMode: Config.Converters.EnvironmentVariableReplacementFailureMode.Ignore) + ); if (initResult && _configurationProvider.TryGetConfig(out _)) { From e93cd49c92d1b3ede158599236a5ca8e5fc05554 Mon Sep 17 00:00:00 2001 From: aaronburtle <93220300+aaronburtle@users.noreply.github.com> Date: Wed, 26 Nov 2025 14:08:05 -0800 Subject: [PATCH 2/8] Support for .akv files to AKV variable replacement. (#2957) ## Why make this change? Closes https://github.com/Azure/data-api-builder/issues/2748 ## What is this change? Adds the option to use a local .akv file instead of Azure Key Vault for @AKV('') replacement in the config file during deserialization. Similar to how we handle .env files. ## How was this tested? A new test was added that verifies we are able to do the replacement and get the correct resultant configuration. --------- Co-authored-by: Aniruddh Munde --- ...erializationVariableReplacementSettings.cs | 79 +++++++- ...untimeConfigLoaderJsonDeserializerTests.cs | 178 +++++++++++++++++- 2 files changed, 254 insertions(+), 3 deletions(-) diff --git a/src/Config/DeserializationVariableReplacementSettings.cs b/src/Config/DeserializationVariableReplacementSettings.cs index d4c02b5252..5c70f4082b 100644 --- a/src/Config/DeserializationVariableReplacementSettings.cs +++ b/src/Config/DeserializationVariableReplacementSettings.cs @@ -8,6 +8,7 @@ using Azure.DataApiBuilder.Service.Exceptions; using Azure.Identity; using Azure.Security.KeyVault.Secrets; +using Microsoft.Extensions.Logging; namespace Azure.DataApiBuilder.Config { @@ -43,6 +44,8 @@ public class DeserializationVariableReplacementSettings private readonly AzureKeyVaultOptions? _azureKeyVaultOptions; private readonly SecretClient? _akvClient; + private readonly Dictionary? _akvFileSecrets; + private readonly ILogger? _logger; public Dictionary> ReplacementStrategies { get; private set; } = new(); @@ -50,12 +53,14 @@ public DeserializationVariableReplacementSettings( AzureKeyVaultOptions? azureKeyVaultOptions = null, bool doReplaceEnvVar = false, bool doReplaceAkvVar = false, - EnvironmentVariableReplacementFailureMode envFailureMode = EnvironmentVariableReplacementFailureMode.Throw) + EnvironmentVariableReplacementFailureMode envFailureMode = EnvironmentVariableReplacementFailureMode.Throw, + ILogger? logger = null) { _azureKeyVaultOptions = azureKeyVaultOptions; DoReplaceEnvVar = doReplaceEnvVar; DoReplaceAkvVar = doReplaceAkvVar; EnvFailureMode = envFailureMode; + _logger = logger; if (DoReplaceEnvVar) { @@ -66,13 +71,68 @@ public DeserializationVariableReplacementSettings( if (DoReplaceAkvVar && _azureKeyVaultOptions is not null) { - _akvClient = CreateSecretClient(_azureKeyVaultOptions); + // Determine if endpoint points to a local .akv file. If so, load secrets from file; otherwise, use remote AKV. + if (IsLocalAkvFileEndpoint(_azureKeyVaultOptions.Endpoint)) + { + _akvFileSecrets = LoadAkvFileSecrets(_azureKeyVaultOptions.Endpoint!, _logger); + } + else + { + _akvClient = CreateSecretClient(_azureKeyVaultOptions); + } + ReplacementStrategies.Add( new Regex(OUTER_AKV_PATTERN, RegexOptions.Compiled), ReplaceAkvVariable); } } + // Checks if the endpoint is a path to a local .akv file. + private static bool IsLocalAkvFileEndpoint(string? endpoint) + => !string.IsNullOrWhiteSpace(endpoint) + && endpoint.EndsWith(".akv", StringComparison.OrdinalIgnoreCase) + && File.Exists(endpoint); + + // Loads key=value pairs from a .akv file, similar to .env style. Lines starting with '#' are comments. + private static Dictionary LoadAkvFileSecrets(string filePath, ILogger? logger = null) + { + Dictionary secrets = new(StringComparer.OrdinalIgnoreCase); + foreach (string rawLine in File.ReadAllLines(filePath)) + { + string line = rawLine.Trim(); + if (string.IsNullOrEmpty(line) || line.StartsWith('#')) + { + continue; + } + + int eqIndex = line.IndexOf('='); + if (eqIndex <= 0) + { + logger?.LogDebug("Ignoring malformed line in AKV secrets file {FilePath}: {Line}", filePath, rawLine); + continue; + } + + string key = line.Substring(0, eqIndex).Trim(); + string value = line[(eqIndex + 1)..].Trim(); + + // Remove optional surrounding quotes + if (value.Length >= 2 && ((value.StartsWith('"') && value.EndsWith('"')) || (value.StartsWith('\'') && value.EndsWith('\'')))) + { + value = value[1..^1]; + } + + if (!string.IsNullOrEmpty(key)) + { + if (!secrets.TryAdd(key, value)) + { + logger?.LogDebug("Duplicate key '{Key}' encountered in AKV secrets file {FilePath}. Skipping later value.", key, filePath); + } + } + } + + return secrets; + } + private string ReplaceEnvVariable(Match match) { // strips first and last characters, ie: '''hello'' --> ''hello' @@ -170,6 +230,15 @@ private static SecretClient CreateSecretClient(AzureKeyVaultOptions options) DataApiBuilderException.SubStatusCodes.ErrorInInitialization); } + // If endpoint is a local .akv file, we should not create a SecretClient. + if (IsLocalAkvFileEndpoint(options.Endpoint)) + { + throw new DataApiBuilderException( + "Attempted to create Azure Key Vault client for local .akv file endpoint.", + System.Net.HttpStatusCode.InternalServerError, + DataApiBuilderException.SubStatusCodes.ErrorInInitialization); + } + SecretClientOptions clientOptions = new(); if (options.RetryPolicy is not null) @@ -195,6 +264,12 @@ private static SecretClient CreateSecretClient(AzureKeyVaultOptions options) private string? GetAkvVariable(string name) { + // If using local .akv file secrets, return from dictionary. + if (_akvFileSecrets is not null) + { + return _akvFileSecrets.TryGetValue(name, out string? value) ? value : null; + } + if (_akvClient is null) { throw new InvalidOperationException("Azure Key Vault client is not initialized."); diff --git a/src/Service.Tests/UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs b/src/Service.Tests/UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs index 47629ca4c8..b990b96368 100644 --- a/src/Service.Tests/UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs +++ b/src/Service.Tests/UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs @@ -13,6 +13,7 @@ using Azure.DataApiBuilder.Config.Converters; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Service.Exceptions; +using Microsoft.Data.SqlClient; using Microsoft.VisualStudio.TestTools.UnitTesting; namespace Azure.DataApiBuilder.Service.Tests.UnitTests @@ -240,6 +241,7 @@ public void CheckCommentParsingInConfigFile() /// but have the effect of default values when deserialized. /// It starts with a minimal config and incrementally /// adds the optional subproperties. At each step, tests for valid deserialization. + /// [TestMethod] public void TestNullableOptionalProps() { @@ -431,7 +433,7 @@ public static string GetModifiedJsonString(string[] reps, string enumString) ""host"": { ""mode"": ""development"", ""cors"": { - ""origins"": [ """ + reps[++index % reps.Length] + @""", """ + reps[++index % reps.Length] + @""" ], + ""origins"": [ """ + reps[++index % reps.Length] + @""", """ + reps[++index % reps.Length] + @"""], ""allow-credentials"": true }, ""authentication"": { @@ -671,5 +673,179 @@ private static bool TryParseAndAssertOnDefaults(string json, out RuntimeConfig p #endregion Helper Functions record StubJsonType(string Foo); + + /// + /// Test to verify Azure Key Vault variable replacement from local .akv file. + /// + [TestMethod] + public void TestAkvVariableReplacementFromLocalFile() + { + // Arrange: create a temporary .akv secrets file + string akvFilePath = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString() + ".akv"); + string secretConnectionString = "Server=tcp:127.0.0.1,1433;Persist Security Info=False;Trusted_Connection=True;TrustServerCertificate=True;MultipleActiveResultSets=False;Connection Timeout=5;"; + File.WriteAllText(akvFilePath, $"DBCONN={secretConnectionString}\nAPI_KEY=abcd\n# Comment line should be ignored\n MALFORMEDLINE \n"); + + // Escape backslashes for JSON + string escapedPath = akvFilePath.Replace("\\", "\\\\"); + + string jsonConfig = $$""" + { + "$schema": "https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch-alpha/dab.draft.schema.json", + "data-source": { + "database-type": "mssql", + "connection-string": "@akv('DBCONN')" + }, + "azure-key-vault": { + "endpoint": "{{escapedPath}}" + }, + "entities": { } + } + """; + + try + { + // Act + DeserializationVariableReplacementSettings replacementSettings = new( + azureKeyVaultOptions: null, + doReplaceEnvVar: false, + doReplaceAkvVar: true); + bool parsed = RuntimeConfigLoader.TryParseConfig(jsonConfig, out RuntimeConfig config, replacementSettings: replacementSettings); + + // Assert + Assert.IsTrue(parsed, "Config should parse successfully with local AKV file replacement."); + Assert.IsNotNull(config, "Config should not be null."); + Assert.AreEqual(secretConnectionString, config.DataSource.ConnectionString, "Connection string should be replaced from AKV local file secret."); + } + finally + { + // Cleanup + if (File.Exists(akvFilePath)) + { + File.Delete(akvFilePath); + } + } + } + + /// + /// Validates that when an AKV secret's value itself contains an @env('...') pattern, it is NOT further resolved + /// because replacement only runs once per original JSON token. Demonstrates that nested env patterns inside + /// AKV secret values are left intact. + /// + [TestMethod] + public void TestAkvSecretValueContainingEnvPatternIsNotEnvExpanded() + { + string akvFilePath = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString() + ".akv"); + // Valid MSSQL connection string which embeds an @env('env') pattern in the Database value. + // This pattern should NOT be expanded because replacement only runs once on the original JSON token (@akv('DBCONN')). + string secretValueWithEnvPattern = "Server=localhost;Database=@env('env');User Id=sa;Password=XXXX;"; + File.WriteAllText(akvFilePath, $"DBCONN={secretValueWithEnvPattern}\n"); + string escapedPath = akvFilePath.Replace("\\", "\\\\"); + + // Set env variable to prove it would be different if expansion occurred. + Environment.SetEnvironmentVariable("env", "SHOULD_NOT_APPEAR"); + + string jsonConfig = $$""" + { + "$schema": "https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch-alpha/dab.draft.schema.json", + "data-source": { + "database-type": "mssql", + "connection-string": "@akv('DBCONN')" + }, + "azure-key-vault": { + "endpoint": "{{escapedPath}}" + }, + "entities": { } + } + """; + + try + { + DeserializationVariableReplacementSettings replacementSettings = new( + azureKeyVaultOptions: null, + doReplaceEnvVar: true, + doReplaceAkvVar: true); + bool parsed = RuntimeConfigLoader.TryParseConfig(jsonConfig, out RuntimeConfig config, replacementSettings: replacementSettings); + Assert.IsTrue(parsed, "Config should parse successfully."); + Assert.IsNotNull(config); + + string actual = config.DataSource.ConnectionString; + Assert.IsTrue(actual.Contains("@env('env')"), "Nested @env pattern inside AKV secret should remain unexpanded."); + Assert.IsFalse(actual.Contains("SHOULD_NOT_APPEAR"), "Env var value should not be expanded inside AKV secret."); + Assert.IsTrue(actual.Contains("Application Name="), "Application Name should be appended for MSSQL when env replacement is enabled."); + + var builderOriginal = new SqlConnectionStringBuilder(secretValueWithEnvPattern.Replace("Server=", "Data Source=").Replace("Database=", "Initial Catalog=")); + var builderActual = new SqlConnectionStringBuilder(actual); + Assert.AreEqual(builderOriginal["Data Source"], builderActual["Data Source"], "Server/Data Source should match."); + Assert.AreEqual(builderOriginal["Initial Catalog"], builderActual["Initial Catalog"], "Database/Initial Catalog should match (with env pattern retained)."); + Assert.AreEqual(builderOriginal["User ID"], builderActual["User ID"], "User Id should match."); + Assert.AreEqual(builderOriginal["Password"], builderActual["Password"], "Password should match."); + } + finally + { + if (File.Exists(akvFilePath)) + { + File.Delete(akvFilePath); + } + + Environment.SetEnvironmentVariable("env", null); + } + } + + /// + /// Validates two-pass replacement where an env var resolves to an AKV pattern which then resolves to the secret value. + /// connection-string = @env('env_variable'), env_variable value = @akv('DBCONN'), AKV secret DBCONN holds the final connection string. + /// + [TestMethod] + public void TestEnvVariableResolvingToAkvPatternIsExpandedInSecondPass() + { + string akvFilePath = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString() + ".akv"); + string finalSecretValue = "Server=localhost;Database=Test;User Id=sa;Password=XXXX;"; + File.WriteAllText(akvFilePath, $"DBCONN={finalSecretValue}\n"); + string escapedPath = akvFilePath.Replace("\\", "\\\\"); + Environment.SetEnvironmentVariable("env_variable", "@akv('DBCONN')"); + + string jsonConfig = $$""" + { + "$schema": "https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch-alpha/dab.draft.schema.json", + "data-source": { + "database-type": "mssql", + "connection-string": "@env('env_variable')" + }, + "azure-key-vault": { + "endpoint": "{{escapedPath}}" + }, + "entities": { } + } + """; + + try + { + DeserializationVariableReplacementSettings replacementSettings = new( + azureKeyVaultOptions: null, + doReplaceEnvVar: true, + doReplaceAkvVar: true); + bool parsed = RuntimeConfigLoader.TryParseConfig(jsonConfig, out RuntimeConfig config, replacementSettings: replacementSettings); + Assert.IsTrue(parsed, "Config should parse successfully."); + Assert.IsNotNull(config); + + string expected = RuntimeConfigLoader.GetConnectionStringWithApplicationName(finalSecretValue); + var builderExpected = new SqlConnectionStringBuilder(expected); + var builderActual = new SqlConnectionStringBuilder(config.DataSource.ConnectionString); + Assert.AreEqual(builderExpected["Data Source"], builderActual["Data Source"], "Data Source should match."); + Assert.AreEqual(builderExpected["Initial Catalog"], builderActual["Initial Catalog"], "Initial Catalog should match."); + Assert.AreEqual(builderExpected["User ID"], builderActual["User ID"], "User ID should match."); + Assert.AreEqual(builderExpected["Password"], builderActual["Password"], "Password should match."); + Assert.IsTrue(builderActual.ApplicationName?.Contains("dab_"), "Application Name should be appended including product identifier."); + } + finally + { + if (File.Exists(akvFilePath)) + { + File.Delete(akvFilePath); + } + + Environment.SetEnvironmentVariable("env_variable", null); + } + } } } From 297ff3de8b49e5e53b527163ec6da1c760bbbed4 Mon Sep 17 00:00:00 2001 From: aaron burtle Date: Tue, 18 Nov 2025 14:19:30 -0800 Subject: [PATCH 3/8] refactor the responses to use common utility --- .../BuiltInTools/ReadRecordsTool.cs | 144 +++------------- .../BuiltInTools/UpdateRecordTool.cs | 163 ++++-------------- 2 files changed, 66 insertions(+), 241 deletions(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs index 42b1f41ea0..3791fd0bba 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs @@ -15,6 +15,7 @@ using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Mcp.Utils; using Azure.DataApiBuilder.Service.Exceptions; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Http; @@ -85,7 +86,7 @@ public async Task ExecuteAsync( if (runtimeConfig.McpDmlTools?.ReadRecords is not true) { - return BuildErrorResult( + return McpResponseBuilder.BuildErrorResult( "ToolDisabled", "The read_records tool is disabled in the configuration.", logger); @@ -105,14 +106,14 @@ public async Task ExecuteAsync( // Extract arguments if (arguments == null) { - return BuildErrorResult("InvalidArguments", "No arguments provided.", logger); + return McpResponseBuilder.BuildErrorResult("InvalidArguments", "No arguments provided.", logger); } JsonElement root = arguments.RootElement; if (!root.TryGetProperty("entity", out JsonElement entityElement) || string.IsNullOrWhiteSpace(entityElement.GetString())) { - return BuildErrorResult("InvalidArguments", "Missing required argument 'entity'.", logger); + return McpResponseBuilder.BuildErrorResult("InvalidArguments", "Missing required argument 'entity'.", logger); } entityName = entityElement.GetString()!; @@ -157,12 +158,12 @@ public async Task ExecuteAsync( } catch (Exception) { - return BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); } if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entityName, out DatabaseObject? dbObject) || dbObject is null) { - return BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); } // Authorization check in the existing entity @@ -173,12 +174,12 @@ public async Task ExecuteAsync( if (httpContext is null || !authResolver.IsValidRoleContext(httpContext)) { - return BuildErrorResult("PermissionDenied", $"You do not have permission to read records for entity '{entityName}'.", logger); + return McpResponseBuilder.BuildErrorResult("PermissionDenied", $"You do not have permission to read records for entity '{entityName}'.", logger); } if (!TryResolveAuthorizedRole(httpContext, authResolver, entityName, out string? effectiveRole, out string authError)) { - return BuildErrorResult("PermissionDenied", authError, logger); + return McpResponseBuilder.BuildErrorResult("PermissionDenied", authError, logger); } // Build and validate Find context @@ -208,7 +209,7 @@ public async Task ExecuteAsync( { if (string.IsNullOrWhiteSpace(param)) { - return BuildErrorResult("InvalidArguments", "Parameters inside 'orderby' argument cannot be empty or null.", logger); + return McpResponseBuilder.BuildErrorResult("InvalidArguments", "Parameters inside 'orderby' argument cannot be empty or null.", logger); } sortQueryString += $"{param}, "; @@ -230,7 +231,7 @@ public async Task ExecuteAsync( requirements: new[] { new ColumnsPermissionsRequirement() }); if (!authorizationResult.Succeeded) { - return BuildErrorResult("PermissionDenied", DataApiBuilderException.AUTHORIZATION_FAILURE, logger); + return McpResponseBuilder.BuildErrorResult("PermissionDenied", DataApiBuilderException.AUTHORIZATION_FAILURE, logger); } // Execute @@ -240,34 +241,40 @@ public async Task ExecuteAsync( : SqlResponseHelpers.FormatFindResult(queryResult.RootElement.Clone(), context, metadataProviderFactory.GetMetadataProvider(dataSourceName), runtimeConfigProvider.GetConfig(), httpContext, true); // Normalize response - string rawPayloadJson = ExtractResultJson(actionResult); - JsonDocument result = JsonDocument.Parse(rawPayloadJson); + string rawPayloadJson = McpResponseBuilder.ExtractResultJson(actionResult); + using JsonDocument result = JsonDocument.Parse(rawPayloadJson); JsonElement queryRoot = result.RootElement; - return BuildSuccessResult( - entityName, - queryRoot.Clone(), - logger); + return McpResponseBuilder.BuildSuccessResult( + new Dictionary + { + ["entity"] = entityName, + ["result"] = queryRoot.Clone(), + ["message"] = $"Successfully read records for entity '{entityName}'" + }, + logger, + $"ReadRecordsTool success for entity {entityName}."); } catch (OperationCanceledException) { - return BuildErrorResult("OperationCanceled", "The read operation was canceled.", logger); + return McpResponseBuilder.BuildErrorResult("OperationCanceled", "The read operation was canceled.", logger); } catch (DbException argEx) { - return BuildErrorResult("DatabaseOperationFailed", argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult("DatabaseOperationFailed", argEx.Message, logger); } catch (ArgumentException argEx) { - return BuildErrorResult("InvalidArguments", argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult("InvalidArguments", argEx.Message, logger); } catch (DataApiBuilderException argEx) { - return BuildErrorResult(argEx.StatusCode.ToString(), argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult(argEx.StatusCode.ToString(), argEx.Message, logger); } - catch (Exception) + catch (Exception ex) { - return BuildErrorResult("UnexpectedError", "Unexpected error occurred in ReadRecordsTool.", logger); + logger?.LogError(ex, "Unexpected error in ReadRecordsTool."); + return McpResponseBuilder.BuildErrorResult("UnexpectedError", "Unexpected error occurred in ReadRecordsTool.", logger); } } @@ -324,100 +331,5 @@ private static bool TryResolveAuthorizedRole( error = $"You do not have permission to read records for entity '{entityName}'."; return false; } - - /// - /// Returns a result from the query in the case that it was successfully ran. - /// - /// Name of the entity used in the request. - /// Query result from engine. - /// MCP logger that returns all logged events. - private static CallToolResult BuildSuccessResult( - string entityName, - JsonElement engineRootElement, - ILogger? logger) - { - // Build normalized response - Dictionary normalized = new() - { - ["status"] = "success", - ["result"] = engineRootElement // only requested values - }; - - string output = JsonSerializer.Serialize(normalized, new JsonSerializerOptions { WriteIndented = true }); - - logger?.LogInformation("ReadRecordsTool success for entity {Entity}.", entityName); - - return new CallToolResult - { - Content = new List - { - new TextContentBlock { Type = "text", Text = output } - } - }; - } - - /// - /// Returns an error if the query failed to run at any point. - /// - /// Type of error that is encountered. - /// Error message given to the user. - /// MCP logger that returns all logged events. - private static CallToolResult BuildErrorResult( - string errorType, - string message, - ILogger? logger) - { - Dictionary errorObj = new() - { - ["status"] = "error", - ["error"] = new Dictionary - { - ["type"] = errorType, - ["message"] = message - } - }; - - string output = JsonSerializer.Serialize(errorObj); - - logger?.LogError("ReadRecordsTool error {ErrorType}: {Message}", errorType, message); - - return new CallToolResult - { - Content = - [ - new TextContentBlock { Type = "text", Text = output } - ], - IsError = true - }; - } - - /// - /// Extracts a JSON string from a typical IActionResult. - /// Falls back to "{}" for unsupported/empty cases to avoid leaking internals. - /// - private static string ExtractResultJson(IActionResult? result) - { - switch (result) - { - case ObjectResult obj: - if (obj.Value is JsonElement je) - { - return je.GetRawText(); - } - - if (obj.Value is JsonDocument jd) - { - return jd.RootElement.GetRawText(); - } - - return JsonSerializer.Serialize(obj.Value ?? new object()); - - case ContentResult content: - return string.IsNullOrWhiteSpace(content.Content) ? "{}" : content.Content; - - default: - return "{}"; - } - } } } diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs index 9e7d101fe6..e58bea7e09 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs @@ -13,6 +13,7 @@ using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Mcp.Utils; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.DependencyInjection; @@ -92,7 +93,7 @@ public async Task ExecuteAsync( // 2)Check if the tool is enabled in configuration before proceeding. if (config.McpDmlTools?.UpdateRecord != true) { - return BuildErrorResult( + return McpResponseBuilder.BuildErrorResult( "ToolDisabled", "The update_record tool is disabled in the configuration.", logger); @@ -106,12 +107,12 @@ public async Task ExecuteAsync( // 3) Parsing & basic argument validation (entity, keys, fields) if (arguments is null) { - return BuildErrorResult("InvalidArguments", "No arguments provided.", logger); + return McpResponseBuilder.BuildErrorResult("InvalidArguments", "No arguments provided.", logger); } if (!TryParseArguments(arguments.RootElement, out string entityName, out Dictionary keys, out Dictionary fields, out string parseError)) { - return BuildErrorResult("InvalidArguments", parseError, logger); + return McpResponseBuilder.BuildErrorResult("InvalidArguments", parseError, logger); } IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); @@ -128,12 +129,12 @@ public async Task ExecuteAsync( } catch (Exception) { - return BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); } if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entityName, out DatabaseObject? dbObject) || dbObject is null) { - return BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); } // 5) Authorization after we have a known entity @@ -143,12 +144,12 @@ public async Task ExecuteAsync( if (httpContext is null || !authResolver.IsValidRoleContext(httpContext)) { - return BuildErrorResult("PermissionDenied", "Permission denied: unable to resolve a valid role context for update operation.", logger); + return McpResponseBuilder.BuildErrorResult("PermissionDenied", "Permission denied: unable to resolve a valid role context for update operation.", logger); } if (!TryResolveAuthorizedRoleHasPermission(httpContext, authResolver, entityName, out string? effectiveRole, out string authError)) { - return BuildErrorResult("PermissionDenied", $"Permission denied: {authError}", logger); + return McpResponseBuilder.BuildErrorResult("PermissionDenied", $"Permission denied: {authError}", logger); } // 6) Build and validate Upsert (UpdateIncremental) context @@ -165,7 +166,7 @@ public async Task ExecuteAsync( { if (kvp.Value is null) { - return BuildErrorResult("InvalidArguments", $"Primary key value for '{kvp.Key}' cannot be null.", logger); + return McpResponseBuilder.BuildErrorResult("InvalidArguments", $"Primary key value for '{kvp.Key}' cannot be null.", logger); } context.PrimaryKeyValuePairs[kvp.Key] = kvp.Value; @@ -193,7 +194,7 @@ public async Task ExecuteAsync( if (errorMsg.Contains("No Update could be performed, record not found", StringComparison.OrdinalIgnoreCase)) { - return BuildErrorResult( + return McpResponseBuilder.BuildErrorResult( "InvalidArguments", "No record found with the given key.", logger); @@ -208,29 +209,46 @@ public async Task ExecuteAsync( cancellationToken.ThrowIfCancellationRequested(); // 8) Normalize response (success or engine error payload) - string rawPayloadJson = ExtractResultJson(mutationResult); + string rawPayloadJson = McpResponseBuilder.ExtractResultJson(mutationResult); using JsonDocument resultDoc = JsonDocument.Parse(rawPayloadJson); JsonElement root = resultDoc.RootElement; - return BuildSuccessResult( - entityName: entityName, - engineRootElement: root.Clone(), - logger: logger); + // Extract first item of value[] array (updated record) + Dictionary filteredResult = new(); + if (root.TryGetProperty("value", out JsonElement valueArray) && + valueArray.ValueKind == JsonValueKind.Array && + valueArray.GetArrayLength() > 0) + { + JsonElement firstItem = valueArray[0]; + foreach (JsonProperty prop in firstItem.EnumerateObject()) + { + filteredResult[prop.Name] = GetJsonValue(prop.Value); + } + } + + return McpResponseBuilder.BuildSuccessResult( + new Dictionary + { + ["entity"] = entityName, + ["result"] = filteredResult, + ["message"] = $"Successfully updated record in entity '{entityName}'" + }, + logger, + $"UpdateRecordTool success for entity {entityName}."); } catch (OperationCanceledException) { - return BuildErrorResult("OperationCanceled", "The update operation was canceled.", logger); + return McpResponseBuilder.BuildErrorResult("OperationCanceled", "The update operation was canceled.", logger); } catch (ArgumentException argEx) { - return BuildErrorResult("InvalidArguments", argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult("InvalidArguments", argEx.Message, logger); } catch (Exception ex) { ILogger? innerLogger = serviceProvider.GetService>(); innerLogger?.LogError(ex, "Unexpected error in UpdateRecordTool."); - - return BuildErrorResult( + return McpResponseBuilder.BuildErrorResult( "UnexpectedError", ex.Message ?? "An unexpected error occurred during the update operation.", logger); @@ -349,53 +367,7 @@ private static bool TryResolveAuthorizedRoleHasPermission( #endregion - #region Response Builders & Utilities - - private static CallToolResult BuildSuccessResult( - string entityName, - JsonElement engineRootElement, - ILogger? logger) - { - // Extract only requested keys and updated fields from engineRootElement - Dictionary filteredResult = new(); - - // Navigate to "value" array in the engine result - if (engineRootElement.TryGetProperty("value", out JsonElement valueArray) && - valueArray.ValueKind == JsonValueKind.Array && - valueArray.GetArrayLength() > 0) - { - JsonElement firstItem = valueArray[0]; - - // Include all properties from the result - foreach (JsonProperty prop in firstItem.EnumerateObject()) - { - filteredResult[prop.Name] = GetJsonValue(prop.Value); - } - } - - // Build normalized response - Dictionary normalized = new() - { - ["status"] = "success", - ["result"] = filteredResult - }; - - string output = JsonSerializer.Serialize(normalized, new JsonSerializerOptions { WriteIndented = true }); - - logger?.LogInformation("UpdateRecordTool success for entity {Entity}.", entityName); - - return new CallToolResult - { - Content = new List - { - new TextContentBlock { Type = "text", Text = output } - } - }; - } - - /// - /// Converts JsonElement to .NET object dynamically. - /// + #region Utilities private static object? GetJsonValue(JsonElement element) { return element.ValueKind switch @@ -405,68 +377,9 @@ private static CallToolResult BuildSuccessResult( JsonValueKind.True => true, JsonValueKind.False => false, JsonValueKind.Null => null, - _ => element.GetRawText() // fallback for arrays/objects + _ => element.GetRawText() }; } - - private static CallToolResult BuildErrorResult( - string errorType, - string message, - ILogger? logger) - { - Dictionary errorObj = new() - { - ["status"] = "error", - ["error"] = new Dictionary - { - ["type"] = errorType, - ["message"] = message - } - }; - - string output = JsonSerializer.Serialize(errorObj); - - logger?.LogWarning("UpdateRecordTool error {ErrorType}: {Message}", errorType, message); - - return new CallToolResult - { - Content = - [ - new TextContentBlock { Type = "text", Text = output } - ], - IsError = true - }; - } - - /// - /// Extracts a JSON string from a typical IActionResult. - /// Falls back to "{}" for unsupported/empty cases to avoid leaking internals. - /// - private static string ExtractResultJson(IActionResult? result) - { - switch (result) - { - case ObjectResult obj: - if (obj.Value is JsonElement je) - { - return je.GetRawText(); - } - - if (obj.Value is JsonDocument jd) - { - return jd.RootElement.GetRawText(); - } - - return JsonSerializer.Serialize(obj.Value ?? new object()); - - case ContentResult content: - return string.IsNullOrWhiteSpace(content.Content) ? "{}" : content.Content; - - default: - return "{}"; - } - } - #endregion } } From 5b963a9188f8313b56d11bb81a736306aee30a13 Mon Sep 17 00:00:00 2001 From: aaron burtle Date: Thu, 20 Nov 2025 15:28:36 -0800 Subject: [PATCH 4/8] include toolname in shared error response --- .../BuiltInTools/CreateRecordTool.cs | 25 ++++++---- .../BuiltInTools/DeleteRecordTool.cs | 39 +++++++++------ .../BuiltInTools/DescribeEntitiesTool.cs | 9 ++++ .../BuiltInTools/ExecuteEntityTool.cs | 47 +++++++++++-------- .../BuiltInTools/ReadRecordsTool.cs | 28 ++++++----- .../BuiltInTools/UpdateRecordTool.cs | 24 +++++----- .../Utils/McpResponseBuilder.cs | 2 + 7 files changed, 107 insertions(+), 67 deletions(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs index 68447f16f4..9d64fd4cd7 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs @@ -57,20 +57,22 @@ public async Task ExecuteAsync( CancellationToken cancellationToken = default) { ILogger? logger = serviceProvider.GetService>(); + string toolName = GetToolMetadata().Name; if (arguments == null) { - return Utils.McpResponseBuilder.BuildErrorResult("Invalid Arguments", "No arguments provided", logger); + return Utils.McpResponseBuilder.BuildErrorResult(toolName, "Invalid Arguments", "No arguments provided", logger); } RuntimeConfigProvider runtimeConfigProvider = serviceProvider.GetRequiredService(); if (!runtimeConfigProvider.TryGetConfig(out RuntimeConfig? runtimeConfig)) { - return Utils.McpResponseBuilder.BuildErrorResult("Invalid Configuration", "Runtime configuration not available", logger); + return Utils.McpResponseBuilder.BuildErrorResult(toolName, "Invalid Configuration", "Runtime configuration not available", logger); } if (runtimeConfig.McpDmlTools?.CreateRecord != true) { return Utils.McpResponseBuilder.BuildErrorResult( + toolName, "ToolDisabled", "The create_record tool is disabled in the configuration.", logger); @@ -84,13 +86,13 @@ public async Task ExecuteAsync( if (!root.TryGetProperty("entity", out JsonElement entityElement) || !root.TryGetProperty("data", out JsonElement dataElement)) { - return Utils.McpResponseBuilder.BuildErrorResult("InvalidArguments", "Missing required arguments 'entity' or 'data'", logger); + return Utils.McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Missing required arguments 'entity' or 'data'", logger); } string entityName = entityElement.GetString() ?? string.Empty; if (string.IsNullOrWhiteSpace(entityName)) { - return Utils.McpResponseBuilder.BuildErrorResult("InvalidArguments", "Entity name cannot be empty", logger); + return Utils.McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Entity name cannot be empty", logger); } string dataSourceName; @@ -100,7 +102,7 @@ public async Task ExecuteAsync( } catch (Exception) { - return Utils.McpResponseBuilder.BuildErrorResult("InvalidConfiguration", $"Entity '{entityName}' not found in configuration", logger); + return Utils.McpResponseBuilder.BuildErrorResult(toolName, "InvalidConfiguration", $"Entity '{entityName}' not found in configuration", logger); } IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); @@ -113,7 +115,7 @@ public async Task ExecuteAsync( } catch (Exception) { - return Utils.McpResponseBuilder.BuildErrorResult("InvalidConfiguration", $"Database object for entity '{entityName}' not found", logger); + return Utils.McpResponseBuilder.BuildErrorResult(toolName, "InvalidConfiguration", $"Database object for entity '{entityName}' not found", logger); } // Create an HTTP context for authorization @@ -123,13 +125,13 @@ public async Task ExecuteAsync( if (httpContext is null || !authorizationResolver.IsValidRoleContext(httpContext)) { - return Utils.McpResponseBuilder.BuildErrorResult("PermissionDenied", "Permission denied: Unable to resolve a valid role context for update operation.", logger); + return Utils.McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", "Permission denied: Unable to resolve a valid role context for update operation.", logger); } // Validate that we have at least one role authorized for create if (!TryResolveAuthorizedRole(httpContext, authorizationResolver, entityName, out string authError)) { - return Utils.McpResponseBuilder.BuildErrorResult("PermissionDenied", authError, logger); + return Utils.McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", authError, logger); } JsonElement insertPayloadRoot = dataElement.Clone(); @@ -150,12 +152,13 @@ public async Task ExecuteAsync( } catch (Exception ex) { - return Utils.McpResponseBuilder.BuildErrorResult("ValidationFailed", $"Request validation failed: {ex.Message}", logger); + return Utils.McpResponseBuilder.BuildErrorResult(toolName, "ValidationFailed", $"Request validation failed: {ex.Message}", logger); } } else { return Utils.McpResponseBuilder.BuildErrorResult( + toolName, "InvalidCreateTarget", "The create_record tool is only available for tables.", logger); @@ -185,6 +188,7 @@ public async Task ExecuteAsync( if (isError) { return Utils.McpResponseBuilder.BuildErrorResult( + toolName, "CreateFailed", $"Failed to create record in entity '{entityName}'. Error: {JsonSerializer.Serialize(objectResult.Value)}", logger); @@ -207,6 +211,7 @@ public async Task ExecuteAsync( if (result is null) { return Utils.McpResponseBuilder.BuildErrorResult( + toolName, "UnexpectedError", $"Mutation engine returned null result for entity '{entityName}'", logger); @@ -226,7 +231,7 @@ public async Task ExecuteAsync( } catch (Exception ex) { - return Utils.McpResponseBuilder.BuildErrorResult("Error", $"Error: {ex.Message}", logger); + return Utils.McpResponseBuilder.BuildErrorResult(toolName, "Error", $"Error: {ex.Message}", logger); } } diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DeleteRecordTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DeleteRecordTool.cs index 7abac888c5..eb310ae364 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DeleteRecordTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DeleteRecordTool.cs @@ -73,6 +73,7 @@ public async Task ExecuteAsync( CancellationToken cancellationToken = default) { ILogger? logger = serviceProvider.GetService>(); + string toolName = GetToolMetadata().Name; try { @@ -87,6 +88,7 @@ public async Task ExecuteAsync( if (config.McpDmlTools?.DeleteRecord != true) { return McpResponseBuilder.BuildErrorResult( + toolName, "ToolDisabled", $"The {this.GetToolMetadata().Name} tool is disabled in the configuration.", logger); @@ -95,12 +97,12 @@ public async Task ExecuteAsync( // 3) Parsing & basic argument validation if (arguments is null) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", "No arguments provided.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger); } if (!McpArgumentParser.TryParseEntityAndKeys(arguments.RootElement, out string entityName, out Dictionary keys, out string parseError)) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", parseError, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); } IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); @@ -117,18 +119,18 @@ public async Task ExecuteAsync( } catch (Exception) { - return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); } if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entityName, out DatabaseObject? dbObject) || dbObject is null) { - return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); } // Validate it's a table or view if (dbObject.SourceType != EntitySourceType.Table && dbObject.SourceType != EntitySourceType.View) { - return McpResponseBuilder.BuildErrorResult("InvalidEntity", $"Entity '{entityName}' is not a table or view. Use 'execute-entity' for stored procedures.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidEntity", $"Entity '{entityName}' is not a table or view. Use 'execute-entity' for stored procedures.", logger); } // 5) Authorization @@ -138,7 +140,7 @@ public async Task ExecuteAsync( if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authResolver, out string roleError)) { - return McpResponseBuilder.BuildErrorResult("PermissionDenied", $"Permission denied: {roleError}", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", $"Permission denied: {roleError}", logger); } if (!McpAuthorizationHelper.TryResolveAuthorizedRole( @@ -149,7 +151,7 @@ public async Task ExecuteAsync( out string? effectiveRole, out string authError)) { - return McpResponseBuilder.BuildErrorResult("PermissionDenied", $"Permission denied: {authError}", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", $"Permission denied: {authError}", logger); } // 6) Build and validate Delete context @@ -164,7 +166,7 @@ public async Task ExecuteAsync( { if (kvp.Value is null) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", $"Primary key value for '{kvp.Key}' cannot be null.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", $"Primary key value for '{kvp.Key}' cannot be null.", logger); } context.PrimaryKeyValuePairs[kvp.Key] = kvp.Value; @@ -195,6 +197,7 @@ public async Task ExecuteAsync( { string keyDetails = McpJsonHelper.FormatKeyDetails(keys); return McpResponseBuilder.BuildErrorResult( + toolName, "RecordNotFound", $"No record found with the specified primary key: {keyDetails}", logger); @@ -203,6 +206,7 @@ public async Task ExecuteAsync( message.Contains("REFERENCE constraint", StringComparison.OrdinalIgnoreCase)) { return McpResponseBuilder.BuildErrorResult( + toolName, "ConstraintViolation", "Cannot delete record due to foreign key constraint. Other records depend on this record.", logger); @@ -211,6 +215,7 @@ public async Task ExecuteAsync( message.Contains("authorization", StringComparison.OrdinalIgnoreCase)) { return McpResponseBuilder.BuildErrorResult( + toolName, "PermissionDenied", "You do not have permission to delete this record.", logger); @@ -219,6 +224,7 @@ public async Task ExecuteAsync( message.Contains("type", StringComparison.OrdinalIgnoreCase)) { return McpResponseBuilder.BuildErrorResult( + toolName, "InvalidArguments", "Invalid data type for one or more key values.", logger); @@ -226,6 +232,7 @@ public async Task ExecuteAsync( // For any other DAB exceptions, return the message as-is return McpResponseBuilder.BuildErrorResult( + toolName, "DataApiBuilderError", dabEx.Message, logger); @@ -242,7 +249,7 @@ public async Task ExecuteAsync( 208 => $"Table '{dbObject.FullName}' not found in the database.", _ => $"Database error: {sqlEx.Message}" }; - return McpResponseBuilder.BuildErrorResult("DatabaseError", errorMessage, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseError", errorMessage, logger); } catch (DbException dbEx) { @@ -254,6 +261,7 @@ public async Task ExecuteAsync( if (errorMsg.Contains("foreign key") || errorMsg.Contains("constraint")) { return McpResponseBuilder.BuildErrorResult( + toolName, "ConstraintViolation", "Cannot delete record due to foreign key constraint. Other records depend on this record.", logger); @@ -261,24 +269,25 @@ public async Task ExecuteAsync( else if (errorMsg.Contains("not found") || errorMsg.Contains("does not exist")) { return McpResponseBuilder.BuildErrorResult( + toolName, "RecordNotFound", "No record found with the specified primary key.", logger); } - return McpResponseBuilder.BuildErrorResult("DatabaseError", $"Database error: {dbEx.Message}", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseError", $"Database error: {dbEx.Message}", logger); } catch (InvalidOperationException ioEx) when (ioEx.Message.Contains("connection", StringComparison.OrdinalIgnoreCase)) { // Handle connection-related issues logger?.LogError(ioEx, "Database connection error"); - return McpResponseBuilder.BuildErrorResult("ConnectionError", "Failed to connect to the database.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "ConnectionError", "Failed to connect to the database.", logger); } catch (TimeoutException timeoutEx) { // Handle query timeout logger?.LogError(timeoutEx, "Delete operation timeout for {Entity}", entityName); - return McpResponseBuilder.BuildErrorResult("TimeoutError", "The delete operation timed out.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "TimeoutError", "The delete operation timed out.", logger); } catch (Exception ex) { @@ -289,6 +298,7 @@ public async Task ExecuteAsync( { string keyDetails = McpJsonHelper.FormatKeyDetails(keys); return McpResponseBuilder.BuildErrorResult( + toolName, "RecordNotFound", $"No entity found with the given key {keyDetails}.", logger); @@ -325,11 +335,11 @@ public async Task ExecuteAsync( } catch (OperationCanceledException) { - return McpResponseBuilder.BuildErrorResult("OperationCanceled", "The delete operation was canceled.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "OperationCanceled", "The delete operation was canceled.", logger); } catch (ArgumentException argEx) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", argEx.Message, logger); } catch (Exception ex) { @@ -337,6 +347,7 @@ public async Task ExecuteAsync( innerLogger?.LogError(ex, "Unexpected error in DeleteRecordTool."); return McpResponseBuilder.BuildErrorResult( + toolName, "UnexpectedError", "An unexpected error occurred during the delete operation.", logger); diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DescribeEntitiesTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DescribeEntitiesTool.cs index 154b37ee80..b8c7d975a2 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DescribeEntitiesTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DescribeEntitiesTool.cs @@ -67,6 +67,7 @@ public Task ExecuteAsync( CancellationToken cancellationToken = default) { ILogger? logger = serviceProvider.GetService>(); + string toolName = GetToolMetadata().Name; try { @@ -78,6 +79,7 @@ public Task ExecuteAsync( if (!IsToolEnabled(runtimeConfig)) { return Task.FromResult(McpResponseBuilder.BuildErrorResult( + toolName, "ToolDisabled", $"The {GetToolMetadata().Name} tool is disabled in the configuration.", logger)); @@ -158,6 +160,7 @@ public Task ExecuteAsync( if (entityFilter != null && entityFilter.Count > 0) { return Task.FromResult(McpResponseBuilder.BuildErrorResult( + toolName, "EntitiesNotFound", $"No entities found matching the filter: {string.Join(", ", entityFilter)}", logger)); @@ -165,6 +168,7 @@ public Task ExecuteAsync( else { return Task.FromResult(McpResponseBuilder.BuildErrorResult( + toolName, "NoEntitiesConfigured", "No entities are configured in the runtime configuration.", logger)); @@ -197,6 +201,7 @@ public Task ExecuteAsync( catch (OperationCanceledException) { return Task.FromResult(McpResponseBuilder.BuildErrorResult( + toolName, "OperationCanceled", "The describe operation was canceled.", logger)); @@ -205,6 +210,7 @@ public Task ExecuteAsync( { logger?.LogError(dabEx, "Data API Builder error in DescribeEntitiesTool"); return Task.FromResult(McpResponseBuilder.BuildErrorResult( + toolName, "DataApiBuilderError", dabEx.Message, logger)); @@ -212,6 +218,7 @@ public Task ExecuteAsync( catch (ArgumentException argEx) { return Task.FromResult(McpResponseBuilder.BuildErrorResult( + toolName, "InvalidArguments", argEx.Message, logger)); @@ -220,6 +227,7 @@ public Task ExecuteAsync( { logger?.LogError(ioEx, "Invalid operation in DescribeEntitiesTool"); return Task.FromResult(McpResponseBuilder.BuildErrorResult( + toolName, "InvalidOperation", "Failed to retrieve entity metadata: " + ioEx.Message, logger)); @@ -228,6 +236,7 @@ public Task ExecuteAsync( { logger?.LogError(ex, "Unexpected error in DescribeEntitiesTool"); return Task.FromResult(McpResponseBuilder.BuildErrorResult( + toolName, "UnexpectedError", "An unexpected error occurred while describing entities.", logger)); diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ExecuteEntityTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ExecuteEntityTool.cs index be2fa7af36..6b0bc28383 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ExecuteEntityTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ExecuteEntityTool.cs @@ -73,6 +73,7 @@ public async Task ExecuteAsync( CancellationToken cancellationToken = default) { ILogger? logger = serviceProvider.GetService>(); + string toolName = GetToolMetadata().Name; try { @@ -87,26 +88,27 @@ public async Task ExecuteAsync( if (config.McpDmlTools?.ExecuteEntity != true) { return McpResponseBuilder.BuildErrorResult( + toolName, "ToolDisabled", - $"The {this.GetToolMetadata().Name} tool is disabled in the configuration.", + $"The {toolName} tool is disabled in the configuration.", logger); } // 3) Parsing & basic argument validation if (arguments is null) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", "No arguments provided.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger); } if (!TryParseExecuteArguments(arguments.RootElement, out string entity, out Dictionary parameters, out string parseError)) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", parseError, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); } // Entity is required if (string.IsNullOrWhiteSpace(entity)) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", "Entity is required", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Entity is required", logger); } IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); @@ -115,12 +117,12 @@ public async Task ExecuteAsync( // 4) Validate entity exists and is a stored procedure if (!config.Entities.TryGetValue(entity, out Entity? entityConfig)) { - return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Entity '{entity}' not found in configuration.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Entity '{entity}' not found in configuration.", logger); } if (entityConfig.Source.Type != EntitySourceType.StoredProcedure) { - return McpResponseBuilder.BuildErrorResult("InvalidEntity", $"Entity {entity} cannot be executed.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidEntity", $"Entity {entity} cannot be executed.", logger); } // 5) Resolve metadata @@ -134,12 +136,12 @@ public async Task ExecuteAsync( } catch (Exception) { - return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Failed to resolve entity metadata for '{entity}'.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Failed to resolve entity metadata for '{entity}'.", logger); } if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entity, out DatabaseObject? dbObject) || dbObject is null) { - return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Failed to resolve database object for entity '{entity}'.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Failed to resolve database object for entity '{entity}'.", logger); } // 6) Authorization - Never bypass permissions @@ -149,7 +151,7 @@ public async Task ExecuteAsync( if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authResolver, out string roleError)) { - return McpResponseBuilder.BuildErrorResult("PermissionDenied", roleError, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", roleError, logger); } if (!McpAuthorizationHelper.TryResolveAuthorizedRole( @@ -160,7 +162,7 @@ public async Task ExecuteAsync( out string? effectiveRole, out string authError)) { - return McpResponseBuilder.BuildErrorResult("PermissionDenied", authError, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", authError, logger); } // 7) Validate parameters against metadata @@ -171,7 +173,7 @@ public async Task ExecuteAsync( { if (!entityConfig.Source.Parameters.Any(p => p.Name == param.Key)) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", $"Invalid parameter: {param.Key}", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", $"Invalid parameter: {param.Key}", logger); } } } @@ -241,6 +243,7 @@ public async Task ExecuteAsync( message.Contains("authorization", StringComparison.OrdinalIgnoreCase)) { return McpResponseBuilder.BuildErrorResult( + toolName, "PermissionDenied", "You do not have permission to execute this stored procedure.", logger); @@ -249,6 +252,7 @@ public async Task ExecuteAsync( message.Contains("type", StringComparison.OrdinalIgnoreCase)) { return McpResponseBuilder.BuildErrorResult( + toolName, "InvalidArguments", "Invalid data type for one or more parameters.", logger); @@ -256,6 +260,7 @@ public async Task ExecuteAsync( // For any other DAB exceptions, return the message as-is return McpResponseBuilder.BuildErrorResult( + toolName, "DataApiBuilderError", dabEx.Message, logger); @@ -273,48 +278,49 @@ public async Task ExecuteAsync( 229 or 262 => $"Permission denied to execute stored procedure '{entityConfig.Source.Object}'.", _ => $"Database error: {sqlEx.Message}" }; - return McpResponseBuilder.BuildErrorResult("DatabaseError", errorMessage, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseError", errorMessage, logger); } catch (DbException dbEx) { // Handle generic database exceptions (works for PostgreSQL, MySQL, etc.) logger?.LogError(dbEx, "Database error executing stored procedure {StoredProcedure}", entity); - return McpResponseBuilder.BuildErrorResult("DatabaseError", $"Database error: {dbEx.Message}", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseError", $"Database error: {dbEx.Message}", logger); } catch (InvalidOperationException ioEx) when (ioEx.Message.Contains("connection", StringComparison.OrdinalIgnoreCase)) { // Handle connection-related issues logger?.LogError(ioEx, "Database connection error"); - return McpResponseBuilder.BuildErrorResult("ConnectionError", "Failed to connect to the database.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "ConnectionError", "Failed to connect to the database.", logger); } catch (TimeoutException timeoutEx) { // Handle query timeout logger?.LogError(timeoutEx, "Stored procedure execution timeout for {StoredProcedure}", entity); - return McpResponseBuilder.BuildErrorResult("TimeoutError", "The stored procedure execution timed out.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "TimeoutError", "The stored procedure execution timed out.", logger); } catch (Exception ex) { // Generic database/execution errors logger?.LogError(ex, "Unexpected error executing stored procedure {StoredProcedure}", entity); - return McpResponseBuilder.BuildErrorResult("DatabaseError", "An error occurred while executing the stored procedure.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseError", "An error occurred while executing the stored procedure.", logger); } // 11) Build response with execution result - return BuildExecuteSuccessResponse(entity, parameters, queryResult, logger); + return BuildExecuteSuccessResponse(toolName, entity, parameters, queryResult, logger); } catch (OperationCanceledException) { - return McpResponseBuilder.BuildErrorResult("OperationCanceled", "The execute operation was canceled.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "OperationCanceled", "The execute operation was canceled.", logger); } catch (ArgumentException argEx) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", argEx.Message, logger); } catch (Exception ex) { logger?.LogError(ex, "Unexpected error in ExecuteEntityTool."); return McpResponseBuilder.BuildErrorResult( + toolName, "UnexpectedError", "An unexpected error occurred during the execute operation.", logger); @@ -386,6 +392,7 @@ private static bool TryParseExecuteArguments( /// Builds a successful response for the execute operation. /// private static CallToolResult BuildExecuteSuccessResponse( + string toolName, string entityName, Dictionary? parameters, IActionResult? queryResult, @@ -426,6 +433,7 @@ private static CallToolResult BuildExecuteSuccessResponse( else if (queryResult is BadRequestObjectResult badRequest) { return McpResponseBuilder.BuildErrorResult( + toolName, "BadRequest", badRequest.Value?.ToString() ?? "Bad request", logger); @@ -433,6 +441,7 @@ private static CallToolResult BuildExecuteSuccessResponse( else if (queryResult is UnauthorizedObjectResult) { return McpResponseBuilder.BuildErrorResult( + toolName, "PermissionDenied", "You do not have permission to execute this entity", logger); diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs index 3791fd0bba..7561e1ae54 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs @@ -79,6 +79,7 @@ public async Task ExecuteAsync( CancellationToken cancellationToken = default) { ILogger? logger = serviceProvider.GetService>(); + string toolName = GetToolMetadata().Name; // Get runtime config RuntimeConfigProvider runtimeConfigProvider = serviceProvider.GetRequiredService(); @@ -87,6 +88,7 @@ public async Task ExecuteAsync( if (runtimeConfig.McpDmlTools?.ReadRecords is not true) { return McpResponseBuilder.BuildErrorResult( + toolName, "ToolDisabled", "The read_records tool is disabled in the configuration.", logger); @@ -106,14 +108,14 @@ public async Task ExecuteAsync( // Extract arguments if (arguments == null) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", "No arguments provided.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger); } JsonElement root = arguments.RootElement; if (!root.TryGetProperty("entity", out JsonElement entityElement) || string.IsNullOrWhiteSpace(entityElement.GetString())) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", "Missing required argument 'entity'.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Missing required argument 'entity'.", logger); } entityName = entityElement.GetString()!; @@ -158,12 +160,12 @@ public async Task ExecuteAsync( } catch (Exception) { - return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); } if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entityName, out DatabaseObject? dbObject) || dbObject is null) { - return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); } // Authorization check in the existing entity @@ -174,12 +176,12 @@ public async Task ExecuteAsync( if (httpContext is null || !authResolver.IsValidRoleContext(httpContext)) { - return McpResponseBuilder.BuildErrorResult("PermissionDenied", $"You do not have permission to read records for entity '{entityName}'.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", $"You do not have permission to read records for entity '{entityName}'.", logger); } if (!TryResolveAuthorizedRole(httpContext, authResolver, entityName, out string? effectiveRole, out string authError)) { - return McpResponseBuilder.BuildErrorResult("PermissionDenied", authError, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", authError, logger); } // Build and validate Find context @@ -209,7 +211,7 @@ public async Task ExecuteAsync( { if (string.IsNullOrWhiteSpace(param)) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", "Parameters inside 'orderby' argument cannot be empty or null.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Parameters inside 'orderby' argument cannot be empty or null.", logger); } sortQueryString += $"{param}, "; @@ -231,7 +233,7 @@ public async Task ExecuteAsync( requirements: new[] { new ColumnsPermissionsRequirement() }); if (!authorizationResult.Succeeded) { - return McpResponseBuilder.BuildErrorResult("PermissionDenied", DataApiBuilderException.AUTHORIZATION_FAILURE, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", DataApiBuilderException.AUTHORIZATION_FAILURE, logger); } // Execute @@ -257,24 +259,24 @@ public async Task ExecuteAsync( } catch (OperationCanceledException) { - return McpResponseBuilder.BuildErrorResult("OperationCanceled", "The read operation was canceled.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "OperationCanceled", "The read operation was canceled.", logger); } catch (DbException argEx) { - return McpResponseBuilder.BuildErrorResult("DatabaseOperationFailed", argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseOperationFailed", argEx.Message, logger); } catch (ArgumentException argEx) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", argEx.Message, logger); } catch (DataApiBuilderException argEx) { - return McpResponseBuilder.BuildErrorResult(argEx.StatusCode.ToString(), argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult(toolName, argEx.StatusCode.ToString(), argEx.Message, logger); } catch (Exception ex) { logger?.LogError(ex, "Unexpected error in ReadRecordsTool."); - return McpResponseBuilder.BuildErrorResult("UnexpectedError", "Unexpected error occurred in ReadRecordsTool.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "UnexpectedError", "Unexpected error occurred in ReadRecordsTool.", logger); } } diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs index e58bea7e09..05e66a7fd7 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs @@ -84,8 +84,7 @@ public async Task ExecuteAsync( CancellationToken cancellationToken = default) { ILogger? logger = serviceProvider.GetService>(); - - // 1) Resolve required services & configuration + string toolName = GetToolMetadata().Name; RuntimeConfigProvider runtimeConfigProvider = serviceProvider.GetRequiredService(); RuntimeConfig config = runtimeConfigProvider.GetConfig(); @@ -94,6 +93,7 @@ public async Task ExecuteAsync( if (config.McpDmlTools?.UpdateRecord != true) { return McpResponseBuilder.BuildErrorResult( + toolName, "ToolDisabled", "The update_record tool is disabled in the configuration.", logger); @@ -107,12 +107,12 @@ public async Task ExecuteAsync( // 3) Parsing & basic argument validation (entity, keys, fields) if (arguments is null) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", "No arguments provided.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger); } if (!TryParseArguments(arguments.RootElement, out string entityName, out Dictionary keys, out Dictionary fields, out string parseError)) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", parseError, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); } IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); @@ -129,12 +129,12 @@ public async Task ExecuteAsync( } catch (Exception) { - return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); } if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entityName, out DatabaseObject? dbObject) || dbObject is null) { - return McpResponseBuilder.BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); } // 5) Authorization after we have a known entity @@ -144,12 +144,12 @@ public async Task ExecuteAsync( if (httpContext is null || !authResolver.IsValidRoleContext(httpContext)) { - return McpResponseBuilder.BuildErrorResult("PermissionDenied", "Permission denied: unable to resolve a valid role context for update operation.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", "Permission denied: unable to resolve a valid role context for update operation.", logger); } if (!TryResolveAuthorizedRoleHasPermission(httpContext, authResolver, entityName, out string? effectiveRole, out string authError)) { - return McpResponseBuilder.BuildErrorResult("PermissionDenied", $"Permission denied: {authError}", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", $"Permission denied: {authError}", logger); } // 6) Build and validate Upsert (UpdateIncremental) context @@ -166,7 +166,7 @@ public async Task ExecuteAsync( { if (kvp.Value is null) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", $"Primary key value for '{kvp.Key}' cannot be null.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", $"Primary key value for '{kvp.Key}' cannot be null.", logger); } context.PrimaryKeyValuePairs[kvp.Key] = kvp.Value; @@ -195,6 +195,7 @@ public async Task ExecuteAsync( if (errorMsg.Contains("No Update could be performed, record not found", StringComparison.OrdinalIgnoreCase)) { return McpResponseBuilder.BuildErrorResult( + toolName, "InvalidArguments", "No record found with the given key.", logger); @@ -238,17 +239,18 @@ public async Task ExecuteAsync( } catch (OperationCanceledException) { - return McpResponseBuilder.BuildErrorResult("OperationCanceled", "The update operation was canceled.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "OperationCanceled", "The update operation was canceled.", logger); } catch (ArgumentException argEx) { - return McpResponseBuilder.BuildErrorResult("InvalidArguments", argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", argEx.Message, logger); } catch (Exception ex) { ILogger? innerLogger = serviceProvider.GetService>(); innerLogger?.LogError(ex, "Unexpected error in UpdateRecordTool."); return McpResponseBuilder.BuildErrorResult( + toolName, "UnexpectedError", ex.Message ?? "An unexpected error occurred during the update operation.", logger); diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpResponseBuilder.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpResponseBuilder.cs index afbccbda38..cfd8739c30 100644 --- a/src/Azure.DataApiBuilder.Mcp/Utils/McpResponseBuilder.cs +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpResponseBuilder.cs @@ -43,12 +43,14 @@ public static CallToolResult BuildSuccessResult( /// Builds an error response for MCP tools. /// public static CallToolResult BuildErrorResult( + string toolName, string errorType, string message, ILogger? logger = null) { Dictionary errorObj = new() { + ["toolName"] = toolName, ["status"] = "error", ["error"] = new Dictionary { From 39a1e35e32a60d8ace42b8d880f91e61bad73d4d Mon Sep 17 00:00:00 2001 From: aaron burtle Date: Tue, 25 Nov 2025 15:01:40 -0800 Subject: [PATCH 5/8] factor out getJson() --- .../BuiltInTools/UpdateRecordTool.cs | 17 +---------------- .../Utils/McpResponseBuilder.cs | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs index 05e66a7fd7..ab8956b618 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs @@ -223,7 +223,7 @@ public async Task ExecuteAsync( JsonElement firstItem = valueArray[0]; foreach (JsonProperty prop in firstItem.EnumerateObject()) { - filteredResult[prop.Name] = GetJsonValue(prop.Value); + filteredResult[prop.Name] = McpResponseBuilder.GetJsonValue(prop.Value); } } @@ -368,20 +368,5 @@ private static bool TryResolveAuthorizedRoleHasPermission( } #endregion - - #region Utilities - private static object? GetJsonValue(JsonElement element) - { - return element.ValueKind switch - { - JsonValueKind.String => element.GetString(), - JsonValueKind.Number => element.TryGetInt64(out long l) ? l : element.GetDouble(), - JsonValueKind.True => true, - JsonValueKind.False => false, - JsonValueKind.Null => null, - _ => element.GetRawText() - }; - } - #endregion } } diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpResponseBuilder.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpResponseBuilder.cs index cfd8739c30..49cacef2c3 100644 --- a/src/Azure.DataApiBuilder.Mcp/Utils/McpResponseBuilder.cs +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpResponseBuilder.cs @@ -101,5 +101,21 @@ public static string ExtractResultJson(IActionResult? result) return "{}"; } } + + /// + /// Extracts value from a JsonElement. + /// + public static object? GetJsonValue(JsonElement element) + { + return element.ValueKind switch + { + JsonValueKind.String => element.GetString(), + JsonValueKind.Number => element.TryGetInt64(out long l) ? l : element.GetDouble(), + JsonValueKind.True => true, + JsonValueKind.False => false, + JsonValueKind.Null => null, + _ => element.GetRawText() + }; + } } } From fa1fbbc8c9a77752a6f8ad02151d9f2a3eda5e2b Mon Sep 17 00:00:00 2001 From: aaronburtle <93220300+aaronburtle@users.noreply.github.com> Date: Wed, 26 Nov 2025 15:39:44 -0800 Subject: [PATCH 6/8] Refactor common code shared amongst built in MCP tools (#2986) ## Why make this change? Closes https://github.com/Azure/data-api-builder/issues/2932 ## What is this change? Add helper class `McpMetadataHelper`, extend `McpArgumentParser`, and utilize `McpAuthorizationHelper` to factor out common code. We now do the initialization of the metadata, the parsing of arguments, and the authorization checks in these shared helper classes. ## How was this tested? With MCP Inspector and against the normal test suite. * DESCRIBE_ENTITIES image * CREATE image * READ image * UPDATE image * DELETE image ## Sample Request(s) N/A --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Souvik Ghosh --- .../BuiltInTools/CreateRecordTool.cs | 126 ++++--------- .../BuiltInTools/DeleteRecordTool.cs | 47 ++--- .../BuiltInTools/DescribeEntitiesTool.cs | 6 +- .../BuiltInTools/ExecuteEntityTool.cs | 86 ++------- .../BuiltInTools/ReadRecordsTool.cs | 125 ++++--------- .../BuiltInTools/UpdateRecordTool.cs | 173 +++--------------- .../Model/McpErrorCode.cs | 14 ++ .../Utils/McpArgumentParser.cs | 140 ++++++++++++-- .../Utils/McpErrorHelpers.cs | 28 +++ .../Utils/McpMetadataHelper.cs | 89 +++++++++ 10 files changed, 386 insertions(+), 448 deletions(-) create mode 100644 src/Azure.DataApiBuilder.Mcp/Model/McpErrorCode.cs create mode 100644 src/Azure.DataApiBuilder.Mcp/Utils/McpErrorHelpers.cs create mode 100644 src/Azure.DataApiBuilder.Mcp/Utils/McpMetadataHelper.cs diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs index 9d64fd4cd7..1a944d115b 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/CreateRecordTool.cs @@ -5,14 +5,13 @@ using Azure.DataApiBuilder.Auth; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; -using Azure.DataApiBuilder.Core.Authorization; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Core.Models; using Azure.DataApiBuilder.Core.Resolvers; using Azure.DataApiBuilder.Core.Resolvers.Factories; using Azure.DataApiBuilder.Core.Services; -using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Mcp.Utils; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.DependencyInjection; @@ -60,22 +59,18 @@ public async Task ExecuteAsync( string toolName = GetToolMetadata().Name; if (arguments == null) { - return Utils.McpResponseBuilder.BuildErrorResult(toolName, "Invalid Arguments", "No arguments provided", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger); } RuntimeConfigProvider runtimeConfigProvider = serviceProvider.GetRequiredService(); if (!runtimeConfigProvider.TryGetConfig(out RuntimeConfig? runtimeConfig)) { - return Utils.McpResponseBuilder.BuildErrorResult(toolName, "Invalid Configuration", "Runtime configuration not available", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidConfiguration", "Runtime configuration not available.", logger); } if (runtimeConfig.McpDmlTools?.CreateRecord != true) { - return Utils.McpResponseBuilder.BuildErrorResult( - toolName, - "ToolDisabled", - "The create_record tool is disabled in the configuration.", - logger); + return McpErrorHelpers.ToolDisabled(toolName, logger); } try @@ -83,39 +78,21 @@ public async Task ExecuteAsync( cancellationToken.ThrowIfCancellationRequested(); JsonElement root = arguments.RootElement; - if (!root.TryGetProperty("entity", out JsonElement entityElement) || - !root.TryGetProperty("data", out JsonElement dataElement)) + if (!McpArgumentParser.TryParseEntityAndData(root, out string entityName, out JsonElement dataElement, out string parseError)) { - return Utils.McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Missing required arguments 'entity' or 'data'", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); } - string entityName = entityElement.GetString() ?? string.Empty; - if (string.IsNullOrWhiteSpace(entityName)) + if (!McpMetadataHelper.TryResolveMetadata( + entityName, + runtimeConfig, + serviceProvider, + out ISqlMetadataProvider sqlMetadataProvider, + out DatabaseObject dbObject, + out string dataSourceName, + out string metadataError)) { - return Utils.McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Entity name cannot be empty", logger); - } - - string dataSourceName; - try - { - dataSourceName = runtimeConfig.GetDataSourceNameFromEntityName(entityName); - } - catch (Exception) - { - return Utils.McpResponseBuilder.BuildErrorResult(toolName, "InvalidConfiguration", $"Entity '{entityName}' not found in configuration", logger); - } - - IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); - ISqlMetadataProvider sqlMetadataProvider = metadataProviderFactory.GetMetadataProvider(dataSourceName); - - DatabaseObject dbObject; - try - { - dbObject = sqlMetadataProvider.GetDatabaseObjectByKey(entityName); - } - catch (Exception) - { - return Utils.McpResponseBuilder.BuildErrorResult(toolName, "InvalidConfiguration", $"Database object for entity '{entityName}' not found", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger); } // Create an HTTP context for authorization @@ -123,15 +100,20 @@ public async Task ExecuteAsync( HttpContext httpContext = httpContextAccessor.HttpContext ?? new DefaultHttpContext(); IAuthorizationResolver authorizationResolver = serviceProvider.GetRequiredService(); - if (httpContext is null || !authorizationResolver.IsValidRoleContext(httpContext)) + if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authorizationResolver, out string roleCtxError)) { - return Utils.McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", "Permission denied: Unable to resolve a valid role context for update operation.", logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "create", roleCtxError, logger); } - // Validate that we have at least one role authorized for create - if (!TryResolveAuthorizedRole(httpContext, authorizationResolver, entityName, out string authError)) + if (!McpAuthorizationHelper.TryResolveAuthorizedRole( + httpContext, + authorizationResolver, + entityName, + EntityActionOperation.Create, + out string? effectiveRole, + out string authError)) { - return Utils.McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", authError, logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "create", authError, logger); } JsonElement insertPayloadRoot = dataElement.Clone(); @@ -152,12 +134,12 @@ public async Task ExecuteAsync( } catch (Exception ex) { - return Utils.McpResponseBuilder.BuildErrorResult(toolName, "ValidationFailed", $"Request validation failed: {ex.Message}", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "ValidationFailed", $"Request validation failed: {ex.Message}", logger); } } else { - return Utils.McpResponseBuilder.BuildErrorResult( + return McpResponseBuilder.BuildErrorResult( toolName, "InvalidCreateTarget", "The create_record tool is only available for tables.", @@ -172,7 +154,7 @@ public async Task ExecuteAsync( if (result is CreatedResult createdResult) { - return Utils.McpResponseBuilder.BuildSuccessResult( + return McpResponseBuilder.BuildSuccessResult( new Dictionary { ["entity"] = entityName, @@ -187,7 +169,7 @@ public async Task ExecuteAsync( bool isError = objectResult.StatusCode.HasValue && objectResult.StatusCode.Value >= 400 && objectResult.StatusCode.Value != 403; if (isError) { - return Utils.McpResponseBuilder.BuildErrorResult( + return McpResponseBuilder.BuildErrorResult( toolName, "CreateFailed", $"Failed to create record in entity '{entityName}'. Error: {JsonSerializer.Serialize(objectResult.Value)}", @@ -195,7 +177,7 @@ public async Task ExecuteAsync( } else { - return Utils.McpResponseBuilder.BuildSuccessResult( + return McpResponseBuilder.BuildSuccessResult( new Dictionary { ["entity"] = entityName, @@ -210,7 +192,7 @@ public async Task ExecuteAsync( { if (result is null) { - return Utils.McpResponseBuilder.BuildErrorResult( + return McpResponseBuilder.BuildErrorResult( toolName, "UnexpectedError", $"Mutation engine returned null result for entity '{entityName}'", @@ -218,7 +200,7 @@ public async Task ExecuteAsync( } else { - return Utils.McpResponseBuilder.BuildSuccessResult( + return McpResponseBuilder.BuildSuccessResult( new Dictionary { ["entity"] = entityName, @@ -231,50 +213,8 @@ public async Task ExecuteAsync( } catch (Exception ex) { - return Utils.McpResponseBuilder.BuildErrorResult(toolName, "Error", $"Error: {ex.Message}", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "Error", $"Error: {ex.Message}", logger); } } - - private static bool TryResolveAuthorizedRole( - HttpContext httpContext, - IAuthorizationResolver authorizationResolver, - string entityName, - out string error) - { - error = string.Empty; - - string roleHeader = httpContext.Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER].ToString(); - - if (string.IsNullOrWhiteSpace(roleHeader)) - { - error = "Client role header is missing or empty."; - return false; - } - - string[] roles = roleHeader - .Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) - .Distinct(StringComparer.OrdinalIgnoreCase) - .ToArray(); - - if (roles.Length == 0) - { - error = "Client role header is missing or empty."; - return false; - } - - foreach (string role in roles) - { - bool allowed = authorizationResolver.AreRoleAndOperationDefinedForEntity( - entityName, role, EntityActionOperation.Create); - - if (allowed) - { - return true; - } - } - - error = "You do not have permission to create records for this entity."; - return false; - } } } diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DeleteRecordTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DeleteRecordTool.cs index eb310ae364..d7837c0103 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DeleteRecordTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DeleteRecordTool.cs @@ -87,11 +87,7 @@ public async Task ExecuteAsync( // 2) Check if the tool is enabled in configuration before proceeding if (config.McpDmlTools?.DeleteRecord != true) { - return McpResponseBuilder.BuildErrorResult( - toolName, - "ToolDisabled", - $"The {this.GetToolMetadata().Name} tool is disabled in the configuration.", - logger); + return McpErrorHelpers.ToolDisabled(GetToolMetadata().Name, logger); } // 3) Parsing & basic argument validation @@ -105,26 +101,17 @@ public async Task ExecuteAsync( return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); } - IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); - IMutationEngineFactory mutationEngineFactory = serviceProvider.GetRequiredService(); - - // 4) Resolve metadata for entity existence check - string dataSourceName; - ISqlMetadataProvider sqlMetadataProvider; - - try + // 4) Resolve metadata for entity existence + if (!McpMetadataHelper.TryResolveMetadata( + entityName, + config, + serviceProvider, + out ISqlMetadataProvider sqlMetadataProvider, + out DatabaseObject dbObject, + out string dataSourceName, + out string metadataError)) { - dataSourceName = config.GetDataSourceNameFromEntityName(entityName); - sqlMetadataProvider = metadataProviderFactory.GetMetadataProvider(dataSourceName); - } - catch (Exception) - { - return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); - } - - if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entityName, out DatabaseObject? dbObject) || dbObject is null) - { - return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger); } // Validate it's a table or view @@ -140,7 +127,7 @@ public async Task ExecuteAsync( if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authResolver, out string roleError)) { - return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", $"Permission denied: {roleError}", logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "delete", roleError, logger); } if (!McpAuthorizationHelper.TryResolveAuthorizedRole( @@ -151,10 +138,11 @@ public async Task ExecuteAsync( out string? effectiveRole, out string authError)) { - return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", $"Permission denied: {authError}", logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "delete", authError, logger); } - // 6) Build and validate Delete context + // Need MetadataProviderFactory for RequestValidator; resolve here. + IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); RequestValidator requestValidator = new(metadataProviderFactory, runtimeConfigProvider); DeleteRequestContext context = new( @@ -174,7 +162,7 @@ public async Task ExecuteAsync( requestValidator.ValidatePrimaryKey(context); - // 7) Execute + IMutationEngineFactory mutationEngineFactory = serviceProvider.GetRequiredService(); DatabaseType dbType = config.GetDataSourceFromDataSourceName(dataSourceName).DatabaseType; IMutationEngine mutationEngine = mutationEngineFactory.GetMutationEngine(dbType); @@ -343,8 +331,7 @@ public async Task ExecuteAsync( } catch (Exception ex) { - ILogger? innerLogger = serviceProvider.GetService>(); - innerLogger?.LogError(ex, "Unexpected error in DeleteRecordTool."); + logger?.LogError(ex, "Unexpected error in DeleteRecordTool."); return McpResponseBuilder.BuildErrorResult( toolName, diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DescribeEntitiesTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DescribeEntitiesTool.cs index b8c7d975a2..cd2a7cc28b 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DescribeEntitiesTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/DescribeEntitiesTool.cs @@ -78,11 +78,7 @@ public Task ExecuteAsync( if (!IsToolEnabled(runtimeConfig)) { - return Task.FromResult(McpResponseBuilder.BuildErrorResult( - toolName, - "ToolDisabled", - $"The {GetToolMetadata().Name} tool is disabled in the configuration.", - logger)); + return Task.FromResult(McpErrorHelpers.ToolDisabled(GetToolMetadata().Name, logger)); } // Get authorization services to determine current user's role diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ExecuteEntityTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ExecuteEntityTool.cs index 6b0bc28383..e780c8ddeb 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ExecuteEntityTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ExecuteEntityTool.cs @@ -87,11 +87,7 @@ public async Task ExecuteAsync( // 2) Check if the tool is enabled in configuration before proceeding if (config.McpDmlTools?.ExecuteEntity != true) { - return McpResponseBuilder.BuildErrorResult( - toolName, - "ToolDisabled", - $"The {toolName} tool is disabled in the configuration.", - logger); + return McpErrorHelpers.ToolDisabled(this.GetToolMetadata().Name, logger); } // 3) Parsing & basic argument validation @@ -100,7 +96,7 @@ public async Task ExecuteAsync( return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger); } - if (!TryParseExecuteArguments(arguments.RootElement, out string entity, out Dictionary parameters, out string parseError)) + if (!McpArgumentParser.TryParseExecuteArguments(arguments.RootElement, out string entity, out Dictionary parameters, out string parseError)) { return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); } @@ -125,23 +121,17 @@ public async Task ExecuteAsync( return McpResponseBuilder.BuildErrorResult(toolName, "InvalidEntity", $"Entity {entity} cannot be executed.", logger); } - // 5) Resolve metadata - string dataSourceName; - ISqlMetadataProvider sqlMetadataProvider; - - try - { - dataSourceName = config.GetDataSourceNameFromEntityName(entity); - sqlMetadataProvider = metadataProviderFactory.GetMetadataProvider(dataSourceName); - } - catch (Exception) - { - return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Failed to resolve entity metadata for '{entity}'.", logger); - } - - if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entity, out DatabaseObject? dbObject) || dbObject is null) + // Use shared metadata helper. + if (!McpMetadataHelper.TryResolveMetadata( + entity, + config, + serviceProvider, + out ISqlMetadataProvider sqlMetadataProvider, + out DatabaseObject dbObject, + out string dataSourceName, + out string metadataError)) { - return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Failed to resolve database object for entity '{entity}'.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger); } // 6) Authorization - Never bypass permissions @@ -151,7 +141,7 @@ public async Task ExecuteAsync( if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authResolver, out string roleError)) { - return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", roleError, logger); + return McpErrorHelpers.PermissionDenied(toolName, entity, "execute", roleError, logger); } if (!McpAuthorizationHelper.TryResolveAuthorizedRole( @@ -162,7 +152,7 @@ public async Task ExecuteAsync( out string? effectiveRole, out string authError)) { - return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", authError, logger); + return McpErrorHelpers.PermissionDenied(toolName, entity, "execute", authError, logger); } // 7) Validate parameters against metadata @@ -327,48 +317,6 @@ public async Task ExecuteAsync( } } - /// - /// Parses the execute arguments from the JSON input. - /// - private static bool TryParseExecuteArguments( - JsonElement rootElement, - out string entity, - out Dictionary parameters, - out string parseError) - { - entity = string.Empty; - parameters = new Dictionary(); - parseError = string.Empty; - - if (rootElement.ValueKind != JsonValueKind.Object) - { - parseError = "Arguments must be an object"; - return false; - } - - // Extract entity name (required) - if (!rootElement.TryGetProperty("entity", out JsonElement entityElement) || - entityElement.ValueKind != JsonValueKind.String) - { - parseError = "Missing or invalid 'entity' parameter"; - return false; - } - - entity = entityElement.GetString() ?? string.Empty; - - // Extract parameters if provided (optional) - if (rootElement.TryGetProperty("parameters", out JsonElement parametersElement) && - parametersElement.ValueKind == JsonValueKind.Object) - { - foreach (JsonProperty property in parametersElement.EnumerateObject()) - { - parameters[property.Name] = GetParameterValue(property.Value); - } - } - - return true; - } - /// /// Converts a JSON element to its appropriate CLR type matching GraphQL data types. /// @@ -440,11 +388,7 @@ private static CallToolResult BuildExecuteSuccessResponse( } else if (queryResult is UnauthorizedObjectResult) { - return McpResponseBuilder.BuildErrorResult( - toolName, - "PermissionDenied", - "You do not have permission to execute this entity", - logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "execute", "You do not have permission to execute this entity", logger); } else { diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs index 7561e1ae54..1ed91c30a8 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs @@ -87,11 +87,7 @@ public async Task ExecuteAsync( if (runtimeConfig.McpDmlTools?.ReadRecords is not true) { - return McpResponseBuilder.BuildErrorResult( - toolName, - "ToolDisabled", - "The read_records tool is disabled in the configuration.", - logger); + return McpErrorHelpers.ToolDisabled(toolName, logger); } try @@ -113,13 +109,11 @@ public async Task ExecuteAsync( JsonElement root = arguments.RootElement; - if (!root.TryGetProperty("entity", out JsonElement entityElement) || string.IsNullOrWhiteSpace(entityElement.GetString())) + if (!McpArgumentParser.TryParseEntity(root, out entityName, out string parseError)) { - return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Missing required argument 'entity'.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); } - entityName = entityElement.GetString()!; - if (root.TryGetProperty("select", out JsonElement selectElement)) { select = selectElement.GetString(); @@ -145,27 +139,16 @@ public async Task ExecuteAsync( after = afterElement.GetString(); } - // Get required services & configuration - IQueryEngineFactory queryEngineFactory = serviceProvider.GetRequiredService(); - IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); - - // Check metadata for entity exists - string dataSourceName; - ISqlMetadataProvider sqlMetadataProvider; - - try - { - dataSourceName = runtimeConfig.GetDataSourceNameFromEntityName(entityName); - sqlMetadataProvider = metadataProviderFactory.GetMetadataProvider(dataSourceName); - } - catch (Exception) + if (!McpMetadataHelper.TryResolveMetadata( + entityName, + runtimeConfig, + serviceProvider, + out ISqlMetadataProvider sqlMetadataProvider, + out DatabaseObject dbObject, + out string dataSourceName, + out string metadataError)) { - return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); - } - - if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entityName, out DatabaseObject? dbObject) || dbObject is null) - { - return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger); } // Authorization check in the existing entity @@ -174,20 +157,29 @@ public async Task ExecuteAsync( IHttpContextAccessor httpContextAccessor = serviceProvider.GetRequiredService(); HttpContext? httpContext = httpContextAccessor.HttpContext; - if (httpContext is null || !authResolver.IsValidRoleContext(httpContext)) + if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authResolver, out string roleCtxError)) { - return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", $"You do not have permission to read records for entity '{entityName}'.", logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", roleCtxError, logger); } - if (!TryResolveAuthorizedRole(httpContext, authResolver, entityName, out string? effectiveRole, out string authError)) + if (!McpAuthorizationHelper.TryResolveAuthorizedRole( + httpContext!, + authResolver, + entityName, + EntityActionOperation.Read, + out string? effectiveRole, + out string readAuthError)) { - return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", authError, logger); + string finalError = readAuthError.StartsWith("You do not have permission", StringComparison.OrdinalIgnoreCase) + ? $"You do not have permission to read records for entity '{entityName}'." + : readAuthError; + return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", finalError, logger); } // Build and validate Find context - RequestValidator requestValidator = new(metadataProviderFactory, runtimeConfigProvider); + RequestValidator requestValidator = new(serviceProvider.GetRequiredService(), runtimeConfigProvider); FindRequestContext context = new(entityName, dbObject, true); - httpContext.Request.Method = "GET"; + httpContext!.Request.Method = "GET"; requestValidator.ValidateEntity(entityName); @@ -233,14 +225,17 @@ public async Task ExecuteAsync( requirements: new[] { new ColumnsPermissionsRequirement() }); if (!authorizationResult.Succeeded) { - return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", DataApiBuilderException.AUTHORIZATION_FAILURE, logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", DataApiBuilderException.AUTHORIZATION_FAILURE, logger); } // Execute + IQueryEngineFactory queryEngineFactory = serviceProvider.GetRequiredService(); IQueryEngine queryEngine = queryEngineFactory.GetQueryEngine(sqlMetadataProvider.GetDatabaseType()); JsonDocument? queryResult = await queryEngine.ExecuteAsync(context); - IActionResult actionResult = queryResult is null ? SqlResponseHelpers.FormatFindResult(JsonDocument.Parse("[]").RootElement.Clone(), context, metadataProviderFactory.GetMetadataProvider(dataSourceName), runtimeConfigProvider.GetConfig(), httpContext, true) - : SqlResponseHelpers.FormatFindResult(queryResult.RootElement.Clone(), context, metadataProviderFactory.GetMetadataProvider(dataSourceName), runtimeConfigProvider.GetConfig(), httpContext, true); + IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); + IActionResult actionResult = queryResult is null + ? SqlResponseHelpers.FormatFindResult(JsonDocument.Parse("[]").RootElement.Clone(), context, sqlMetadataProvider, runtimeConfig, httpContext, true) + : SqlResponseHelpers.FormatFindResult(queryResult.RootElement.Clone(), context, sqlMetadataProvider, runtimeConfig, httpContext, true); // Normalize response string rawPayloadJson = McpResponseBuilder.ExtractResultJson(actionResult); @@ -279,59 +274,5 @@ public async Task ExecuteAsync( return McpResponseBuilder.BuildErrorResult(toolName, "UnexpectedError", "Unexpected error occurred in ReadRecordsTool.", logger); } } - - /// - /// Ensures that the role used on the request has the necessary authorizations. - /// - /// Contains request headers and metadata of the user. - /// Resolver used to check if role has necessary authorizations. - /// Name of the entity used in the request. - /// Role defined in client role header. - /// Error message given to the user. - /// True if the user role is authorized, along with the role. - private static bool TryResolveAuthorizedRole( - HttpContext httpContext, - IAuthorizationResolver authorizationResolver, - string entityName, - out string? effectiveRole, - out string error) - { - effectiveRole = null; - error = string.Empty; - - string roleHeader = httpContext.Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER].ToString(); - - if (string.IsNullOrWhiteSpace(roleHeader)) - { - error = $"Client role header '{AuthorizationResolver.CLIENT_ROLE_HEADER}' is missing or empty."; - return false; - } - - string[] roles = roleHeader - .Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) - .Distinct(StringComparer.OrdinalIgnoreCase) - .ToArray(); - - if (roles.Length == 0) - { - error = $"Client role header '{AuthorizationResolver.CLIENT_ROLE_HEADER}' is missing or empty."; - return false; - } - - foreach (string role in roles) - { - bool allowed = authorizationResolver.AreRoleAndOperationDefinedForEntity( - entityName, role, EntityActionOperation.Read); - - if (allowed) - { - effectiveRole = role; - return true; - } - } - - error = $"You do not have permission to read records for entity '{entityName}'."; - return false; - } } } diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs index ab8956b618..195e27a0cd 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/UpdateRecordTool.cs @@ -5,7 +5,6 @@ using Azure.DataApiBuilder.Auth; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; -using Azure.DataApiBuilder.Core.Authorization; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Core.Models; using Azure.DataApiBuilder.Core.Resolvers; @@ -92,11 +91,7 @@ public async Task ExecuteAsync( // 2)Check if the tool is enabled in configuration before proceeding. if (config.McpDmlTools?.UpdateRecord != true) { - return McpResponseBuilder.BuildErrorResult( - toolName, - "ToolDisabled", - "The update_record tool is disabled in the configuration.", - logger); + return McpErrorHelpers.ToolDisabled(GetToolMetadata().Name, logger); } try @@ -110,7 +105,12 @@ public async Task ExecuteAsync( return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger); } - if (!TryParseArguments(arguments.RootElement, out string entityName, out Dictionary keys, out Dictionary fields, out string parseError)) + if (!McpArgumentParser.TryParseEntityKeysAndFields( + arguments.RootElement, + out string entityName, + out Dictionary keys, + out Dictionary fields, + out string parseError)) { return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); } @@ -118,23 +118,16 @@ public async Task ExecuteAsync( IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); IMutationEngineFactory mutationEngineFactory = serviceProvider.GetRequiredService(); - // 4) Resolve metadata for entity existence check - string dataSourceName; - ISqlMetadataProvider sqlMetadataProvider; - - try - { - dataSourceName = config.GetDataSourceNameFromEntityName(entityName); - sqlMetadataProvider = metadataProviderFactory.GetMetadataProvider(dataSourceName); - } - catch (Exception) - { - return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); - } - - if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entityName, out DatabaseObject? dbObject) || dbObject is null) + if (!McpMetadataHelper.TryResolveMetadata( + entityName, + config, + serviceProvider, + out ISqlMetadataProvider sqlMetadataProvider, + out DatabaseObject dbObject, + out string dataSourceName, + out string metadataError)) { - return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger); } // 5) Authorization after we have a known entity @@ -144,12 +137,18 @@ public async Task ExecuteAsync( if (httpContext is null || !authResolver.IsValidRoleContext(httpContext)) { - return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", "Permission denied: unable to resolve a valid role context for update operation.", logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "update", "unable to resolve a valid role context for update operation.", logger); } - if (!TryResolveAuthorizedRoleHasPermission(httpContext, authResolver, entityName, out string? effectiveRole, out string authError)) + if (!McpAuthorizationHelper.TryResolveAuthorizedRole( + httpContext!, + authResolver, + entityName, + EntityActionOperation.Update, + out string? effectiveRole, + out string authError)) { - return McpResponseBuilder.BuildErrorResult(toolName, "PermissionDenied", $"Permission denied: {authError}", logger); + return McpErrorHelpers.PermissionDenied(toolName, entityName, "update", authError, logger); } // 6) Build and validate Upsert (UpdateIncremental) context @@ -194,11 +193,7 @@ public async Task ExecuteAsync( if (errorMsg.Contains("No Update could be performed, record not found", StringComparison.OrdinalIgnoreCase)) { - return McpResponseBuilder.BuildErrorResult( - toolName, - "InvalidArguments", - "No record found with the given key.", - logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No record found with the given key.", logger); } else { @@ -247,8 +242,8 @@ public async Task ExecuteAsync( } catch (Exception ex) { - ILogger? innerLogger = serviceProvider.GetService>(); - innerLogger?.LogError(ex, "Unexpected error in UpdateRecordTool."); + logger?.LogError(ex, "Unexpected error in UpdateRecordTool."); + return McpResponseBuilder.BuildErrorResult( toolName, "UnexpectedError", @@ -256,117 +251,5 @@ public async Task ExecuteAsync( logger); } } - - #region Parsing & Authorization - - private static bool TryParseArguments( - JsonElement root, - out string entityName, - out Dictionary keys, - out Dictionary fields, - out string error) - { - entityName = string.Empty; - keys = new Dictionary(); - fields = new Dictionary(); - error = string.Empty; - - if (!root.TryGetProperty("entity", out JsonElement entityEl) || - !root.TryGetProperty("keys", out JsonElement keysEl) || - !root.TryGetProperty("fields", out JsonElement fieldsEl)) - { - error = "Missing required arguments 'entity', 'keys', or 'fields'."; - return false; - } - - // Parse and validate required arguments: entity, keys, fields - entityName = entityEl.GetString() ?? string.Empty; - if (string.IsNullOrWhiteSpace(entityName)) - { - throw new ArgumentException("Entity is required", nameof(entityName)); - } - - if (keysEl.ValueKind != JsonValueKind.Object || fieldsEl.ValueKind != JsonValueKind.Object) - { - throw new ArgumentException("'keys' and 'fields' must be JSON objects."); - } - - try - { - keys = JsonSerializer.Deserialize>(keysEl.GetRawText()) ?? new Dictionary(); - fields = JsonSerializer.Deserialize>(fieldsEl.GetRawText()) ?? new Dictionary(); - } - catch (Exception ex) - { - throw new ArgumentException("Failed to parse 'keys' or 'fields'", ex); - } - - if (keys.Count == 0) - { - throw new ArgumentException("Keys are required to update an entity"); - } - - if (fields.Count == 0) - { - throw new ArgumentException("At least one field must be provided to update an entity", nameof(fields)); - } - - foreach (KeyValuePair kv in keys) - { - if (kv.Value is null || (kv.Value is string str && string.IsNullOrWhiteSpace(str))) - { - throw new ArgumentException($"Key value for '{kv.Key}' cannot be null or empty."); - } - } - - return true; - } - - private static bool TryResolveAuthorizedRoleHasPermission( - HttpContext httpContext, - IAuthorizationResolver authorizationResolver, - string entityName, - out string? effectiveRole, - out string error) - { - effectiveRole = null; - error = string.Empty; - - string roleHeader = httpContext.Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER].ToString(); - - if (string.IsNullOrWhiteSpace(roleHeader)) - { - error = "Client role header is missing or empty."; - return false; - } - - string[] roles = roleHeader - .Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) - .Distinct(StringComparer.OrdinalIgnoreCase) - .ToArray(); - - if (roles.Length == 0) - { - error = "Client role header is missing or empty."; - return false; - } - - foreach (string role in roles) - { - bool allowed = authorizationResolver.AreRoleAndOperationDefinedForEntity( - entityName, role, EntityActionOperation.Update); - - if (allowed) - { - effectiveRole = role; - return true; - } - } - - error = "You do not have permission to update records for this entity."; - return false; - } - - #endregion } } diff --git a/src/Azure.DataApiBuilder.Mcp/Model/McpErrorCode.cs b/src/Azure.DataApiBuilder.Mcp/Model/McpErrorCode.cs new file mode 100644 index 0000000000..ed13f62783 --- /dev/null +++ b/src/Azure.DataApiBuilder.Mcp/Model/McpErrorCode.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.DataApiBuilder.Mcp.Model +{ + /// + /// MCP error codes standardized for built-in tools. + /// + public enum McpErrorCode + { + ToolDisabled, + PermissionDenied + } +} diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpArgumentParser.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpArgumentParser.cs index 04d14eb5d6..02344c2956 100644 --- a/src/Azure.DataApiBuilder.Mcp/Utils/McpArgumentParser.cs +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpArgumentParser.cs @@ -11,26 +11,24 @@ namespace Azure.DataApiBuilder.Mcp.Utils public static class McpArgumentParser { /// - /// Parses entity and keys arguments for delete/update operations. + /// Parses only the entity name from arguments. /// - public static bool TryParseEntityAndKeys( + public static bool TryParseEntity( JsonElement root, out string entityName, - out Dictionary keys, - out string error) + out string error, + CancellationToken cancellationToken = default) { + cancellationToken.ThrowIfCancellationRequested(); entityName = string.Empty; - keys = new Dictionary(); error = string.Empty; - if (!root.TryGetProperty("entity", out JsonElement entityEl) || - !root.TryGetProperty("keys", out JsonElement keysEl)) + if (!root.TryGetProperty("entity", out JsonElement entityEl)) { - error = "Missing required arguments 'entity' or 'keys'."; + error = "Missing required argument 'entity'."; return false; } - // Parse and validate entity name entityName = entityEl.GetString() ?? string.Empty; if (string.IsNullOrWhiteSpace(entityName)) { @@ -38,7 +36,65 @@ public static bool TryParseEntityAndKeys( return false; } - // Parse and validate keys + return true; + } + + /// + /// Parses entity and data arguments for create operations. + /// + public static bool TryParseEntityAndData( + JsonElement root, + out string entityName, + out JsonElement dataElement, + out string error, + CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + dataElement = default; + + if (!TryParseEntity(root, out entityName, out error, cancellationToken)) + { + return false; + } + + if (!root.TryGetProperty("data", out dataElement)) + { + error = "Missing required argument 'data'."; + return false; + } + + if (dataElement.ValueKind != JsonValueKind.Object) + { + error = "'data' must be a JSON object."; + return false; + } + + return true; + } + + /// + /// Parses entity and keys arguments for delete/update operations. + /// + public static bool TryParseEntityAndKeys( + JsonElement root, + out string entityName, + out Dictionary keys, + out string error, + CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + keys = new Dictionary(); + if (!TryParseEntity(root, out entityName, out error, cancellationToken)) + { + return false; + } + + if (!root.TryGetProperty("keys", out JsonElement keysEl)) + { + error = "Missing required argument 'keys'."; + return false; + } + if (keysEl.ValueKind != JsonValueKind.Object) { error = "'keys' must be a JSON object."; @@ -64,6 +120,8 @@ public static bool TryParseEntityAndKeys( // Validate key values foreach (KeyValuePair kv in keys) { + cancellationToken.ThrowIfCancellationRequested(); + if (kv.Value is null || (kv.Value is string str && string.IsNullOrWhiteSpace(str))) { error = $"Primary key value for '{kv.Key}' cannot be null or empty"; @@ -82,12 +140,14 @@ public static bool TryParseEntityKeysAndFields( out string entityName, out Dictionary keys, out Dictionary fields, - out string error) + out string error, + CancellationToken cancellationToken = default) { + cancellationToken.ThrowIfCancellationRequested(); fields = new Dictionary(); // First parse entity and keys - if (!TryParseEntityAndKeys(root, out entityName, out keys, out error)) + if (!TryParseEntityAndKeys(root, out entityName, out keys, out error, cancellationToken)) { return false; } @@ -123,5 +183,61 @@ public static bool TryParseEntityKeysAndFields( return true; } + + /// + /// Parses the execute arguments from the JSON input. + /// + public static bool TryParseExecuteArguments( + JsonElement rootElement, + out string entity, + out Dictionary parameters, + out string parseError, + CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + entity = string.Empty; + parameters = new Dictionary(); + + if (rootElement.ValueKind != JsonValueKind.Object) + { + parseError = "Arguments must be an object"; + return false; + } + + if (!TryParseEntity(rootElement, out entity, out parseError, cancellationToken)) + { + return false; + } + + // Extract parameters if provided (optional) + if (rootElement.TryGetProperty("parameters", out JsonElement parametersElement) && + parametersElement.ValueKind == JsonValueKind.Object) + { + foreach (JsonProperty property in parametersElement.EnumerateObject()) + { + cancellationToken.ThrowIfCancellationRequested(); + parameters[property.Name] = GetExecuteParameterValue(property.Value); + } + } + + return true; + } + + // Local helper replicating ExecuteEntityTool.GetParameterValue without refactoring other tools. + private static object? GetExecuteParameterValue(JsonElement element) + { + return element.ValueKind switch + { + JsonValueKind.String => element.GetString(), + JsonValueKind.Number => + element.TryGetInt64(out long longValue) ? longValue : + element.TryGetDecimal(out decimal decimalValue) ? decimalValue : + element.GetDouble(), + JsonValueKind.True => true, + JsonValueKind.False => false, + JsonValueKind.Null => null, + _ => element.ToString() + }; + } } } diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpErrorHelpers.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpErrorHelpers.cs new file mode 100644 index 0000000000..75335b2db1 --- /dev/null +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpErrorHelpers.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; + +namespace Azure.DataApiBuilder.Mcp.Utils +{ + /// + /// Helper utilities for creating standardized MCP error responses. + /// Only includes helpers currently being centralized. + /// + public static class McpErrorHelpers + { + public static CallToolResult PermissionDenied(string toolName, string entityName, string operation, string detail, ILogger? logger) + { + string message = $"Permission denied for {operation} on entity '{entityName}'. {detail}"; + return McpResponseBuilder.BuildErrorResult(toolName, Model.McpErrorCode.PermissionDenied.ToString(), message, logger); + } + + // Centralized language for 'tool disabled' errors. Pass the tool name, e.g. "read_records". + public static CallToolResult ToolDisabled(string toolName, ILogger? logger) + { + string message = $"The {toolName} tool is disabled in the configuration."; + return McpResponseBuilder.BuildErrorResult(toolName, Model.McpErrorCode.ToolDisabled.ToString(), message, logger); + } + } +} diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpMetadataHelper.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpMetadataHelper.cs new file mode 100644 index 0000000000..1e79e86b15 --- /dev/null +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpMetadataHelper.cs @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.DataApiBuilder.Config.DatabasePrimitives; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Service.Exceptions; // Added for DataApiBuilderException +using Microsoft.Extensions.DependencyInjection; + +namespace Azure.DataApiBuilder.Mcp.Utils +{ + /// + /// Utility class for resolving metadata and datasource information for MCP tools. + /// + public static class McpMetadataHelper + { + public static bool TryResolveMetadata( + string entityName, + RuntimeConfig config, + IServiceProvider serviceProvider, + out Azure.DataApiBuilder.Core.Services.ISqlMetadataProvider sqlMetadataProvider, + out DatabaseObject dbObject, + out string dataSourceName, + out string error, + CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + sqlMetadataProvider = default!; + dbObject = default!; + dataSourceName = string.Empty; + error = string.Empty; + + if (string.IsNullOrWhiteSpace(entityName)) + { + error = "Entity name cannot be null or empty."; + return false; + } + + var metadataProviderFactory = serviceProvider.GetRequiredService(); + + // Resolve datasource name for the entity. + try + { + cancellationToken.ThrowIfCancellationRequested(); + dataSourceName = config.GetDataSourceNameFromEntityName(entityName); + } + catch (DataApiBuilderException dabEx) when (dabEx.SubStatusCode == DataApiBuilderException.SubStatusCodes.EntityNotFound) + { + error = $"Entity '{entityName}' is not defined in the configuration."; + return false; + } + catch (DataApiBuilderException dabEx) + { + // Other DAB exceptions during entity->datasource resolution. + error = dabEx.Message; + return false; + } + + // Resolve metadata provider for the datasource. + try + { + cancellationToken.ThrowIfCancellationRequested(); + sqlMetadataProvider = metadataProviderFactory.GetMetadataProvider(dataSourceName); + } + catch (DataApiBuilderException dabEx) when (dabEx.SubStatusCode == DataApiBuilderException.SubStatusCodes.DataSourceNotFound) + { + error = $"Data source '{dataSourceName}' for entity '{entityName}' is not defined in the configuration."; + return false; + } + catch (DataApiBuilderException dabEx) + { + // Other DAB exceptions during metadata provider resolution. + error = dabEx.Message; + return false; + } + + cancellationToken.ThrowIfCancellationRequested(); + + // Validate entity exists in metadata mapping. + if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entityName, out DatabaseObject? temp) || temp is null) + { + error = $"Entity '{entityName}' is not defined in the configuration."; + return false; + } + + dbObject = temp; + return true; + } + } +} From 487aac3fa19a996c96ee9fea71a9c2907406a28b Mon Sep 17 00:00:00 2001 From: aaron burtle Date: Fri, 12 Dec 2025 08:03:28 -0800 Subject: [PATCH 7/8] format --- src/Azure.DataApiBuilder.Mcp/Utils/McpMetadataHelper.cs | 3 ++- .../UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpMetadataHelper.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpMetadataHelper.cs index 1e79e86b15..d92117dba1 100644 --- a/src/Azure.DataApiBuilder.Mcp/Utils/McpMetadataHelper.cs +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpMetadataHelper.cs @@ -3,6 +3,7 @@ using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Service.Exceptions; // Added for DataApiBuilderException using Microsoft.Extensions.DependencyInjection; @@ -35,7 +36,7 @@ public static bool TryResolveMetadata( return false; } - var metadataProviderFactory = serviceProvider.GetRequiredService(); + IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); // Resolve datasource name for the entity. try diff --git a/src/Service.Tests/UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs b/src/Service.Tests/UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs index b990b96368..8329dc2134 100644 --- a/src/Service.Tests/UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs +++ b/src/Service.Tests/UnitTests/RuntimeConfigLoaderJsonDeserializerTests.cs @@ -773,8 +773,8 @@ public void TestAkvSecretValueContainingEnvPatternIsNotEnvExpanded() Assert.IsFalse(actual.Contains("SHOULD_NOT_APPEAR"), "Env var value should not be expanded inside AKV secret."); Assert.IsTrue(actual.Contains("Application Name="), "Application Name should be appended for MSSQL when env replacement is enabled."); - var builderOriginal = new SqlConnectionStringBuilder(secretValueWithEnvPattern.Replace("Server=", "Data Source=").Replace("Database=", "Initial Catalog=")); - var builderActual = new SqlConnectionStringBuilder(actual); + SqlConnectionStringBuilder builderOriginal = new(secretValueWithEnvPattern.Replace("Server=", "Data Source=").Replace("Database=", "Initial Catalog=")); + SqlConnectionStringBuilder builderActual = new(actual); Assert.AreEqual(builderOriginal["Data Source"], builderActual["Data Source"], "Server/Data Source should match."); Assert.AreEqual(builderOriginal["Initial Catalog"], builderActual["Initial Catalog"], "Database/Initial Catalog should match (with env pattern retained)."); Assert.AreEqual(builderOriginal["User ID"], builderActual["User ID"], "User Id should match."); @@ -829,8 +829,8 @@ public void TestEnvVariableResolvingToAkvPatternIsExpandedInSecondPass() Assert.IsNotNull(config); string expected = RuntimeConfigLoader.GetConnectionStringWithApplicationName(finalSecretValue); - var builderExpected = new SqlConnectionStringBuilder(expected); - var builderActual = new SqlConnectionStringBuilder(config.DataSource.ConnectionString); + SqlConnectionStringBuilder builderExpected = new(expected); + SqlConnectionStringBuilder builderActual = new(config.DataSource.ConnectionString); Assert.AreEqual(builderExpected["Data Source"], builderActual["Data Source"], "Data Source should match."); Assert.AreEqual(builderExpected["Initial Catalog"], builderActual["Initial Catalog"], "Initial Catalog should match."); Assert.AreEqual(builderExpected["User ID"], builderActual["User ID"], "User ID should match."); From e3511e7efa031192b943eb73fb6aee561713c5aa Mon Sep 17 00:00:00 2001 From: aaronburtle <93220300+aaronburtle@users.noreply.github.com> Date: Tue, 9 Dec 2025 20:12:36 -0800 Subject: [PATCH 8/8] Suppress SM05137 in deserialization variable replacement settings (#3004) ## Why make this change? Silences CodeQL flag. ## What is this change? Adds the suppression language to the usage of `DefaultAzureCredential()` ## How was this tested? Against usual test suite, no real code change, just a comment. ## Sample Request(s) N/A --- src/Config/DeserializationVariableReplacementSettings.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Config/DeserializationVariableReplacementSettings.cs b/src/Config/DeserializationVariableReplacementSettings.cs index 5c70f4082b..350824409b 100644 --- a/src/Config/DeserializationVariableReplacementSettings.cs +++ b/src/Config/DeserializationVariableReplacementSettings.cs @@ -259,7 +259,7 @@ private static SecretClient CreateSecretClient(AzureKeyVaultOptions options) clientOptions.Retry.NetworkTimeout = TimeSpan.FromSeconds(options.RetryPolicy.NetworkTimeoutSeconds ?? AKVRetryPolicyOptions.DEFAULT_NETWORK_TIMEOUT_SECONDS); } - return new SecretClient(new Uri(options.Endpoint), new DefaultAzureCredential(), clientOptions); + return new SecretClient(new Uri(options.Endpoint), new DefaultAzureCredential(), clientOptions); // CodeQL [SM05137] DefaultAzureCredential will use Managed Identity if available or fallback to default. } private string? GetAkvVariable(string name)