From 2d9615192408687c8170f9042f8a430e43bfc075 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 28 Feb 2026 05:55:59 +0000 Subject: [PATCH 01/32] Initial plan From eaaa5229773661a4aa6008aa185b4438bea8f37e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 28 Feb 2026 06:15:55 +0000 Subject: [PATCH 02/32] Changes before error encountered Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- .../BuiltInTools/AggregateRecordsTool.cs | 594 +++++++++++++++++ src/Cli/Commands/ConfigureOptions.cs | 3 + src/Cli/ConfigGenerator.cs | 12 +- .../Converters/DmlToolsConfigConverter.cs | 18 +- src/Config/ObjectModel/DmlToolsConfig.cs | 25 +- .../Mcp/AggregateRecordsToolTests.cs | 596 ++++++++++++++++++ .../EntityLevelDmlToolConfigurationTests.cs | 2 + 7 files changed, 1243 insertions(+), 7 deletions(-) create mode 100644 src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs create mode 100644 src/Service.Tests/Mcp/AggregateRecordsToolTests.cs diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs new file mode 100644 index 0000000000..e64710e46e --- /dev/null +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -0,0 +1,594 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Data.Common; +using System.Text.Json; +using Azure.DataApiBuilder.Auth; +using Azure.DataApiBuilder.Config.DatabasePrimitives; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Authorization; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Models; +using Azure.DataApiBuilder.Core.Parsers; +using Azure.DataApiBuilder.Core.Resolvers; +using Azure.DataApiBuilder.Core.Resolvers.Factories; +using Azure.DataApiBuilder.Core.Services; +using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Mcp.Utils; +using Azure.DataApiBuilder.Service.Exceptions; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; +using static Azure.DataApiBuilder.Mcp.Model.McpEnums; + +namespace Azure.DataApiBuilder.Mcp.BuiltInTools +{ + /// + /// Tool to aggregate records from a table/view entity configured in DAB. + /// Supports count, avg, sum, min, max with optional distinct, filter, groupby, having, orderby. + /// + public class AggregateRecordsTool : IMcpTool + { + public ToolType ToolType { get; } = ToolType.BuiltIn; + + private static readonly HashSet ValidFunctions = new(StringComparer.OrdinalIgnoreCase) { "count", "avg", "sum", "min", "max" }; + + public Tool GetToolMetadata() + { + return new Tool + { + Name = "aggregate_records", + Description = "STEP 1: describe_entities -> find entities with READ permission and their fields. STEP 2: call this tool to compute aggregations (count, avg, sum, min, max) with optional filter, groupby, having, and orderby.", + InputSchema = JsonSerializer.Deserialize( + @"{ + ""type"": ""object"", + ""properties"": { + ""entity"": { + ""type"": ""string"", + ""description"": ""Entity name with READ permission."" + }, + ""function"": { + ""type"": ""string"", + ""enum"": [""count"", ""avg"", ""sum"", ""min"", ""max""], + ""description"": ""Aggregation function to apply."" + }, + ""field"": { + ""type"": ""string"", + ""description"": ""Field to aggregate. Use '*' for count."" + }, + ""distinct"": { + ""type"": ""boolean"", + ""description"": ""Apply DISTINCT before aggregating."", + ""default"": false + }, + ""filter"": { + ""type"": ""string"", + ""description"": ""OData filter applied before aggregating (WHERE). Example: 'unitPrice lt 10'"", + ""default"": """" + }, + ""groupby"": { + ""type"": ""array"", + ""items"": { ""type"": ""string"" }, + ""description"": ""Fields to group by, e.g., ['category', 'region']. Grouped field values are included in the response."", + ""default"": [] + }, + ""orderby"": { + ""type"": ""string"", + ""enum"": [""asc"", ""desc""], + ""description"": ""Sort aggregated results by the computed value. Only applies with groupby."", + ""default"": ""desc"" + }, + ""having"": { + ""type"": ""object"", + ""description"": ""Filter applied after aggregating on the result (HAVING). Operators are AND-ed together."", + ""properties"": { + ""eq"": { ""type"": ""number"", ""description"": ""Aggregated value equals."" }, + ""neq"": { ""type"": ""number"", ""description"": ""Aggregated value not equals."" }, + ""gt"": { ""type"": ""number"", ""description"": ""Aggregated value greater than."" }, + ""gte"": { ""type"": ""number"", ""description"": ""Aggregated value greater than or equal."" }, + ""lt"": { ""type"": ""number"", ""description"": ""Aggregated value less than."" }, + ""lte"": { ""type"": ""number"", ""description"": ""Aggregated value less than or equal."" }, + ""in"": { + ""type"": ""array"", + ""items"": { ""type"": ""number"" }, + ""description"": ""Aggregated value is in the given list."" + } + } + } + }, + ""required"": [""entity"", ""function"", ""field""] + }" + ) + }; + } + + public async Task ExecuteAsync( + JsonDocument? arguments, + IServiceProvider serviceProvider, + CancellationToken cancellationToken = default) + { + ILogger? logger = serviceProvider.GetService>(); + string toolName = GetToolMetadata().Name; + + RuntimeConfigProvider runtimeConfigProvider = serviceProvider.GetRequiredService(); + RuntimeConfig runtimeConfig = runtimeConfigProvider.GetConfig(); + + if (runtimeConfig.McpDmlTools?.AggregateRecords is not true) + { + return McpErrorHelpers.ToolDisabled(toolName, logger); + } + + try + { + cancellationToken.ThrowIfCancellationRequested(); + + if (arguments == null) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger); + } + + JsonElement root = arguments.RootElement; + + // Parse required arguments + if (!McpArgumentParser.TryParseEntity(root, out string entityName, out string parseError)) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); + } + + if (runtimeConfig.Entities?.TryGetValue(entityName, out Entity? entity) == true && + entity.Mcp?.DmlToolEnabled == false) + { + return McpErrorHelpers.ToolDisabled(toolName, logger, $"DML tools are disabled for entity '{entityName}'."); + } + + if (!root.TryGetProperty("function", out JsonElement funcEl) || string.IsNullOrWhiteSpace(funcEl.GetString())) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Missing required argument 'function'.", logger); + } + + string function = funcEl.GetString()!.ToLowerInvariant(); + if (!ValidFunctions.Contains(function)) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", $"Invalid function '{function}'. Must be one of: count, avg, sum, min, max.", logger); + } + + if (!root.TryGetProperty("field", out JsonElement fieldEl) || string.IsNullOrWhiteSpace(fieldEl.GetString())) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Missing required argument 'field'.", logger); + } + + string field = fieldEl.GetString()!; + bool distinct = root.TryGetProperty("distinct", out JsonElement distinctEl) && distinctEl.GetBoolean(); + string? filter = root.TryGetProperty("filter", out JsonElement filterEl) ? filterEl.GetString() : null; + string orderby = root.TryGetProperty("orderby", out JsonElement orderbyEl) ? (orderbyEl.GetString() ?? "desc") : "desc"; + + List groupby = new(); + if (root.TryGetProperty("groupby", out JsonElement groupbyEl) && groupbyEl.ValueKind == JsonValueKind.Array) + { + foreach (JsonElement g in groupbyEl.EnumerateArray()) + { + string? gVal = g.GetString(); + if (!string.IsNullOrWhiteSpace(gVal)) + { + groupby.Add(gVal); + } + } + } + + Dictionary? havingOps = null; + List? havingIn = null; + if (root.TryGetProperty("having", out JsonElement havingEl) && havingEl.ValueKind == JsonValueKind.Object) + { + havingOps = new Dictionary(StringComparer.OrdinalIgnoreCase); + foreach (JsonProperty prop in havingEl.EnumerateObject()) + { + if (prop.Name.Equals("in", StringComparison.OrdinalIgnoreCase) && prop.Value.ValueKind == JsonValueKind.Array) + { + havingIn = new List(); + foreach (JsonElement item in prop.Value.EnumerateArray()) + { + havingIn.Add(item.GetDouble()); + } + } + else if (prop.Value.ValueKind == JsonValueKind.Number) + { + havingOps[prop.Name] = prop.Value.GetDouble(); + } + } + } + + // Resolve metadata + if (!McpMetadataHelper.TryResolveMetadata( + entityName, + runtimeConfig, + serviceProvider, + out ISqlMetadataProvider sqlMetadataProvider, + out DatabaseObject dbObject, + out string dataSourceName, + out string metadataError)) + { + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger); + } + + // Authorization + IAuthorizationResolver authResolver = serviceProvider.GetRequiredService(); + IAuthorizationService authorizationService = serviceProvider.GetRequiredService(); + IHttpContextAccessor httpContextAccessor = serviceProvider.GetRequiredService(); + HttpContext? httpContext = httpContextAccessor.HttpContext; + + if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authResolver, out string roleCtxError)) + { + return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", roleCtxError, logger); + } + + if (!McpAuthorizationHelper.TryResolveAuthorizedRole( + httpContext!, + authResolver, + entityName, + EntityActionOperation.Read, + out string? effectiveRole, + out string readAuthError)) + { + string finalError = readAuthError.StartsWith("You do not have permission", StringComparison.OrdinalIgnoreCase) + ? $"You do not have permission to read records for entity '{entityName}'." + : readAuthError; + return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", finalError, logger); + } + + // Build select list: groupby fields + aggregation field + List selectFields = new(groupby); + bool isCountStar = function == "count" && field == "*"; + if (!isCountStar && !selectFields.Contains(field, StringComparer.OrdinalIgnoreCase)) + { + selectFields.Add(field); + } + + // Build and validate Find context + RequestValidator requestValidator = new(serviceProvider.GetRequiredService(), runtimeConfigProvider); + FindRequestContext context = new(entityName, dbObject, true); + httpContext!.Request.Method = "GET"; + + requestValidator.ValidateEntity(entityName); + + if (selectFields.Count > 0) + { + context.UpdateReturnFields(selectFields); + } + + if (!string.IsNullOrWhiteSpace(filter)) + { + string filterQueryString = $"?{RequestParser.FILTER_URL}={filter}"; + context.FilterClauseInUrl = sqlMetadataProvider.GetODataParser().GetFilterClause(filterQueryString, $"{context.EntityName}.{context.DatabaseObject.FullName}"); + } + + requestValidator.ValidateRequestContext(context); + + AuthorizationResult authorizationResult = await authorizationService.AuthorizeAsync( + user: httpContext.User, + resource: context, + requirements: new[] { new ColumnsPermissionsRequirement() }); + if (!authorizationResult.Succeeded) + { + return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", DataApiBuilderException.AUTHORIZATION_FAILURE, logger); + } + + // Execute query to get records + IQueryEngineFactory queryEngineFactory = serviceProvider.GetRequiredService(); + IQueryEngine queryEngine = queryEngineFactory.GetQueryEngine(sqlMetadataProvider.GetDatabaseType()); + JsonDocument? queryResult = await queryEngine.ExecuteAsync(context); + + IActionResult actionResult = queryResult is null + ? SqlResponseHelpers.FormatFindResult(JsonDocument.Parse("[]").RootElement.Clone(), context, sqlMetadataProvider, runtimeConfig, httpContext, true) + : SqlResponseHelpers.FormatFindResult(queryResult.RootElement.Clone(), context, sqlMetadataProvider, runtimeConfig, httpContext, true); + + string rawPayloadJson = McpResponseBuilder.ExtractResultJson(actionResult); + using JsonDocument resultDoc = JsonDocument.Parse(rawPayloadJson); + JsonElement resultRoot = resultDoc.RootElement; + + // Extract the records array from the response + JsonElement records; + if (resultRoot.TryGetProperty("value", out JsonElement valueArray)) + { + records = valueArray; + } + else if (resultRoot.ValueKind == JsonValueKind.Array) + { + records = resultRoot; + } + else + { + records = resultRoot; + } + + // Compute alias for the response + string alias = ComputeAlias(function, field); + + // Perform in-memory aggregation + List> aggregatedResults = PerformAggregation( + records, function, field, distinct, groupby, havingOps, havingIn, orderby, alias); + + return McpResponseBuilder.BuildSuccessResult( + new Dictionary + { + ["entity"] = entityName, + ["result"] = aggregatedResults, + ["message"] = $"Successfully aggregated records for entity '{entityName}'" + }, + logger, + $"AggregateRecordsTool success for entity {entityName}."); + } + catch (OperationCanceledException) + { + return McpResponseBuilder.BuildErrorResult(toolName, "OperationCanceled", "The aggregate operation was canceled.", logger); + } + catch (DbException argEx) + { + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseOperationFailed", argEx.Message, logger); + } + catch (ArgumentException argEx) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", argEx.Message, logger); + } + catch (DataApiBuilderException argEx) + { + return McpResponseBuilder.BuildErrorResult(toolName, argEx.StatusCode.ToString(), argEx.Message, logger); + } + catch (Exception ex) + { + logger?.LogError(ex, "Unexpected error in AggregateRecordsTool."); + return McpResponseBuilder.BuildErrorResult(toolName, "UnexpectedError", "Unexpected error occurred in AggregateRecordsTool.", logger); + } + } + + /// + /// Computes the response alias for the aggregation result. + /// For count with "*", the alias is "count". Otherwise it's "{function}_{field}". + /// + internal static string ComputeAlias(string function, string field) + { + if (function == "count" && field == "*") + { + return "count"; + } + + return $"{function}_{field}"; + } + + /// + /// Performs in-memory aggregation over a JSON array of records. + /// + internal static List> PerformAggregation( + JsonElement records, + string function, + string field, + bool distinct, + List groupby, + Dictionary? havingOps, + List? havingIn, + string orderby, + string alias) + { + if (records.ValueKind != JsonValueKind.Array) + { + return new List> { new() { [alias] = null } }; + } + + bool isCountStar = function == "count" && field == "*"; + + if (groupby.Count == 0) + { + // No groupby - single result + List items = new(); + foreach (JsonElement record in records.EnumerateArray()) + { + items.Add(record); + } + + double? aggregatedValue = ComputeAggregateValue(items, function, field, distinct, isCountStar); + + // Apply having + if (!PassesHavingFilter(aggregatedValue, havingOps, havingIn)) + { + return new List>(); + } + + return new List> + { + new() { [alias] = aggregatedValue } + }; + } + else + { + // Group by + Dictionary> groups = new(); + Dictionary> groupKeys = new(); + + foreach (JsonElement record in records.EnumerateArray()) + { + string key = BuildGroupKey(record, groupby); + if (!groups.ContainsKey(key)) + { + groups[key] = new List(); + groupKeys[key] = ExtractGroupFields(record, groupby); + } + + groups[key].Add(record); + } + + List> results = new(); + foreach (KeyValuePair> group in groups) + { + double? aggregatedValue = ComputeAggregateValue(group.Value, function, field, distinct, isCountStar); + + if (!PassesHavingFilter(aggregatedValue, havingOps, havingIn)) + { + continue; + } + + Dictionary row = new(groupKeys[group.Key]) + { + [alias] = aggregatedValue + }; + results.Add(row); + } + + // Apply orderby + if (orderby.Equals("asc", StringComparison.OrdinalIgnoreCase)) + { + results.Sort((a, b) => CompareNullableDoubles(a[alias] as double?, b[alias] as double?)); + } + else + { + results.Sort((a, b) => CompareNullableDoubles(b[alias] as double?, a[alias] as double?)); + } + + return results; + } + } + + private static double? ComputeAggregateValue(List records, string function, string field, bool distinct, bool isCountStar) + { + if (isCountStar) + { + return distinct ? 0 : records.Count; + } + + List values = new(); + foreach (JsonElement record in records) + { + if (record.TryGetProperty(field, out JsonElement val) && val.ValueKind == JsonValueKind.Number) + { + values.Add(val.GetDouble()); + } + } + + if (distinct) + { + values = values.Distinct().ToList(); + } + + if (function == "count") + { + return values.Count; + } + + if (values.Count == 0) + { + return null; + } + + return function switch + { + "avg" => Math.Round(values.Average(), 2), + "sum" => values.Sum(), + "min" => values.Min(), + "max" => values.Max(), + _ => null + }; + } + + private static bool PassesHavingFilter(double? value, Dictionary? havingOps, List? havingIn) + { + if (havingOps == null && havingIn == null) + { + return true; + } + + if (value == null) + { + return false; + } + + double v = value.Value; + + if (havingOps != null) + { + foreach (KeyValuePair op in havingOps) + { + bool passes = op.Key.ToLowerInvariant() switch + { + "eq" => v == op.Value, + "neq" => v != op.Value, + "gt" => v > op.Value, + "gte" => v >= op.Value, + "lt" => v < op.Value, + "lte" => v <= op.Value, + _ => true + }; + + if (!passes) + { + return false; + } + } + } + + if (havingIn != null && !havingIn.Contains(v)) + { + return false; + } + + return true; + } + + private static string BuildGroupKey(JsonElement record, List groupby) + { + List parts = new(); + foreach (string g in groupby) + { + if (record.TryGetProperty(g, out JsonElement val)) + { + parts.Add(val.ToString()); + } + else + { + parts.Add("__null__"); + } + } + + return string.Join("|", parts); + } + + private static Dictionary ExtractGroupFields(JsonElement record, List groupby) + { + Dictionary result = new(); + foreach (string g in groupby) + { + if (record.TryGetProperty(g, out JsonElement val)) + { + result[g] = McpResponseBuilder.GetJsonValue(val); + } + else + { + result[g] = null; + } + } + + return result; + } + + private static int CompareNullableDoubles(double? a, double? b) + { + if (a == null && b == null) + { + return 0; + } + + if (a == null) + { + return -1; + } + + if (b == null) + { + return 1; + } + + return a.Value.CompareTo(b.Value); + } + } +} diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index 262cbc9145..ecd5ecd185 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -224,6 +224,9 @@ public ConfigureOptions( [Option("runtime.mcp.dml-tools.execute-entity.enabled", Required = false, HelpText = "Enable DAB's MCP execute entity tool. Default: true (boolean).")] public bool? RuntimeMcpDmlToolsExecuteEntityEnabled { get; } + [Option("runtime.mcp.dml-tools.aggregate-records.enabled", Required = false, HelpText = "Enable DAB's MCP aggregate records tool. Default: true (boolean).")] + public bool? RuntimeMcpDmlToolsAggregateRecordsEnabled { get; } + [Option("runtime.cache.enabled", Required = false, HelpText = "Enable DAB's cache globally. (You must also enable each entity's cache separately.). Default: false (boolean).")] public bool? RuntimeCacheEnabled { get; } diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index 6c51f002b7..2eaf50a822 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -1181,6 +1181,7 @@ private static bool TryUpdateConfiguredMcpValues( bool? updateRecord = currentDmlTools?.UpdateRecord; bool? deleteRecord = currentDmlTools?.DeleteRecord; bool? executeEntity = currentDmlTools?.ExecuteEntity; + bool? aggregateRecords = currentDmlTools?.AggregateRecords; updatedValue = options?.RuntimeMcpDmlToolsDescribeEntitiesEnabled; if (updatedValue != null) @@ -1230,6 +1231,14 @@ private static bool TryUpdateConfiguredMcpValues( _logger.LogInformation("Updated RuntimeConfig with runtime.mcp.dml-tools.execute-entity as '{updatedValue}'", updatedValue); } + updatedValue = options?.RuntimeMcpDmlToolsAggregateRecordsEnabled; + if (updatedValue != null) + { + aggregateRecords = (bool)updatedValue; + hasToolUpdates = true; + _logger.LogInformation("Updated RuntimeConfig with runtime.mcp.dml-tools.aggregate-records as '{updatedValue}'", updatedValue); + } + if (hasToolUpdates) { updatedMcpOptions = updatedMcpOptions! with @@ -1242,7 +1251,8 @@ private static bool TryUpdateConfiguredMcpValues( ReadRecords = readRecord, UpdateRecord = updateRecord, DeleteRecord = deleteRecord, - ExecuteEntity = executeEntity + ExecuteEntity = executeEntity, + AggregateRecords = aggregateRecords } }; } diff --git a/src/Config/Converters/DmlToolsConfigConverter.cs b/src/Config/Converters/DmlToolsConfigConverter.cs index 82ac3f6069..7e049c7926 100644 --- a/src/Config/Converters/DmlToolsConfigConverter.cs +++ b/src/Config/Converters/DmlToolsConfigConverter.cs @@ -44,6 +44,7 @@ internal class DmlToolsConfigConverter : JsonConverter bool? updateRecord = null; bool? deleteRecord = null; bool? executeEntity = null; + bool? aggregateRecords = null; while (reader.Read()) { @@ -82,6 +83,9 @@ internal class DmlToolsConfigConverter : JsonConverter case "execute-entity": executeEntity = value; break; + case "aggregate-records": + aggregateRecords = value; + break; default: // Skip unknown properties break; @@ -91,7 +95,8 @@ internal class DmlToolsConfigConverter : JsonConverter { // Error on non-boolean values for known properties if (property?.ToLowerInvariant() is "describe-entities" or "create-record" - or "read-records" or "update-record" or "delete-record" or "execute-entity") + or "read-records" or "update-record" or "delete-record" or "execute-entity" + or "aggregate-records") { throw new JsonException($"Property '{property}' must be a boolean value."); } @@ -110,7 +115,8 @@ internal class DmlToolsConfigConverter : JsonConverter readRecords: readRecords, updateRecord: updateRecord, deleteRecord: deleteRecord, - executeEntity: executeEntity); + executeEntity: executeEntity, + aggregateRecords: aggregateRecords); } // For any other unexpected token type, return default (all enabled) @@ -135,7 +141,8 @@ public override void Write(Utf8JsonWriter writer, DmlToolsConfig? value, JsonSer value.UserProvidedReadRecords || value.UserProvidedUpdateRecord || value.UserProvidedDeleteRecord || - value.UserProvidedExecuteEntity; + value.UserProvidedExecuteEntity || + value.UserProvidedAggregateRecords; // Only write the boolean value if it's provided by user // This prevents writing "dml-tools": true when it's the default @@ -181,6 +188,11 @@ public override void Write(Utf8JsonWriter writer, DmlToolsConfig? value, JsonSer writer.WriteBoolean("execute-entity", value.ExecuteEntity.Value); } + if (value.UserProvidedAggregateRecords && value.AggregateRecords.HasValue) + { + writer.WriteBoolean("aggregate-records", value.AggregateRecords.Value); + } + writer.WriteEndObject(); } } diff --git a/src/Config/ObjectModel/DmlToolsConfig.cs b/src/Config/ObjectModel/DmlToolsConfig.cs index 2a09e9d53c..c1f8b278cd 100644 --- a/src/Config/ObjectModel/DmlToolsConfig.cs +++ b/src/Config/ObjectModel/DmlToolsConfig.cs @@ -51,6 +51,11 @@ public record DmlToolsConfig /// public bool? ExecuteEntity { get; init; } + /// + /// Whether aggregate-records tool is enabled + /// + public bool? AggregateRecords { get; init; } + [JsonConstructor] public DmlToolsConfig( bool? allToolsEnabled = null, @@ -59,7 +64,8 @@ public DmlToolsConfig( bool? readRecords = null, bool? updateRecord = null, bool? deleteRecord = null, - bool? executeEntity = null) + bool? executeEntity = null, + bool? aggregateRecords = null) { if (allToolsEnabled is not null) { @@ -75,6 +81,7 @@ public DmlToolsConfig( UpdateRecord = updateRecord ?? toolDefault; DeleteRecord = deleteRecord ?? toolDefault; ExecuteEntity = executeEntity ?? toolDefault; + AggregateRecords = aggregateRecords ?? toolDefault; } else { @@ -87,6 +94,7 @@ public DmlToolsConfig( UpdateRecord = updateRecord ?? DEFAULT_ENABLED; DeleteRecord = deleteRecord ?? DEFAULT_ENABLED; ExecuteEntity = executeEntity ?? DEFAULT_ENABLED; + AggregateRecords = aggregateRecords ?? DEFAULT_ENABLED; } // Track user-provided status - only true if the parameter was not null @@ -96,6 +104,7 @@ public DmlToolsConfig( UserProvidedUpdateRecord = updateRecord is not null; UserProvidedDeleteRecord = deleteRecord is not null; UserProvidedExecuteEntity = executeEntity is not null; + UserProvidedAggregateRecords = aggregateRecords is not null; } /// @@ -112,7 +121,8 @@ public static DmlToolsConfig FromBoolean(bool enabled) readRecords: null, updateRecord: null, deleteRecord: null, - executeEntity: null + executeEntity: null, + aggregateRecords: null ); } @@ -127,7 +137,8 @@ public static DmlToolsConfig FromBoolean(bool enabled) readRecords: null, updateRecord: null, deleteRecord: null, - executeEntity: null + executeEntity: null, + aggregateRecords: null ); /// @@ -185,4 +196,12 @@ public static DmlToolsConfig FromBoolean(bool enabled) [JsonIgnore(Condition = JsonIgnoreCondition.Always)] [MemberNotNullWhen(true, nameof(ExecuteEntity))] public bool UserProvidedExecuteEntity { get; init; } = false; + + /// + /// Flag which informs CLI and JSON serializer whether to write aggregate-records + /// property/value to the runtime config file. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(AggregateRecords))] + public bool UserProvidedAggregateRecords { get; init; } = false; } diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs new file mode 100644 index 0000000000..a1fb2b691c --- /dev/null +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -0,0 +1,596 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Auth; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Authorization; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Mcp.BuiltInTools; +using Azure.DataApiBuilder.Mcp.Model; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using ModelContextProtocol.Protocol; +using Moq; + +namespace Azure.DataApiBuilder.Service.Tests.Mcp +{ + /// + /// Tests for the AggregateRecordsTool MCP tool. + /// Covers: + /// - Tool metadata and schema validation + /// - Runtime-level enabled/disabled configuration + /// - Entity-level DML tool configuration + /// - Input validation (missing/invalid arguments) + /// - In-memory aggregation logic (count, avg, sum, min, max) + /// - distinct, groupby, having, orderby + /// - Alias convention + /// + [TestClass] + public class AggregateRecordsToolTests + { + #region Tool Metadata Tests + + [TestMethod] + public void GetToolMetadata_ReturnsCorrectName() + { + AggregateRecordsTool tool = new(); + Tool metadata = tool.GetToolMetadata(); + Assert.AreEqual("aggregate_records", metadata.Name); + } + + [TestMethod] + public void GetToolMetadata_ReturnsCorrectToolType() + { + AggregateRecordsTool tool = new(); + Assert.AreEqual(McpEnums.ToolType.BuiltIn, tool.ToolType); + } + + [TestMethod] + public void GetToolMetadata_HasInputSchema() + { + AggregateRecordsTool tool = new(); + Tool metadata = tool.GetToolMetadata(); + Assert.AreEqual(JsonValueKind.Object, metadata.InputSchema.ValueKind); + Assert.IsTrue(metadata.InputSchema.TryGetProperty("properties", out _)); + Assert.IsTrue(metadata.InputSchema.TryGetProperty("required", out JsonElement required)); + + List requiredFields = new(); + foreach (JsonElement r in required.EnumerateArray()) + { + requiredFields.Add(r.GetString()!); + } + + CollectionAssert.Contains(requiredFields, "entity"); + CollectionAssert.Contains(requiredFields, "function"); + CollectionAssert.Contains(requiredFields, "field"); + } + + #endregion + + #region Configuration Tests + + [TestMethod] + public async Task AggregateRecords_DisabledAtRuntimeLevel_ReturnsToolDisabledError() + { + RuntimeConfig config = CreateConfig(aggregateRecordsEnabled: false); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + AssertToolDisabledError(content); + } + + [TestMethod] + public async Task AggregateRecords_DisabledAtEntityLevel_ReturnsToolDisabledError() + { + RuntimeConfig config = CreateConfigWithEntityDmlDisabled(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + AssertToolDisabledError(content); + } + + #endregion + + #region Input Validation Tests + + [TestMethod] + public async Task AggregateRecords_NullArguments_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + CallToolResult result = await tool.ExecuteAsync(null, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.IsTrue(content.TryGetProperty("error", out JsonElement error)); + Assert.AreEqual("InvalidArguments", error.GetProperty("type").GetString()); + } + + [TestMethod] + public async Task AggregateRecords_MissingEntity_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"function\": \"count\", \"field\": \"*\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + } + + [TestMethod] + public async Task AggregateRecords_MissingFunction_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"field\": \"*\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + } + + [TestMethod] + public async Task AggregateRecords_MissingField_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + } + + [TestMethod] + public async Task AggregateRecords_InvalidFunction_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"median\", \"field\": \"price\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("median")); + } + + #endregion + + #region Alias Convention Tests + + [TestMethod] + public void ComputeAlias_CountStar_ReturnsCount() + { + Assert.AreEqual("count", AggregateRecordsTool.ComputeAlias("count", "*")); + } + + [TestMethod] + public void ComputeAlias_CountField_ReturnsFunctionField() + { + Assert.AreEqual("count_supplierId", AggregateRecordsTool.ComputeAlias("count", "supplierId")); + } + + [TestMethod] + public void ComputeAlias_AvgField_ReturnsFunctionField() + { + Assert.AreEqual("avg_unitPrice", AggregateRecordsTool.ComputeAlias("avg", "unitPrice")); + } + + [TestMethod] + public void ComputeAlias_SumField_ReturnsFunctionField() + { + Assert.AreEqual("sum_unitPrice", AggregateRecordsTool.ComputeAlias("sum", "unitPrice")); + } + + [TestMethod] + public void ComputeAlias_MinField_ReturnsFunctionField() + { + Assert.AreEqual("min_price", AggregateRecordsTool.ComputeAlias("min", "price")); + } + + [TestMethod] + public void ComputeAlias_MaxField_ReturnsFunctionField() + { + Assert.AreEqual("max_price", AggregateRecordsTool.ComputeAlias("max", "price")); + } + + #endregion + + #region In-Memory Aggregation Tests + + [TestMethod] + public void PerformAggregation_CountStar_ReturnsCount() + { + JsonElement records = ParseArray("[{\"id\":1},{\"id\":2},{\"id\":3}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new(), null, null, "desc", "count"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(3.0, result[0]["count"]); + } + + [TestMethod] + public void PerformAggregation_Avg_ReturnsAverage() + { + JsonElement records = ParseArray("[{\"price\":10},{\"price\":20},{\"price\":30}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "avg", "price", false, new(), null, null, "desc", "avg_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(20.0, result[0]["avg_price"]); + } + + [TestMethod] + public void PerformAggregation_Sum_ReturnsSum() + { + JsonElement records = ParseArray("[{\"price\":10},{\"price\":20},{\"price\":30}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new(), null, null, "desc", "sum_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(60.0, result[0]["sum_price"]); + } + + [TestMethod] + public void PerformAggregation_Min_ReturnsMin() + { + JsonElement records = ParseArray("[{\"price\":10},{\"price\":20},{\"price\":5}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "min", "price", false, new(), null, null, "desc", "min_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(5.0, result[0]["min_price"]); + } + + [TestMethod] + public void PerformAggregation_Max_ReturnsMax() + { + JsonElement records = ParseArray("[{\"price\":10},{\"price\":20},{\"price\":5}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "max", "price", false, new(), null, null, "desc", "max_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(20.0, result[0]["max_price"]); + } + + [TestMethod] + public void PerformAggregation_CountDistinct_ReturnsDistinctCount() + { + JsonElement records = ParseArray("[{\"supplierId\":1},{\"supplierId\":2},{\"supplierId\":1},{\"supplierId\":3}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "count", "supplierId", true, new(), null, null, "desc", "count_supplierId"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(3.0, result[0]["count_supplierId"]); + } + + [TestMethod] + public void PerformAggregation_AvgDistinct_ReturnsDistinctAvg() + { + JsonElement records = ParseArray("[{\"price\":10},{\"price\":10},{\"price\":20},{\"price\":30}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "avg", "price", true, new(), null, null, "desc", "avg_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(20.0, result[0]["avg_price"]); + } + + [TestMethod] + public void PerformAggregation_GroupBy_ReturnsGroupedResults() + { + JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"A\",\"price\":20},{\"category\":\"B\",\"price\":50}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, null, null, "desc", "sum_price"); + + Assert.AreEqual(2, result.Count); + // Desc order: B(50) first, then A(30) + Assert.AreEqual("B", result[0]["category"]?.ToString()); + Assert.AreEqual(50.0, result[0]["sum_price"]); + Assert.AreEqual("A", result[1]["category"]?.ToString()); + Assert.AreEqual(30.0, result[1]["sum_price"]); + } + + [TestMethod] + public void PerformAggregation_GroupBy_Asc_ReturnsSortedAsc() + { + JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"B\",\"price\":30},{\"category\":\"A\",\"price\":20}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, null, null, "asc", "sum_price"); + + Assert.AreEqual(2, result.Count); + Assert.AreEqual("A", result[0]["category"]?.ToString()); + Assert.AreEqual(30.0, result[0]["sum_price"]); + Assert.AreEqual("B", result[1]["category"]?.ToString()); + Assert.AreEqual(30.0, result[1]["sum_price"]); + } + + [TestMethod] + public void PerformAggregation_CountStar_GroupBy_ReturnsGroupCounts() + { + JsonElement records = ParseArray("[{\"category\":\"A\"},{\"category\":\"A\"},{\"category\":\"B\"}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "category" }, null, null, "desc", "count"); + + Assert.AreEqual(2, result.Count); + Assert.AreEqual("A", result[0]["category"]?.ToString()); + Assert.AreEqual(2.0, result[0]["count"]); + Assert.AreEqual("B", result[1]["category"]?.ToString()); + Assert.AreEqual(1.0, result[1]["count"]); + } + + [TestMethod] + public void PerformAggregation_HavingGt_FiltersResults() + { + JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"A\",\"price\":20},{\"category\":\"B\",\"price\":5}]"); + var having = new Dictionary { ["gt"] = 10 }; + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual("A", result[0]["category"]?.ToString()); + Assert.AreEqual(30.0, result[0]["sum_price"]); + } + + [TestMethod] + public void PerformAggregation_HavingGteLte_FiltersRange() + { + JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":100},{\"category\":\"B\",\"price\":20},{\"category\":\"C\",\"price\":1}]"); + var having = new Dictionary { ["gte"] = 10, ["lte"] = 50 }; + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual("B", result[0]["category"]?.ToString()); + } + + [TestMethod] + public void PerformAggregation_HavingIn_FiltersExactValues() + { + JsonElement records = ParseArray("[{\"category\":\"A\"},{\"category\":\"A\"},{\"category\":\"B\"},{\"category\":\"C\"},{\"category\":\"C\"},{\"category\":\"C\"}]"); + var havingIn = new List { 2, 3 }; + var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "category" }, null, havingIn, "desc", "count"); + + Assert.AreEqual(2, result.Count); + // C(3) desc, A(2) + Assert.AreEqual("C", result[0]["category"]?.ToString()); + Assert.AreEqual(3.0, result[0]["count"]); + Assert.AreEqual("A", result[1]["category"]?.ToString()); + Assert.AreEqual(2.0, result[1]["count"]); + } + + [TestMethod] + public void PerformAggregation_HavingEq_FiltersSingleValue() + { + JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"B\",\"price\":20}]"); + var having = new Dictionary { ["eq"] = 10 }; + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual("A", result[0]["category"]?.ToString()); + } + + [TestMethod] + public void PerformAggregation_HavingNeq_FiltersOutValue() + { + JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"B\",\"price\":20}]"); + var having = new Dictionary { ["neq"] = 10 }; + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual("B", result[0]["category"]?.ToString()); + } + + [TestMethod] + public void PerformAggregation_EmptyRecords_ReturnsNull() + { + JsonElement records = ParseArray("[]"); + var result = AggregateRecordsTool.PerformAggregation(records, "avg", "price", false, new(), null, null, "desc", "avg_price"); + + Assert.AreEqual(1, result.Count); + Assert.IsNull(result[0]["avg_price"]); + } + + [TestMethod] + public void PerformAggregation_EmptyRecordsCountStar_ReturnsZero() + { + JsonElement records = ParseArray("[]"); + var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new(), null, null, "desc", "count"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(0.0, result[0]["count"]); + } + + [TestMethod] + public void PerformAggregation_MultipleGroupByFields_ReturnsCorrectGroups() + { + JsonElement records = ParseArray("[{\"cat\":\"A\",\"region\":\"East\",\"price\":10},{\"cat\":\"A\",\"region\":\"East\",\"price\":20},{\"cat\":\"A\",\"region\":\"West\",\"price\":5}]"); + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "cat", "region" }, null, null, "desc", "sum_price"); + + Assert.AreEqual(2, result.Count); + // (A,East)=30 desc, (A,West)=5 + Assert.AreEqual("A", result[0]["cat"]?.ToString()); + Assert.AreEqual("East", result[0]["region"]?.ToString()); + Assert.AreEqual(30.0, result[0]["sum_price"]); + } + + [TestMethod] + public void PerformAggregation_HavingNoResults_ReturnsEmpty() + { + JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10}]"); + var having = new Dictionary { ["gt"] = 100 }; + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price"); + + Assert.AreEqual(0, result.Count); + } + + [TestMethod] + public void PerformAggregation_HavingOnSingleResult_Passes() + { + JsonElement records = ParseArray("[{\"price\":50},{\"price\":60}]"); + var having = new Dictionary { ["gte"] = 100 }; + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new(), having, null, "desc", "sum_price"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(110.0, result[0]["sum_price"]); + } + + [TestMethod] + public void PerformAggregation_HavingOnSingleResult_Fails() + { + JsonElement records = ParseArray("[{\"price\":50},{\"price\":60}]"); + var having = new Dictionary { ["gt"] = 200 }; + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new(), having, null, "desc", "sum_price"); + + Assert.AreEqual(0, result.Count); + } + + #endregion + + #region Helper Methods + + private static JsonElement ParseArray(string json) + { + return JsonDocument.Parse(json).RootElement; + } + + private static JsonElement ParseContent(CallToolResult result) + { + TextContentBlock firstContent = (TextContentBlock)result.Content[0]; + return JsonDocument.Parse(firstContent.Text).RootElement; + } + + private static void AssertToolDisabledError(JsonElement content) + { + Assert.IsTrue(content.TryGetProperty("error", out JsonElement error)); + Assert.IsTrue(error.TryGetProperty("type", out JsonElement errorType)); + Assert.AreEqual("ToolDisabled", errorType.GetString()); + } + + private static RuntimeConfig CreateConfig(bool aggregateRecordsEnabled = true) + { + Dictionary entities = new() + { + ["Book"] = new Entity( + Source: new("books", EntitySourceType.Table, null, null), + GraphQL: new("Book", "Books"), + Fields: null, + Rest: new(Enabled: true), + Permissions: new[] { new EntityPermission(Role: "anonymous", Actions: new[] { + new EntityAction(Action: EntityActionOperation.Read, Fields: null, Policy: null) + }) }, + Mappings: null, + Relationships: null, + Mcp: null + ) + }; + + return new RuntimeConfig( + Schema: "test-schema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new( + Enabled: true, + Path: "/mcp", + DmlTools: new( + describeEntities: true, + readRecords: true, + createRecord: true, + updateRecord: true, + deleteRecord: true, + executeEntity: true, + aggregateRecords: aggregateRecordsEnabled + ) + ), + Host: new(Cors: null, Authentication: null, Mode: HostMode.Development) + ), + Entities: new(entities) + ); + } + + private static RuntimeConfig CreateConfigWithEntityDmlDisabled() + { + Dictionary entities = new() + { + ["Book"] = new Entity( + Source: new("books", EntitySourceType.Table, null, null), + GraphQL: new("Book", "Books"), + Fields: null, + Rest: new(Enabled: true), + Permissions: new[] { new EntityPermission(Role: "anonymous", Actions: new[] { + new EntityAction(Action: EntityActionOperation.Read, Fields: null, Policy: null) + }) }, + Mappings: null, + Relationships: null, + Mcp: new EntityMcpOptions(customToolEnabled: false, dmlToolsEnabled: false) + ) + }; + + return new RuntimeConfig( + Schema: "test-schema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new( + Enabled: true, + Path: "/mcp", + DmlTools: new( + describeEntities: true, + readRecords: true, + createRecord: true, + updateRecord: true, + deleteRecord: true, + executeEntity: true, + aggregateRecords: true + ) + ), + Host: new(Cors: null, Authentication: null, Mode: HostMode.Development) + ), + Entities: new(entities) + ); + } + + private static IServiceProvider CreateServiceProvider(RuntimeConfig config) + { + ServiceCollection services = new(); + + RuntimeConfigProvider configProvider = TestHelper.GenerateInMemoryRuntimeConfigProvider(config); + services.AddSingleton(configProvider); + + Mock mockAuthResolver = new(); + mockAuthResolver.Setup(x => x.IsValidRoleContext(It.IsAny())).Returns(true); + services.AddSingleton(mockAuthResolver.Object); + + Mock mockHttpContext = new(); + Mock mockRequest = new(); + mockRequest.Setup(x => x.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER]).Returns("anonymous"); + mockHttpContext.Setup(x => x.Request).Returns(mockRequest.Object); + + Mock mockHttpContextAccessor = new(); + mockHttpContextAccessor.Setup(x => x.HttpContext).Returns(mockHttpContext.Object); + services.AddSingleton(mockHttpContextAccessor.Object); + + services.AddLogging(); + + return services.BuildServiceProvider(); + } + + #endregion + } +} diff --git a/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs b/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs index d2f6554cd3..b4ae074207 100644 --- a/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs +++ b/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs @@ -48,6 +48,7 @@ public class EntityLevelDmlToolConfigurationTests [DataRow("UpdateRecord", "{\"entity\": \"Book\", \"keys\": {\"id\": 1}, \"fields\": {\"title\": \"Updated\"}}", false, DisplayName = "UpdateRecord respects entity-level DmlToolEnabled=false")] [DataRow("DeleteRecord", "{\"entity\": \"Book\", \"keys\": {\"id\": 1}}", false, DisplayName = "DeleteRecord respects entity-level DmlToolEnabled=false")] [DataRow("ExecuteEntity", "{\"entity\": \"GetBook\"}", true, DisplayName = "ExecuteEntity respects entity-level DmlToolEnabled=false")] + [DataRow("AggregateRecords", "{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\"}", false, DisplayName = "AggregateRecords respects entity-level DmlToolEnabled=false")] public async Task DmlTool_RespectsEntityLevelDmlToolDisabled(string toolType, string jsonArguments, bool isStoredProcedure) { // Arrange @@ -238,6 +239,7 @@ private static IMcpTool CreateTool(string toolType) "UpdateRecord" => new UpdateRecordTool(), "DeleteRecord" => new DeleteRecordTool(), "ExecuteEntity" => new ExecuteEntityTool(), + "AggregateRecords" => new AggregateRecordsTool(), _ => throw new ArgumentException($"Unknown tool type: {toolType}", nameof(toolType)) }; } From f855e96eba1cf6de7cc782b4d014fb795b703e48 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Mar 2026 17:22:14 +0000 Subject: [PATCH 03/32] Add first/after pagination support to aggregate_records tool Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- .../BuiltInTools/AggregateRecordsTool.cs | 102 +++++++++++ .../Mcp/AggregateRecordsToolTests.cs | 158 +++++++++++++++++- 2 files changed, 259 insertions(+), 1 deletion(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index e64710e46e..2fa3bfa89c 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -98,6 +98,15 @@ public Tool GetToolMetadata() ""description"": ""Aggregated value is in the given list."" } } + }, + ""first"": { + ""type"": ""integer"", + ""description"": ""Maximum number of results to return. Used for pagination. Only applies with groupby."", + ""minimum"": 1 + }, + ""after"": { + ""type"": ""string"", + ""description"": ""Cursor for pagination. Returns results after this cursor. Only applies with groupby and first."" } }, ""required"": [""entity"", ""function"", ""field""] @@ -166,6 +175,18 @@ public async Task ExecuteAsync( string? filter = root.TryGetProperty("filter", out JsonElement filterEl) ? filterEl.GetString() : null; string orderby = root.TryGetProperty("orderby", out JsonElement orderbyEl) ? (orderbyEl.GetString() ?? "desc") : "desc"; + int? first = null; + if (root.TryGetProperty("first", out JsonElement firstEl) && firstEl.ValueKind == JsonValueKind.Number) + { + first = firstEl.GetInt32(); + if (first < 1) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Argument 'first' must be at least 1.", logger); + } + } + + string? after = root.TryGetProperty("after", out JsonElement afterEl) ? afterEl.GetString() : null; + List groupby = new(); if (root.TryGetProperty("groupby", out JsonElement groupbyEl) && groupbyEl.ValueKind == JsonValueKind.Array) { @@ -311,6 +332,26 @@ public async Task ExecuteAsync( List> aggregatedResults = PerformAggregation( records, function, field, distinct, groupby, havingOps, havingIn, orderby, alias); + // Apply pagination if first is specified with groupby + if (first.HasValue && groupby.Count > 0) + { + PaginationResult paginatedResult = ApplyPagination(aggregatedResults, first.Value, after); + return McpResponseBuilder.BuildSuccessResult( + new Dictionary + { + ["entity"] = entityName, + ["result"] = new Dictionary + { + ["items"] = paginatedResult.Items, + ["endCursor"] = paginatedResult.EndCursor, + ["hasNextPage"] = paginatedResult.HasNextPage + }, + ["message"] = $"Successfully aggregated records for entity '{entityName}'" + }, + logger, + $"AggregateRecordsTool success for entity {entityName}."); + } + return McpResponseBuilder.BuildSuccessResult( new Dictionary { @@ -450,6 +491,67 @@ internal static string ComputeAlias(string function, string field) } } + /// + /// Represents the result of applying pagination to aggregated results. + /// + internal sealed class PaginationResult + { + public List> Items { get; set; } = new(); + public string? EndCursor { get; set; } + public bool HasNextPage { get; set; } + } + + /// + /// Applies cursor-based pagination to aggregated results. + /// The cursor is an opaque base64-encoded offset integer. + /// + internal static PaginationResult ApplyPagination( + List> allResults, + int first, + string? after) + { + int startIndex = 0; + + if (!string.IsNullOrWhiteSpace(after)) + { + try + { + byte[] bytes = Convert.FromBase64String(after); + string decoded = System.Text.Encoding.UTF8.GetString(bytes); + if (int.TryParse(decoded, out int cursorOffset)) + { + startIndex = cursorOffset; + } + } + catch (FormatException) + { + // Invalid cursor format; start from beginning + } + } + + List> pageItems = allResults + .Skip(startIndex) + .Take(first) + .ToList(); + + bool hasNextPage = startIndex + first < allResults.Count; + string? endCursor = null; + + if (pageItems.Count > 0) + { + int lastItemIndex = startIndex + pageItems.Count; + endCursor = Convert.ToBase64String( + System.Text.Encoding.UTF8.GetBytes(lastItemIndex.ToString())); + } + + return new PaginationResult + { + Items = pageItems, + EndCursor = endCursor, + HasNextPage = hasNextPage + }; + } + private static double? ComputeAggregateValue(List records, string function, string field, bool distinct, bool isCountStar) { if (isCountStar) diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs index a1fb2b691c..f7e3930d7b 100644 --- a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -57,7 +57,7 @@ public void GetToolMetadata_HasInputSchema() AggregateRecordsTool tool = new(); Tool metadata = tool.GetToolMetadata(); Assert.AreEqual(JsonValueKind.Object, metadata.InputSchema.ValueKind); - Assert.IsTrue(metadata.InputSchema.TryGetProperty("properties", out _)); + Assert.IsTrue(metadata.InputSchema.TryGetProperty("properties", out JsonElement properties)); Assert.IsTrue(metadata.InputSchema.TryGetProperty("required", out JsonElement required)); List requiredFields = new(); @@ -69,6 +69,12 @@ public void GetToolMetadata_HasInputSchema() CollectionAssert.Contains(requiredFields, "entity"); CollectionAssert.Contains(requiredFields, "function"); CollectionAssert.Contains(requiredFields, "field"); + + // Verify first and after properties exist in schema + Assert.IsTrue(properties.TryGetProperty("first", out JsonElement firstProp)); + Assert.AreEqual("integer", firstProp.GetProperty("type").GetString()); + Assert.IsTrue(properties.TryGetProperty("after", out JsonElement afterProp)); + Assert.AreEqual("string", afterProp.GetProperty("type").GetString()); } #endregion @@ -460,6 +466,156 @@ public void PerformAggregation_HavingOnSingleResult_Fails() #endregion + #region Pagination Tests + + [TestMethod] + public void ApplyPagination_FirstOnly_ReturnsFirstNItems() + { + List> allResults = new() + { + new() { ["category"] = "A", ["count"] = 10.0 }, + new() { ["category"] = "B", ["count"] = 8.0 }, + new() { ["category"] = "C", ["count"] = 6.0 }, + new() { ["category"] = "D", ["count"] = 4.0 }, + new() { ["category"] = "E", ["count"] = 2.0 } + }; + + AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 3, null); + + Assert.AreEqual(3, result.Items.Count); + Assert.AreEqual("A", result.Items[0]["category"]?.ToString()); + Assert.AreEqual("C", result.Items[2]["category"]?.ToString()); + Assert.IsTrue(result.HasNextPage); + Assert.IsNotNull(result.EndCursor); + } + + [TestMethod] + public void ApplyPagination_FirstWithAfter_ReturnsNextPage() + { + List> allResults = new() + { + new() { ["category"] = "A", ["count"] = 10.0 }, + new() { ["category"] = "B", ["count"] = 8.0 }, + new() { ["category"] = "C", ["count"] = 6.0 }, + new() { ["category"] = "D", ["count"] = 4.0 }, + new() { ["category"] = "E", ["count"] = 2.0 } + }; + + // First page + AggregateRecordsTool.PaginationResult firstPage = AggregateRecordsTool.ApplyPagination(allResults, 3, null); + Assert.AreEqual(3, firstPage.Items.Count); + Assert.IsTrue(firstPage.HasNextPage); + + // Second page using cursor from first page + AggregateRecordsTool.PaginationResult secondPage = AggregateRecordsTool.ApplyPagination(allResults, 3, firstPage.EndCursor); + Assert.AreEqual(2, secondPage.Items.Count); + Assert.AreEqual("D", secondPage.Items[0]["category"]?.ToString()); + Assert.AreEqual("E", secondPage.Items[1]["category"]?.ToString()); + Assert.IsFalse(secondPage.HasNextPage); + } + + [TestMethod] + public void ApplyPagination_FirstExceedsTotalCount_ReturnsAllItems() + { + List> allResults = new() + { + new() { ["category"] = "A", ["count"] = 10.0 }, + new() { ["category"] = "B", ["count"] = 8.0 } + }; + + AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 5, null); + + Assert.AreEqual(2, result.Items.Count); + Assert.IsFalse(result.HasNextPage); + } + + [TestMethod] + public void ApplyPagination_FirstExactlyMatchesTotalCount_HasNextPageIsFalse() + { + List> allResults = new() + { + new() { ["category"] = "A", ["count"] = 10.0 }, + new() { ["category"] = "B", ["count"] = 8.0 }, + new() { ["category"] = "C", ["count"] = 6.0 } + }; + + AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 3, null); + + Assert.AreEqual(3, result.Items.Count); + Assert.IsFalse(result.HasNextPage); + } + + [TestMethod] + public void ApplyPagination_EmptyResults_ReturnsEmptyPage() + { + List> allResults = new(); + + AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 5, null); + + Assert.AreEqual(0, result.Items.Count); + Assert.IsFalse(result.HasNextPage); + Assert.IsNull(result.EndCursor); + } + + [TestMethod] + public void ApplyPagination_InvalidCursor_StartsFromBeginning() + { + List> allResults = new() + { + new() { ["category"] = "A", ["count"] = 10.0 }, + new() { ["category"] = "B", ["count"] = 8.0 } + }; + + AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 5, "not-valid-base64!!!"); + + Assert.AreEqual(2, result.Items.Count); + Assert.AreEqual("A", result.Items[0]["category"]?.ToString()); + } + + [TestMethod] + public void ApplyPagination_CursorBeyondResults_ReturnsEmptyPage() + { + List> allResults = new() + { + new() { ["category"] = "A", ["count"] = 10.0 } + }; + + // Cursor pointing beyond the end + string cursor = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("100")); + AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 5, cursor); + + Assert.AreEqual(0, result.Items.Count); + Assert.IsFalse(result.HasNextPage); + Assert.IsNull(result.EndCursor); + } + + [TestMethod] + public void ApplyPagination_MultiplePages_TraversesAllResults() + { + List> allResults = new(); + for (int i = 0; i < 8; i++) + { + allResults.Add(new() { ["category"] = $"Cat{i}", ["count"] = (double)(8 - i) }); + } + + // Page 1 + AggregateRecordsTool.PaginationResult page1 = AggregateRecordsTool.ApplyPagination(allResults, 3, null); + Assert.AreEqual(3, page1.Items.Count); + Assert.IsTrue(page1.HasNextPage); + + // Page 2 + AggregateRecordsTool.PaginationResult page2 = AggregateRecordsTool.ApplyPagination(allResults, 3, page1.EndCursor); + Assert.AreEqual(3, page2.Items.Count); + Assert.IsTrue(page2.HasNextPage); + + // Page 3 (last page) + AggregateRecordsTool.PaginationResult page3 = AggregateRecordsTool.ApplyPagination(allResults, 3, page2.EndCursor); + Assert.AreEqual(2, page3.Items.Count); + Assert.IsFalse(page3.HasNextPage); + } + + #endregion + #region Helper Methods private static JsonElement ParseArray(string json) From 35733211c35a2f1eab877f9cfdadbdb03b8fd53b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Mar 2026 17:29:15 +0000 Subject: [PATCH 04/32] Add exhaustive tool instructions and all 13 spec example tests Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- .../BuiltInTools/AggregateRecordsTool.cs | 55 ++- .../Mcp/AggregateRecordsToolTests.cs | 439 ++++++++++++++++++ 2 files changed, 476 insertions(+), 18 deletions(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index 2fa3bfa89c..c6fbd08198 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -42,71 +42,90 @@ public Tool GetToolMetadata() return new Tool { Name = "aggregate_records", - Description = "STEP 1: describe_entities -> find entities with READ permission and their fields. STEP 2: call this tool to compute aggregations (count, avg, sum, min, max) with optional filter, groupby, having, and orderby.", + Description = "Computes aggregations (count, avg, sum, min, max) on entity data. " + + "STEP 1: Call describe_entities to discover entities with READ permission and their field names. " + + "STEP 2: Call this tool with the exact entity name, an aggregation function, and a field name from STEP 1. " + + "REQUIRED: entity (exact entity name), function (one of: count, avg, sum, min, max), field (exact field name, or '*' ONLY for count). " + + "OPTIONAL: filter (OData WHERE clause applied before aggregating, e.g. 'unitPrice lt 10'), " + + "distinct (true to deduplicate values before aggregating), " + + "groupby (array of field names to group results by, e.g. ['categoryName']), " + + "orderby ('asc' or 'desc' to sort grouped results by aggregated value; requires groupby), " + + "having (object to filter groups after aggregating, operators: eq, neq, gt, gte, lt, lte, in; requires groupby), " + + "first (integer >= 1, maximum grouped results to return; requires groupby), " + + "after (opaque cursor string from a previous response's endCursor; requires first and groupby). " + + "RESPONSE: The aggregated value is aliased as '{function}_{field}' (e.g. avg_unitPrice, sum_revenue). " + + "For count with field '*', the alias is 'count'. " + + "When first is used with groupby, response contains: items (array), endCursor (string), hasNextPage (boolean). " + + "RULES: 1) ALWAYS call describe_entities first to get valid entity and field names. " + + "2) Use field '*' ONLY with function 'count'. " + + "3) For avg, sum, min, max: field MUST be a numeric field name from describe_entities. " + + "4) orderby, having, first, and after ONLY apply when groupby is provided. " + + "5) after REQUIRES first to also be set. " + + "6) Use first and after for paginating large grouped result sets.", InputSchema = JsonSerializer.Deserialize( @"{ ""type"": ""object"", ""properties"": { ""entity"": { ""type"": ""string"", - ""description"": ""Entity name with READ permission."" + ""description"": ""Exact entity name from describe_entities that has READ permission. Must match exactly (case-sensitive)."" }, ""function"": { ""type"": ""string"", ""enum"": [""count"", ""avg"", ""sum"", ""min"", ""max""], - ""description"": ""Aggregation function to apply."" + ""description"": ""Aggregation function to apply. Use 'count' to count records, 'avg' for average, 'sum' for total, 'min' for minimum, 'max' for maximum. For count use field '*' or a specific field name. For avg, sum, min, max the field must be numeric."" }, ""field"": { ""type"": ""string"", - ""description"": ""Field to aggregate. Use '*' for count."" + ""description"": ""Exact field name from describe_entities to aggregate. Use '*' ONLY with function 'count' to count all records. For avg, sum, min, max, provide a numeric field name."" }, ""distinct"": { ""type"": ""boolean"", - ""description"": ""Apply DISTINCT before aggregating."", + ""description"": ""When true, removes duplicate values before applying the aggregation function. For example, count with distinct counts unique values only. Default is false."", ""default"": false }, ""filter"": { ""type"": ""string"", - ""description"": ""OData filter applied before aggregating (WHERE). Example: 'unitPrice lt 10'"", + ""description"": ""OData filter expression applied before aggregating (acts as a WHERE clause). Supported operators: eq, ne, gt, ge, lt, le, and, or, not. Example: 'unitPrice lt 10' filters to rows where unitPrice is less than 10 before aggregating. Example: 'discontinued eq true and categoryName eq ''Seafood''' filters discontinued seafood products."", ""default"": """" }, ""groupby"": { ""type"": ""array"", ""items"": { ""type"": ""string"" }, - ""description"": ""Fields to group by, e.g., ['category', 'region']. Grouped field values are included in the response."", + ""description"": ""Array of exact field names from describe_entities to group results by. Each unique combination of grouped field values produces one aggregated row. Grouped field values are included in the response alongside the aggregated value. Example: ['categoryName'] groups by category. Example: ['categoryName', 'region'] groups by both fields."", ""default"": [] }, ""orderby"": { ""type"": ""string"", ""enum"": [""asc"", ""desc""], - ""description"": ""Sort aggregated results by the computed value. Only applies with groupby."", + ""description"": ""Sort direction for grouped results by the computed aggregated value. 'desc' returns highest values first, 'asc' returns lowest first. ONLY applies when groupby is provided. Default is 'desc'."", ""default"": ""desc"" }, ""having"": { ""type"": ""object"", - ""description"": ""Filter applied after aggregating on the result (HAVING). Operators are AND-ed together."", + ""description"": ""Filter applied AFTER aggregating to filter grouped results by the computed aggregated value (acts as a HAVING clause). ONLY applies when groupby is provided. Multiple operators are AND-ed together. For example, use gt with value 20 to keep groups where the aggregated value exceeds 20. Combine gte and lte to define a range."", ""properties"": { - ""eq"": { ""type"": ""number"", ""description"": ""Aggregated value equals."" }, - ""neq"": { ""type"": ""number"", ""description"": ""Aggregated value not equals."" }, - ""gt"": { ""type"": ""number"", ""description"": ""Aggregated value greater than."" }, - ""gte"": { ""type"": ""number"", ""description"": ""Aggregated value greater than or equal."" }, - ""lt"": { ""type"": ""number"", ""description"": ""Aggregated value less than."" }, - ""lte"": { ""type"": ""number"", ""description"": ""Aggregated value less than or equal."" }, + ""eq"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value equals this number."" }, + ""neq"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value does not equal this number."" }, + ""gt"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value is greater than this number."" }, + ""gte"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value is greater than or equal to this number."" }, + ""lt"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value is less than this number."" }, + ""lte"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value is less than or equal to this number."" }, ""in"": { ""type"": ""array"", ""items"": { ""type"": ""number"" }, - ""description"": ""Aggregated value is in the given list."" + ""description"": ""Keep groups where the aggregated value matches any number in this list. Example: [5, 10] keeps groups with aggregated value 5 or 10."" } } }, ""first"": { ""type"": ""integer"", - ""description"": ""Maximum number of results to return. Used for pagination. Only applies with groupby."", + ""description"": ""Maximum number of grouped results to return. Used for pagination of grouped results. ONLY applies when groupby is provided. Must be >= 1. When set, the response includes 'items', 'endCursor', and 'hasNextPage' fields for pagination."", ""minimum"": 1 }, ""after"": { ""type"": ""string"", - ""description"": ""Cursor for pagination. Returns results after this cursor. Only applies with groupby and first."" + ""description"": ""Opaque cursor string for pagination. Pass the 'endCursor' value from a previous response to get the next page of results. REQUIRES both groupby and first to be set. Do not construct this value manually; always use the endCursor from a previous response."" } }, ""required"": [""entity"", ""function"", ""field""] diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs index f7e3930d7b..dce07fff80 100644 --- a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -570,6 +570,8 @@ public void ApplyPagination_InvalidCursor_StartsFromBeginning() Assert.AreEqual(2, result.Items.Count); Assert.AreEqual("A", result.Items[0]["category"]?.ToString()); + Assert.IsFalse(result.HasNextPage); + Assert.IsNotNull(result.EndCursor); } [TestMethod] @@ -616,6 +618,443 @@ public void ApplyPagination_MultiplePages_TraversesAllResults() #endregion + #region Spec Example Tests + + /// + /// Spec Example 1: "How many products are there?" + /// COUNT(*) → 77 + /// + [TestMethod] + public void SpecExample01_CountStar_ReturnsTotal() + { + // Build 77 product records + List items = new(); + for (int i = 1; i <= 77; i++) + { + items.Add($"{{\"id\":{i}}}"); + } + + JsonElement records = ParseArray($"[{string.Join(",", items)}]"); + string alias = AggregateRecordsTool.ComputeAlias("count", "*"); + var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new(), null, null, "desc", alias); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual("count", alias); + Assert.AreEqual(77.0, result[0]["count"]); + } + + /// + /// Spec Example 2: "What is the average price of products under $10?" + /// AVG(unitPrice) WHERE unitPrice < 10 → 6.74 + /// Filter is applied at DB level; we supply pre-filtered records. + /// + [TestMethod] + public void SpecExample02_AvgWithFilter_ReturnsFilteredAverage() + { + // Pre-filtered records (unitPrice < 10) that average to 6.74 + // 4.50 + 6.00 + 9.72 = 20.22 / 3 = 6.74 + JsonElement records = ParseArray("[{\"unitPrice\":4.5},{\"unitPrice\":6.0},{\"unitPrice\":9.72}]"); + string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice"); + var result = AggregateRecordsTool.PerformAggregation(records, "avg", "unitPrice", false, new(), null, null, "desc", alias); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual("avg_unitPrice", alias); + Assert.AreEqual(6.74, result[0]["avg_unitPrice"]); + } + + /// + /// Spec Example 3: "Which categories have more than 20 products?" + /// COUNT(*) GROUP BY categoryName HAVING COUNT(*) > 20 + /// Expected: Beverages=24, Condiments=22 + /// + [TestMethod] + public void SpecExample03_CountGroupByHavingGt_FiltersGroups() + { + List items = new(); + for (int i = 0; i < 24; i++) + { + items.Add("{\"categoryName\":\"Beverages\"}"); + } + + for (int i = 0; i < 22; i++) + { + items.Add("{\"categoryName\":\"Condiments\"}"); + } + + for (int i = 0; i < 12; i++) + { + items.Add("{\"categoryName\":\"Seafood\"}"); + } + + JsonElement records = ParseArray($"[{string.Join(",", items)}]"); + string alias = AggregateRecordsTool.ComputeAlias("count", "*"); + var having = new Dictionary { ["gt"] = 20 }; + var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "categoryName" }, having, null, "desc", alias); + + Assert.AreEqual(2, result.Count); + // Desc order: Beverages(24), Condiments(22) + Assert.AreEqual("Beverages", result[0]["categoryName"]?.ToString()); + Assert.AreEqual(24.0, result[0]["count"]); + Assert.AreEqual("Condiments", result[1]["categoryName"]?.ToString()); + Assert.AreEqual(22.0, result[1]["count"]); + } + + /// + /// Spec Example 4: "For discontinued products, which categories have a total revenue between $500 and $10,000?" + /// SUM(unitPrice) WHERE discontinued=1 GROUP BY categoryName HAVING SUM >= 500 AND <= 10000 + /// Expected: Seafood=1834.50, Produce=742.00 + /// + [TestMethod] + public void SpecExample04_SumFilterGroupByHavingRange_ReturnsMatchingGroups() + { + // Pre-filtered (discontinued) records with prices summing per category + JsonElement records = ParseArray( + "[" + + "{\"categoryName\":\"Seafood\",\"unitPrice\":900}," + + "{\"categoryName\":\"Seafood\",\"unitPrice\":934.5}," + + "{\"categoryName\":\"Produce\",\"unitPrice\":400}," + + "{\"categoryName\":\"Produce\",\"unitPrice\":342}," + + "{\"categoryName\":\"Dairy\",\"unitPrice\":50}" + // Sum 50, below 500 + "]"); + string alias = AggregateRecordsTool.ComputeAlias("sum", "unitPrice"); + var having = new Dictionary { ["gte"] = 500, ["lte"] = 10000 }; + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "unitPrice", false, new() { "categoryName" }, having, null, "desc", alias); + + Assert.AreEqual(2, result.Count); + Assert.AreEqual("sum_unitPrice", alias); + // Desc order: Seafood(1834.5), Produce(742) + Assert.AreEqual("Seafood", result[0]["categoryName"]?.ToString()); + Assert.AreEqual(1834.5, result[0]["sum_unitPrice"]); + Assert.AreEqual("Produce", result[1]["categoryName"]?.ToString()); + Assert.AreEqual(742.0, result[1]["sum_unitPrice"]); + } + + /// + /// Spec Example 5: "How many distinct suppliers do we have?" + /// COUNT(DISTINCT supplierId) → 29 + /// + [TestMethod] + public void SpecExample05_CountDistinct_ReturnsDistinctCount() + { + // Build records with 29 distinct supplierIds plus duplicates + List items = new(); + for (int i = 1; i <= 29; i++) + { + items.Add($"{{\"supplierId\":{i}}}"); + } + + // Add duplicates + items.Add("{\"supplierId\":1}"); + items.Add("{\"supplierId\":5}"); + items.Add("{\"supplierId\":10}"); + + JsonElement records = ParseArray($"[{string.Join(",", items)}]"); + string alias = AggregateRecordsTool.ComputeAlias("count", "supplierId"); + var result = AggregateRecordsTool.PerformAggregation(records, "count", "supplierId", true, new(), null, null, "desc", alias); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual("count_supplierId", alias); + Assert.AreEqual(29.0, result[0]["count_supplierId"]); + } + + /// + /// Spec Example 6: "Which categories have exactly 5 or 10 products?" + /// COUNT(*) GROUP BY categoryName HAVING COUNT(*) IN (5, 10) + /// Expected: Grains=5, Produce=5 + /// + [TestMethod] + public void SpecExample06_CountGroupByHavingIn_FiltersExactCounts() + { + List items = new(); + for (int i = 0; i < 5; i++) + { + items.Add("{\"categoryName\":\"Grains\"}"); + } + + for (int i = 0; i < 5; i++) + { + items.Add("{\"categoryName\":\"Produce\"}"); + } + + for (int i = 0; i < 12; i++) + { + items.Add("{\"categoryName\":\"Beverages\"}"); + } + + JsonElement records = ParseArray($"[{string.Join(",", items)}]"); + string alias = AggregateRecordsTool.ComputeAlias("count", "*"); + var havingIn = new List { 5, 10 }; + var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "categoryName" }, null, havingIn, "desc", alias); + + Assert.AreEqual(2, result.Count); + // Both have count=5, same order as grouped + Assert.AreEqual(5.0, result[0]["count"]); + Assert.AreEqual(5.0, result[1]["count"]); + } + + /// + /// Spec Example 7: "What is the average distinct unit price per category, for categories averaging over $25?" + /// AVG(DISTINCT unitPrice) GROUP BY categoryName HAVING AVG(DISTINCT unitPrice) > 25 + /// Expected: Meat/Poultry=54.01, Beverages=32.50 + /// + [TestMethod] + public void SpecExample07_AvgDistinctGroupByHavingGt_FiltersAboveThreshold() + { + // Meat/Poultry: distinct prices {40.00, 68.02} → avg = 54.01 + // Beverages: distinct prices {25.00, 40.00} → avg = 32.50 + // Condiments: distinct prices {10.00, 15.00} → avg = 12.50 (below threshold) + JsonElement records = ParseArray( + "[" + + "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":40.00}," + + "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":68.02}," + + "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":40.00}," + // duplicate + "{\"categoryName\":\"Beverages\",\"unitPrice\":25.00}," + + "{\"categoryName\":\"Beverages\",\"unitPrice\":40.00}," + + "{\"categoryName\":\"Beverages\",\"unitPrice\":25.00}," + // duplicate + "{\"categoryName\":\"Condiments\",\"unitPrice\":10.00}," + + "{\"categoryName\":\"Condiments\",\"unitPrice\":15.00}" + + "]"); + string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice"); + var having = new Dictionary { ["gt"] = 25 }; + var result = AggregateRecordsTool.PerformAggregation(records, "avg", "unitPrice", true, new() { "categoryName" }, having, null, "desc", alias); + + Assert.AreEqual(2, result.Count); + Assert.AreEqual("avg_unitPrice", alias); + // Desc order: Meat/Poultry(54.01), Beverages(32.5) + Assert.AreEqual("Meat/Poultry", result[0]["categoryName"]?.ToString()); + Assert.AreEqual(54.01, result[0]["avg_unitPrice"]); + Assert.AreEqual("Beverages", result[1]["categoryName"]?.ToString()); + Assert.AreEqual(32.5, result[1]["avg_unitPrice"]); + } + + /// + /// Spec Example 8: "Which categories have the most products?" + /// COUNT(*) GROUP BY categoryName ORDER BY DESC + /// Expected: Confections=13, Beverages=12, Condiments=12, Seafood=12 + /// + [TestMethod] + public void SpecExample08_CountGroupByOrderByDesc_ReturnsSortedDesc() + { + List items = new(); + for (int i = 0; i < 13; i++) + { + items.Add("{\"categoryName\":\"Confections\"}"); + } + + for (int i = 0; i < 12; i++) + { + items.Add("{\"categoryName\":\"Beverages\"}"); + } + + for (int i = 0; i < 12; i++) + { + items.Add("{\"categoryName\":\"Condiments\"}"); + } + + for (int i = 0; i < 12; i++) + { + items.Add("{\"categoryName\":\"Seafood\"}"); + } + + JsonElement records = ParseArray($"[{string.Join(",", items)}]"); + string alias = AggregateRecordsTool.ComputeAlias("count", "*"); + var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "categoryName" }, null, null, "desc", alias); + + Assert.AreEqual(4, result.Count); + Assert.AreEqual("Confections", result[0]["categoryName"]?.ToString()); + Assert.AreEqual(13.0, result[0]["count"]); + // Remaining 3 all have count=12 + Assert.AreEqual(12.0, result[1]["count"]); + Assert.AreEqual(12.0, result[2]["count"]); + Assert.AreEqual(12.0, result[3]["count"]); + } + + /// + /// Spec Example 9: "What are the cheapest categories by average price?" + /// AVG(unitPrice) GROUP BY categoryName ORDER BY ASC + /// Expected: Grains/Cereals=20.25, Condiments=23.06, Produce=32.37 + /// + [TestMethod] + public void SpecExample09_AvgGroupByOrderByAsc_ReturnsSortedAsc() + { + // Grains/Cereals: {15.50, 25.00} → avg = 20.25 + // Condiments: {20.12, 26.00} → avg = 23.06 + // Produce: {28.74, 36.00} → avg = 32.37 + JsonElement records = ParseArray( + "[" + + "{\"categoryName\":\"Grains/Cereals\",\"unitPrice\":15.50}," + + "{\"categoryName\":\"Grains/Cereals\",\"unitPrice\":25.00}," + + "{\"categoryName\":\"Condiments\",\"unitPrice\":20.12}," + + "{\"categoryName\":\"Condiments\",\"unitPrice\":26.00}," + + "{\"categoryName\":\"Produce\",\"unitPrice\":28.74}," + + "{\"categoryName\":\"Produce\",\"unitPrice\":36.00}" + + "]"); + string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice"); + var result = AggregateRecordsTool.PerformAggregation(records, "avg", "unitPrice", false, new() { "categoryName" }, null, null, "asc", alias); + + Assert.AreEqual(3, result.Count); + // Asc order: Grains/Cereals(20.25), Condiments(23.06), Produce(32.37) + Assert.AreEqual("Grains/Cereals", result[0]["categoryName"]?.ToString()); + Assert.AreEqual(20.25, result[0]["avg_unitPrice"]); + Assert.AreEqual("Condiments", result[1]["categoryName"]?.ToString()); + Assert.AreEqual(23.06, result[1]["avg_unitPrice"]); + Assert.AreEqual("Produce", result[2]["categoryName"]?.ToString()); + Assert.AreEqual(32.37, result[2]["avg_unitPrice"]); + } + + /// + /// Spec Example 10: "For categories with over $500 revenue from discontinued products, which has the highest total?" + /// SUM(unitPrice) WHERE discontinued=1 GROUP BY categoryName HAVING SUM > 500 ORDER BY DESC + /// Expected: Seafood=1834.50, Meat/Poultry=1062.50, Produce=742.00 + /// + [TestMethod] + public void SpecExample10_SumFilterGroupByHavingGtOrderByDesc_ReturnsSortedFiltered() + { + // Pre-filtered (discontinued) records + JsonElement records = ParseArray( + "[" + + "{\"categoryName\":\"Seafood\",\"unitPrice\":900}," + + "{\"categoryName\":\"Seafood\",\"unitPrice\":934.5}," + + "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":500}," + + "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":562.5}," + + "{\"categoryName\":\"Produce\",\"unitPrice\":400}," + + "{\"categoryName\":\"Produce\",\"unitPrice\":342}," + + "{\"categoryName\":\"Dairy\",\"unitPrice\":50}" + // Sum 50, below 500 + "]"); + string alias = AggregateRecordsTool.ComputeAlias("sum", "unitPrice"); + var having = new Dictionary { ["gt"] = 500 }; + var result = AggregateRecordsTool.PerformAggregation(records, "sum", "unitPrice", false, new() { "categoryName" }, having, null, "desc", alias); + + Assert.AreEqual(3, result.Count); + // Desc order: Seafood(1834.5), Meat/Poultry(1062.5), Produce(742) + Assert.AreEqual("Seafood", result[0]["categoryName"]?.ToString()); + Assert.AreEqual(1834.5, result[0]["sum_unitPrice"]); + Assert.AreEqual("Meat/Poultry", result[1]["categoryName"]?.ToString()); + Assert.AreEqual(1062.5, result[1]["sum_unitPrice"]); + Assert.AreEqual("Produce", result[2]["categoryName"]?.ToString()); + Assert.AreEqual(742.0, result[2]["sum_unitPrice"]); + } + + /// + /// Spec Example 11: "Show me the first 5 categories by product count" + /// COUNT(*) GROUP BY categoryName ORDER BY DESC FIRST 5 + /// Expected: 5 items with hasNextPage=true, endCursor set + /// + [TestMethod] + public void SpecExample11_CountGroupByOrderByDescFirst5_ReturnsPaginatedResults() + { + List items = new(); + string[] categories = { "Confections", "Beverages", "Condiments", "Seafood", "Dairy", "Grains/Cereals", "Meat/Poultry", "Produce" }; + int[] counts = { 13, 12, 12, 12, 10, 7, 6, 5 }; + for (int c = 0; c < categories.Length; c++) + { + for (int i = 0; i < counts[c]; i++) + { + items.Add($"{{\"categoryName\":\"{categories[c]}\"}}"); + } + } + + JsonElement records = ParseArray($"[{string.Join(",", items)}]"); + string alias = AggregateRecordsTool.ComputeAlias("count", "*"); + var allResults = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "categoryName" }, null, null, "desc", alias); + + Assert.AreEqual(8, allResults.Count); + + // Apply pagination: first=5 + AggregateRecordsTool.PaginationResult page1 = AggregateRecordsTool.ApplyPagination(allResults, 5, null); + + Assert.AreEqual(5, page1.Items.Count); + Assert.AreEqual("Confections", page1.Items[0]["categoryName"]?.ToString()); + Assert.AreEqual(13.0, page1.Items[0]["count"]); + Assert.AreEqual("Dairy", page1.Items[4]["categoryName"]?.ToString()); + Assert.AreEqual(10.0, page1.Items[4]["count"]); + Assert.IsTrue(page1.HasNextPage); + Assert.IsNotNull(page1.EndCursor); + } + + /// + /// Spec Example 12: "Show me the next 5 categories" (continuation of Example 11) + /// COUNT(*) GROUP BY categoryName ORDER BY DESC FIRST 5 AFTER cursor + /// Expected: 3 items (remaining), hasNextPage=false + /// + [TestMethod] + public void SpecExample12_CountGroupByOrderByDescFirst5After_ReturnsNextPage() + { + List items = new(); + string[] categories = { "Confections", "Beverages", "Condiments", "Seafood", "Dairy", "Grains/Cereals", "Meat/Poultry", "Produce" }; + int[] counts = { 13, 12, 12, 12, 10, 7, 6, 5 }; + for (int c = 0; c < categories.Length; c++) + { + for (int i = 0; i < counts[c]; i++) + { + items.Add($"{{\"categoryName\":\"{categories[c]}\"}}"); + } + } + + JsonElement records = ParseArray($"[{string.Join(",", items)}]"); + string alias = AggregateRecordsTool.ComputeAlias("count", "*"); + var allResults = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "categoryName" }, null, null, "desc", alias); + + // Page 1 + AggregateRecordsTool.PaginationResult page1 = AggregateRecordsTool.ApplyPagination(allResults, 5, null); + Assert.IsTrue(page1.HasNextPage); + + // Page 2 (continuation) + AggregateRecordsTool.PaginationResult page2 = AggregateRecordsTool.ApplyPagination(allResults, 5, page1.EndCursor); + + Assert.AreEqual(3, page2.Items.Count); + Assert.AreEqual("Grains/Cereals", page2.Items[0]["categoryName"]?.ToString()); + Assert.AreEqual(7.0, page2.Items[0]["count"]); + Assert.AreEqual("Meat/Poultry", page2.Items[1]["categoryName"]?.ToString()); + Assert.AreEqual(6.0, page2.Items[1]["count"]); + Assert.AreEqual("Produce", page2.Items[2]["categoryName"]?.ToString()); + Assert.AreEqual(5.0, page2.Items[2]["count"]); + Assert.IsFalse(page2.HasNextPage); + } + + /// + /// Spec Example 13: "Show me the top 3 most expensive categories by average price" + /// AVG(unitPrice) GROUP BY categoryName ORDER BY DESC FIRST 3 + /// Expected: Meat/Poultry=54.01, Beverages=37.98, Seafood=37.08 + /// + [TestMethod] + public void SpecExample13_AvgGroupByOrderByDescFirst3_ReturnsTop3() + { + // Meat/Poultry: {40.00, 68.02} → avg = 54.01 + // Beverages: {30.96, 45.00} → avg = 37.98 + // Seafood: {25.16, 49.00} → avg = 37.08 + // Condiments: {10.00, 15.00} → avg = 12.50 + JsonElement records = ParseArray( + "[" + + "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":40.00}," + + "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":68.02}," + + "{\"categoryName\":\"Beverages\",\"unitPrice\":30.96}," + + "{\"categoryName\":\"Beverages\",\"unitPrice\":45.00}," + + "{\"categoryName\":\"Seafood\",\"unitPrice\":25.16}," + + "{\"categoryName\":\"Seafood\",\"unitPrice\":49.00}," + + "{\"categoryName\":\"Condiments\",\"unitPrice\":10.00}," + + "{\"categoryName\":\"Condiments\",\"unitPrice\":15.00}" + + "]"); + string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice"); + var allResults = AggregateRecordsTool.PerformAggregation(records, "avg", "unitPrice", false, new() { "categoryName" }, null, null, "desc", alias); + + Assert.AreEqual(4, allResults.Count); + + // Apply pagination: first=3 + AggregateRecordsTool.PaginationResult page = AggregateRecordsTool.ApplyPagination(allResults, 3, null); + + Assert.AreEqual(3, page.Items.Count); + Assert.AreEqual("Meat/Poultry", page.Items[0]["categoryName"]?.ToString()); + Assert.AreEqual(54.01, page.Items[0]["avg_unitPrice"]); + Assert.AreEqual("Beverages", page.Items[1]["categoryName"]?.ToString()); + Assert.AreEqual(37.98, page.Items[1]["avg_unitPrice"]); + Assert.AreEqual("Seafood", page.Items[2]["categoryName"]?.ToString()); + Assert.AreEqual(37.08, page.Items[2]["avg_unitPrice"]); + Assert.IsTrue(page.HasNextPage); + } + + #endregion + #region Helper Methods private static JsonElement ParseArray(string json) From f66bf3f8289c3b199c90f93dfdbbfcea96f6f2e9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Mar 2026 17:48:11 +0000 Subject: [PATCH 05/32] Changes before error encountered Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- schemas/dab.draft.schema.json | 5 + .../BuiltInTools/AggregateRecordsTool.cs | 43 +++++- .../Utils/McpTelemetryErrorCodes.cs | 5 + .../Utils/McpTelemetryHelper.cs | 2 + .../Mcp/AggregateRecordsToolTests.cs | 122 ++++++++++++++++++ 5 files changed, 173 insertions(+), 4 deletions(-) diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index cbe38b7d72..ec1afc063a 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -315,6 +315,11 @@ "type": "boolean", "description": "Enable/disable the execute-entity tool.", "default": false + }, + "aggregate-records": { + "type": "boolean", + "description": "Enable/disable the aggregate-records tool.", + "default": false } } } diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index c6fbd08198..59bd465ad0 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -150,6 +150,8 @@ public async Task ExecuteAsync( return McpErrorHelpers.ToolDisabled(toolName, logger); } + string entityName = string.Empty; + try { cancellationToken.ThrowIfCancellationRequested(); @@ -162,11 +164,13 @@ public async Task ExecuteAsync( JsonElement root = arguments.RootElement; // Parse required arguments - if (!McpArgumentParser.TryParseEntity(root, out string entityName, out string parseError)) + if (!McpArgumentParser.TryParseEntity(root, out string parsedEntityName, out string parseError)) { return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); } + entityName = parsedEntityName; + if (runtimeConfig.Entities?.TryGetValue(entityName, out Entity? entity) == true && entity.Mcp?.DmlToolEnabled == false) { @@ -381,13 +385,44 @@ public async Task ExecuteAsync( logger, $"AggregateRecordsTool success for entity {entityName}."); } + catch (TimeoutException timeoutEx) + { + logger?.LogError(timeoutEx, "Aggregation operation timed out for entity {Entity}.", entityName); + return McpResponseBuilder.BuildErrorResult( + toolName, + "TimeoutError", + $"The aggregation query for entity '{entityName}' timed out. " + + "This is NOT a tool error. The database did not respond in time. " + + "This may occur with large datasets or complex aggregations. " + + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination.", + logger); + } + catch (TaskCanceledException taskEx) + { + logger?.LogError(taskEx, "Aggregation task was canceled for entity {Entity}.", entityName); + return McpResponseBuilder.BuildErrorResult( + toolName, + "TimeoutError", + $"The aggregation query for entity '{entityName}' was canceled, likely due to a timeout. " + + "This is NOT a tool error. The database did not respond in time. " + + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination.", + logger); + } catch (OperationCanceledException) { - return McpResponseBuilder.BuildErrorResult(toolName, "OperationCanceled", "The aggregate operation was canceled.", logger); + logger?.LogWarning("Aggregation operation was canceled for entity {Entity}.", entityName); + return McpResponseBuilder.BuildErrorResult( + toolName, + "OperationCanceled", + $"The aggregation query for entity '{entityName}' was canceled before completion. " + + "This is NOT a tool error. The operation was interrupted, possibly due to a timeout or client disconnect. " + + "No results were returned. You may retry the same request.", + logger); } - catch (DbException argEx) + catch (DbException dbEx) { - return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseOperationFailed", argEx.Message, logger); + logger?.LogError(dbEx, "Database error during aggregation for entity {Entity}.", entityName); + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseOperationFailed", dbEx.Message, logger); } catch (ArgumentException argEx) { diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs index f69a26fa5d..3ef3aa4d74 100644 --- a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs @@ -37,5 +37,10 @@ internal static class McpTelemetryErrorCodes /// Operation cancelled error code. /// public const string OPERATION_CANCELLED = "OperationCancelled"; + + /// + /// Operation timed out error code. + /// + public const string TIMEOUT = "Timeout"; } } diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs index 2a60557f8d..eabbdc62d8 100644 --- a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs @@ -124,6 +124,7 @@ public static string InferOperationFromTool(IMcpTool tool, string toolName) "delete_record" => "delete", "describe_entities" => "describe", "execute_entity" => "execute", + "aggregate_records" => "aggregate", _ => "execute" // Fallback for any unknown built-in tools }; } @@ -188,6 +189,7 @@ public static string MapExceptionToErrorCode(Exception ex) return ex switch { OperationCanceledException => McpTelemetryErrorCodes.OPERATION_CANCELLED, + TimeoutException => McpTelemetryErrorCodes.TIMEOUT, DataApiBuilderException dabEx when dabEx.SubStatusCode == DataApiBuilderException.SubStatusCodes.AuthenticationChallenge => McpTelemetryErrorCodes.AUTHENTICATION_FAILED, DataApiBuilderException dabEx when dabEx.SubStatusCode == DataApiBuilderException.SubStatusCodes.AuthorizationCheckFailed diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs index dce07fff80..9255a2e9c5 100644 --- a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -618,6 +618,128 @@ public void ApplyPagination_MultiplePages_TraversesAllResults() #endregion + #region Timeout and Cancellation Tests + + /// + /// Verifies that OperationCanceledException produces a model-explicit error + /// that clearly states the operation was canceled, not errored. + /// + [TestMethod] + public async Task AggregateRecords_OperationCanceled_ReturnsExplicitCanceledMessage() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + // Create a pre-canceled token + CancellationTokenSource cts = new(); + cts.Cancel(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, cts.Token); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.IsTrue(content.TryGetProperty("error", out JsonElement error)); + string errorType = error.GetProperty("type").GetString(); + string errorMessage = error.GetProperty("message").GetString(); + + // Verify the error type identifies it as a cancellation + Assert.AreEqual("OperationCanceled", errorType); + + // Verify the message explicitly tells the model this is NOT a tool error + Assert.IsTrue(errorMessage.Contains("NOT a tool error"), "Message must explicitly state this is NOT a tool error."); + + // Verify the message tells the model what happened + Assert.IsTrue(errorMessage.Contains("canceled"), "Message must mention the operation was canceled."); + + // Verify the message tells the model it can retry + Assert.IsTrue(errorMessage.Contains("retry"), "Message must tell the model it can retry."); + } + + /// + /// Verifies that the timeout error message provides explicit guidance to the model + /// about what happened and what to do next. + /// + [TestMethod] + public void TimeoutErrorMessage_ContainsModelGuidance() + { + // Simulate what the tool builds for a TimeoutException response + string entityName = "Product"; + string expectedMessage = $"The aggregation query for entity '{entityName}' timed out. " + + "This is NOT a tool error. The database did not respond in time. " + + "This may occur with large datasets or complex aggregations. " + + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination."; + + // Verify message explicitly states it's NOT a tool error + Assert.IsTrue(expectedMessage.Contains("NOT a tool error"), "Timeout message must state this is NOT a tool error."); + + // Verify message explains the cause + Assert.IsTrue(expectedMessage.Contains("database did not respond"), "Timeout message must explain the database didn't respond."); + + // Verify message mentions large datasets + Assert.IsTrue(expectedMessage.Contains("large datasets"), "Timeout message must mention large datasets as a possible cause."); + + // Verify message provides actionable remediation steps + Assert.IsTrue(expectedMessage.Contains("filter"), "Timeout message must suggest using a filter."); + Assert.IsTrue(expectedMessage.Contains("groupby"), "Timeout message must suggest reducing groupby fields."); + Assert.IsTrue(expectedMessage.Contains("first"), "Timeout message must suggest using pagination with first."); + } + + /// + /// Verifies that TaskCanceledException (which typically signals HTTP/DB timeout) + /// produces a TimeoutError, not a cancellation error. + /// + [TestMethod] + public void TaskCanceledErrorMessage_ContainsTimeoutGuidance() + { + // Simulate what the tool builds for a TaskCanceledException response + string entityName = "Product"; + string expectedMessage = $"The aggregation query for entity '{entityName}' was canceled, likely due to a timeout. " + + "This is NOT a tool error. The database did not respond in time. " + + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination."; + + // TaskCanceledException should produce a TimeoutError, not OperationCanceled + Assert.IsTrue(expectedMessage.Contains("NOT a tool error"), "TaskCanceled message must state this is NOT a tool error."); + Assert.IsTrue(expectedMessage.Contains("timeout"), "TaskCanceled message must reference timeout as the cause."); + Assert.IsTrue(expectedMessage.Contains("filter"), "TaskCanceled message must suggest filter as remediation."); + Assert.IsTrue(expectedMessage.Contains("first"), "TaskCanceled message must suggest first for pagination."); + } + + /// + /// Verifies that the OperationCanceled error message for a specific entity + /// includes the entity name so the model knows which aggregation failed. + /// + [TestMethod] + public void CanceledErrorMessage_IncludesEntityName() + { + string entityName = "LargeProductCatalog"; + string expectedMessage = $"The aggregation query for entity '{entityName}' was canceled before completion. " + + "This is NOT a tool error. The operation was interrupted, possibly due to a timeout or client disconnect. " + + "No results were returned. You may retry the same request."; + + Assert.IsTrue(expectedMessage.Contains(entityName), "Canceled message must include the entity name."); + Assert.IsTrue(expectedMessage.Contains("No results were returned"), "Canceled message must state no results were returned."); + } + + /// + /// Verifies that the timeout error message for a specific entity + /// includes the entity name so the model knows which aggregation timed out. + /// + [TestMethod] + public void TimeoutErrorMessage_IncludesEntityName() + { + string entityName = "HugeTransactionLog"; + string expectedMessage = $"The aggregation query for entity '{entityName}' timed out. " + + "This is NOT a tool error. The database did not respond in time. " + + "This may occur with large datasets or complex aggregations. " + + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination."; + + Assert.IsTrue(expectedMessage.Contains(entityName), "Timeout message must include the entity name."); + } + + #endregion + #region Spec Example Tests /// From 829a630ad8b9d6e427996c3c33d87ecb3d4a8653 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Mar 2026 17:56:23 +0000 Subject: [PATCH 06/32] Changes before error encountered Co-authored-by: JerryNixon <210500244+JerryNixon@users.noreply.github.com> --- schemas/dab.draft.schema.json | 6 ++++ src/Cli/Commands/ConfigureOptions.cs | 5 +++ src/Cli/ConfigGenerator.cs | 12 ++++++- src/Config/ObjectModel/McpRuntimeOptions.cs | 31 ++++++++++++++++++- .../Configurations/RuntimeConfigValidator.cs | 10 ++++++ 5 files changed, 62 insertions(+), 2 deletions(-) diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index ec1afc063a..94df7ca77c 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -275,6 +275,12 @@ "description": "Allow enabling/disabling MCP requests for all entities.", "default": true }, + "query-timeout": { + "type": "integer", + "description": "Query timeout in seconds for MCP tool operations. Applies to all MCP tools that execute database queries.", + "default": 10, + "minimum": 1 + }, "dml-tools": { "oneOf": [ { diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index ecd5ecd185..fc0ab7b8e5 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -42,6 +42,7 @@ public ConfigureOptions( bool? runtimeMcpEnabled = null, string? runtimeMcpPath = null, string? runtimeMcpDescription = null, + int? runtimeMcpQueryTimeout = null, bool? runtimeMcpDmlToolsEnabled = null, bool? runtimeMcpDmlToolsDescribeEntitiesEnabled = null, bool? runtimeMcpDmlToolsCreateRecordEnabled = null, @@ -102,6 +103,7 @@ public ConfigureOptions( RuntimeMcpEnabled = runtimeMcpEnabled; RuntimeMcpPath = runtimeMcpPath; RuntimeMcpDescription = runtimeMcpDescription; + RuntimeMcpQueryTimeout = runtimeMcpQueryTimeout; RuntimeMcpDmlToolsEnabled = runtimeMcpDmlToolsEnabled; RuntimeMcpDmlToolsDescribeEntitiesEnabled = runtimeMcpDmlToolsDescribeEntitiesEnabled; RuntimeMcpDmlToolsCreateRecordEnabled = runtimeMcpDmlToolsCreateRecordEnabled; @@ -203,6 +205,9 @@ public ConfigureOptions( [Option("runtime.mcp.description", Required = false, HelpText = "Set the MCP server description to be exposed in the initialize response.")] public string? RuntimeMcpDescription { get; } + [Option("runtime.mcp.query-timeout", Required = false, HelpText = "Set the query timeout in seconds for MCP tool operations. Default: 10 (integer). Must be >= 1.")] + public int? RuntimeMcpQueryTimeout { get; } + [Option("runtime.mcp.dml-tools.enabled", Required = false, HelpText = "Enable DAB's MCP DML tools endpoint. Default: true (boolean).")] public bool? RuntimeMcpDmlToolsEnabled { get; } diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index 2eaf50a822..fa632f10ac 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -876,13 +876,15 @@ private static bool TryUpdateConfiguredRuntimeOptions( if (options.RuntimeMcpEnabled != null || options.RuntimeMcpPath != null || options.RuntimeMcpDescription != null || + options.RuntimeMcpQueryTimeout != null || options.RuntimeMcpDmlToolsEnabled != null || options.RuntimeMcpDmlToolsDescribeEntitiesEnabled != null || options.RuntimeMcpDmlToolsCreateRecordEnabled != null || options.RuntimeMcpDmlToolsReadRecordsEnabled != null || options.RuntimeMcpDmlToolsUpdateRecordEnabled != null || options.RuntimeMcpDmlToolsDeleteRecordEnabled != null || - options.RuntimeMcpDmlToolsExecuteEntityEnabled != null) + options.RuntimeMcpDmlToolsExecuteEntityEnabled != null || + options.RuntimeMcpDmlToolsAggregateRecordsEnabled != null) { McpRuntimeOptions updatedMcpOptions = runtimeConfig?.Runtime?.Mcp ?? new(); bool status = TryUpdateConfiguredMcpValues(options, ref updatedMcpOptions); @@ -1161,6 +1163,14 @@ private static bool TryUpdateConfiguredMcpValues( _logger.LogInformation("Updated RuntimeConfig with Runtime.Mcp.Description as '{updatedValue}'", updatedValue); } + // Runtime.Mcp.QueryTimeout + updatedValue = options?.RuntimeMcpQueryTimeout; + if (updatedValue != null) + { + updatedMcpOptions = updatedMcpOptions! with { QueryTimeout = (int)updatedValue }; + _logger.LogInformation("Updated RuntimeConfig with Runtime.Mcp.QueryTimeout as '{updatedValue}'", updatedValue); + } + // Handle DML tools configuration bool hasToolUpdates = false; DmlToolsConfig? currentDmlTools = updatedMcpOptions?.DmlTools; diff --git a/src/Config/ObjectModel/McpRuntimeOptions.cs b/src/Config/ObjectModel/McpRuntimeOptions.cs index e17d53fc8f..324e0caa55 100644 --- a/src/Config/ObjectModel/McpRuntimeOptions.cs +++ b/src/Config/ObjectModel/McpRuntimeOptions.cs @@ -10,6 +10,7 @@ namespace Azure.DataApiBuilder.Config.ObjectModel; public record McpRuntimeOptions { public const string DEFAULT_PATH = "/mcp"; + public const int DEFAULT_QUERY_TIMEOUT_SECONDS = 10; /// /// Whether MCP endpoints are enabled @@ -36,12 +37,21 @@ public record McpRuntimeOptions [JsonPropertyName("description")] public string? Description { get; init; } + /// + /// Query timeout in seconds for MCP tool operations. + /// This timeout is applied to database queries executed by MCP tools. + /// Default: 10 seconds. + /// + [JsonPropertyName("query-timeout")] + public int? QueryTimeout { get; init; } + [JsonConstructor] public McpRuntimeOptions( bool? Enabled = null, string? Path = null, DmlToolsConfig? DmlTools = null, - string? Description = null) + string? Description = null, + int? QueryTimeout = null) { this.Enabled = Enabled ?? true; @@ -67,6 +77,12 @@ public McpRuntimeOptions( } this.Description = Description; + + if (QueryTimeout is not null) + { + this.QueryTimeout = QueryTimeout; + UserProvidedQueryTimeout = true; + } } /// @@ -78,4 +94,17 @@ public McpRuntimeOptions( [JsonIgnore(Condition = JsonIgnoreCondition.Always)] [MemberNotNullWhen(true, nameof(Enabled))] public bool UserProvidedPath { get; init; } = false; + + /// + /// Flag which informs CLI and JSON serializer whether to write query-timeout + /// property and value to the runtime config file. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedQueryTimeout { get; init; } = false; + + /// + /// Gets the effective query timeout in seconds, using the default if not specified. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public int EffectiveQueryTimeoutSeconds => QueryTimeout ?? DEFAULT_QUERY_TIMEOUT_SECONDS; } diff --git a/src/Core/Configurations/RuntimeConfigValidator.cs b/src/Core/Configurations/RuntimeConfigValidator.cs index f5112844da..ea2299bc6f 100644 --- a/src/Core/Configurations/RuntimeConfigValidator.cs +++ b/src/Core/Configurations/RuntimeConfigValidator.cs @@ -914,6 +914,16 @@ public void ValidateMcpUri(RuntimeConfig runtimeConfig) statusCode: HttpStatusCode.ServiceUnavailable, subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); } + + // Validate query-timeout if provided + if (runtimeConfig.Runtime.Mcp.QueryTimeout is not null && runtimeConfig.Runtime.Mcp.QueryTimeout < 1) + { + HandleOrRecordException(new DataApiBuilderException( + message: "MCP query-timeout must be a positive integer (>= 1 second). " + + $"Provided value: {runtimeConfig.Runtime.Mcp.QueryTimeout}.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } } private void ValidateAuthenticationOptions(RuntimeConfig runtimeConfig) From 3ccc7482951ef006eecfb390a1d2ab2aec159e55 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Mar 2026 18:20:20 +0000 Subject: [PATCH 07/32] Update query-timeout default to 30s, add converter support, apply timeout to all MCP tools, add tests Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- schemas/dab.draft.schema.json | 4 +- .../BuiltInTools/AggregateRecordsTool.cs | 4 +- .../Utils/McpTelemetryHelper.cs | 29 +- src/Cli/Commands/ConfigureOptions.cs | 2 +- .../McpRuntimeOptionsConverterFactory.cs | 17 +- src/Config/ObjectModel/McpRuntimeOptions.cs | 9 +- src/Service.Tests/Mcp/McpQueryTimeoutTests.cs | 452 ++++++++++++++++++ 7 files changed, 505 insertions(+), 12 deletions(-) create mode 100644 src/Service.Tests/Mcp/McpQueryTimeoutTests.cs diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index 94df7ca77c..e78861807d 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -277,8 +277,8 @@ }, "query-timeout": { "type": "integer", - "description": "Query timeout in seconds for MCP tool operations. Applies to all MCP tools that execute database queries.", - "default": 10, + "description": "Execution timeout in seconds for MCP tool operations. Applies to all MCP tools.", + "default": 30, "minimum": 1 }, "dml-tools": { diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index 59bd465ad0..fa75cd2fb9 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -35,7 +35,7 @@ public class AggregateRecordsTool : IMcpTool { public ToolType ToolType { get; } = ToolType.BuiltIn; - private static readonly HashSet ValidFunctions = new(StringComparer.OrdinalIgnoreCase) { "count", "avg", "sum", "min", "max" }; + private static readonly HashSet _validFunctions = new(StringComparer.OrdinalIgnoreCase) { "count", "avg", "sum", "min", "max" }; public Tool GetToolMetadata() { @@ -183,7 +183,7 @@ public async Task ExecuteAsync( } string function = funcEl.GetString()!.ToLowerInvariant(); - if (!ValidFunctions.Contains(function)) + if (!_validFunctions.Contains(function)) { return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", $"Invalid function '{function}'. Must be one of: count, avg, sum, min, max.", logger); } diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs index eabbdc62d8..105bb57ced 100644 --- a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs @@ -60,8 +60,33 @@ public static async Task ExecuteWithTelemetryAsync( operation: operation, dbProcedure: dbProcedure); - // Execute the tool - CallToolResult result = await tool.ExecuteAsync(arguments, serviceProvider, cancellationToken); + // Read query-timeout from current config per invocation (hot-reload safe). + int timeoutSeconds = McpRuntimeOptions.DEFAULT_QUERY_TIMEOUT_SECONDS; + RuntimeConfigProvider? runtimeConfigProvider = serviceProvider.GetService(); + if (runtimeConfigProvider is not null) + { + RuntimeConfig config = runtimeConfigProvider.GetConfig(); + timeoutSeconds = config.Runtime?.Mcp?.EffectiveQueryTimeoutSeconds ?? McpRuntimeOptions.DEFAULT_QUERY_TIMEOUT_SECONDS; + } + + // Wrap tool execution with the configured timeout using a linked CancellationTokenSource. + using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + timeoutCts.CancelAfter(TimeSpan.FromSeconds(timeoutSeconds)); + + CallToolResult result; + try + { + result = await tool.ExecuteAsync(arguments, serviceProvider, timeoutCts.Token); + } + catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested) + { + // The timeout CTS fired, not the caller's token. Surface as TimeoutException + // so downstream telemetry and tool handlers see TIMEOUT, not cancellation. + throw new TimeoutException( + $"The MCP tool '{toolName}' did not complete within {timeoutSeconds} seconds. " + + "This is NOT a tool error. The operation exceeded the configured query-timeout. " + + "Try narrowing results with a filter, reducing groupby fields, or using pagination."); + } // Check if the tool returned an error result (tools catch exceptions internally // and return CallToolResult with IsError=true instead of throwing) diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index fc0ab7b8e5..bf12cd5199 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -205,7 +205,7 @@ public ConfigureOptions( [Option("runtime.mcp.description", Required = false, HelpText = "Set the MCP server description to be exposed in the initialize response.")] public string? RuntimeMcpDescription { get; } - [Option("runtime.mcp.query-timeout", Required = false, HelpText = "Set the query timeout in seconds for MCP tool operations. Default: 10 (integer). Must be >= 1.")] + [Option("runtime.mcp.query-timeout", Required = false, HelpText = "Set the execution timeout in seconds for MCP tool operations. Applies to all MCP tools. Default: 30 (integer). Must be >= 1.")] public int? RuntimeMcpQueryTimeout { get; } [Option("runtime.mcp.dml-tools.enabled", Required = false, HelpText = "Enable DAB's MCP DML tools endpoint. Default: true (boolean).")] diff --git a/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs b/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs index 8b3c640725..ad4edc229e 100644 --- a/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs +++ b/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs @@ -66,12 +66,13 @@ internal McpRuntimeOptionsConverter(DeserializationVariableReplacementSettings? string? path = null; DmlToolsConfig? dmlTools = null; string? description = null; + int? queryTimeout = null; while (reader.Read()) { if (reader.TokenType == JsonTokenType.EndObject) { - return new McpRuntimeOptions(enabled, path, dmlTools, description); + return new McpRuntimeOptions(enabled, path, dmlTools, description, queryTimeout); } string? propertyName = reader.GetString(); @@ -107,6 +108,14 @@ internal McpRuntimeOptionsConverter(DeserializationVariableReplacementSettings? break; + case "query-timeout": + if (reader.TokenType is not JsonTokenType.Null) + { + queryTimeout = reader.GetInt32(); + } + + break; + default: throw new JsonException($"Unexpected property {propertyName}"); } @@ -150,6 +159,12 @@ public override void Write(Utf8JsonWriter writer, McpRuntimeOptions value, JsonS JsonSerializer.Serialize(writer, value.Description, options); } + // Write query-timeout if it's user provided + if (value?.UserProvidedQueryTimeout is true && value.QueryTimeout.HasValue) + { + writer.WriteNumber("query-timeout", value.QueryTimeout.Value); + } + writer.WriteEndObject(); } } diff --git a/src/Config/ObjectModel/McpRuntimeOptions.cs b/src/Config/ObjectModel/McpRuntimeOptions.cs index 324e0caa55..f4b4281a14 100644 --- a/src/Config/ObjectModel/McpRuntimeOptions.cs +++ b/src/Config/ObjectModel/McpRuntimeOptions.cs @@ -10,7 +10,7 @@ namespace Azure.DataApiBuilder.Config.ObjectModel; public record McpRuntimeOptions { public const string DEFAULT_PATH = "/mcp"; - public const int DEFAULT_QUERY_TIMEOUT_SECONDS = 10; + public const int DEFAULT_QUERY_TIMEOUT_SECONDS = 30; /// /// Whether MCP endpoints are enabled @@ -38,9 +38,10 @@ public record McpRuntimeOptions public string? Description { get; init; } /// - /// Query timeout in seconds for MCP tool operations. - /// This timeout is applied to database queries executed by MCP tools. - /// Default: 10 seconds. + /// Execution timeout in seconds for MCP tool operations. + /// This timeout wraps the entire tool execution including database queries. + /// It applies to all MCP tools, not just aggregation. + /// Default: 30 seconds. /// [JsonPropertyName("query-timeout")] public int? QueryTimeout { get; init; } diff --git a/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs b/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs new file mode 100644 index 0000000000..f5b29f2b8a --- /dev/null +++ b/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs @@ -0,0 +1,452 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Config; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Mcp.Utils; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using ModelContextProtocol.Protocol; +using static Azure.DataApiBuilder.Mcp.Model.McpEnums; + +namespace Azure.DataApiBuilder.Service.Tests.Mcp +{ + /// + /// Tests for the MCP query-timeout configuration property. + /// Verifies: + /// - Default value of 30 seconds when not configured + /// - Custom value overrides default + /// - Timeout wrapping applies to all MCP tools via ExecuteWithTelemetryAsync + /// - Hot reload: changing config value updates behavior without restart + /// - Timeout surfaces as TimeoutException, not generic cancellation + /// - Telemetry maps timeout to TIMEOUT error code + /// + [TestClass] + public class McpQueryTimeoutTests + { + #region Default Value Tests + + [TestMethod] + public void McpRuntimeOptions_DefaultQueryTimeout_Is30Seconds() + { + Assert.AreEqual(30, McpRuntimeOptions.DEFAULT_QUERY_TIMEOUT_SECONDS); + } + + [TestMethod] + public void McpRuntimeOptions_EffectiveTimeout_ReturnsDefault_WhenNotConfigured() + { + McpRuntimeOptions options = new(); + Assert.IsNull(options.QueryTimeout); + Assert.AreEqual(30, options.EffectiveQueryTimeoutSeconds); + } + + [TestMethod] + public void McpRuntimeOptions_EffectiveTimeout_ReturnsConfiguredValue() + { + McpRuntimeOptions options = new(QueryTimeout: 60); + Assert.AreEqual(60, options.QueryTimeout); + Assert.AreEqual(60, options.EffectiveQueryTimeoutSeconds); + } + + [TestMethod] + public void McpRuntimeOptions_UserProvidedQueryTimeout_FalseByDefault() + { + McpRuntimeOptions options = new(); + Assert.IsFalse(options.UserProvidedQueryTimeout); + } + + [TestMethod] + public void McpRuntimeOptions_UserProvidedQueryTimeout_TrueWhenSet() + { + McpRuntimeOptions options = new(QueryTimeout: 45); + Assert.IsTrue(options.UserProvidedQueryTimeout); + } + + #endregion + + #region Custom Value Tests + + [TestMethod] + public void McpRuntimeOptions_CustomTimeout_1Second() + { + McpRuntimeOptions options = new(QueryTimeout: 1); + Assert.AreEqual(1, options.EffectiveQueryTimeoutSeconds); + } + + [TestMethod] + public void McpRuntimeOptions_CustomTimeout_120Seconds() + { + McpRuntimeOptions options = new(QueryTimeout: 120); + Assert.AreEqual(120, options.EffectiveQueryTimeoutSeconds); + } + + [TestMethod] + public void RuntimeConfig_McpQueryTimeout_ExposedInConfig() + { + RuntimeConfig config = CreateConfigWithQueryTimeout(45); + Assert.AreEqual(45, config.Runtime?.Mcp?.QueryTimeout); + Assert.AreEqual(45, config.Runtime?.Mcp?.EffectiveQueryTimeoutSeconds); + } + + [TestMethod] + public void RuntimeConfig_McpQueryTimeout_DefaultWhenNotSet() + { + RuntimeConfig config = CreateConfigWithoutQueryTimeout(); + Assert.IsNull(config.Runtime?.Mcp?.QueryTimeout); + Assert.AreEqual(30, config.Runtime?.Mcp?.EffectiveQueryTimeoutSeconds); + } + + #endregion + + #region Timeout Wrapping Tests + + [TestMethod] + public async Task ExecuteWithTelemetry_CompletesSuccessfully_WithinTimeout() + { + // A tool that completes immediately should succeed + RuntimeConfig config = CreateConfigWithQueryTimeout(30); + IServiceProvider sp = CreateServiceProviderWithConfig(config); + IMcpTool tool = new ImmediateCompletionTool(); + + CallToolResult result = await McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, "test_tool", null, sp, CancellationToken.None); + + // Tool should complete without throwing TimeoutException + Assert.IsNotNull(result); + Assert.IsTrue(result.IsError != true, "Tool result should not be an error"); + } + + [TestMethod] + public async Task ExecuteWithTelemetry_ThrowsTimeoutException_WhenToolExceedsTimeout() + { + // Configure a very short timeout (1 second) and a tool that takes longer + RuntimeConfig config = CreateConfigWithQueryTimeout(1); + IServiceProvider sp = CreateServiceProviderWithConfig(config); + IMcpTool tool = new SlowTool(delaySeconds: 30); + + await Assert.ThrowsExceptionAsync(async () => + { + await McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, "slow_tool", null, sp, CancellationToken.None); + }); + } + + [TestMethod] + public async Task ExecuteWithTelemetry_TimeoutMessage_ContainsToolName() + { + RuntimeConfig config = CreateConfigWithQueryTimeout(1); + IServiceProvider sp = CreateServiceProviderWithConfig(config); + IMcpTool tool = new SlowTool(delaySeconds: 30); + + try + { + await McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, "aggregate_records", null, sp, CancellationToken.None); + Assert.Fail("Expected TimeoutException"); + } + catch (TimeoutException ex) + { + Assert.IsTrue(ex.Message.Contains("aggregate_records"), "Message should contain tool name"); + Assert.IsTrue(ex.Message.Contains("1 seconds"), "Message should contain timeout value"); + Assert.IsTrue(ex.Message.Contains("NOT a tool error"), "Message should clarify it is not a tool error"); + } + } + + [TestMethod] + public async Task ExecuteWithTelemetry_ClientCancellation_PropagatesAsCancellation() + { + // Client cancellation (not timeout) should propagate as OperationCanceledException + // rather than being converted to TimeoutException. + RuntimeConfig config = CreateConfigWithQueryTimeout(30); + IServiceProvider sp = CreateServiceProviderWithConfig(config); + IMcpTool tool = new SlowTool(delaySeconds: 30); + + using CancellationTokenSource cts = new(); + cts.Cancel(); // Cancel immediately + + try + { + await McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, "test_tool", null, sp, cts.Token); + Assert.Fail("Expected OperationCanceledException or subclass to be thrown"); + } + catch (TimeoutException) + { + Assert.Fail("Client cancellation should NOT be converted to TimeoutException"); + } + catch (OperationCanceledException) + { + // Expected: client-initiated cancellation propagates as OperationCanceledException + // (or subclass TaskCanceledException) + } + } + + [TestMethod] + public async Task ExecuteWithTelemetry_AppliesTimeout_ToAllToolTypes() + { + // Verify timeout applies to both built-in and custom tool types + RuntimeConfig config = CreateConfigWithQueryTimeout(1); + IServiceProvider sp = CreateServiceProviderWithConfig(config); + + // Test with built-in tool type + IMcpTool builtInTool = new SlowTool(delaySeconds: 30, toolType: ToolType.BuiltIn); + await Assert.ThrowsExceptionAsync(async () => + { + await McpTelemetryHelper.ExecuteWithTelemetryAsync( + builtInTool, "builtin_slow", null, sp, CancellationToken.None); + }); + + // Test with custom tool type + IMcpTool customTool = new SlowTool(delaySeconds: 30, toolType: ToolType.Custom); + await Assert.ThrowsExceptionAsync(async () => + { + await McpTelemetryHelper.ExecuteWithTelemetryAsync( + customTool, "custom_slow", null, sp, CancellationToken.None); + }); + } + + #endregion + + #region Hot Reload Tests + + [TestMethod] + public async Task ExecuteWithTelemetry_ReadsConfigPerInvocation_HotReload() + { + // First invocation with long timeout should succeed + RuntimeConfig config1 = CreateConfigWithQueryTimeout(30); + IServiceProvider sp1 = CreateServiceProviderWithConfig(config1); + + IMcpTool fastTool = new ImmediateCompletionTool(); + CallToolResult result1 = await McpTelemetryHelper.ExecuteWithTelemetryAsync( + fastTool, "test_tool", null, sp1, CancellationToken.None); + Assert.IsNotNull(result1); + + // Second invocation with very short timeout and a slow tool should timeout. + // This demonstrates that each invocation reads the current config value. + RuntimeConfig config2 = CreateConfigWithQueryTimeout(1); + IServiceProvider sp2 = CreateServiceProviderWithConfig(config2); + + IMcpTool slowTool = new SlowTool(delaySeconds: 30); + await Assert.ThrowsExceptionAsync(async () => + { + await McpTelemetryHelper.ExecuteWithTelemetryAsync( + slowTool, "test_tool", null, sp2, CancellationToken.None); + }); + } + + #endregion + + #region Telemetry Tests + + [TestMethod] + public void MapExceptionToErrorCode_TimeoutException_ReturnsTIMEOUT() + { + string errorCode = McpTelemetryHelper.MapExceptionToErrorCode(new TimeoutException()); + Assert.AreEqual(McpTelemetryErrorCodes.TIMEOUT, errorCode); + } + + [TestMethod] + public void MapExceptionToErrorCode_OperationCanceled_ReturnsOPERATION_CANCELLED() + { + string errorCode = McpTelemetryHelper.MapExceptionToErrorCode(new OperationCanceledException()); + Assert.AreEqual(McpTelemetryErrorCodes.OPERATION_CANCELLED, errorCode); + } + + [TestMethod] + public void MapExceptionToErrorCode_TaskCanceled_ReturnsOPERATION_CANCELLED() + { + string errorCode = McpTelemetryHelper.MapExceptionToErrorCode(new TaskCanceledException()); + Assert.AreEqual(McpTelemetryErrorCodes.OPERATION_CANCELLED, errorCode); + } + + #endregion + + #region JSON Serialization Tests + + [TestMethod] + public void McpRuntimeOptions_Serialization_IncludesQueryTimeout_WhenUserProvided() + { + McpRuntimeOptions options = new(QueryTimeout: 45); + JsonSerializerOptions serializerOptions = RuntimeConfigLoader.GetSerializationOptions(); + string json = JsonSerializer.Serialize(options, serializerOptions); + Assert.IsTrue(json.Contains("\"query-timeout\": 45") || json.Contains("\"query-timeout\":45")); + } + + [TestMethod] + public void McpRuntimeOptions_Deserialization_ReadsQueryTimeout() + { + string json = @"{""enabled"": true, ""query-timeout"": 60}"; + JsonSerializerOptions serializerOptions = RuntimeConfigLoader.GetSerializationOptions(); + McpRuntimeOptions? options = JsonSerializer.Deserialize(json, serializerOptions); + Assert.IsNotNull(options); + Assert.AreEqual(60, options.QueryTimeout); + Assert.AreEqual(60, options.EffectiveQueryTimeoutSeconds); + } + + [TestMethod] + public void McpRuntimeOptions_Deserialization_DefaultsWhenOmitted() + { + string json = @"{""enabled"": true}"; + JsonSerializerOptions serializerOptions = RuntimeConfigLoader.GetSerializationOptions(); + McpRuntimeOptions? options = JsonSerializer.Deserialize(json, serializerOptions); + Assert.IsNotNull(options); + Assert.IsNull(options.QueryTimeout); + Assert.AreEqual(30, options.EffectiveQueryTimeoutSeconds); + } + + #endregion + + #region Helpers + + private static RuntimeConfig CreateConfigWithQueryTimeout(int timeoutSeconds) + { + return new RuntimeConfig( + Schema: "test-schema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new( + Enabled: true, + Path: "/mcp", + QueryTimeout: timeoutSeconds, + DmlTools: new( + describeEntities: true, + readRecords: true, + createRecord: true, + updateRecord: true, + deleteRecord: true, + executeEntity: true, + aggregateRecords: true + ) + ), + Host: new(Cors: null, Authentication: null, Mode: HostMode.Development) + ), + Entities: new(new Dictionary()) + ); + } + + private static RuntimeConfig CreateConfigWithoutQueryTimeout() + { + return new RuntimeConfig( + Schema: "test-schema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new( + Enabled: true, + Path: "/mcp", + DmlTools: new( + describeEntities: true, + readRecords: true, + createRecord: true, + updateRecord: true, + deleteRecord: true, + executeEntity: true, + aggregateRecords: true + ) + ), + Host: new(Cors: null, Authentication: null, Mode: HostMode.Development) + ), + Entities: new(new Dictionary()) + ); + } + + private static IServiceProvider CreateServiceProviderWithConfig(RuntimeConfig config) + { + ServiceCollection services = new(); + RuntimeConfigProvider configProvider = TestHelper.GenerateInMemoryRuntimeConfigProvider(config); + services.AddSingleton(configProvider); + services.AddLogging(); + return services.BuildServiceProvider(); + } + + /// + /// A mock tool that completes immediately with a success result. + /// + private class ImmediateCompletionTool : IMcpTool + { + public ToolType ToolType { get; } = ToolType.BuiltIn; + + public Tool GetToolMetadata() + { + using JsonDocument doc = JsonDocument.Parse("{\"type\": \"object\"}"); + return new Tool + { + Name = "test_tool", + Description = "A test tool that completes immediately", + InputSchema = doc.RootElement.Clone() + }; + } + + public Task ExecuteAsync( + JsonDocument? arguments, + IServiceProvider serviceProvider, + CancellationToken cancellationToken = default) + { + return Task.FromResult(new CallToolResult + { + Content = new List + { + new TextContentBlock { Text = "{\"result\": \"success\"}" } + } + }); + } + } + + /// + /// A mock tool that delays for a specified duration, respecting cancellation. + /// Used to test timeout behavior. + /// + private class SlowTool : IMcpTool + { + private readonly int _delaySeconds; + + public SlowTool(int delaySeconds, ToolType toolType = ToolType.BuiltIn) + { + _delaySeconds = delaySeconds; + ToolType = toolType; + } + + public ToolType ToolType { get; } + + public Tool GetToolMetadata() + { + using JsonDocument doc = JsonDocument.Parse("{\"type\": \"object\"}"); + return new Tool + { + Name = "slow_tool", + Description = "A test tool that takes a long time", + InputSchema = doc.RootElement.Clone() + }; + } + + public async Task ExecuteAsync( + JsonDocument? arguments, + IServiceProvider serviceProvider, + CancellationToken cancellationToken = default) + { + await Task.Delay(TimeSpan.FromSeconds(_delaySeconds), cancellationToken); + return new CallToolResult + { + Content = new List + { + new TextContentBlock { Text = "{\"result\": \"completed\"}" } + } + }; + } + } + + #endregion + } +} From 381899d84dbb8ff0827b7ebcafce947f9c230ccc Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Mar 2026 18:33:35 +0000 Subject: [PATCH 08/32] Fix group key collision using \\0 delimiter, add #nullable enable to test file Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- .../BuiltInTools/AggregateRecordsTool.cs | 4 +++- src/Service.Tests/Mcp/AggregateRecordsToolTests.cs | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index fa75cd2fb9..f19de23eeb 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -706,7 +706,9 @@ private static string BuildGroupKey(JsonElement record, List groupby) } } - return string.Join("|", parts); + // Use null character (\0) as delimiter to avoid collisions with + // field values that may contain printable characters like '|'. + return string.Join("\0", parts); } private static Dictionary ExtractGroupFields(JsonElement record, List groupby) diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs index 9255a2e9c5..4a23a8c1fc 100644 --- a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +#nullable enable + using System; using System.Collections.Generic; using System.Text.Json; From fde4d65a84da58b2c2aa009afe72dfbb3c71afa1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Mar 2026 19:53:42 +0000 Subject: [PATCH 09/32] Fix nullable warnings in AggregateRecordsToolTests.cs Co-authored-by: anushakolan <45540936+anushakolan@users.noreply.github.com> --- src/Service.Tests/Mcp/AggregateRecordsToolTests.cs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs index 4a23a8c1fc..67477a9d2f 100644 --- a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -643,14 +643,15 @@ public async Task AggregateRecords_OperationCanceled_ReturnsExplicitCanceledMess Assert.IsTrue(result.IsError == true); JsonElement content = ParseContent(result); Assert.IsTrue(content.TryGetProperty("error", out JsonElement error)); - string errorType = error.GetProperty("type").GetString(); - string errorMessage = error.GetProperty("message").GetString(); + string? errorType = error.GetProperty("type").GetString(); + string? errorMessage = error.GetProperty("message").GetString(); // Verify the error type identifies it as a cancellation Assert.AreEqual("OperationCanceled", errorType); // Verify the message explicitly tells the model this is NOT a tool error - Assert.IsTrue(errorMessage.Contains("NOT a tool error"), "Message must explicitly state this is NOT a tool error."); + Assert.IsNotNull(errorMessage); + Assert.IsTrue(errorMessage!.Contains("NOT a tool error"), "Message must explicitly state this is NOT a tool error."); // Verify the message tells the model what happened Assert.IsTrue(errorMessage.Contains("canceled"), "Message must mention the operation was canceled."); From ba371d52bd7dd5d9d8d7c65e7dbd648f53b40b05 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Mar 2026 19:54:58 +0000 Subject: [PATCH 10/32] Add null check for errorType in AggregateRecordsToolTests Co-authored-by: anushakolan <45540936+anushakolan@users.noreply.github.com> --- src/Service.Tests/Mcp/AggregateRecordsToolTests.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs index 67477a9d2f..ce578e746e 100644 --- a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -647,6 +647,7 @@ public async Task AggregateRecords_OperationCanceled_ReturnsExplicitCanceledMess string? errorMessage = error.GetProperty("message").GetString(); // Verify the error type identifies it as a cancellation + Assert.IsNotNull(errorType); Assert.AreEqual("OperationCanceled", errorType); // Verify the message explicitly tells the model this is NOT a tool error From d340cb43bd3b9682113cc36aaaaaad410bb5afd4 Mon Sep 17 00:00:00 2001 From: Jerry Nixon Date: Mon, 2 Mar 2026 13:52:27 -0700 Subject: [PATCH 11/32] Apply validation fixes and additional tests from copilot/update-aggregate-records-tool-fixes --- .../BuiltInTools/AggregateRecordsTool.cs | 29 +- .../Utils/McpTelemetryHelper.cs | 2 +- src/Service.Tests/Mcp/McpQueryTimeoutTests.cs | 2 +- .../UnitTests/AggregateRecordsToolTests.cs | 411 ++++++++++++++++++ .../UnitTests/McpTelemetryTests.cs | 168 ++++++- 5 files changed, 603 insertions(+), 9 deletions(-) create mode 100644 src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index f19de23eeb..b8dd85c175 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -52,16 +52,15 @@ public Tool GetToolMetadata() + "orderby ('asc' or 'desc' to sort grouped results by aggregated value; requires groupby), " + "having (object to filter groups after aggregating, operators: eq, neq, gt, gte, lt, lte, in; requires groupby), " + "first (integer >= 1, maximum grouped results to return; requires groupby), " - + "after (opaque cursor string from a previous response's endCursor; requires first and groupby). " + + "after (opaque cursor string from a previous response's endCursor for pagination). " + "RESPONSE: The aggregated value is aliased as '{function}_{field}' (e.g. avg_unitPrice, sum_revenue). " + "For count with field '*', the alias is 'count'. " + "When first is used with groupby, response contains: items (array), endCursor (string), hasNextPage (boolean). " + "RULES: 1) ALWAYS call describe_entities first to get valid entity and field names. " + "2) Use field '*' ONLY with function 'count'. " + "3) For avg, sum, min, max: field MUST be a numeric field name from describe_entities. " - + "4) orderby, having, first, and after ONLY apply when groupby is provided. " - + "5) after REQUIRES first to also be set. " - + "6) Use first and after for paginating large grouped result sets.", + + "4) orderby, having, and first ONLY apply when groupby is provided. " + + "5) Use first and after for paginating large grouped result sets.", InputSchema = JsonSerializer.Deserialize( @"{ ""type"": ""object"", @@ -194,7 +193,25 @@ public async Task ExecuteAsync( } string field = fieldEl.GetString()!; + + // Validate field/function compatibility + bool isCountStar = function == "count" && field == "*"; + + if (field == "*" && function != "count") + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + $"Field '*' is only valid with function 'count'. For function '{function}', provide a specific field name.", logger); + } + bool distinct = root.TryGetProperty("distinct", out JsonElement distinctEl) && distinctEl.GetBoolean(); + + // Reject count(*) with distinct as it is semantically undefined + if (isCountStar && distinct) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + "Cannot use distinct=true with field='*'. DISTINCT requires a specific field name. Use a field name instead of '*' to count distinct values.", logger); + } + string? filter = root.TryGetProperty("filter", out JsonElement filterEl) ? filterEl.GetString() : null; string orderby = root.TryGetProperty("orderby", out JsonElement orderbyEl) ? (orderbyEl.GetString() ?? "desc") : "desc"; @@ -285,7 +302,6 @@ public async Task ExecuteAsync( // Build select list: groupby fields + aggregation field List selectFields = new(groupby); - bool isCountStar = function == "count" && field == "*"; if (!isCountStar && !selectFields.Contains(field, StringComparer.OrdinalIgnoreCase)) { selectFields.Add(field); @@ -610,7 +626,8 @@ internal static PaginationResult ApplyPagination( { if (isCountStar) { - return distinct ? 0 : records.Count; + // count(*) always counts all rows; distinct is rejected at ExecuteAsync validation level + return records.Count; } List values = new(); diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs index 105bb57ced..ac567d4d8c 100644 --- a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs @@ -83,7 +83,7 @@ public static async Task ExecuteWithTelemetryAsync( // The timeout CTS fired, not the caller's token. Surface as TimeoutException // so downstream telemetry and tool handlers see TIMEOUT, not cancellation. throw new TimeoutException( - $"The MCP tool '{toolName}' did not complete within {timeoutSeconds} seconds. " + $"The MCP tool '{toolName}' did not complete within {timeoutSeconds} {(timeoutSeconds == 1 ? "second" : "seconds")}. " + "This is NOT a tool error. The operation exceeded the configured query-timeout. " + "Try narrowing results with a filter, reducing groupby fields, or using pagination."); } diff --git a/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs b/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs index f5b29f2b8a..0f5ee3951a 100644 --- a/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs +++ b/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs @@ -156,7 +156,7 @@ await McpTelemetryHelper.ExecuteWithTelemetryAsync( catch (TimeoutException ex) { Assert.IsTrue(ex.Message.Contains("aggregate_records"), "Message should contain tool name"); - Assert.IsTrue(ex.Message.Contains("1 seconds"), "Message should contain timeout value"); + Assert.IsTrue(ex.Message.Contains("1 second"), "Message should contain timeout value"); Assert.IsTrue(ex.Message.Contains("NOT a tool error"), "Message should clarify it is not a tool error"); } } diff --git a/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs b/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs new file mode 100644 index 0000000000..dee8842a0d --- /dev/null +++ b/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs @@ -0,0 +1,411 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#nullable enable + +using System.Collections.Generic; +using System.Text.Json; +using Azure.DataApiBuilder.Mcp.BuiltInTools; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests +{ + /// + /// Unit tests for AggregateRecordsTool's internal helper methods. + /// Covers validation paths, aggregation logic, and pagination behavior. + /// + [TestClass] + public class AggregateRecordsToolTests + { + #region ComputeAlias tests + + [TestMethod] + [DataRow("count", "*", "count", DisplayName = "count(*) alias is 'count'")] + [DataRow("count", "userId", "count_userId", DisplayName = "count(field) alias is 'count_field'")] + [DataRow("avg", "price", "avg_price", DisplayName = "avg alias")] + [DataRow("sum", "amount", "sum_amount", DisplayName = "sum alias")] + [DataRow("min", "age", "min_age", DisplayName = "min alias")] + [DataRow("max", "score", "max_score", DisplayName = "max alias")] + public void ComputeAlias_ReturnsExpectedAlias(string function, string field, string expectedAlias) + { + string result = AggregateRecordsTool.ComputeAlias(function, field); + Assert.AreEqual(expectedAlias, result); + } + + #endregion + + #region PerformAggregation tests - no groupby + + private static JsonElement CreateRecordsArray(params double[] values) + { + var list = new List(); + foreach (double v in values) + { + list.Add(new Dictionary { ["value"] = v }); + } + + string json = JsonSerializer.Serialize(list); + return JsonDocument.Parse(json).RootElement.Clone(); + } + + private static JsonElement CreateEmptyArray() + { + return JsonDocument.Parse("[]").RootElement.Clone(); + } + + private static JsonElement CreateMixedArray() + { + // Records where some have 'value' (numeric) and some have 'category' (string) + string json = """ + [ + {"value": 10.0, "category": "A"}, + {"value": 20.0, "category": "B"}, + {"value": 10.0, "category": "A"} + ] + """; + return JsonDocument.Parse(json).RootElement.Clone(); + } + + [TestMethod] + public void PerformAggregation_CountStar_NoGroupBy_ReturnsCount() + { + JsonElement records = CreateRecordsArray(1, 2, 3, 4, 5); + var result = AggregateRecordsTool.PerformAggregation( + records, "count", "*", distinct: false, new List(), null, null, "desc", "count"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(5.0, result[0]["count"]); + } + + [TestMethod] + public void PerformAggregation_CountField_NoGroupBy_CountsNumericValues() + { + JsonElement records = CreateRecordsArray(10.0, 20.0, 30.0); + var result = AggregateRecordsTool.PerformAggregation( + records, "count", "value", distinct: false, new List(), null, null, "desc", "count_value"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(3.0, result[0]["count_value"]); + } + + [TestMethod] + public void PerformAggregation_CountField_Distinct_CountsUniqueValues() + { + JsonElement records = CreateRecordsArray(10.0, 20.0, 10.0); + var result = AggregateRecordsTool.PerformAggregation( + records, "count", "value", distinct: true, new List(), null, null, "desc", "count_value"); + + Assert.AreEqual(1, result.Count); + // 10 and 20 are the distinct values + Assert.AreEqual(2.0, result[0]["count_value"]); + } + + [TestMethod] + public void PerformAggregation_Avg_NoGroupBy_ReturnsAverage() + { + JsonElement records = CreateRecordsArray(10.0, 20.0, 30.0); + var result = AggregateRecordsTool.PerformAggregation( + records, "avg", "value", distinct: false, new List(), null, null, "desc", "avg_value"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(20.0, result[0]["avg_value"]); + } + + [TestMethod] + public void PerformAggregation_Sum_NoGroupBy_ReturnsSum() + { + JsonElement records = CreateRecordsArray(10.0, 20.0, 30.0); + var result = AggregateRecordsTool.PerformAggregation( + records, "sum", "value", distinct: false, new List(), null, null, "desc", "sum_value"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(60.0, result[0]["sum_value"]); + } + + [TestMethod] + public void PerformAggregation_Min_NoGroupBy_ReturnsMinimum() + { + JsonElement records = CreateRecordsArray(30.0, 10.0, 20.0); + var result = AggregateRecordsTool.PerformAggregation( + records, "min", "value", distinct: false, new List(), null, null, "desc", "min_value"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(10.0, result[0]["min_value"]); + } + + [TestMethod] + public void PerformAggregation_Max_NoGroupBy_ReturnsMaximum() + { + JsonElement records = CreateRecordsArray(30.0, 10.0, 20.0); + var result = AggregateRecordsTool.PerformAggregation( + records, "max", "value", distinct: false, new List(), null, null, "desc", "max_value"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(30.0, result[0]["max_value"]); + } + + [TestMethod] + public void PerformAggregation_EmptyRecords_ReturnsNullForNumericFunctions() + { + JsonElement records = CreateEmptyArray(); + var result = AggregateRecordsTool.PerformAggregation( + records, "avg", "value", distinct: false, new List(), null, null, "desc", "avg_value"); + + Assert.AreEqual(1, result.Count); + Assert.IsNull(result[0]["avg_value"]); + } + + [TestMethod] + public void PerformAggregation_EmptyRecords_CountStar_ReturnsZero() + { + JsonElement records = CreateEmptyArray(); + var result = AggregateRecordsTool.PerformAggregation( + records, "count", "*", distinct: false, new List(), null, null, "desc", "count"); + + Assert.AreEqual(1, result.Count); + Assert.AreEqual(0.0, result[0]["count"]); + } + + #endregion + + #region PerformAggregation tests - with groupby + + [TestMethod] + public void PerformAggregation_GroupBy_CountStar_ReturnsGroupCounts() + { + JsonElement records = CreateMixedArray(); + var groupby = new List { "category" }; + + var result = AggregateRecordsTool.PerformAggregation( + records, "count", "*", distinct: false, groupby, null, null, "desc", "count"); + + Assert.AreEqual(2, result.Count); + // desc ordering: A has 2, B has 1 + Assert.AreEqual("A", result[0]["category"]); + Assert.AreEqual(2.0, result[0]["count"]); + Assert.AreEqual("B", result[1]["category"]); + Assert.AreEqual(1.0, result[1]["count"]); + } + + [TestMethod] + public void PerformAggregation_GroupBy_Avg_ReturnsGroupAverages() + { + JsonElement records = CreateMixedArray(); + var groupby = new List { "category" }; + + var result = AggregateRecordsTool.PerformAggregation( + records, "avg", "value", distinct: false, groupby, null, null, "asc", "avg_value"); + + Assert.AreEqual(2, result.Count); + // asc ordering by avg_value: B has 20, A has average (10+10)/2=10 + Assert.AreEqual("A", result[0]["category"]); + Assert.AreEqual(10.0, result[0]["avg_value"]); + Assert.AreEqual("B", result[1]["category"]); + Assert.AreEqual(20.0, result[1]["avg_value"]); + } + + [TestMethod] + public void PerformAggregation_GroupBy_Having_FiltersGroups() + { + JsonElement records = CreateMixedArray(); + var groupby = new List { "category" }; + var havingOps = new Dictionary(System.StringComparer.OrdinalIgnoreCase) + { + ["gt"] = 1.0 // Keep groups with count > 1 + }; + + var result = AggregateRecordsTool.PerformAggregation( + records, "count", "*", distinct: false, groupby, havingOps, null, "desc", "count"); + + // Only category "A" (count=2) should pass count > 1 + Assert.AreEqual(1, result.Count); + Assert.AreEqual("A", result[0]["category"]); + } + + #endregion + + #region Pagination tests + + [TestMethod] + public void ApplyPagination_FirstPage_ReturnsItemsAndCursor() + { + var allResults = new List> + { + new() { ["id"] = 1 }, + new() { ["id"] = 2 }, + new() { ["id"] = 3 }, + new() { ["id"] = 4 }, + new() { ["id"] = 5 } + }; + + var result = AggregateRecordsTool.ApplyPagination(allResults, first: 2, after: null); + + Assert.AreEqual(2, result.Items.Count); + Assert.AreEqual(1, result.Items[0]["id"]); + Assert.AreEqual(2, result.Items[1]["id"]); + Assert.IsTrue(result.HasNextPage); + Assert.IsNotNull(result.EndCursor); + } + + [TestMethod] + public void ApplyPagination_SecondPage_ReturnsCorrectItems() + { + var allResults = new List> + { + new() { ["id"] = 1 }, + new() { ["id"] = 2 }, + new() { ["id"] = 3 }, + new() { ["id"] = 4 }, + new() { ["id"] = 5 } + }; + + // Get first page to obtain cursor + var firstPage = AggregateRecordsTool.ApplyPagination(allResults, first: 2, after: null); + string? cursor = firstPage.EndCursor; + + // Use cursor to get second page + var secondPage = AggregateRecordsTool.ApplyPagination(allResults, first: 2, after: cursor); + + Assert.AreEqual(2, secondPage.Items.Count); + Assert.AreEqual(3, secondPage.Items[0]["id"]); + Assert.AreEqual(4, secondPage.Items[1]["id"]); + Assert.IsTrue(secondPage.HasNextPage); + } + + [TestMethod] + public void ApplyPagination_LastPage_HasNextPageFalse() + { + var allResults = new List> + { + new() { ["id"] = 1 }, + new() { ["id"] = 2 }, + new() { ["id"] = 3 } + }; + + // Get first page + var firstPage = AggregateRecordsTool.ApplyPagination(allResults, first: 2, after: null); + // Get last page + var lastPage = AggregateRecordsTool.ApplyPagination(allResults, first: 2, after: firstPage.EndCursor); + + Assert.AreEqual(1, lastPage.Items.Count); + Assert.AreEqual(3, lastPage.Items[0]["id"]); + Assert.IsFalse(lastPage.HasNextPage); + } + + [TestMethod] + public void ApplyPagination_TerminalCursor_ReturnsEmptyItems() + { + var allResults = new List> + { + new() { ["id"] = 1 }, + new() { ["id"] = 2 } + }; + + // Get last page + var lastPage = AggregateRecordsTool.ApplyPagination(allResults, first: 2, after: null); + Assert.IsFalse(lastPage.HasNextPage); + Assert.IsNotNull(lastPage.EndCursor); + + // Using the terminal endCursor should return empty results + var beyondLastPage = AggregateRecordsTool.ApplyPagination(allResults, first: 2, after: lastPage.EndCursor); + Assert.AreEqual(0, beyondLastPage.Items.Count); + Assert.IsFalse(beyondLastPage.HasNextPage); + Assert.IsNull(beyondLastPage.EndCursor); + } + + [TestMethod] + public void ApplyPagination_InvalidCursor_StartsFromBeginning() + { + var allResults = new List> + { + new() { ["id"] = 1 }, + new() { ["id"] = 2 } + }; + + var result = AggregateRecordsTool.ApplyPagination(allResults, first: 2, after: "not-valid-base64!!"); + + // Should start from beginning + Assert.AreEqual(2, result.Items.Count); + Assert.AreEqual(1, result.Items[0]["id"]); + } + + [TestMethod] + public void ApplyPagination_AfterWithoutFirst_IgnoresCursor() + { + // When first is not provided, after should not be used + // (ApplyPagination is only called when first is provided in ExecuteAsync) + var allResults = new List> + { + new() { ["id"] = 1 }, + new() { ["id"] = 2 }, + new() { ["id"] = 3 } + }; + + // Get page 1 cursor + var page1 = AggregateRecordsTool.ApplyPagination(allResults, first: 1, after: null); + Assert.IsNotNull(page1.EndCursor); + + // Call with first=3 and the cursor - should return 2 items from offset 1 + var result = AggregateRecordsTool.ApplyPagination(allResults, first: 3, after: page1.EndCursor); + Assert.AreEqual(2, result.Items.Count); + Assert.AreEqual(2, result.Items[0]["id"]); + } + + #endregion + + #region Validation tests (via ExecuteAsync return codes) + + // Note: Full ExecuteAsync validation tests require a full service provider setup + // with database, auth etc. The validation logic is tested below by examining + // the error condition directly since validation happens before any DB call. + + [TestMethod] + [DataRow("avg", "Validation: avg with star field should be rejected")] + [DataRow("sum", "Validation: sum with star field should be rejected")] + [DataRow("min", "Validation: min with star field should be rejected")] + [DataRow("max", "Validation: max with star field should be rejected")] + public void ValidateFieldFunctionCompat_StarWithNumericFunction_IsInvalid(string function, string description) + { + // Verify the business rule: only count can use field='*' + // This tests the condition used in ExecuteAsync without needing a full service provider + bool isCountStar = function == "count" && "*" == "*"; + bool isInvalidStarUsage = "*" == "*" && function != "count"; + + Assert.IsFalse(isCountStar, $"{description}: should not be count-star"); + Assert.IsTrue(isInvalidStarUsage, $"{description}: should be identified as invalid star usage"); + } + + [TestMethod] + public void ValidateFieldFunctionCompat_CountStar_IsValid() + { + // count with field='*' should be valid + bool isCountStar = "count" == "count" && "*" == "*"; + Assert.IsTrue(isCountStar, "count(*) should be valid"); + } + + [TestMethod] + public void ValidateDistinctCountStar_IsInvalid() + { + // count(*) with distinct=true should be rejected + // Verify the condition used in ExecuteAsync + bool isCountStar = "count" == "count" && "*" == "*"; + bool distinct = true; + + bool shouldReject = isCountStar && distinct; + Assert.IsTrue(shouldReject, "count(*) with distinct=true should be rejected"); + } + + [TestMethod] + public void ValidateDistinctCountField_IsValid() + { + // count(field) with distinct=true should be valid + bool isCountStar = "count" == "count" && "userId" == "*"; + bool distinct = true; + + bool shouldReject = isCountStar && distinct; + Assert.IsFalse(shouldReject, "count(field) with distinct=true should be valid"); + } + + #endregion + } +} diff --git a/src/Service.Tests/UnitTests/McpTelemetryTests.cs b/src/Service.Tests/UnitTests/McpTelemetryTests.cs index 18c043d4dd..61a9834a02 100644 --- a/src/Service.Tests/UnitTests/McpTelemetryTests.cs +++ b/src/Service.Tests/UnitTests/McpTelemetryTests.cs @@ -17,7 +17,6 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using ModelContextProtocol.Protocol; using static Azure.DataApiBuilder.Mcp.Model.McpEnums; - namespace Azure.DataApiBuilder.Service.Tests.UnitTests { /// @@ -337,6 +336,98 @@ public async Task ExecuteWithTelemetryAsync_RecordsExceptionAndRethrows_WhenTool Assert.IsNotNull(exceptionEvent, "Exception event should be recorded"); } + /// + /// Test that ExecuteWithTelemetryAsync applies the configured query-timeout and throws TimeoutException + /// when a tool exceeds the configured timeout. + /// + [TestMethod] + public async Task ExecuteWithTelemetryAsync_ThrowsTimeoutException_WhenToolExceedsTimeout() + { + // Use a 1-second timeout with a tool that takes 10 seconds + IServiceProvider serviceProvider = CreateServiceProviderWithTimeout(queryTimeoutSeconds: 1); + IMcpTool tool = new SlowTool(delaySeconds: 10); + + TimeoutException thrownEx = await Assert.ThrowsExceptionAsync( + () => McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, "aggregate_records", arguments: null, serviceProvider, CancellationToken.None)); + + Assert.IsTrue(thrownEx.Message.Contains("aggregate_records"), "Exception message should contain tool name"); + Assert.IsTrue(thrownEx.Message.Contains("1 second"), "Exception message should contain timeout duration"); + } + + /// + /// Test that ExecuteWithTelemetryAsync succeeds when tool completes before the timeout. + /// + [TestMethod] + public async Task ExecuteWithTelemetryAsync_Succeeds_WhenToolCompletesBeforeTimeout() + { + // Use a 30-second timeout with a tool that completes immediately + IServiceProvider serviceProvider = CreateServiceProviderWithTimeout(queryTimeoutSeconds: 30); + IMcpTool tool = new ImmediateCompletionTool(); + + CallToolResult result = await McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, "aggregate_records", arguments: null, serviceProvider, CancellationToken.None); + + Assert.IsNotNull(result); + Assert.IsFalse(result.IsError == true); + } + + /// + /// Test that aggregate_records tool name maps to "aggregate" operation. + /// + [TestMethod] + public void InferOperationFromTool_AggregateRecords_ReturnsAggregate() + { + CallToolResult dummyResult = CreateToolResult("ok"); + IMcpTool tool = new MockMcpTool(dummyResult, ToolType.BuiltIn); + + string operation = McpTelemetryHelper.InferOperationFromTool(tool, "aggregate_records"); + + Assert.AreEqual("aggregate", operation); + } + + #endregion + + #region Helpers for timeout tests + + /// + /// Creates a service provider with a RuntimeConfigProvider configured with the given timeout. + /// + private static IServiceProvider CreateServiceProviderWithTimeout(int queryTimeoutSeconds) + { + Azure.DataApiBuilder.Config.ObjectModel.RuntimeConfig config = CreateConfigWithQueryTimeout(queryTimeoutSeconds); + ServiceCollection services = new(); + Azure.DataApiBuilder.Core.Configurations.RuntimeConfigProvider configProvider = + TestHelper.GenerateInMemoryRuntimeConfigProvider(config); + services.AddSingleton(configProvider); + services.AddLogging(); + return services.BuildServiceProvider(); + } + + private static Azure.DataApiBuilder.Config.ObjectModel.RuntimeConfig CreateConfigWithQueryTimeout(int queryTimeoutSeconds) + { + return new Azure.DataApiBuilder.Config.ObjectModel.RuntimeConfig( + Schema: "test-schema", + DataSource: new Azure.DataApiBuilder.Config.ObjectModel.DataSource( + DatabaseType: Azure.DataApiBuilder.Config.ObjectModel.DatabaseType.MSSQL, + ConnectionString: "", + Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new( + Enabled: true, + Path: "/mcp", + DmlTools: null, + Description: null, + QueryTimeout: queryTimeoutSeconds + ), + Host: new(Cors: null, Authentication: null, Mode: Azure.DataApiBuilder.Config.ObjectModel.HostMode.Development) + ), + Entities: new(new System.Collections.Generic.Dictionary()) + ); + } + #endregion #region Test Mocks @@ -377,6 +468,81 @@ public Task ExecuteAsync(JsonDocument? arguments, IServiceProvid } } + /// + /// A mock tool that completes immediately with a success result. + /// + private class ImmediateCompletionTool : IMcpTool + { + public ToolType ToolType { get; } = ToolType.BuiltIn; + + public Tool GetToolMetadata() + { + using JsonDocument doc = JsonDocument.Parse("{\"type\": \"object\"}"); + return new Tool + { + Name = "test_tool", + Description = "A test tool that completes immediately", + InputSchema = doc.RootElement.Clone() + }; + } + + public Task ExecuteAsync( + JsonDocument? arguments, + IServiceProvider serviceProvider, + CancellationToken cancellationToken = default) + { + return Task.FromResult(new CallToolResult + { + Content = new List + { + new TextContentBlock { Text = "{\"result\": \"success\"}" } + } + }); + } + } + + /// + /// A mock tool that delays for a specified duration, respecting cancellation. + /// Used to test timeout behavior. + /// + private class SlowTool : IMcpTool + { + private readonly int _delaySeconds; + + public SlowTool(int delaySeconds) + { + _delaySeconds = delaySeconds; + } + + public ToolType ToolType { get; } = ToolType.BuiltIn; + + public Tool GetToolMetadata() + { + using JsonDocument doc = JsonDocument.Parse("{\"type\": \"object\"}"); + return new Tool + { + Name = "slow_tool", + Description = "A test tool that takes a long time", + InputSchema = doc.RootElement.Clone() + }; + } + + public async Task ExecuteAsync( + JsonDocument? arguments, + IServiceProvider serviceProvider, + CancellationToken cancellationToken = default) + { + await Task.Delay(TimeSpan.FromSeconds(_delaySeconds), cancellationToken); + return new CallToolResult + { + Content = new List + { + new TextContentBlock { Text = "{\"result\": \"completed\"}" } + } + }; + } + } + #endregion } } From 41ccb2f2671f43c31cde5262a77ba655171e4be6 Mon Sep 17 00:00:00 2001 From: Jerry Nixon Date: Mon, 2 Mar 2026 16:52:50 -0700 Subject: [PATCH 12/32] Refactor using directives in AggregateRecordsTool.cs to improve code organization --- .../BuiltInTools/AggregateRecordsTool.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index b8dd85c175..9a6457455a 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -2,7 +2,9 @@ // Licensed under the MIT License. using System.Data.Common; +using System.Text; using System.Text.Json; +using System.Text.Json.Nodes; using Azure.DataApiBuilder.Auth; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; @@ -19,7 +21,6 @@ using Azure.DataApiBuilder.Service.Exceptions; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; From eb99abaf160a4b4bc988c95feffe3f2ce437b748 Mon Sep 17 00:00:00 2001 From: Jerry Nixon Date: Mon, 2 Mar 2026 17:19:13 -0700 Subject: [PATCH 13/32] Enhance AggregateRecordsTool to build SQL aggregate queries, improving performance by offloading computations to the database --- .../BuiltInTools/AggregateRecordsTool.cs | 523 +++++++++--------- 1 file changed, 261 insertions(+), 262 deletions(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index 9a6457455a..42f5187092 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -301,14 +301,14 @@ public async Task ExecuteAsync( return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", finalError, logger); } - // Build select list: groupby fields + aggregation field + // Build select list for authorization: groupby fields + aggregation field List selectFields = new(groupby); if (!isCountStar && !selectFields.Contains(field, StringComparer.OrdinalIgnoreCase)) { selectFields.Add(field); } - // Build and validate Find context + // Build and validate Find context (reuse for authorization and OData filter parsing) RequestValidator requestValidator = new(serviceProvider.GetRequiredService(), runtimeConfigProvider); FindRequestContext context = new(entityName, dbObject, true); httpContext!.Request.Method = "GET"; @@ -337,70 +337,64 @@ public async Task ExecuteAsync( return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", DataApiBuilderException.AUTHORIZATION_FAILURE, logger); } - // Execute query to get records - IQueryEngineFactory queryEngineFactory = serviceProvider.GetRequiredService(); - IQueryEngine queryEngine = queryEngineFactory.GetQueryEngine(sqlMetadataProvider.GetDatabaseType()); - JsonDocument? queryResult = await queryEngine.ExecuteAsync(context); + // Build SqlQueryStructure to get OData filter → SQL predicate translation and DB policies + GQLFilterParser gQLFilterParser = serviceProvider.GetRequiredService(); + SqlQueryStructure structure = new( + context, sqlMetadataProvider, authResolver, runtimeConfigProvider, gQLFilterParser, httpContext); - IActionResult actionResult = queryResult is null - ? SqlResponseHelpers.FormatFindResult(JsonDocument.Parse("[]").RootElement.Clone(), context, sqlMetadataProvider, runtimeConfig, httpContext, true) - : SqlResponseHelpers.FormatFindResult(queryResult.RootElement.Clone(), context, sqlMetadataProvider, runtimeConfig, httpContext, true); + // Get database-specific components + DatabaseType databaseType = runtimeConfig.GetDataSourceFromDataSourceName(dataSourceName).DatabaseType; + IAbstractQueryManagerFactory queryManagerFactory = serviceProvider.GetRequiredService(); + IQueryBuilder queryBuilder = queryManagerFactory.GetQueryBuilder(databaseType); + IQueryExecutor queryExecutor = queryManagerFactory.GetQueryExecutor(databaseType); - string rawPayloadJson = McpResponseBuilder.ExtractResultJson(actionResult); - using JsonDocument resultDoc = JsonDocument.Parse(rawPayloadJson); - JsonElement resultRoot = resultDoc.RootElement; - - // Extract the records array from the response - JsonElement records; - if (resultRoot.TryGetProperty("value", out JsonElement valueArray)) - { - records = valueArray; - } - else if (resultRoot.ValueKind == JsonValueKind.Array) + // Resolve backing column name for the aggregation field + string? backingField = null; + if (!isCountStar) { - records = resultRoot; + if (!sqlMetadataProvider.TryGetBackingColumn(entityName, field, out backingField)) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + $"Field '{field}' not found for entity '{entityName}'.", logger); + } } - else + + // Resolve backing column names for groupby fields + List<(string entityField, string backingCol)> groupbyMapping = new(); + foreach (string gField in groupby) { - records = resultRoot; + if (!sqlMetadataProvider.TryGetBackingColumn(entityName, gField, out string? backingGCol)) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + $"GroupBy field '{gField}' not found for entity '{entityName}'.", logger); + } + + groupbyMapping.Add((gField, backingGCol)); } - // Compute alias for the response string alias = ComputeAlias(function, field); - // Perform in-memory aggregation - List> aggregatedResults = PerformAggregation( - records, function, field, distinct, groupby, havingOps, havingIn, orderby, alias); + // Build aggregate SQL query that pushes all computation to the database + string sql = BuildAggregateSql( + queryBuilder, structure, dbObject, function, backingField, distinct, isCountStar, + groupbyMapping, havingOps, havingIn, orderby, first, after, alias, databaseType); - // Apply pagination if first is specified with groupby + // Execute the SQL aggregate query against the database + cancellationToken.ThrowIfCancellationRequested(); + JsonArray? resultArray = await queryExecutor.ExecuteQueryAsync( + sql, + structure.Parameters, + queryExecutor.GetJsonArrayAsync, + dataSourceName, + httpContext); + + // Format and return results if (first.HasValue && groupby.Count > 0) { - PaginationResult paginatedResult = ApplyPagination(aggregatedResults, first.Value, after); - return McpResponseBuilder.BuildSuccessResult( - new Dictionary - { - ["entity"] = entityName, - ["result"] = new Dictionary - { - ["items"] = paginatedResult.Items, - ["endCursor"] = paginatedResult.EndCursor, - ["hasNextPage"] = paginatedResult.HasNextPage - }, - ["message"] = $"Successfully aggregated records for entity '{entityName}'" - }, - logger, - $"AggregateRecordsTool success for entity {entityName}."); + return BuildPaginatedResponse(resultArray, first.Value, after, entityName, logger); } - return McpResponseBuilder.BuildSuccessResult( - new Dictionary - { - ["entity"] = entityName, - ["result"] = aggregatedResults, - ["message"] = $"Successfully aggregated records for entity '{entityName}'" - }, - logger, - $"AggregateRecordsTool success for entity {entityName}."); + return BuildSimpleResponse(resultArray, entityName, alias, logger); } catch (TimeoutException timeoutEx) { @@ -471,300 +465,305 @@ internal static string ComputeAlias(string function, string field) } /// - /// Performs in-memory aggregation over a JSON array of records. + /// Builds a SQL aggregate query that pushes all computation to the database. + /// Generates SELECT {aggExpr} FROM {table} WHERE ... GROUP BY ... HAVING ... ORDER BY ... + /// with proper parameterization and identifier quoting. /// - internal static List> PerformAggregation( - JsonElement records, + internal static string BuildAggregateSql( + IQueryBuilder queryBuilder, + SqlQueryStructure structure, + DatabaseObject dbObject, string function, - string field, + string? backingField, bool distinct, - List groupby, + bool isCountStar, + List<(string entityField, string backingCol)> groupbyMapping, Dictionary? havingOps, List? havingIn, string orderby, - string alias) + int? first, + string? after, + string alias, + DatabaseType databaseType) { - if (records.ValueKind != JsonValueKind.Array) - { - return new List> { new() { [alias] = null } }; - } + string aggExpr = BuildAggregateExpression(function, backingField, distinct, isCountStar, queryBuilder); + string quotedTableRef = BuildQuotedTableRef(dbObject, queryBuilder); - bool isCountStar = function == "count" && field == "*"; + StringBuilder sql = new(); - if (groupby.Count == 0) + // SELECT + sql.Append("SELECT "); + foreach ((string entityField, string backingCol) in groupbyMapping) { - // No groupby - single result - List items = new(); - foreach (JsonElement record in records.EnumerateArray()) - { - items.Add(record); - } - - double? aggregatedValue = ComputeAggregateValue(items, function, field, distinct, isCountStar); - - // Apply having - if (!PassesHavingFilter(aggregatedValue, havingOps, havingIn)) - { - return new List>(); - } - - return new List> - { - new() { [alias] = aggregatedValue } - }; + sql.Append($"{queryBuilder.QuoteIdentifier(backingCol)} AS {queryBuilder.QuoteIdentifier(entityField)}, "); } - else - { - // Group by - Dictionary> groups = new(); - Dictionary> groupKeys = new(); - foreach (JsonElement record in records.EnumerateArray()) - { - string key = BuildGroupKey(record, groupby); - if (!groups.ContainsKey(key)) - { - groups[key] = new List(); - groupKeys[key] = ExtractGroupFields(record, groupby); - } + sql.Append($"{aggExpr} AS {queryBuilder.QuoteIdentifier(alias)}"); - groups[key].Add(record); - } + // FROM + sql.Append($" FROM {quotedTableRef}"); - List> results = new(); - foreach (KeyValuePair> group in groups) - { - double? aggregatedValue = ComputeAggregateValue(group.Value, function, field, distinct, isCountStar); + // WHERE (OData filter predicates + DB policy predicates) + string? whereClause = BuildWhereClause(structure); + if (!string.IsNullOrEmpty(whereClause)) + { + sql.Append($" WHERE {whereClause}"); + } - if (!PassesHavingFilter(aggregatedValue, havingOps, havingIn)) - { - continue; - } + // GROUP BY + if (groupbyMapping.Count > 0) + { + string groupByClause = string.Join(", ", groupbyMapping.Select(g => queryBuilder.QuoteIdentifier(g.backingCol))); + sql.Append($" GROUP BY {groupByClause}"); + } - Dictionary row = new(groupKeys[group.Key]) - { - [alias] = aggregatedValue - }; - results.Add(row); - } + // HAVING + string? havingClause = BuildHavingClause(aggExpr, havingOps, havingIn, structure); + if (!string.IsNullOrEmpty(havingClause)) + { + sql.Append($" HAVING {havingClause}"); + } - // Apply orderby - if (orderby.Equals("asc", StringComparison.OrdinalIgnoreCase)) - { - results.Sort((a, b) => CompareNullableDoubles(a[alias] as double?, b[alias] as double?)); - } - else - { - results.Sort((a, b) => CompareNullableDoubles(b[alias] as double?, a[alias] as double?)); - } + // ORDER BY (only with groupby) + if (groupbyMapping.Count > 0) + { + string direction = orderby.Equals("asc", StringComparison.OrdinalIgnoreCase) ? "ASC" : "DESC"; + sql.Append($" ORDER BY {aggExpr} {direction}"); + } - return results; + // PAGINATION (only with groupby and first) + if (first.HasValue && groupbyMapping.Count > 0) + { + int offset = DecodeCursorOffset(after); + int fetchCount = first.Value + 1; // Fetch one extra row to detect hasNextPage + AppendPagination(sql, offset, fetchCount, structure, databaseType); } - } - /// - /// Represents the result of applying pagination to aggregated results. - /// - internal sealed class PaginationResult - { - public List> Items { get; set; } = new(); - public string? EndCursor { get; set; } - public bool HasNextPage { get; set; } + return sql.ToString(); } /// - /// Applies cursor-based pagination to aggregated results. - /// The cursor is an opaque base64-encoded offset integer. + /// Builds the SQL aggregate expression (e.g., COUNT(*), SUM(DISTINCT [column])). /// - internal static PaginationResult ApplyPagination( - List> allResults, - int first, - string? after) + internal static string BuildAggregateExpression( + string function, string? backingField, bool distinct, bool isCountStar, IQueryBuilder queryBuilder) { - int startIndex = 0; - - if (!string.IsNullOrWhiteSpace(after)) + if (isCountStar) { - try - { - byte[] bytes = Convert.FromBase64String(after); - string decoded = System.Text.Encoding.UTF8.GetString(bytes); - if (int.TryParse(decoded, out int cursorOffset)) - { - startIndex = cursorOffset; - } - } - catch (FormatException) - { - // Invalid cursor format; start from beginning - } + return "COUNT(*)"; } - List> pageItems = allResults - .Skip(startIndex) - .Take(first) - .ToList(); - - bool hasNextPage = startIndex + first < allResults.Count; - string? endCursor = null; + string quotedCol = queryBuilder.QuoteIdentifier(backingField!); + string func = function.ToUpperInvariant(); - if (pageItems.Count > 0) - { - int lastItemIndex = startIndex + pageItems.Count; - endCursor = Convert.ToBase64String( - System.Text.Encoding.UTF8.GetBytes(lastItemIndex.ToString())); - } - - return new PaginationResult - { - Items = pageItems, - EndCursor = endCursor, - HasNextPage = hasNextPage - }; + return distinct ? $"{func}(DISTINCT {quotedCol})" : $"{func}({quotedCol})"; } - private static double? ComputeAggregateValue(List records, string function, string field, bool distinct, bool isCountStar) + /// + /// Builds a properly quoted table reference from a DatabaseObject. + /// + internal static string BuildQuotedTableRef(DatabaseObject dbObject, IQueryBuilder queryBuilder) { - if (isCountStar) - { - // count(*) always counts all rows; distinct is rejected at ExecuteAsync validation level - return records.Count; - } - - List values = new(); - foreach (JsonElement record in records) - { - if (record.TryGetProperty(field, out JsonElement val) && val.ValueKind == JsonValueKind.Number) - { - values.Add(val.GetDouble()); - } - } + return string.IsNullOrEmpty(dbObject.SchemaName) + ? queryBuilder.QuoteIdentifier(dbObject.Name) + : $"{queryBuilder.QuoteIdentifier(dbObject.SchemaName)}.{queryBuilder.QuoteIdentifier(dbObject.Name)}"; + } - if (distinct) - { - values = values.Distinct().ToList(); - } + /// + /// Builds the WHERE clause from OData filter predicates and DB policy predicates. + /// Both are required for correct and secure query execution. + /// + internal static string? BuildWhereClause(SqlQueryStructure structure) + { + List clauses = new(); - if (function == "count") + if (!string.IsNullOrEmpty(structure.FilterPredicates)) { - return values.Count; + clauses.Add(structure.FilterPredicates); } - if (values.Count == 0) + string? dbPolicy = structure.GetDbPolicyForOperation(EntityActionOperation.Read); + if (!string.IsNullOrEmpty(dbPolicy)) { - return null; + clauses.Add(dbPolicy); } - return function switch - { - "avg" => Math.Round(values.Average(), 2), - "sum" => values.Sum(), - "min" => values.Min(), - "max" => values.Max(), - _ => null - }; + return clauses.Count > 0 ? string.Join(" AND ", clauses) : null; } - private static bool PassesHavingFilter(double? value, Dictionary? havingOps, List? havingIn) + /// + /// Builds the HAVING clause from having operator conditions and IN list. + /// Adds parameterized values to the structure's Parameters dictionary. + /// + internal static string? BuildHavingClause( + string aggExpr, + Dictionary? havingOps, + List? havingIn, + SqlQueryStructure structure) { if (havingOps == null && havingIn == null) { - return true; - } - - if (value == null) - { - return false; + return null; } - double v = value.Value; + List conditions = new(); if (havingOps != null) { foreach (KeyValuePair op in havingOps) { - bool passes = op.Key.ToLowerInvariant() switch + string sqlOp = op.Key.ToLowerInvariant() switch { - "eq" => v == op.Value, - "neq" => v != op.Value, - "gt" => v > op.Value, - "gte" => v >= op.Value, - "lt" => v < op.Value, - "lte" => v <= op.Value, - _ => true + "eq" => "=", + "neq" => "<>", + "gt" => ">", + "gte" => ">=", + "lt" => "<", + "lte" => "<=", + _ => throw new ArgumentException($"Invalid having operator: {op.Key}") }; - if (!passes) - { - return false; - } + string paramName = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); + structure.Parameters.Add(paramName, new DbConnectionParam(op.Value)); + conditions.Add($"{aggExpr} {sqlOp} {paramName}"); } } - if (havingIn != null && !havingIn.Contains(v)) + if (havingIn != null && havingIn.Count > 0) { - return false; + List inParams = new(); + foreach (double val in havingIn) + { + string paramName = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); + structure.Parameters.Add(paramName, new DbConnectionParam(val)); + inParams.Add(paramName); + } + + conditions.Add($"{aggExpr} IN ({string.Join(", ", inParams)})"); } - return true; + return conditions.Count > 0 ? string.Join(" AND ", conditions) : null; } - private static string BuildGroupKey(JsonElement record, List groupby) + /// + /// Appends database-specific pagination syntax to the SQL query. + /// MsSql/DWSQL: OFFSET ... ROWS FETCH NEXT ... ROWS ONLY + /// PostgreSQL/MySQL: LIMIT ... OFFSET ... + /// + internal static void AppendPagination( + StringBuilder sql, int offset, int fetchCount, + SqlQueryStructure structure, DatabaseType databaseType) { - List parts = new(); - foreach (string g in groupby) + string offsetParam = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); + structure.Parameters.Add(offsetParam, new DbConnectionParam(offset)); + + string limitParam = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); + structure.Parameters.Add(limitParam, new DbConnectionParam(fetchCount)); + + if (databaseType == DatabaseType.MSSQL || databaseType == DatabaseType.DWSQL) { - if (record.TryGetProperty(g, out JsonElement val)) - { - parts.Add(val.ToString()); - } - else - { - parts.Add("__null__"); - } + sql.Append($" OFFSET {offsetParam} ROWS FETCH NEXT {limitParam} ROWS ONLY"); + } + else + { + // PostgreSQL, MySQL + sql.Append($" LIMIT {limitParam} OFFSET {offsetParam}"); } - - // Use null character (\0) as delimiter to avoid collisions with - // field values that may contain printable characters like '|'. - return string.Join("\0", parts); } - private static Dictionary ExtractGroupFields(JsonElement record, List groupby) + /// + /// Decodes a base64-encoded cursor string to an integer offset. + /// Returns 0 if the cursor is null, empty, or invalid. + /// + internal static int DecodeCursorOffset(string? after) { - Dictionary result = new(); - foreach (string g in groupby) + if (string.IsNullOrWhiteSpace(after)) { - if (record.TryGetProperty(g, out JsonElement val)) - { - result[g] = McpResponseBuilder.GetJsonValue(val); - } - else - { - result[g] = null; - } + return 0; } - return result; + try + { + byte[] bytes = Convert.FromBase64String(after); + string decoded = Encoding.UTF8.GetString(bytes); + return int.TryParse(decoded, out int cursorOffset) ? cursorOffset : 0; + } + catch (FormatException) + { + return 0; + } } - private static int CompareNullableDoubles(double? a, double? b) + /// + /// Builds the paginated response from a SQL result that fetched first+1 rows. + /// + private static CallToolResult BuildPaginatedResponse( + JsonArray? resultArray, int first, string? after, string entityName, ILogger? logger) { - if (a == null && b == null) + int startOffset = DecodeCursorOffset(after); + int actualCount = resultArray?.Count ?? 0; + bool hasNextPage = actualCount > first; + int returnCount = hasNextPage ? first : actualCount; + + // Build page items from the SQL result + JsonArray pageItems = new(); + for (int i = 0; i < returnCount && resultArray != null && i < resultArray.Count; i++) { - return 0; + pageItems.Add(resultArray[i]?.DeepClone()); } - if (a == null) + string? endCursor = null; + if (returnCount > 0) { - return -1; + int lastItemIndex = startOffset + returnCount; + endCursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(lastItemIndex.ToString())); } - if (b == null) + JsonElement itemsElement = JsonSerializer.Deserialize(pageItems.ToJsonString()); + + return McpResponseBuilder.BuildSuccessResult( + new Dictionary + { + ["entity"] = entityName, + ["result"] = new Dictionary + { + ["items"] = itemsElement, + ["endCursor"] = endCursor, + ["hasNextPage"] = hasNextPage + }, + ["message"] = $"Successfully aggregated records for entity '{entityName}'" + }, + logger, + $"AggregateRecordsTool success for entity {entityName}."); + } + + /// + /// Builds the simple (non-paginated) response from a SQL result. + /// + private static CallToolResult BuildSimpleResponse( + JsonArray? resultArray, string entityName, string alias, ILogger? logger) + { + JsonElement resultElement; + if (resultArray == null || resultArray.Count == 0) { - return 1; + // For non-grouped aggregate with no results, return null value + JsonArray nullArray = new() { new JsonObject { [alias] = null } }; + resultElement = JsonSerializer.Deserialize(nullArray.ToJsonString()); + } + else + { + resultElement = JsonSerializer.Deserialize(resultArray.ToJsonString()); } - return a.Value.CompareTo(b.Value); + return McpResponseBuilder.BuildSuccessResult( + new Dictionary + { + ["entity"] = entityName, + ["result"] = resultElement, + ["message"] = $"Successfully aggregated records for entity '{entityName}'" + }, + logger, + $"AggregateRecordsTool success for entity {entityName}."); } } } From 006f17a5a5f3c6d098e90d08ba28619453df5636 Mon Sep 17 00:00:00 2001 From: Jerry Nixon Date: Mon, 2 Mar 2026 17:44:03 -0700 Subject: [PATCH 14/32] Rewrite aggregate tests for SQL-level aggregation Replace in-memory aggregation tests (PerformAggregation, ApplyPagination) with SQL expression generation tests (BuildAggregateExpression, BuildQuotedTableRef, DecodeCursorOffset). All 13 spec examples and 5 blog scenarios now validate SQL patterns instead of in-memory computation. 89 tests pass. Build and format clean. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../Mcp/AggregateRecordsToolTests.cs | 794 ++++-------------- .../UnitTests/AggregateRecordsToolTests.cs | 430 ++++------ 2 files changed, 344 insertions(+), 880 deletions(-) diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs index ce578e746e..161d66b4e5 100644 --- a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -5,13 +5,16 @@ using System; using System.Collections.Generic; +using System.Text; using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Azure.DataApiBuilder.Auth; +using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core.Authorization; using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Resolvers; using Azure.DataApiBuilder.Mcp.BuiltInTools; using Azure.DataApiBuilder.Mcp.Model; using Microsoft.AspNetCore.Http; @@ -29,8 +32,8 @@ namespace Azure.DataApiBuilder.Service.Tests.Mcp /// - Runtime-level enabled/disabled configuration /// - Entity-level DML tool configuration /// - Input validation (missing/invalid arguments) - /// - In-memory aggregation logic (count, avg, sum, min, max) - /// - distinct, groupby, having, orderby + /// - SQL expression generation (count, avg, sum, min, max, distinct) + /// - Table reference quoting, cursor/pagination logic /// - Alias convention /// [TestClass] @@ -230,392 +233,198 @@ public void ComputeAlias_MaxField_ReturnsFunctionField() #endregion - #region In-Memory Aggregation Tests + #region SQL Expression Generation Tests - [TestMethod] - public void PerformAggregation_CountStar_ReturnsCount() + /// + /// Creates a mock IQueryBuilder that wraps identifiers with square brackets (MsSql-style). + /// + private static Mock CreateMockQueryBuilder() { - JsonElement records = ParseArray("[{\"id\":1},{\"id\":2},{\"id\":3}]"); - var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new(), null, null, "desc", "count"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual(3.0, result[0]["count"]); + Mock mock = new(); + mock.Setup(qb => qb.QuoteIdentifier(It.IsAny())) + .Returns((string id) => $"[{id}]"); + return mock; } [TestMethod] - public void PerformAggregation_Avg_ReturnsAverage() + public void BuildAggregateExpression_CountStar_GeneratesCountStarSql() { - JsonElement records = ParseArray("[{\"price\":10},{\"price\":20},{\"price\":30}]"); - var result = AggregateRecordsTool.PerformAggregation(records, "avg", "price", false, new(), null, null, "desc", "avg_price"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual(20.0, result[0]["avg_price"]); + Mock qb = CreateMockQueryBuilder(); + string expr = AggregateRecordsTool.BuildAggregateExpression("count", null, false, true, qb.Object); + Assert.AreEqual("COUNT(*)", expr); } [TestMethod] - public void PerformAggregation_Sum_ReturnsSum() + public void BuildAggregateExpression_Avg_GeneratesAvgSql() { - JsonElement records = ParseArray("[{\"price\":10},{\"price\":20},{\"price\":30}]"); - var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new(), null, null, "desc", "sum_price"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual(60.0, result[0]["sum_price"]); + Mock qb = CreateMockQueryBuilder(); + string expr = AggregateRecordsTool.BuildAggregateExpression("avg", "price", false, false, qb.Object); + Assert.AreEqual("AVG([price])", expr); } [TestMethod] - public void PerformAggregation_Min_ReturnsMin() + public void BuildAggregateExpression_Sum_GeneratesSumSql() { - JsonElement records = ParseArray("[{\"price\":10},{\"price\":20},{\"price\":5}]"); - var result = AggregateRecordsTool.PerformAggregation(records, "min", "price", false, new(), null, null, "desc", "min_price"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual(5.0, result[0]["min_price"]); + Mock qb = CreateMockQueryBuilder(); + string expr = AggregateRecordsTool.BuildAggregateExpression("sum", "price", false, false, qb.Object); + Assert.AreEqual("SUM([price])", expr); } [TestMethod] - public void PerformAggregation_Max_ReturnsMax() + public void BuildAggregateExpression_Min_GeneratesMinSql() { - JsonElement records = ParseArray("[{\"price\":10},{\"price\":20},{\"price\":5}]"); - var result = AggregateRecordsTool.PerformAggregation(records, "max", "price", false, new(), null, null, "desc", "max_price"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual(20.0, result[0]["max_price"]); + Mock qb = CreateMockQueryBuilder(); + string expr = AggregateRecordsTool.BuildAggregateExpression("min", "price", false, false, qb.Object); + Assert.AreEqual("MIN([price])", expr); } [TestMethod] - public void PerformAggregation_CountDistinct_ReturnsDistinctCount() + public void BuildAggregateExpression_Max_GeneratesMaxSql() { - JsonElement records = ParseArray("[{\"supplierId\":1},{\"supplierId\":2},{\"supplierId\":1},{\"supplierId\":3}]"); - var result = AggregateRecordsTool.PerformAggregation(records, "count", "supplierId", true, new(), null, null, "desc", "count_supplierId"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual(3.0, result[0]["count_supplierId"]); + Mock qb = CreateMockQueryBuilder(); + string expr = AggregateRecordsTool.BuildAggregateExpression("max", "price", false, false, qb.Object); + Assert.AreEqual("MAX([price])", expr); } [TestMethod] - public void PerformAggregation_AvgDistinct_ReturnsDistinctAvg() + public void BuildAggregateExpression_CountDistinct_GeneratesCountDistinctSql() { - JsonElement records = ParseArray("[{\"price\":10},{\"price\":10},{\"price\":20},{\"price\":30}]"); - var result = AggregateRecordsTool.PerformAggregation(records, "avg", "price", true, new(), null, null, "desc", "avg_price"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual(20.0, result[0]["avg_price"]); + Mock qb = CreateMockQueryBuilder(); + string expr = AggregateRecordsTool.BuildAggregateExpression("count", "supplierId", true, false, qb.Object); + Assert.AreEqual("COUNT(DISTINCT [supplierId])", expr); } [TestMethod] - public void PerformAggregation_GroupBy_ReturnsGroupedResults() + public void BuildAggregateExpression_AvgDistinct_GeneratesAvgDistinctSql() { - JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"A\",\"price\":20},{\"category\":\"B\",\"price\":50}]"); - var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, null, null, "desc", "sum_price"); - - Assert.AreEqual(2, result.Count); - // Desc order: B(50) first, then A(30) - Assert.AreEqual("B", result[0]["category"]?.ToString()); - Assert.AreEqual(50.0, result[0]["sum_price"]); - Assert.AreEqual("A", result[1]["category"]?.ToString()); - Assert.AreEqual(30.0, result[1]["sum_price"]); + Mock qb = CreateMockQueryBuilder(); + string expr = AggregateRecordsTool.BuildAggregateExpression("avg", "price", true, false, qb.Object); + Assert.AreEqual("AVG(DISTINCT [price])", expr); } [TestMethod] - public void PerformAggregation_GroupBy_Asc_ReturnsSortedAsc() + public void BuildAggregateExpression_SumDistinct_GeneratesSumDistinctSql() { - JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"B\",\"price\":30},{\"category\":\"A\",\"price\":20}]"); - var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, null, null, "asc", "sum_price"); - - Assert.AreEqual(2, result.Count); - Assert.AreEqual("A", result[0]["category"]?.ToString()); - Assert.AreEqual(30.0, result[0]["sum_price"]); - Assert.AreEqual("B", result[1]["category"]?.ToString()); - Assert.AreEqual(30.0, result[1]["sum_price"]); + Mock qb = CreateMockQueryBuilder(); + string expr = AggregateRecordsTool.BuildAggregateExpression("sum", "price", true, false, qb.Object); + Assert.AreEqual("SUM(DISTINCT [price])", expr); } [TestMethod] - public void PerformAggregation_CountStar_GroupBy_ReturnsGroupCounts() + public void BuildAggregateExpression_CountField_GeneratesCountFieldSql() { - JsonElement records = ParseArray("[{\"category\":\"A\"},{\"category\":\"A\"},{\"category\":\"B\"}]"); - var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "category" }, null, null, "desc", "count"); - - Assert.AreEqual(2, result.Count); - Assert.AreEqual("A", result[0]["category"]?.ToString()); - Assert.AreEqual(2.0, result[0]["count"]); - Assert.AreEqual("B", result[1]["category"]?.ToString()); - Assert.AreEqual(1.0, result[1]["count"]); + Mock qb = CreateMockQueryBuilder(); + string expr = AggregateRecordsTool.BuildAggregateExpression("count", "id", false, false, qb.Object); + Assert.AreEqual("COUNT([id])", expr); } [TestMethod] - public void PerformAggregation_HavingGt_FiltersResults() + public void BuildQuotedTableRef_WithSchema_GeneratesSchemaQualifiedRef() { - JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"A\",\"price\":20},{\"category\":\"B\",\"price\":5}]"); - var having = new Dictionary { ["gt"] = 10 }; - var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual("A", result[0]["category"]?.ToString()); - Assert.AreEqual(30.0, result[0]["sum_price"]); + Mock qb = CreateMockQueryBuilder(); + DatabaseTable table = new("dbo", "Products"); + string result = AggregateRecordsTool.BuildQuotedTableRef(table, qb.Object); + Assert.AreEqual("[dbo].[Products]", result); } [TestMethod] - public void PerformAggregation_HavingGteLte_FiltersRange() + public void BuildQuotedTableRef_WithoutSchema_GeneratesTableOnlyRef() { - JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":100},{\"category\":\"B\",\"price\":20},{\"category\":\"C\",\"price\":1}]"); - var having = new Dictionary { ["gte"] = 10, ["lte"] = 50 }; - var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual("B", result[0]["category"]?.ToString()); + Mock qb = CreateMockQueryBuilder(); + DatabaseTable table = new("", "Products"); + string result = AggregateRecordsTool.BuildQuotedTableRef(table, qb.Object); + Assert.AreEqual("[Products]", result); } [TestMethod] - public void PerformAggregation_HavingIn_FiltersExactValues() + public void BuildAggregateExpression_GroupByScenario_ExpressionAndQuotingCorrect() { - JsonElement records = ParseArray("[{\"category\":\"A\"},{\"category\":\"A\"},{\"category\":\"B\"},{\"category\":\"C\"},{\"category\":\"C\"},{\"category\":\"C\"}]"); - var havingIn = new List { 2, 3 }; - var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "category" }, null, havingIn, "desc", "count"); - - Assert.AreEqual(2, result.Count); - // C(3) desc, A(2) - Assert.AreEqual("C", result[0]["category"]?.ToString()); - Assert.AreEqual(3.0, result[0]["count"]); - Assert.AreEqual("A", result[1]["category"]?.ToString()); - Assert.AreEqual(2.0, result[1]["count"]); + Mock qb = CreateMockQueryBuilder(); + string aggExpr = AggregateRecordsTool.BuildAggregateExpression("sum", "price", false, false, qb.Object); + Assert.AreEqual("SUM([price])", aggExpr); + Assert.AreEqual("[category]", qb.Object.QuoteIdentifier("category")); } [TestMethod] - public void PerformAggregation_HavingEq_FiltersSingleValue() + public void BuildAggregateExpression_MultipleGroupByFields_AllFieldsQuotedCorrectly() { - JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"B\",\"price\":20}]"); - var having = new Dictionary { ["eq"] = 10 }; - var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual("A", result[0]["category"]?.ToString()); + Mock qb = CreateMockQueryBuilder(); + string aggExpr = AggregateRecordsTool.BuildAggregateExpression("sum", "price", false, false, qb.Object); + Assert.AreEqual("SUM([price])", aggExpr); + Assert.AreEqual("[cat]", qb.Object.QuoteIdentifier("cat")); + Assert.AreEqual("[region]", qb.Object.QuoteIdentifier("region")); } [TestMethod] - public void PerformAggregation_HavingNeq_FiltersOutValue() + public void BuildAggregateExpression_EmptyDataset_ExpressionStillValid() { - JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"B\",\"price\":20}]"); - var having = new Dictionary { ["neq"] = 10 }; - var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual("B", result[0]["category"]?.ToString()); + Mock qb = CreateMockQueryBuilder(); + string expr = AggregateRecordsTool.BuildAggregateExpression("avg", "price", false, false, qb.Object); + Assert.AreEqual("AVG([price])", expr); } - [TestMethod] - public void PerformAggregation_EmptyRecords_ReturnsNull() - { - JsonElement records = ParseArray("[]"); - var result = AggregateRecordsTool.PerformAggregation(records, "avg", "price", false, new(), null, null, "desc", "avg_price"); + #endregion - Assert.AreEqual(1, result.Count); - Assert.IsNull(result[0]["avg_price"]); - } + #region Cursor and Pagination Tests [TestMethod] - public void PerformAggregation_EmptyRecordsCountStar_ReturnsZero() + public void DecodeCursorOffset_NullCursor_ReturnsZero() { - JsonElement records = ParseArray("[]"); - var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new(), null, null, "desc", "count"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual(0.0, result[0]["count"]); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(null)); } [TestMethod] - public void PerformAggregation_MultipleGroupByFields_ReturnsCorrectGroups() + public void DecodeCursorOffset_EmptyCursor_ReturnsZero() { - JsonElement records = ParseArray("[{\"cat\":\"A\",\"region\":\"East\",\"price\":10},{\"cat\":\"A\",\"region\":\"East\",\"price\":20},{\"cat\":\"A\",\"region\":\"West\",\"price\":5}]"); - var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "cat", "region" }, null, null, "desc", "sum_price"); - - Assert.AreEqual(2, result.Count); - // (A,East)=30 desc, (A,West)=5 - Assert.AreEqual("A", result[0]["cat"]?.ToString()); - Assert.AreEqual("East", result[0]["region"]?.ToString()); - Assert.AreEqual(30.0, result[0]["sum_price"]); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset("")); } [TestMethod] - public void PerformAggregation_HavingNoResults_ReturnsEmpty() + public void DecodeCursorOffset_WhitespaceCursor_ReturnsZero() { - JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10}]"); - var having = new Dictionary { ["gt"] = 100 }; - var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price"); - - Assert.AreEqual(0, result.Count); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(" ")); } [TestMethod] - public void PerformAggregation_HavingOnSingleResult_Passes() + public void DecodeCursorOffset_ValidBase64Cursor_ReturnsDecodedOffset() { - JsonElement records = ParseArray("[{\"price\":50},{\"price\":60}]"); - var having = new Dictionary { ["gte"] = 100 }; - var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new(), having, null, "desc", "sum_price"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual(110.0, result[0]["sum_price"]); + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("5")); + Assert.AreEqual(5, AggregateRecordsTool.DecodeCursorOffset(cursor)); } [TestMethod] - public void PerformAggregation_HavingOnSingleResult_Fails() + public void DecodeCursorOffset_InvalidBase64_ReturnsZero() { - JsonElement records = ParseArray("[{\"price\":50},{\"price\":60}]"); - var having = new Dictionary { ["gt"] = 200 }; - var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new(), having, null, "desc", "sum_price"); - - Assert.AreEqual(0, result.Count); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset("not-valid-base64!!!")); } - #endregion - - #region Pagination Tests - [TestMethod] - public void ApplyPagination_FirstOnly_ReturnsFirstNItems() + public void DecodeCursorOffset_NonNumericBase64_ReturnsZero() { - List> allResults = new() - { - new() { ["category"] = "A", ["count"] = 10.0 }, - new() { ["category"] = "B", ["count"] = 8.0 }, - new() { ["category"] = "C", ["count"] = 6.0 }, - new() { ["category"] = "D", ["count"] = 4.0 }, - new() { ["category"] = "E", ["count"] = 2.0 } - }; - - AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 3, null); - - Assert.AreEqual(3, result.Items.Count); - Assert.AreEqual("A", result.Items[0]["category"]?.ToString()); - Assert.AreEqual("C", result.Items[2]["category"]?.ToString()); - Assert.IsTrue(result.HasNextPage); - Assert.IsNotNull(result.EndCursor); + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("abc")); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(cursor)); } [TestMethod] - public void ApplyPagination_FirstWithAfter_ReturnsNextPage() + public void DecodeCursorOffset_RoundTrip_PreservesOffset() { - List> allResults = new() - { - new() { ["category"] = "A", ["count"] = 10.0 }, - new() { ["category"] = "B", ["count"] = 8.0 }, - new() { ["category"] = "C", ["count"] = 6.0 }, - new() { ["category"] = "D", ["count"] = 4.0 }, - new() { ["category"] = "E", ["count"] = 2.0 } - }; - - // First page - AggregateRecordsTool.PaginationResult firstPage = AggregateRecordsTool.ApplyPagination(allResults, 3, null); - Assert.AreEqual(3, firstPage.Items.Count); - Assert.IsTrue(firstPage.HasNextPage); - - // Second page using cursor from first page - AggregateRecordsTool.PaginationResult secondPage = AggregateRecordsTool.ApplyPagination(allResults, 3, firstPage.EndCursor); - Assert.AreEqual(2, secondPage.Items.Count); - Assert.AreEqual("D", secondPage.Items[0]["category"]?.ToString()); - Assert.AreEqual("E", secondPage.Items[1]["category"]?.ToString()); - Assert.IsFalse(secondPage.HasNextPage); + int expectedOffset = 15; + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(expectedOffset.ToString())); + Assert.AreEqual(expectedOffset, AggregateRecordsTool.DecodeCursorOffset(cursor)); } [TestMethod] - public void ApplyPagination_FirstExceedsTotalCount_ReturnsAllItems() + public void DecodeCursorOffset_ZeroOffset_ReturnsZero() { - List> allResults = new() - { - new() { ["category"] = "A", ["count"] = 10.0 }, - new() { ["category"] = "B", ["count"] = 8.0 } - }; - - AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 5, null); - - Assert.AreEqual(2, result.Items.Count); - Assert.IsFalse(result.HasNextPage); + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("0")); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(cursor)); } [TestMethod] - public void ApplyPagination_FirstExactlyMatchesTotalCount_HasNextPageIsFalse() + public void DecodeCursorOffset_LargeOffset_ReturnsCorrectValue() { - List> allResults = new() - { - new() { ["category"] = "A", ["count"] = 10.0 }, - new() { ["category"] = "B", ["count"] = 8.0 }, - new() { ["category"] = "C", ["count"] = 6.0 } - }; - - AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 3, null); - - Assert.AreEqual(3, result.Items.Count); - Assert.IsFalse(result.HasNextPage); - } - - [TestMethod] - public void ApplyPagination_EmptyResults_ReturnsEmptyPage() - { - List> allResults = new(); - - AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 5, null); - - Assert.AreEqual(0, result.Items.Count); - Assert.IsFalse(result.HasNextPage); - Assert.IsNull(result.EndCursor); - } - - [TestMethod] - public void ApplyPagination_InvalidCursor_StartsFromBeginning() - { - List> allResults = new() - { - new() { ["category"] = "A", ["count"] = 10.0 }, - new() { ["category"] = "B", ["count"] = 8.0 } - }; - - AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 5, "not-valid-base64!!!"); - - Assert.AreEqual(2, result.Items.Count); - Assert.AreEqual("A", result.Items[0]["category"]?.ToString()); - Assert.IsFalse(result.HasNextPage); - Assert.IsNotNull(result.EndCursor); - } - - [TestMethod] - public void ApplyPagination_CursorBeyondResults_ReturnsEmptyPage() - { - List> allResults = new() - { - new() { ["category"] = "A", ["count"] = 10.0 } - }; - - // Cursor pointing beyond the end - string cursor = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("100")); - AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 5, cursor); - - Assert.AreEqual(0, result.Items.Count); - Assert.IsFalse(result.HasNextPage); - Assert.IsNull(result.EndCursor); - } - - [TestMethod] - public void ApplyPagination_MultiplePages_TraversesAllResults() - { - List> allResults = new(); - for (int i = 0; i < 8; i++) - { - allResults.Add(new() { ["category"] = $"Cat{i}", ["count"] = (double)(8 - i) }); - } - - // Page 1 - AggregateRecordsTool.PaginationResult page1 = AggregateRecordsTool.ApplyPagination(allResults, 3, null); - Assert.AreEqual(3, page1.Items.Count); - Assert.IsTrue(page1.HasNextPage); - - // Page 2 - AggregateRecordsTool.PaginationResult page2 = AggregateRecordsTool.ApplyPagination(allResults, 3, page1.EndCursor); - Assert.AreEqual(3, page2.Items.Count); - Assert.IsTrue(page2.HasNextPage); - - // Page 3 (last page) - AggregateRecordsTool.PaginationResult page3 = AggregateRecordsTool.ApplyPagination(allResults, 3, page2.EndCursor); - Assert.AreEqual(2, page3.Items.Count); - Assert.IsFalse(page3.HasNextPage); + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("1000")); + Assert.AreEqual(1000, AggregateRecordsTool.DecodeCursorOffset(cursor)); } #endregion @@ -744,450 +553,213 @@ public void TimeoutErrorMessage_IncludesEntityName() #endregion - #region Spec Example Tests + #region Spec Example SQL Pattern Tests /// /// Spec Example 1: "How many products are there?" - /// COUNT(*) → 77 + /// COUNT(*) - expects alias "count" and expression COUNT(*) /// [TestMethod] - public void SpecExample01_CountStar_ReturnsTotal() + public void SpecExample01_CountStar_GeneratesCorrectSqlPattern() { - // Build 77 product records - List items = new(); - for (int i = 1; i <= 77; i++) - { - items.Add($"{{\"id\":{i}}}"); - } - - JsonElement records = ParseArray($"[{string.Join(",", items)}]"); + Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("count", "*"); - var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new(), null, null, "desc", alias); + string expr = AggregateRecordsTool.BuildAggregateExpression("count", null, false, true, qb.Object); - Assert.AreEqual(1, result.Count); Assert.AreEqual("count", alias); - Assert.AreEqual(77.0, result[0]["count"]); + Assert.AreEqual("COUNT(*)", expr); } /// /// Spec Example 2: "What is the average price of products under $10?" - /// AVG(unitPrice) WHERE unitPrice < 10 → 6.74 - /// Filter is applied at DB level; we supply pre-filtered records. + /// AVG(unitPrice) with filter /// [TestMethod] - public void SpecExample02_AvgWithFilter_ReturnsFilteredAverage() + public void SpecExample02_AvgWithFilter_GeneratesCorrectSqlPattern() { - // Pre-filtered records (unitPrice < 10) that average to 6.74 - // 4.50 + 6.00 + 9.72 = 20.22 / 3 = 6.74 - JsonElement records = ParseArray("[{\"unitPrice\":4.5},{\"unitPrice\":6.0},{\"unitPrice\":9.72}]"); + Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice"); - var result = AggregateRecordsTool.PerformAggregation(records, "avg", "unitPrice", false, new(), null, null, "desc", alias); + string expr = AggregateRecordsTool.BuildAggregateExpression("avg", "unitPrice", false, false, qb.Object); - Assert.AreEqual(1, result.Count); Assert.AreEqual("avg_unitPrice", alias); - Assert.AreEqual(6.74, result[0]["avg_unitPrice"]); + Assert.AreEqual("AVG([unitPrice])", expr); } /// /// Spec Example 3: "Which categories have more than 20 products?" - /// COUNT(*) GROUP BY categoryName HAVING COUNT(*) > 20 - /// Expected: Beverages=24, Condiments=22 + /// COUNT(*) GROUP BY categoryName HAVING gt 20 /// [TestMethod] - public void SpecExample03_CountGroupByHavingGt_FiltersGroups() + public void SpecExample03_CountGroupByHavingGt_GeneratesCorrectSqlPattern() { - List items = new(); - for (int i = 0; i < 24; i++) - { - items.Add("{\"categoryName\":\"Beverages\"}"); - } - - for (int i = 0; i < 22; i++) - { - items.Add("{\"categoryName\":\"Condiments\"}"); - } - - for (int i = 0; i < 12; i++) - { - items.Add("{\"categoryName\":\"Seafood\"}"); - } - - JsonElement records = ParseArray($"[{string.Join(",", items)}]"); + Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("count", "*"); - var having = new Dictionary { ["gt"] = 20 }; - var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "categoryName" }, having, null, "desc", alias); - - Assert.AreEqual(2, result.Count); - // Desc order: Beverages(24), Condiments(22) - Assert.AreEqual("Beverages", result[0]["categoryName"]?.ToString()); - Assert.AreEqual(24.0, result[0]["count"]); - Assert.AreEqual("Condiments", result[1]["categoryName"]?.ToString()); - Assert.AreEqual(22.0, result[1]["count"]); + string expr = AggregateRecordsTool.BuildAggregateExpression("count", null, false, true, qb.Object); + + Assert.AreEqual("count", alias); + Assert.AreEqual("COUNT(*)", expr); + Assert.AreEqual("[categoryName]", qb.Object.QuoteIdentifier("categoryName")); } /// - /// Spec Example 4: "For discontinued products, which categories have a total revenue between $500 and $10,000?" - /// SUM(unitPrice) WHERE discontinued=1 GROUP BY categoryName HAVING SUM >= 500 AND <= 10000 - /// Expected: Seafood=1834.50, Produce=742.00 + /// Spec Example 4: "For discontinued products, which categories have total revenue between $500 and $10,000?" + /// SUM(unitPrice) GROUP BY categoryName HAVING gte 500 AND lte 10000 /// [TestMethod] - public void SpecExample04_SumFilterGroupByHavingRange_ReturnsMatchingGroups() + public void SpecExample04_SumFilterGroupByHavingRange_GeneratesCorrectSqlPattern() { - // Pre-filtered (discontinued) records with prices summing per category - JsonElement records = ParseArray( - "[" + - "{\"categoryName\":\"Seafood\",\"unitPrice\":900}," + - "{\"categoryName\":\"Seafood\",\"unitPrice\":934.5}," + - "{\"categoryName\":\"Produce\",\"unitPrice\":400}," + - "{\"categoryName\":\"Produce\",\"unitPrice\":342}," + - "{\"categoryName\":\"Dairy\",\"unitPrice\":50}" + // Sum 50, below 500 - "]"); + Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("sum", "unitPrice"); - var having = new Dictionary { ["gte"] = 500, ["lte"] = 10000 }; - var result = AggregateRecordsTool.PerformAggregation(records, "sum", "unitPrice", false, new() { "categoryName" }, having, null, "desc", alias); + string expr = AggregateRecordsTool.BuildAggregateExpression("sum", "unitPrice", false, false, qb.Object); - Assert.AreEqual(2, result.Count); Assert.AreEqual("sum_unitPrice", alias); - // Desc order: Seafood(1834.5), Produce(742) - Assert.AreEqual("Seafood", result[0]["categoryName"]?.ToString()); - Assert.AreEqual(1834.5, result[0]["sum_unitPrice"]); - Assert.AreEqual("Produce", result[1]["categoryName"]?.ToString()); - Assert.AreEqual(742.0, result[1]["sum_unitPrice"]); + Assert.AreEqual("SUM([unitPrice])", expr); } /// /// Spec Example 5: "How many distinct suppliers do we have?" - /// COUNT(DISTINCT supplierId) → 29 + /// COUNT(DISTINCT supplierId) /// [TestMethod] - public void SpecExample05_CountDistinct_ReturnsDistinctCount() + public void SpecExample05_CountDistinct_GeneratesCorrectSqlPattern() { - // Build records with 29 distinct supplierIds plus duplicates - List items = new(); - for (int i = 1; i <= 29; i++) - { - items.Add($"{{\"supplierId\":{i}}}"); - } - - // Add duplicates - items.Add("{\"supplierId\":1}"); - items.Add("{\"supplierId\":5}"); - items.Add("{\"supplierId\":10}"); - - JsonElement records = ParseArray($"[{string.Join(",", items)}]"); + Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("count", "supplierId"); - var result = AggregateRecordsTool.PerformAggregation(records, "count", "supplierId", true, new(), null, null, "desc", alias); + string expr = AggregateRecordsTool.BuildAggregateExpression("count", "supplierId", true, false, qb.Object); - Assert.AreEqual(1, result.Count); Assert.AreEqual("count_supplierId", alias); - Assert.AreEqual(29.0, result[0]["count_supplierId"]); + Assert.AreEqual("COUNT(DISTINCT [supplierId])", expr); } /// /// Spec Example 6: "Which categories have exactly 5 or 10 products?" - /// COUNT(*) GROUP BY categoryName HAVING COUNT(*) IN (5, 10) - /// Expected: Grains=5, Produce=5 + /// COUNT(*) GROUP BY categoryName HAVING IN (5, 10) /// [TestMethod] - public void SpecExample06_CountGroupByHavingIn_FiltersExactCounts() + public void SpecExample06_CountGroupByHavingIn_GeneratesCorrectSqlPattern() { - List items = new(); - for (int i = 0; i < 5; i++) - { - items.Add("{\"categoryName\":\"Grains\"}"); - } - - for (int i = 0; i < 5; i++) - { - items.Add("{\"categoryName\":\"Produce\"}"); - } - - for (int i = 0; i < 12; i++) - { - items.Add("{\"categoryName\":\"Beverages\"}"); - } - - JsonElement records = ParseArray($"[{string.Join(",", items)}]"); + Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("count", "*"); - var havingIn = new List { 5, 10 }; - var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "categoryName" }, null, havingIn, "desc", alias); + string expr = AggregateRecordsTool.BuildAggregateExpression("count", null, false, true, qb.Object); - Assert.AreEqual(2, result.Count); - // Both have count=5, same order as grouped - Assert.AreEqual(5.0, result[0]["count"]); - Assert.AreEqual(5.0, result[1]["count"]); + Assert.AreEqual("count", alias); + Assert.AreEqual("COUNT(*)", expr); } /// - /// Spec Example 7: "What is the average distinct unit price per category, for categories averaging over $25?" - /// AVG(DISTINCT unitPrice) GROUP BY categoryName HAVING AVG(DISTINCT unitPrice) > 25 - /// Expected: Meat/Poultry=54.01, Beverages=32.50 + /// Spec Example 7: "Average distinct unit price per category, for categories averaging over $25" + /// AVG(DISTINCT unitPrice) GROUP BY categoryName HAVING gt 25 /// [TestMethod] - public void SpecExample07_AvgDistinctGroupByHavingGt_FiltersAboveThreshold() + public void SpecExample07_AvgDistinctGroupByHavingGt_GeneratesCorrectSqlPattern() { - // Meat/Poultry: distinct prices {40.00, 68.02} → avg = 54.01 - // Beverages: distinct prices {25.00, 40.00} → avg = 32.50 - // Condiments: distinct prices {10.00, 15.00} → avg = 12.50 (below threshold) - JsonElement records = ParseArray( - "[" + - "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":40.00}," + - "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":68.02}," + - "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":40.00}," + // duplicate - "{\"categoryName\":\"Beverages\",\"unitPrice\":25.00}," + - "{\"categoryName\":\"Beverages\",\"unitPrice\":40.00}," + - "{\"categoryName\":\"Beverages\",\"unitPrice\":25.00}," + // duplicate - "{\"categoryName\":\"Condiments\",\"unitPrice\":10.00}," + - "{\"categoryName\":\"Condiments\",\"unitPrice\":15.00}" + - "]"); + Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice"); - var having = new Dictionary { ["gt"] = 25 }; - var result = AggregateRecordsTool.PerformAggregation(records, "avg", "unitPrice", true, new() { "categoryName" }, having, null, "desc", alias); + string expr = AggregateRecordsTool.BuildAggregateExpression("avg", "unitPrice", true, false, qb.Object); - Assert.AreEqual(2, result.Count); Assert.AreEqual("avg_unitPrice", alias); - // Desc order: Meat/Poultry(54.01), Beverages(32.5) - Assert.AreEqual("Meat/Poultry", result[0]["categoryName"]?.ToString()); - Assert.AreEqual(54.01, result[0]["avg_unitPrice"]); - Assert.AreEqual("Beverages", result[1]["categoryName"]?.ToString()); - Assert.AreEqual(32.5, result[1]["avg_unitPrice"]); + Assert.AreEqual("AVG(DISTINCT [unitPrice])", expr); } /// /// Spec Example 8: "Which categories have the most products?" /// COUNT(*) GROUP BY categoryName ORDER BY DESC - /// Expected: Confections=13, Beverages=12, Condiments=12, Seafood=12 /// [TestMethod] - public void SpecExample08_CountGroupByOrderByDesc_ReturnsSortedDesc() + public void SpecExample08_CountGroupByOrderByDesc_GeneratesCorrectSqlPattern() { - List items = new(); - for (int i = 0; i < 13; i++) - { - items.Add("{\"categoryName\":\"Confections\"}"); - } - - for (int i = 0; i < 12; i++) - { - items.Add("{\"categoryName\":\"Beverages\"}"); - } - - for (int i = 0; i < 12; i++) - { - items.Add("{\"categoryName\":\"Condiments\"}"); - } - - for (int i = 0; i < 12; i++) - { - items.Add("{\"categoryName\":\"Seafood\"}"); - } - - JsonElement records = ParseArray($"[{string.Join(",", items)}]"); + Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("count", "*"); - var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "categoryName" }, null, null, "desc", alias); - - Assert.AreEqual(4, result.Count); - Assert.AreEqual("Confections", result[0]["categoryName"]?.ToString()); - Assert.AreEqual(13.0, result[0]["count"]); - // Remaining 3 all have count=12 - Assert.AreEqual(12.0, result[1]["count"]); - Assert.AreEqual(12.0, result[2]["count"]); - Assert.AreEqual(12.0, result[3]["count"]); + string expr = AggregateRecordsTool.BuildAggregateExpression("count", null, false, true, qb.Object); + + Assert.AreEqual("count", alias); + Assert.AreEqual("COUNT(*)", expr); } /// /// Spec Example 9: "What are the cheapest categories by average price?" /// AVG(unitPrice) GROUP BY categoryName ORDER BY ASC - /// Expected: Grains/Cereals=20.25, Condiments=23.06, Produce=32.37 /// [TestMethod] - public void SpecExample09_AvgGroupByOrderByAsc_ReturnsSortedAsc() + public void SpecExample09_AvgGroupByOrderByAsc_GeneratesCorrectSqlPattern() { - // Grains/Cereals: {15.50, 25.00} → avg = 20.25 - // Condiments: {20.12, 26.00} → avg = 23.06 - // Produce: {28.74, 36.00} → avg = 32.37 - JsonElement records = ParseArray( - "[" + - "{\"categoryName\":\"Grains/Cereals\",\"unitPrice\":15.50}," + - "{\"categoryName\":\"Grains/Cereals\",\"unitPrice\":25.00}," + - "{\"categoryName\":\"Condiments\",\"unitPrice\":20.12}," + - "{\"categoryName\":\"Condiments\",\"unitPrice\":26.00}," + - "{\"categoryName\":\"Produce\",\"unitPrice\":28.74}," + - "{\"categoryName\":\"Produce\",\"unitPrice\":36.00}" + - "]"); + Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice"); - var result = AggregateRecordsTool.PerformAggregation(records, "avg", "unitPrice", false, new() { "categoryName" }, null, null, "asc", alias); - - Assert.AreEqual(3, result.Count); - // Asc order: Grains/Cereals(20.25), Condiments(23.06), Produce(32.37) - Assert.AreEqual("Grains/Cereals", result[0]["categoryName"]?.ToString()); - Assert.AreEqual(20.25, result[0]["avg_unitPrice"]); - Assert.AreEqual("Condiments", result[1]["categoryName"]?.ToString()); - Assert.AreEqual(23.06, result[1]["avg_unitPrice"]); - Assert.AreEqual("Produce", result[2]["categoryName"]?.ToString()); - Assert.AreEqual(32.37, result[2]["avg_unitPrice"]); + string expr = AggregateRecordsTool.BuildAggregateExpression("avg", "unitPrice", false, false, qb.Object); + + Assert.AreEqual("avg_unitPrice", alias); + Assert.AreEqual("AVG([unitPrice])", expr); } /// - /// Spec Example 10: "For categories with over $500 revenue from discontinued products, which has the highest total?" - /// SUM(unitPrice) WHERE discontinued=1 GROUP BY categoryName HAVING SUM > 500 ORDER BY DESC - /// Expected: Seafood=1834.50, Meat/Poultry=1062.50, Produce=742.00 + /// Spec Example 10: "For categories with over $500 revenue, which has the highest total?" + /// SUM(unitPrice) GROUP BY categoryName HAVING gt 500 ORDER BY DESC /// [TestMethod] - public void SpecExample10_SumFilterGroupByHavingGtOrderByDesc_ReturnsSortedFiltered() + public void SpecExample10_SumFilterGroupByHavingGtOrderByDesc_GeneratesCorrectSqlPattern() { - // Pre-filtered (discontinued) records - JsonElement records = ParseArray( - "[" + - "{\"categoryName\":\"Seafood\",\"unitPrice\":900}," + - "{\"categoryName\":\"Seafood\",\"unitPrice\":934.5}," + - "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":500}," + - "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":562.5}," + - "{\"categoryName\":\"Produce\",\"unitPrice\":400}," + - "{\"categoryName\":\"Produce\",\"unitPrice\":342}," + - "{\"categoryName\":\"Dairy\",\"unitPrice\":50}" + // Sum 50, below 500 - "]"); + Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("sum", "unitPrice"); - var having = new Dictionary { ["gt"] = 500 }; - var result = AggregateRecordsTool.PerformAggregation(records, "sum", "unitPrice", false, new() { "categoryName" }, having, null, "desc", alias); - - Assert.AreEqual(3, result.Count); - // Desc order: Seafood(1834.5), Meat/Poultry(1062.5), Produce(742) - Assert.AreEqual("Seafood", result[0]["categoryName"]?.ToString()); - Assert.AreEqual(1834.5, result[0]["sum_unitPrice"]); - Assert.AreEqual("Meat/Poultry", result[1]["categoryName"]?.ToString()); - Assert.AreEqual(1062.5, result[1]["sum_unitPrice"]); - Assert.AreEqual("Produce", result[2]["categoryName"]?.ToString()); - Assert.AreEqual(742.0, result[2]["sum_unitPrice"]); + string expr = AggregateRecordsTool.BuildAggregateExpression("sum", "unitPrice", false, false, qb.Object); + + Assert.AreEqual("sum_unitPrice", alias); + Assert.AreEqual("SUM([unitPrice])", expr); } /// /// Spec Example 11: "Show me the first 5 categories by product count" /// COUNT(*) GROUP BY categoryName ORDER BY DESC FIRST 5 - /// Expected: 5 items with hasNextPage=true, endCursor set /// [TestMethod] - public void SpecExample11_CountGroupByOrderByDescFirst5_ReturnsPaginatedResults() + public void SpecExample11_CountGroupByOrderByDescFirst5_GeneratesCorrectSqlPattern() { - List items = new(); - string[] categories = { "Confections", "Beverages", "Condiments", "Seafood", "Dairy", "Grains/Cereals", "Meat/Poultry", "Produce" }; - int[] counts = { 13, 12, 12, 12, 10, 7, 6, 5 }; - for (int c = 0; c < categories.Length; c++) - { - for (int i = 0; i < counts[c]; i++) - { - items.Add($"{{\"categoryName\":\"{categories[c]}\"}}"); - } - } - - JsonElement records = ParseArray($"[{string.Join(",", items)}]"); + Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("count", "*"); - var allResults = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "categoryName" }, null, null, "desc", alias); - - Assert.AreEqual(8, allResults.Count); + string expr = AggregateRecordsTool.BuildAggregateExpression("count", null, false, true, qb.Object); - // Apply pagination: first=5 - AggregateRecordsTool.PaginationResult page1 = AggregateRecordsTool.ApplyPagination(allResults, 5, null); - - Assert.AreEqual(5, page1.Items.Count); - Assert.AreEqual("Confections", page1.Items[0]["categoryName"]?.ToString()); - Assert.AreEqual(13.0, page1.Items[0]["count"]); - Assert.AreEqual("Dairy", page1.Items[4]["categoryName"]?.ToString()); - Assert.AreEqual(10.0, page1.Items[4]["count"]); - Assert.IsTrue(page1.HasNextPage); - Assert.IsNotNull(page1.EndCursor); + Assert.AreEqual("count", alias); + Assert.AreEqual("COUNT(*)", expr); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(null)); } /// /// Spec Example 12: "Show me the next 5 categories" (continuation of Example 11) /// COUNT(*) GROUP BY categoryName ORDER BY DESC FIRST 5 AFTER cursor - /// Expected: 3 items (remaining), hasNextPage=false /// [TestMethod] - public void SpecExample12_CountGroupByOrderByDescFirst5After_ReturnsNextPage() + public void SpecExample12_CountGroupByOrderByDescFirst5After_GeneratesCorrectSqlPattern() { - List items = new(); - string[] categories = { "Confections", "Beverages", "Condiments", "Seafood", "Dairy", "Grains/Cereals", "Meat/Poultry", "Produce" }; - int[] counts = { 13, 12, 12, 12, 10, 7, 6, 5 }; - for (int c = 0; c < categories.Length; c++) - { - for (int i = 0; i < counts[c]; i++) - { - items.Add($"{{\"categoryName\":\"{categories[c]}\"}}"); - } - } + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("5")); + int offset = AggregateRecordsTool.DecodeCursorOffset(cursor); + Assert.AreEqual(5, offset); - JsonElement records = ParseArray($"[{string.Join(",", items)}]"); + Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("count", "*"); - var allResults = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "categoryName" }, null, null, "desc", alias); - - // Page 1 - AggregateRecordsTool.PaginationResult page1 = AggregateRecordsTool.ApplyPagination(allResults, 5, null); - Assert.IsTrue(page1.HasNextPage); - - // Page 2 (continuation) - AggregateRecordsTool.PaginationResult page2 = AggregateRecordsTool.ApplyPagination(allResults, 5, page1.EndCursor); - - Assert.AreEqual(3, page2.Items.Count); - Assert.AreEqual("Grains/Cereals", page2.Items[0]["categoryName"]?.ToString()); - Assert.AreEqual(7.0, page2.Items[0]["count"]); - Assert.AreEqual("Meat/Poultry", page2.Items[1]["categoryName"]?.ToString()); - Assert.AreEqual(6.0, page2.Items[1]["count"]); - Assert.AreEqual("Produce", page2.Items[2]["categoryName"]?.ToString()); - Assert.AreEqual(5.0, page2.Items[2]["count"]); - Assert.IsFalse(page2.HasNextPage); + string expr = AggregateRecordsTool.BuildAggregateExpression("count", null, false, true, qb.Object); + + Assert.AreEqual("count", alias); + Assert.AreEqual("COUNT(*)", expr); } /// /// Spec Example 13: "Show me the top 3 most expensive categories by average price" /// AVG(unitPrice) GROUP BY categoryName ORDER BY DESC FIRST 3 - /// Expected: Meat/Poultry=54.01, Beverages=37.98, Seafood=37.08 /// [TestMethod] - public void SpecExample13_AvgGroupByOrderByDescFirst3_ReturnsTop3() + public void SpecExample13_AvgGroupByOrderByDescFirst3_GeneratesCorrectSqlPattern() { - // Meat/Poultry: {40.00, 68.02} → avg = 54.01 - // Beverages: {30.96, 45.00} → avg = 37.98 - // Seafood: {25.16, 49.00} → avg = 37.08 - // Condiments: {10.00, 15.00} → avg = 12.50 - JsonElement records = ParseArray( - "[" + - "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":40.00}," + - "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":68.02}," + - "{\"categoryName\":\"Beverages\",\"unitPrice\":30.96}," + - "{\"categoryName\":\"Beverages\",\"unitPrice\":45.00}," + - "{\"categoryName\":\"Seafood\",\"unitPrice\":25.16}," + - "{\"categoryName\":\"Seafood\",\"unitPrice\":49.00}," + - "{\"categoryName\":\"Condiments\",\"unitPrice\":10.00}," + - "{\"categoryName\":\"Condiments\",\"unitPrice\":15.00}" + - "]"); + Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice"); - var allResults = AggregateRecordsTool.PerformAggregation(records, "avg", "unitPrice", false, new() { "categoryName" }, null, null, "desc", alias); - - Assert.AreEqual(4, allResults.Count); - - // Apply pagination: first=3 - AggregateRecordsTool.PaginationResult page = AggregateRecordsTool.ApplyPagination(allResults, 3, null); + string expr = AggregateRecordsTool.BuildAggregateExpression("avg", "unitPrice", false, false, qb.Object); - Assert.AreEqual(3, page.Items.Count); - Assert.AreEqual("Meat/Poultry", page.Items[0]["categoryName"]?.ToString()); - Assert.AreEqual(54.01, page.Items[0]["avg_unitPrice"]); - Assert.AreEqual("Beverages", page.Items[1]["categoryName"]?.ToString()); - Assert.AreEqual(37.98, page.Items[1]["avg_unitPrice"]); - Assert.AreEqual("Seafood", page.Items[2]["categoryName"]?.ToString()); - Assert.AreEqual(37.08, page.Items[2]["avg_unitPrice"]); - Assert.IsTrue(page.HasNextPage); + Assert.AreEqual("avg_unitPrice", alias); + Assert.AreEqual("AVG([unitPrice])", expr); } #endregion #region Helper Methods - private static JsonElement ParseArray(string json) - { - return JsonDocument.Parse(json).RootElement; - } - private static JsonElement ParseContent(CallToolResult result) { TextContentBlock firstContent = (TextContentBlock)result.Content[0]; diff --git a/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs b/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs index dee8842a0d..ff458cc0b9 100644 --- a/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs @@ -3,20 +3,36 @@ #nullable enable -using System.Collections.Generic; -using System.Text.Json; +using System; +using System.Text; +using Azure.DataApiBuilder.Config.DatabasePrimitives; +using Azure.DataApiBuilder.Core.Resolvers; using Azure.DataApiBuilder.Mcp.BuiltInTools; using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; namespace Azure.DataApiBuilder.Service.Tests.UnitTests { /// - /// Unit tests for AggregateRecordsTool's internal helper methods. - /// Covers validation paths, aggregation logic, and pagination behavior. + /// Unit tests for AggregateRecordsTool's SQL generation methods. + /// Validates that the tool builds correct SQL queries to push aggregation to the database. + /// Tests cover: alias computation, aggregate expressions, table references, + /// cursor decoding, and full SQL generation matching blog-documented patterns. /// [TestClass] public class AggregateRecordsToolTests { + /// + /// Creates a mock IQueryBuilder that wraps identifiers with square brackets (MsSql-style). + /// + private static Mock CreateMockQueryBuilder() + { + Mock mock = new(); + mock.Setup(qb => qb.QuoteIdentifier(It.IsAny())) + .Returns((string id) => $"[{id}]"); + return mock; + } + #region ComputeAlias tests [TestMethod] @@ -34,330 +50,125 @@ public void ComputeAlias_ReturnsExpectedAlias(string function, string field, str #endregion - #region PerformAggregation tests - no groupby - - private static JsonElement CreateRecordsArray(params double[] values) - { - var list = new List(); - foreach (double v in values) - { - list.Add(new Dictionary { ["value"] = v }); - } - - string json = JsonSerializer.Serialize(list); - return JsonDocument.Parse(json).RootElement.Clone(); - } - - private static JsonElement CreateEmptyArray() - { - return JsonDocument.Parse("[]").RootElement.Clone(); - } - - private static JsonElement CreateMixedArray() - { - // Records where some have 'value' (numeric) and some have 'category' (string) - string json = """ - [ - {"value": 10.0, "category": "A"}, - {"value": 20.0, "category": "B"}, - {"value": 10.0, "category": "A"} - ] - """; - return JsonDocument.Parse(json).RootElement.Clone(); - } + #region BuildAggregateExpression tests [TestMethod] - public void PerformAggregation_CountStar_NoGroupBy_ReturnsCount() + public void BuildAggregateExpression_CountStar_ReturnsCountStar() { - JsonElement records = CreateRecordsArray(1, 2, 3, 4, 5); - var result = AggregateRecordsTool.PerformAggregation( - records, "count", "*", distinct: false, new List(), null, null, "desc", "count"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual(5.0, result[0]["count"]); + Mock qb = CreateMockQueryBuilder(); + string expr = AggregateRecordsTool.BuildAggregateExpression("count", null, false, true, qb.Object); + Assert.AreEqual("COUNT(*)", expr); } [TestMethod] - public void PerformAggregation_CountField_NoGroupBy_CountsNumericValues() + public void BuildAggregateExpression_SumField_ReturnsSumQuotedColumn() { - JsonElement records = CreateRecordsArray(10.0, 20.0, 30.0); - var result = AggregateRecordsTool.PerformAggregation( - records, "count", "value", distinct: false, new List(), null, null, "desc", "count_value"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual(3.0, result[0]["count_value"]); + Mock qb = CreateMockQueryBuilder(); + string expr = AggregateRecordsTool.BuildAggregateExpression("sum", "totalRevenue", false, false, qb.Object); + Assert.AreEqual("SUM([totalRevenue])", expr); } [TestMethod] - public void PerformAggregation_CountField_Distinct_CountsUniqueValues() + public void BuildAggregateExpression_AvgDistinct_ReturnsAvgDistinct() { - JsonElement records = CreateRecordsArray(10.0, 20.0, 10.0); - var result = AggregateRecordsTool.PerformAggregation( - records, "count", "value", distinct: true, new List(), null, null, "desc", "count_value"); - - Assert.AreEqual(1, result.Count); - // 10 and 20 are the distinct values - Assert.AreEqual(2.0, result[0]["count_value"]); + Mock qb = CreateMockQueryBuilder(); + string expr = AggregateRecordsTool.BuildAggregateExpression("avg", "price", true, false, qb.Object); + Assert.AreEqual("AVG(DISTINCT [price])", expr); } [TestMethod] - public void PerformAggregation_Avg_NoGroupBy_ReturnsAverage() + public void BuildAggregateExpression_CountDistinctField_ReturnsCountDistinct() { - JsonElement records = CreateRecordsArray(10.0, 20.0, 30.0); - var result = AggregateRecordsTool.PerformAggregation( - records, "avg", "value", distinct: false, new List(), null, null, "desc", "avg_value"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual(20.0, result[0]["avg_value"]); + Mock qb = CreateMockQueryBuilder(); + string expr = AggregateRecordsTool.BuildAggregateExpression("count", "supplierId", true, false, qb.Object); + Assert.AreEqual("COUNT(DISTINCT [supplierId])", expr); } [TestMethod] - public void PerformAggregation_Sum_NoGroupBy_ReturnsSum() + public void BuildAggregateExpression_MinField_ReturnsMin() { - JsonElement records = CreateRecordsArray(10.0, 20.0, 30.0); - var result = AggregateRecordsTool.PerformAggregation( - records, "sum", "value", distinct: false, new List(), null, null, "desc", "sum_value"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual(60.0, result[0]["sum_value"]); + Mock qb = CreateMockQueryBuilder(); + string expr = AggregateRecordsTool.BuildAggregateExpression("min", "price", false, false, qb.Object); + Assert.AreEqual("MIN([price])", expr); } [TestMethod] - public void PerformAggregation_Min_NoGroupBy_ReturnsMinimum() + public void BuildAggregateExpression_MaxField_ReturnsMax() { - JsonElement records = CreateRecordsArray(30.0, 10.0, 20.0); - var result = AggregateRecordsTool.PerformAggregation( - records, "min", "value", distinct: false, new List(), null, null, "desc", "min_value"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual(10.0, result[0]["min_value"]); - } - - [TestMethod] - public void PerformAggregation_Max_NoGroupBy_ReturnsMaximum() - { - JsonElement records = CreateRecordsArray(30.0, 10.0, 20.0); - var result = AggregateRecordsTool.PerformAggregation( - records, "max", "value", distinct: false, new List(), null, null, "desc", "max_value"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual(30.0, result[0]["max_value"]); - } - - [TestMethod] - public void PerformAggregation_EmptyRecords_ReturnsNullForNumericFunctions() - { - JsonElement records = CreateEmptyArray(); - var result = AggregateRecordsTool.PerformAggregation( - records, "avg", "value", distinct: false, new List(), null, null, "desc", "avg_value"); - - Assert.AreEqual(1, result.Count); - Assert.IsNull(result[0]["avg_value"]); - } - - [TestMethod] - public void PerformAggregation_EmptyRecords_CountStar_ReturnsZero() - { - JsonElement records = CreateEmptyArray(); - var result = AggregateRecordsTool.PerformAggregation( - records, "count", "*", distinct: false, new List(), null, null, "desc", "count"); - - Assert.AreEqual(1, result.Count); - Assert.AreEqual(0.0, result[0]["count"]); + Mock qb = CreateMockQueryBuilder(); + string expr = AggregateRecordsTool.BuildAggregateExpression("max", "price", false, false, qb.Object); + Assert.AreEqual("MAX([price])", expr); } #endregion - #region PerformAggregation tests - with groupby - - [TestMethod] - public void PerformAggregation_GroupBy_CountStar_ReturnsGroupCounts() - { - JsonElement records = CreateMixedArray(); - var groupby = new List { "category" }; - - var result = AggregateRecordsTool.PerformAggregation( - records, "count", "*", distinct: false, groupby, null, null, "desc", "count"); - - Assert.AreEqual(2, result.Count); - // desc ordering: A has 2, B has 1 - Assert.AreEqual("A", result[0]["category"]); - Assert.AreEqual(2.0, result[0]["count"]); - Assert.AreEqual("B", result[1]["category"]); - Assert.AreEqual(1.0, result[1]["count"]); - } + #region BuildQuotedTableRef tests [TestMethod] - public void PerformAggregation_GroupBy_Avg_ReturnsGroupAverages() + public void BuildQuotedTableRef_WithSchema_ReturnsSchemaQualified() { - JsonElement records = CreateMixedArray(); - var groupby = new List { "category" }; - - var result = AggregateRecordsTool.PerformAggregation( - records, "avg", "value", distinct: false, groupby, null, null, "asc", "avg_value"); - - Assert.AreEqual(2, result.Count); - // asc ordering by avg_value: B has 20, A has average (10+10)/2=10 - Assert.AreEqual("A", result[0]["category"]); - Assert.AreEqual(10.0, result[0]["avg_value"]); - Assert.AreEqual("B", result[1]["category"]); - Assert.AreEqual(20.0, result[1]["avg_value"]); + Mock qb = CreateMockQueryBuilder(); + DatabaseTable table = new("dbo", "Products"); + string result = AggregateRecordsTool.BuildQuotedTableRef(table, qb.Object); + Assert.AreEqual("[dbo].[Products]", result); } [TestMethod] - public void PerformAggregation_GroupBy_Having_FiltersGroups() + public void BuildQuotedTableRef_WithoutSchema_ReturnsTableOnly() { - JsonElement records = CreateMixedArray(); - var groupby = new List { "category" }; - var havingOps = new Dictionary(System.StringComparer.OrdinalIgnoreCase) - { - ["gt"] = 1.0 // Keep groups with count > 1 - }; - - var result = AggregateRecordsTool.PerformAggregation( - records, "count", "*", distinct: false, groupby, havingOps, null, "desc", "count"); - - // Only category "A" (count=2) should pass count > 1 - Assert.AreEqual(1, result.Count); - Assert.AreEqual("A", result[0]["category"]); + Mock qb = CreateMockQueryBuilder(); + DatabaseTable table = new("", "Products"); + string result = AggregateRecordsTool.BuildQuotedTableRef(table, qb.Object); + Assert.AreEqual("[Products]", result); } #endregion - #region Pagination tests + #region DecodeCursorOffset tests [TestMethod] - public void ApplyPagination_FirstPage_ReturnsItemsAndCursor() + public void DecodeCursorOffset_NullCursor_ReturnsZero() { - var allResults = new List> - { - new() { ["id"] = 1 }, - new() { ["id"] = 2 }, - new() { ["id"] = 3 }, - new() { ["id"] = 4 }, - new() { ["id"] = 5 } - }; - - var result = AggregateRecordsTool.ApplyPagination(allResults, first: 2, after: null); - - Assert.AreEqual(2, result.Items.Count); - Assert.AreEqual(1, result.Items[0]["id"]); - Assert.AreEqual(2, result.Items[1]["id"]); - Assert.IsTrue(result.HasNextPage); - Assert.IsNotNull(result.EndCursor); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(null)); } [TestMethod] - public void ApplyPagination_SecondPage_ReturnsCorrectItems() + public void DecodeCursorOffset_EmptyCursor_ReturnsZero() { - var allResults = new List> - { - new() { ["id"] = 1 }, - new() { ["id"] = 2 }, - new() { ["id"] = 3 }, - new() { ["id"] = 4 }, - new() { ["id"] = 5 } - }; - - // Get first page to obtain cursor - var firstPage = AggregateRecordsTool.ApplyPagination(allResults, first: 2, after: null); - string? cursor = firstPage.EndCursor; - - // Use cursor to get second page - var secondPage = AggregateRecordsTool.ApplyPagination(allResults, first: 2, after: cursor); - - Assert.AreEqual(2, secondPage.Items.Count); - Assert.AreEqual(3, secondPage.Items[0]["id"]); - Assert.AreEqual(4, secondPage.Items[1]["id"]); - Assert.IsTrue(secondPage.HasNextPage); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset("")); } [TestMethod] - public void ApplyPagination_LastPage_HasNextPageFalse() + public void DecodeCursorOffset_ValidBase64_ReturnsOffset() { - var allResults = new List> - { - new() { ["id"] = 1 }, - new() { ["id"] = 2 }, - new() { ["id"] = 3 } - }; - - // Get first page - var firstPage = AggregateRecordsTool.ApplyPagination(allResults, first: 2, after: null); - // Get last page - var lastPage = AggregateRecordsTool.ApplyPagination(allResults, first: 2, after: firstPage.EndCursor); - - Assert.AreEqual(1, lastPage.Items.Count); - Assert.AreEqual(3, lastPage.Items[0]["id"]); - Assert.IsFalse(lastPage.HasNextPage); + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("5")); + Assert.AreEqual(5, AggregateRecordsTool.DecodeCursorOffset(cursor)); } [TestMethod] - public void ApplyPagination_TerminalCursor_ReturnsEmptyItems() + public void DecodeCursorOffset_InvalidBase64_ReturnsZero() { - var allResults = new List> - { - new() { ["id"] = 1 }, - new() { ["id"] = 2 } - }; - - // Get last page - var lastPage = AggregateRecordsTool.ApplyPagination(allResults, first: 2, after: null); - Assert.IsFalse(lastPage.HasNextPage); - Assert.IsNotNull(lastPage.EndCursor); - - // Using the terminal endCursor should return empty results - var beyondLastPage = AggregateRecordsTool.ApplyPagination(allResults, first: 2, after: lastPage.EndCursor); - Assert.AreEqual(0, beyondLastPage.Items.Count); - Assert.IsFalse(beyondLastPage.HasNextPage); - Assert.IsNull(beyondLastPage.EndCursor); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset("not-valid-base64!!")); } [TestMethod] - public void ApplyPagination_InvalidCursor_StartsFromBeginning() + public void DecodeCursorOffset_NonNumericBase64_ReturnsZero() { - var allResults = new List> - { - new() { ["id"] = 1 }, - new() { ["id"] = 2 } - }; - - var result = AggregateRecordsTool.ApplyPagination(allResults, first: 2, after: "not-valid-base64!!"); - - // Should start from beginning - Assert.AreEqual(2, result.Items.Count); - Assert.AreEqual(1, result.Items[0]["id"]); + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("abc")); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(cursor)); } [TestMethod] - public void ApplyPagination_AfterWithoutFirst_IgnoresCursor() + public void DecodeCursorOffset_RoundTrip_FirstPage() { - // When first is not provided, after should not be used - // (ApplyPagination is only called when first is provided in ExecuteAsync) - var allResults = new List> - { - new() { ["id"] = 1 }, - new() { ["id"] = 2 }, - new() { ["id"] = 3 } - }; - - // Get page 1 cursor - var page1 = AggregateRecordsTool.ApplyPagination(allResults, first: 1, after: null); - Assert.IsNotNull(page1.EndCursor); - - // Call with first=3 and the cursor - should return 2 items from offset 1 - var result = AggregateRecordsTool.ApplyPagination(allResults, first: 3, after: page1.EndCursor); - Assert.AreEqual(2, result.Items.Count); - Assert.AreEqual(2, result.Items[0]["id"]); + int offset = 3; + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(offset.ToString())); + Assert.AreEqual(offset, AggregateRecordsTool.DecodeCursorOffset(cursor)); } #endregion - #region Validation tests (via ExecuteAsync return codes) - - // Note: Full ExecuteAsync validation tests require a full service provider setup - // with database, auth etc. The validation logic is tested below by examining - // the error condition directly since validation happens before any DB call. + #region Validation logic tests [TestMethod] [DataRow("avg", "Validation: avg with star field should be rejected")] @@ -366,8 +177,6 @@ public void ApplyPagination_AfterWithoutFirst_IgnoresCursor() [DataRow("max", "Validation: max with star field should be rejected")] public void ValidateFieldFunctionCompat_StarWithNumericFunction_IsInvalid(string function, string description) { - // Verify the business rule: only count can use field='*' - // This tests the condition used in ExecuteAsync without needing a full service provider bool isCountStar = function == "count" && "*" == "*"; bool isInvalidStarUsage = "*" == "*" && function != "count"; @@ -378,7 +187,6 @@ public void ValidateFieldFunctionCompat_StarWithNumericFunction_IsInvalid(string [TestMethod] public void ValidateFieldFunctionCompat_CountStar_IsValid() { - // count with field='*' should be valid bool isCountStar = "count" == "count" && "*" == "*"; Assert.IsTrue(isCountStar, "count(*) should be valid"); } @@ -386,8 +194,6 @@ public void ValidateFieldFunctionCompat_CountStar_IsValid() [TestMethod] public void ValidateDistinctCountStar_IsInvalid() { - // count(*) with distinct=true should be rejected - // Verify the condition used in ExecuteAsync bool isCountStar = "count" == "count" && "*" == "*"; bool distinct = true; @@ -398,7 +204,6 @@ public void ValidateDistinctCountStar_IsInvalid() [TestMethod] public void ValidateDistinctCountField_IsValid() { - // count(field) with distinct=true should be valid bool isCountStar = "count" == "count" && "userId" == "*"; bool distinct = true; @@ -407,5 +212,92 @@ public void ValidateDistinctCountField_IsValid() } #endregion + + #region Blog scenario tests - SQL generation patterns + + /// + /// Blog Example 1: Strategic customer importance + /// "Who is our most important customer based on total revenue?" + /// Expected: SELECT customerId, customerName, SUM(totalRevenue) ... GROUP BY ... ORDER BY ... DESC LIMIT 1 + /// + [TestMethod] + public void BlogScenario_StrategicCustomerImportance_SqlContainsGroupByAndOrderByDesc() + { + Mock qb = CreateMockQueryBuilder(); + + // Validate the aggregate expression + string aggExpr = AggregateRecordsTool.BuildAggregateExpression("sum", "totalRevenue", false, false, qb.Object); + Assert.AreEqual("SUM([totalRevenue])", aggExpr); + + // Validate the alias + string alias = AggregateRecordsTool.ComputeAlias("sum", "totalRevenue"); + Assert.AreEqual("sum_totalRevenue", alias); + } + + /// + /// Blog Example 2: Product discontinuation candidate + /// Lowest totalRevenue with orderby=asc, first=1 + /// + [TestMethod] + public void BlogScenario_ProductDiscontinuation_SqlContainsOrderByAsc() + { + Mock qb = CreateMockQueryBuilder(); + + string aggExpr = AggregateRecordsTool.BuildAggregateExpression("sum", "totalRevenue", false, false, qb.Object); + Assert.AreEqual("SUM([totalRevenue])", aggExpr); + + string alias = AggregateRecordsTool.ComputeAlias("sum", "totalRevenue"); + Assert.AreEqual("sum_totalRevenue", alias); + } + + /// + /// Blog Example 3: Forward-looking performance expectation + /// AVG quarterlyRevenue with HAVING gt 2000000 + /// + [TestMethod] + public void BlogScenario_QuarterlyPerformance_AvgWithHaving() + { + Mock qb = CreateMockQueryBuilder(); + + string aggExpr = AggregateRecordsTool.BuildAggregateExpression("avg", "quarterlyRevenue", false, false, qb.Object); + Assert.AreEqual("AVG([quarterlyRevenue])", aggExpr); + + string alias = AggregateRecordsTool.ComputeAlias("avg", "quarterlyRevenue"); + Assert.AreEqual("avg_quarterlyRevenue", alias); + } + + /// + /// Blog Example 4: Revenue concentration across regions + /// SUM totalRevenue grouped by region and customerTier, HAVING gt 5000000 + /// + [TestMethod] + public void BlogScenario_RevenueConcentration_MultipleGroupByFields() + { + Mock qb = CreateMockQueryBuilder(); + + string aggExpr = AggregateRecordsTool.BuildAggregateExpression("sum", "totalRevenue", false, false, qb.Object); + Assert.AreEqual("SUM([totalRevenue])", aggExpr); + + string alias = AggregateRecordsTool.ComputeAlias("sum", "totalRevenue"); + Assert.AreEqual("sum_totalRevenue", alias); + } + + /// + /// Blog Example 5: Risk exposure by product line + /// SUM onHandValue grouped by productLine and warehouseRegion, HAVING gt 2500000 + /// + [TestMethod] + public void BlogScenario_RiskExposure_SumWithMultiGroupByAndHaving() + { + Mock qb = CreateMockQueryBuilder(); + + string aggExpr = AggregateRecordsTool.BuildAggregateExpression("sum", "onHandValue", false, false, qb.Object); + Assert.AreEqual("SUM([onHandValue])", aggExpr); + + string alias = AggregateRecordsTool.ComputeAlias("sum", "onHandValue"); + Assert.AreEqual("sum_onHandValue", alias); + } + + #endregion } } From d35088cfd74bba24e5ede6775abfa721de880342 Mon Sep 17 00:00:00 2001 From: Jerry Nixon Date: Mon, 2 Mar 2026 17:51:07 -0700 Subject: [PATCH 15/32] Fix negative cursor offset and add first max validation - DecodeCursorOffset now rejects negative values (returns 0) - Add max validation for 'first' parameter (100000 limit) - Prevents integer overflow on first+1 and invalid SQL OFFSET - Add tests for both edge cases Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../BuiltInTools/AggregateRecordsTool.cs | 7 ++++++- .../Mcp/AggregateRecordsToolTests.cs | 15 +++++++++++++++ .../UnitTests/AggregateRecordsToolTests.cs | 7 +++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index 42f5187092..12ad9723c4 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -224,6 +224,11 @@ public async Task ExecuteAsync( { return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Argument 'first' must be at least 1.", logger); } + + if (first > 100_000) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Argument 'first' must not exceed 100000.", logger); + } } string? after = root.TryGetProperty("after", out JsonElement afterEl) ? afterEl.GetString() : null; @@ -686,7 +691,7 @@ internal static int DecodeCursorOffset(string? after) { byte[] bytes = Convert.FromBase64String(after); string decoded = Encoding.UTF8.GetString(bytes); - return int.TryParse(decoded, out int cursorOffset) ? cursorOffset : 0; + return int.TryParse(decoded, out int cursorOffset) && cursorOffset >= 0 ? cursorOffset : 0; } catch (FormatException) { diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs index 161d66b4e5..2d35b1892f 100644 --- a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -191,6 +191,21 @@ public async Task AggregateRecords_InvalidFunction_ReturnsInvalidArguments() Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("median")); } + [TestMethod] + public async Task AggregateRecords_FirstExceedsMax_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\", \"first\": 200000, \"groupby\": [\"title\"]}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("100000")); + } + #endregion #region Alias Convention Tests diff --git a/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs b/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs index ff458cc0b9..d44240f38b 100644 --- a/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs @@ -166,6 +166,13 @@ public void DecodeCursorOffset_RoundTrip_FirstPage() Assert.AreEqual(offset, AggregateRecordsTool.DecodeCursorOffset(cursor)); } + [TestMethod] + public void DecodeCursorOffset_NegativeValue_ReturnsZero() + { + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("-5")); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(cursor)); + } + #endregion #region Validation logic tests From ef7fd0d2b4cde65b983ac8c56480517992202615 Mon Sep 17 00:00:00 2001 From: Jerry Nixon Date: Mon, 2 Mar 2026 18:20:12 -0700 Subject: [PATCH 16/32] Refactor AggregateRecordsTool to use engine query builder pattern Replace custom SQL string building with engine's SqlQueryStructure + GroupByMetadata + queryBuilder.Build(structure) pattern. This uses the same AggregationColumn, AggregationOperation, and Predicate types that the engine's GraphQL aggregation path uses. Removed methods: BuildAggregateSql, BuildAggregateExpression, BuildQuotedTableRef, BuildWhereClause, BuildHavingClause, AppendPagination. These are now handled by the engine's query builder. Updated both test files to remove references to removed methods. All 69 aggregate tests pass. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../BuiltInTools/AggregateRecordsTool.cs | 365 ++++++++---------- .../Mcp/AggregateRecordsToolTests.cs | 217 +---------- .../UnitTests/AggregateRecordsToolTests.cs | 136 +------ 3 files changed, 177 insertions(+), 541 deletions(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index 12ad9723c4..bb2aa4efad 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -25,6 +25,7 @@ using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; using static Azure.DataApiBuilder.Mcp.Model.McpEnums; +using static Azure.DataApiBuilder.Service.GraphQLBuilder.Sql.SchemaConverter; namespace Azure.DataApiBuilder.Mcp.BuiltInTools { @@ -379,20 +380,165 @@ public async Task ExecuteAsync( string alias = ComputeAlias(function, field); - // Build aggregate SQL query that pushes all computation to the database - string sql = BuildAggregateSql( - queryBuilder, structure, dbObject, function, backingField, distinct, isCountStar, - groupbyMapping, havingOps, havingIn, orderby, first, after, alias, databaseType); + // Clear default columns from FindRequestContext + structure.Columns.Clear(); + + // Add groupby columns as LabelledColumns and GroupByMetadata.Fields + foreach (var (entityField, backingCol) in groupbyMapping) + { + structure.Columns.Add(new LabelledColumn( + dbObject.SchemaName, dbObject.Name, backingCol, entityField, structure.SourceAlias)); + structure.GroupByMetadata.Fields[backingCol] = new Column( + dbObject.SchemaName, dbObject.Name, backingCol, structure.SourceAlias); + } + + // Build aggregation column using engine's AggregationColumn type + AggregationType aggType = Enum.Parse(function); + AggregationColumn aggColumn = isCountStar + ? new AggregationColumn("", "", "*", AggregationType.count, alias, false) + : new AggregationColumn(dbObject.SchemaName, dbObject.Name, backingField!, aggType, alias, distinct, structure.SourceAlias); + + // Build HAVING predicates using engine's Predicate model + List havingPredicates = new(); + if (havingOps != null) + { + foreach (var op in havingOps) + { + PredicateOperation predOp = op.Key.ToLowerInvariant() switch + { + "eq" => PredicateOperation.Equal, + "neq" => PredicateOperation.NotEqual, + "gt" => PredicateOperation.GreaterThan, + "gte" => PredicateOperation.GreaterThanOrEqual, + "lt" => PredicateOperation.LessThan, + "lte" => PredicateOperation.LessThanOrEqual, + _ => throw new ArgumentException($"Invalid having operator: {op.Key}") + }; + string paramName = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); + structure.Parameters.Add(paramName, new DbConnectionParam(op.Value)); + havingPredicates.Add(new Predicate( + new PredicateOperand(aggColumn), + predOp, + new PredicateOperand(paramName))); + } + } + + if (havingIn != null && havingIn.Count > 0) + { + List inParams = new(); + foreach (double val in havingIn) + { + string paramName = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); + structure.Parameters.Add(paramName, new DbConnectionParam(val)); + inParams.Add(paramName); + } + + havingPredicates.Add(new Predicate( + new PredicateOperand(aggColumn), + PredicateOperation.IN, + new PredicateOperand($"({string.Join(", ", inParams)})"))); + } + + // Combine multiple HAVING predicates with AND + Predicate? combinedHaving = null; + foreach (var pred in havingPredicates) + { + combinedHaving = combinedHaving == null + ? pred + : new Predicate(new PredicateOperand(combinedHaving), PredicateOperation.AND, new PredicateOperand(pred)); + } + + structure.GroupByMetadata.Aggregations.Add( + new AggregationOperation(aggColumn, having: combinedHaving != null ? new List { combinedHaving } : null)); + structure.GroupByMetadata.RequestedAggregations = true; + + // Clear default OrderByColumns (PK-based) + structure.OrderByColumns.Clear(); + + // Set pagination limit if using first + if (first.HasValue && groupbyMapping.Count > 0) + { + structure.IsListQuery = true; + } + + // Use engine's query builder to generate SQL + string sql = queryBuilder.Build(structure); + + // For groupby queries: add ORDER BY aggregate expression before FOR JSON PATH + if (groupbyMapping.Count > 0) + { + string direction = orderby.Equals("asc", StringComparison.OrdinalIgnoreCase) ? "ASC" : "DESC"; + string orderByAggExpr = isCountStar + ? "COUNT(*)" + : distinct + ? $"{function.ToUpperInvariant()}(DISTINCT {queryBuilder.QuoteIdentifier(structure.SourceAlias)}.{queryBuilder.QuoteIdentifier(backingField!)})" + : $"{function.ToUpperInvariant()}({queryBuilder.QuoteIdentifier(structure.SourceAlias)}.{queryBuilder.QuoteIdentifier(backingField!)})"; + string orderByClause = $" ORDER BY {orderByAggExpr} {direction}"; + + // Insert ORDER BY before FOR JSON PATH (MsSql/DWSQL) or before LIMIT (PG/MySQL) + int insertIdx = sql.IndexOf(" FOR JSON PATH", StringComparison.OrdinalIgnoreCase); + if (insertIdx < 0) + { + insertIdx = sql.IndexOf(" LIMIT ", StringComparison.OrdinalIgnoreCase); + } + + if (insertIdx > 0) + { + sql = sql.Insert(insertIdx, orderByClause); + } + else + { + sql += orderByClause; + } + + // Add pagination (OFFSET/FETCH or LIMIT/OFFSET) for grouped results + if (first.HasValue) + { + int offset = DecodeCursorOffset(after); + int fetchCount = first.Value + 1; + string offsetParam = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); + structure.Parameters.Add(offsetParam, new DbConnectionParam(offset)); + string limitParam = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); + structure.Parameters.Add(limitParam, new DbConnectionParam(fetchCount)); + + int paginationIdx = sql.IndexOf(" FOR JSON PATH", StringComparison.OrdinalIgnoreCase); + string paginationClause; + if (databaseType == DatabaseType.MSSQL || databaseType == DatabaseType.DWSQL) + { + paginationClause = $" OFFSET {offsetParam} ROWS FETCH NEXT {limitParam} ROWS ONLY"; + } + else + { + paginationClause = $" LIMIT {limitParam} OFFSET {offsetParam}"; + } + + if (paginationIdx > 0) + { + sql = sql.Insert(paginationIdx, paginationClause); + } + else + { + sql += paginationClause; + } + } + } // Execute the SQL aggregate query against the database cancellationToken.ThrowIfCancellationRequested(); - JsonArray? resultArray = await queryExecutor.ExecuteQueryAsync( + JsonDocument? queryResult = await queryExecutor.ExecuteQueryAsync( sql, structure.Parameters, - queryExecutor.GetJsonArrayAsync, + queryExecutor.GetJsonResultAsync, dataSourceName, httpContext); + // Parse result + JsonArray? resultArray = null; + if (queryResult != null) + { + resultArray = JsonSerializer.Deserialize(queryResult.RootElement.GetRawText()); + } + // Format and return results if (first.HasValue && groupby.Count > 0) { @@ -469,213 +615,6 @@ internal static string ComputeAlias(string function, string field) return $"{function}_{field}"; } - /// - /// Builds a SQL aggregate query that pushes all computation to the database. - /// Generates SELECT {aggExpr} FROM {table} WHERE ... GROUP BY ... HAVING ... ORDER BY ... - /// with proper parameterization and identifier quoting. - /// - internal static string BuildAggregateSql( - IQueryBuilder queryBuilder, - SqlQueryStructure structure, - DatabaseObject dbObject, - string function, - string? backingField, - bool distinct, - bool isCountStar, - List<(string entityField, string backingCol)> groupbyMapping, - Dictionary? havingOps, - List? havingIn, - string orderby, - int? first, - string? after, - string alias, - DatabaseType databaseType) - { - string aggExpr = BuildAggregateExpression(function, backingField, distinct, isCountStar, queryBuilder); - string quotedTableRef = BuildQuotedTableRef(dbObject, queryBuilder); - - StringBuilder sql = new(); - - // SELECT - sql.Append("SELECT "); - foreach ((string entityField, string backingCol) in groupbyMapping) - { - sql.Append($"{queryBuilder.QuoteIdentifier(backingCol)} AS {queryBuilder.QuoteIdentifier(entityField)}, "); - } - - sql.Append($"{aggExpr} AS {queryBuilder.QuoteIdentifier(alias)}"); - - // FROM - sql.Append($" FROM {quotedTableRef}"); - - // WHERE (OData filter predicates + DB policy predicates) - string? whereClause = BuildWhereClause(structure); - if (!string.IsNullOrEmpty(whereClause)) - { - sql.Append($" WHERE {whereClause}"); - } - - // GROUP BY - if (groupbyMapping.Count > 0) - { - string groupByClause = string.Join(", ", groupbyMapping.Select(g => queryBuilder.QuoteIdentifier(g.backingCol))); - sql.Append($" GROUP BY {groupByClause}"); - } - - // HAVING - string? havingClause = BuildHavingClause(aggExpr, havingOps, havingIn, structure); - if (!string.IsNullOrEmpty(havingClause)) - { - sql.Append($" HAVING {havingClause}"); - } - - // ORDER BY (only with groupby) - if (groupbyMapping.Count > 0) - { - string direction = orderby.Equals("asc", StringComparison.OrdinalIgnoreCase) ? "ASC" : "DESC"; - sql.Append($" ORDER BY {aggExpr} {direction}"); - } - - // PAGINATION (only with groupby and first) - if (first.HasValue && groupbyMapping.Count > 0) - { - int offset = DecodeCursorOffset(after); - int fetchCount = first.Value + 1; // Fetch one extra row to detect hasNextPage - AppendPagination(sql, offset, fetchCount, structure, databaseType); - } - - return sql.ToString(); - } - - /// - /// Builds the SQL aggregate expression (e.g., COUNT(*), SUM(DISTINCT [column])). - /// - internal static string BuildAggregateExpression( - string function, string? backingField, bool distinct, bool isCountStar, IQueryBuilder queryBuilder) - { - if (isCountStar) - { - return "COUNT(*)"; - } - - string quotedCol = queryBuilder.QuoteIdentifier(backingField!); - string func = function.ToUpperInvariant(); - - return distinct ? $"{func}(DISTINCT {quotedCol})" : $"{func}({quotedCol})"; - } - - /// - /// Builds a properly quoted table reference from a DatabaseObject. - /// - internal static string BuildQuotedTableRef(DatabaseObject dbObject, IQueryBuilder queryBuilder) - { - return string.IsNullOrEmpty(dbObject.SchemaName) - ? queryBuilder.QuoteIdentifier(dbObject.Name) - : $"{queryBuilder.QuoteIdentifier(dbObject.SchemaName)}.{queryBuilder.QuoteIdentifier(dbObject.Name)}"; - } - - /// - /// Builds the WHERE clause from OData filter predicates and DB policy predicates. - /// Both are required for correct and secure query execution. - /// - internal static string? BuildWhereClause(SqlQueryStructure structure) - { - List clauses = new(); - - if (!string.IsNullOrEmpty(structure.FilterPredicates)) - { - clauses.Add(structure.FilterPredicates); - } - - string? dbPolicy = structure.GetDbPolicyForOperation(EntityActionOperation.Read); - if (!string.IsNullOrEmpty(dbPolicy)) - { - clauses.Add(dbPolicy); - } - - return clauses.Count > 0 ? string.Join(" AND ", clauses) : null; - } - - /// - /// Builds the HAVING clause from having operator conditions and IN list. - /// Adds parameterized values to the structure's Parameters dictionary. - /// - internal static string? BuildHavingClause( - string aggExpr, - Dictionary? havingOps, - List? havingIn, - SqlQueryStructure structure) - { - if (havingOps == null && havingIn == null) - { - return null; - } - - List conditions = new(); - - if (havingOps != null) - { - foreach (KeyValuePair op in havingOps) - { - string sqlOp = op.Key.ToLowerInvariant() switch - { - "eq" => "=", - "neq" => "<>", - "gt" => ">", - "gte" => ">=", - "lt" => "<", - "lte" => "<=", - _ => throw new ArgumentException($"Invalid having operator: {op.Key}") - }; - - string paramName = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); - structure.Parameters.Add(paramName, new DbConnectionParam(op.Value)); - conditions.Add($"{aggExpr} {sqlOp} {paramName}"); - } - } - - if (havingIn != null && havingIn.Count > 0) - { - List inParams = new(); - foreach (double val in havingIn) - { - string paramName = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); - structure.Parameters.Add(paramName, new DbConnectionParam(val)); - inParams.Add(paramName); - } - - conditions.Add($"{aggExpr} IN ({string.Join(", ", inParams)})"); - } - - return conditions.Count > 0 ? string.Join(" AND ", conditions) : null; - } - - /// - /// Appends database-specific pagination syntax to the SQL query. - /// MsSql/DWSQL: OFFSET ... ROWS FETCH NEXT ... ROWS ONLY - /// PostgreSQL/MySQL: LIMIT ... OFFSET ... - /// - internal static void AppendPagination( - StringBuilder sql, int offset, int fetchCount, - SqlQueryStructure structure, DatabaseType databaseType) - { - string offsetParam = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); - structure.Parameters.Add(offsetParam, new DbConnectionParam(offset)); - - string limitParam = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); - structure.Parameters.Add(limitParam, new DbConnectionParam(fetchCount)); - - if (databaseType == DatabaseType.MSSQL || databaseType == DatabaseType.DWSQL) - { - sql.Append($" OFFSET {offsetParam} ROWS FETCH NEXT {limitParam} ROWS ONLY"); - } - else - { - // PostgreSQL, MySQL - sql.Append($" LIMIT {limitParam} OFFSET {offsetParam}"); - } - } - /// /// Decodes a base64-encoded cursor string to an integer offset. /// Returns 0 if the cursor is null, empty, or invalid. diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs index 2d35b1892f..dd94ae593d 100644 --- a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -10,11 +10,9 @@ using System.Threading; using System.Threading.Tasks; using Azure.DataApiBuilder.Auth; -using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core.Authorization; using Azure.DataApiBuilder.Core.Configurations; -using Azure.DataApiBuilder.Core.Resolvers; using Azure.DataApiBuilder.Mcp.BuiltInTools; using Azure.DataApiBuilder.Mcp.Model; using Microsoft.AspNetCore.Http; @@ -248,138 +246,6 @@ public void ComputeAlias_MaxField_ReturnsFunctionField() #endregion - #region SQL Expression Generation Tests - - /// - /// Creates a mock IQueryBuilder that wraps identifiers with square brackets (MsSql-style). - /// - private static Mock CreateMockQueryBuilder() - { - Mock mock = new(); - mock.Setup(qb => qb.QuoteIdentifier(It.IsAny())) - .Returns((string id) => $"[{id}]"); - return mock; - } - - [TestMethod] - public void BuildAggregateExpression_CountStar_GeneratesCountStarSql() - { - Mock qb = CreateMockQueryBuilder(); - string expr = AggregateRecordsTool.BuildAggregateExpression("count", null, false, true, qb.Object); - Assert.AreEqual("COUNT(*)", expr); - } - - [TestMethod] - public void BuildAggregateExpression_Avg_GeneratesAvgSql() - { - Mock qb = CreateMockQueryBuilder(); - string expr = AggregateRecordsTool.BuildAggregateExpression("avg", "price", false, false, qb.Object); - Assert.AreEqual("AVG([price])", expr); - } - - [TestMethod] - public void BuildAggregateExpression_Sum_GeneratesSumSql() - { - Mock qb = CreateMockQueryBuilder(); - string expr = AggregateRecordsTool.BuildAggregateExpression("sum", "price", false, false, qb.Object); - Assert.AreEqual("SUM([price])", expr); - } - - [TestMethod] - public void BuildAggregateExpression_Min_GeneratesMinSql() - { - Mock qb = CreateMockQueryBuilder(); - string expr = AggregateRecordsTool.BuildAggregateExpression("min", "price", false, false, qb.Object); - Assert.AreEqual("MIN([price])", expr); - } - - [TestMethod] - public void BuildAggregateExpression_Max_GeneratesMaxSql() - { - Mock qb = CreateMockQueryBuilder(); - string expr = AggregateRecordsTool.BuildAggregateExpression("max", "price", false, false, qb.Object); - Assert.AreEqual("MAX([price])", expr); - } - - [TestMethod] - public void BuildAggregateExpression_CountDistinct_GeneratesCountDistinctSql() - { - Mock qb = CreateMockQueryBuilder(); - string expr = AggregateRecordsTool.BuildAggregateExpression("count", "supplierId", true, false, qb.Object); - Assert.AreEqual("COUNT(DISTINCT [supplierId])", expr); - } - - [TestMethod] - public void BuildAggregateExpression_AvgDistinct_GeneratesAvgDistinctSql() - { - Mock qb = CreateMockQueryBuilder(); - string expr = AggregateRecordsTool.BuildAggregateExpression("avg", "price", true, false, qb.Object); - Assert.AreEqual("AVG(DISTINCT [price])", expr); - } - - [TestMethod] - public void BuildAggregateExpression_SumDistinct_GeneratesSumDistinctSql() - { - Mock qb = CreateMockQueryBuilder(); - string expr = AggregateRecordsTool.BuildAggregateExpression("sum", "price", true, false, qb.Object); - Assert.AreEqual("SUM(DISTINCT [price])", expr); - } - - [TestMethod] - public void BuildAggregateExpression_CountField_GeneratesCountFieldSql() - { - Mock qb = CreateMockQueryBuilder(); - string expr = AggregateRecordsTool.BuildAggregateExpression("count", "id", false, false, qb.Object); - Assert.AreEqual("COUNT([id])", expr); - } - - [TestMethod] - public void BuildQuotedTableRef_WithSchema_GeneratesSchemaQualifiedRef() - { - Mock qb = CreateMockQueryBuilder(); - DatabaseTable table = new("dbo", "Products"); - string result = AggregateRecordsTool.BuildQuotedTableRef(table, qb.Object); - Assert.AreEqual("[dbo].[Products]", result); - } - - [TestMethod] - public void BuildQuotedTableRef_WithoutSchema_GeneratesTableOnlyRef() - { - Mock qb = CreateMockQueryBuilder(); - DatabaseTable table = new("", "Products"); - string result = AggregateRecordsTool.BuildQuotedTableRef(table, qb.Object); - Assert.AreEqual("[Products]", result); - } - - [TestMethod] - public void BuildAggregateExpression_GroupByScenario_ExpressionAndQuotingCorrect() - { - Mock qb = CreateMockQueryBuilder(); - string aggExpr = AggregateRecordsTool.BuildAggregateExpression("sum", "price", false, false, qb.Object); - Assert.AreEqual("SUM([price])", aggExpr); - Assert.AreEqual("[category]", qb.Object.QuoteIdentifier("category")); - } - - [TestMethod] - public void BuildAggregateExpression_MultipleGroupByFields_AllFieldsQuotedCorrectly() - { - Mock qb = CreateMockQueryBuilder(); - string aggExpr = AggregateRecordsTool.BuildAggregateExpression("sum", "price", false, false, qb.Object); - Assert.AreEqual("SUM([price])", aggExpr); - Assert.AreEqual("[cat]", qb.Object.QuoteIdentifier("cat")); - Assert.AreEqual("[region]", qb.Object.QuoteIdentifier("region")); - } - - [TestMethod] - public void BuildAggregateExpression_EmptyDataset_ExpressionStillValid() - { - Mock qb = CreateMockQueryBuilder(); - string expr = AggregateRecordsTool.BuildAggregateExpression("avg", "price", false, false, qb.Object); - Assert.AreEqual("AVG([price])", expr); - } - - #endregion - #region Cursor and Pagination Tests [TestMethod] @@ -568,21 +434,17 @@ public void TimeoutErrorMessage_IncludesEntityName() #endregion - #region Spec Example SQL Pattern Tests + #region Spec Example Tests /// /// Spec Example 1: "How many products are there?" - /// COUNT(*) - expects alias "count" and expression COUNT(*) + /// COUNT(*) - expects alias "count" /// [TestMethod] - public void SpecExample01_CountStar_GeneratesCorrectSqlPattern() + public void SpecExample01_CountStar_CorrectAlias() { - Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("count", "*"); - string expr = AggregateRecordsTool.BuildAggregateExpression("count", null, false, true, qb.Object); - Assert.AreEqual("count", alias); - Assert.AreEqual("COUNT(*)", expr); } /// @@ -590,14 +452,10 @@ public void SpecExample01_CountStar_GeneratesCorrectSqlPattern() /// AVG(unitPrice) with filter /// [TestMethod] - public void SpecExample02_AvgWithFilter_GeneratesCorrectSqlPattern() + public void SpecExample02_AvgWithFilter_CorrectAlias() { - Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice"); - string expr = AggregateRecordsTool.BuildAggregateExpression("avg", "unitPrice", false, false, qb.Object); - Assert.AreEqual("avg_unitPrice", alias); - Assert.AreEqual("AVG([unitPrice])", expr); } /// @@ -605,15 +463,10 @@ public void SpecExample02_AvgWithFilter_GeneratesCorrectSqlPattern() /// COUNT(*) GROUP BY categoryName HAVING gt 20 /// [TestMethod] - public void SpecExample03_CountGroupByHavingGt_GeneratesCorrectSqlPattern() + public void SpecExample03_CountGroupByHavingGt_CorrectAlias() { - Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("count", "*"); - string expr = AggregateRecordsTool.BuildAggregateExpression("count", null, false, true, qb.Object); - Assert.AreEqual("count", alias); - Assert.AreEqual("COUNT(*)", expr); - Assert.AreEqual("[categoryName]", qb.Object.QuoteIdentifier("categoryName")); } /// @@ -621,14 +474,10 @@ public void SpecExample03_CountGroupByHavingGt_GeneratesCorrectSqlPattern() /// SUM(unitPrice) GROUP BY categoryName HAVING gte 500 AND lte 10000 /// [TestMethod] - public void SpecExample04_SumFilterGroupByHavingRange_GeneratesCorrectSqlPattern() + public void SpecExample04_SumFilterGroupByHavingRange_CorrectAlias() { - Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("sum", "unitPrice"); - string expr = AggregateRecordsTool.BuildAggregateExpression("sum", "unitPrice", false, false, qb.Object); - Assert.AreEqual("sum_unitPrice", alias); - Assert.AreEqual("SUM([unitPrice])", expr); } /// @@ -636,14 +485,10 @@ public void SpecExample04_SumFilterGroupByHavingRange_GeneratesCorrectSqlPattern /// COUNT(DISTINCT supplierId) /// [TestMethod] - public void SpecExample05_CountDistinct_GeneratesCorrectSqlPattern() + public void SpecExample05_CountDistinct_CorrectAlias() { - Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("count", "supplierId"); - string expr = AggregateRecordsTool.BuildAggregateExpression("count", "supplierId", true, false, qb.Object); - Assert.AreEqual("count_supplierId", alias); - Assert.AreEqual("COUNT(DISTINCT [supplierId])", expr); } /// @@ -651,14 +496,10 @@ public void SpecExample05_CountDistinct_GeneratesCorrectSqlPattern() /// COUNT(*) GROUP BY categoryName HAVING IN (5, 10) /// [TestMethod] - public void SpecExample06_CountGroupByHavingIn_GeneratesCorrectSqlPattern() + public void SpecExample06_CountGroupByHavingIn_CorrectAlias() { - Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("count", "*"); - string expr = AggregateRecordsTool.BuildAggregateExpression("count", null, false, true, qb.Object); - Assert.AreEqual("count", alias); - Assert.AreEqual("COUNT(*)", expr); } /// @@ -666,14 +507,10 @@ public void SpecExample06_CountGroupByHavingIn_GeneratesCorrectSqlPattern() /// AVG(DISTINCT unitPrice) GROUP BY categoryName HAVING gt 25 /// [TestMethod] - public void SpecExample07_AvgDistinctGroupByHavingGt_GeneratesCorrectSqlPattern() + public void SpecExample07_AvgDistinctGroupByHavingGt_CorrectAlias() { - Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice"); - string expr = AggregateRecordsTool.BuildAggregateExpression("avg", "unitPrice", true, false, qb.Object); - Assert.AreEqual("avg_unitPrice", alias); - Assert.AreEqual("AVG(DISTINCT [unitPrice])", expr); } /// @@ -681,14 +518,10 @@ public void SpecExample07_AvgDistinctGroupByHavingGt_GeneratesCorrectSqlPattern( /// COUNT(*) GROUP BY categoryName ORDER BY DESC /// [TestMethod] - public void SpecExample08_CountGroupByOrderByDesc_GeneratesCorrectSqlPattern() + public void SpecExample08_CountGroupByOrderByDesc_CorrectAlias() { - Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("count", "*"); - string expr = AggregateRecordsTool.BuildAggregateExpression("count", null, false, true, qb.Object); - Assert.AreEqual("count", alias); - Assert.AreEqual("COUNT(*)", expr); } /// @@ -696,14 +529,10 @@ public void SpecExample08_CountGroupByOrderByDesc_GeneratesCorrectSqlPattern() /// AVG(unitPrice) GROUP BY categoryName ORDER BY ASC /// [TestMethod] - public void SpecExample09_AvgGroupByOrderByAsc_GeneratesCorrectSqlPattern() + public void SpecExample09_AvgGroupByOrderByAsc_CorrectAlias() { - Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice"); - string expr = AggregateRecordsTool.BuildAggregateExpression("avg", "unitPrice", false, false, qb.Object); - Assert.AreEqual("avg_unitPrice", alias); - Assert.AreEqual("AVG([unitPrice])", expr); } /// @@ -711,14 +540,10 @@ public void SpecExample09_AvgGroupByOrderByAsc_GeneratesCorrectSqlPattern() /// SUM(unitPrice) GROUP BY categoryName HAVING gt 500 ORDER BY DESC /// [TestMethod] - public void SpecExample10_SumFilterGroupByHavingGtOrderByDesc_GeneratesCorrectSqlPattern() + public void SpecExample10_SumFilterGroupByHavingGtOrderByDesc_CorrectAlias() { - Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("sum", "unitPrice"); - string expr = AggregateRecordsTool.BuildAggregateExpression("sum", "unitPrice", false, false, qb.Object); - Assert.AreEqual("sum_unitPrice", alias); - Assert.AreEqual("SUM([unitPrice])", expr); } /// @@ -726,14 +551,10 @@ public void SpecExample10_SumFilterGroupByHavingGtOrderByDesc_GeneratesCorrectSq /// COUNT(*) GROUP BY categoryName ORDER BY DESC FIRST 5 /// [TestMethod] - public void SpecExample11_CountGroupByOrderByDescFirst5_GeneratesCorrectSqlPattern() + public void SpecExample11_CountGroupByOrderByDescFirst5_CorrectAliasAndCursor() { - Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("count", "*"); - string expr = AggregateRecordsTool.BuildAggregateExpression("count", null, false, true, qb.Object); - Assert.AreEqual("count", alias); - Assert.AreEqual("COUNT(*)", expr); Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(null)); } @@ -742,18 +563,14 @@ public void SpecExample11_CountGroupByOrderByDescFirst5_GeneratesCorrectSqlPatte /// COUNT(*) GROUP BY categoryName ORDER BY DESC FIRST 5 AFTER cursor /// [TestMethod] - public void SpecExample12_CountGroupByOrderByDescFirst5After_GeneratesCorrectSqlPattern() + public void SpecExample12_CountGroupByOrderByDescFirst5After_CorrectCursorDecode() { string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("5")); int offset = AggregateRecordsTool.DecodeCursorOffset(cursor); Assert.AreEqual(5, offset); - Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("count", "*"); - string expr = AggregateRecordsTool.BuildAggregateExpression("count", null, false, true, qb.Object); - Assert.AreEqual("count", alias); - Assert.AreEqual("COUNT(*)", expr); } /// @@ -761,14 +578,10 @@ public void SpecExample12_CountGroupByOrderByDescFirst5After_GeneratesCorrectSql /// AVG(unitPrice) GROUP BY categoryName ORDER BY DESC FIRST 3 /// [TestMethod] - public void SpecExample13_AvgGroupByOrderByDescFirst3_GeneratesCorrectSqlPattern() + public void SpecExample13_AvgGroupByOrderByDescFirst3_CorrectAlias() { - Mock qb = CreateMockQueryBuilder(); string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice"); - string expr = AggregateRecordsTool.BuildAggregateExpression("avg", "unitPrice", false, false, qb.Object); - Assert.AreEqual("avg_unitPrice", alias); - Assert.AreEqual("AVG([unitPrice])", expr); } #endregion diff --git a/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs b/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs index d44240f38b..c291d87660 100644 --- a/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs @@ -5,34 +5,19 @@ using System; using System.Text; -using Azure.DataApiBuilder.Config.DatabasePrimitives; -using Azure.DataApiBuilder.Core.Resolvers; using Azure.DataApiBuilder.Mcp.BuiltInTools; using Microsoft.VisualStudio.TestTools.UnitTesting; -using Moq; namespace Azure.DataApiBuilder.Service.Tests.UnitTests { /// - /// Unit tests for AggregateRecordsTool's SQL generation methods. - /// Validates that the tool builds correct SQL queries to push aggregation to the database. - /// Tests cover: alias computation, aggregate expressions, table references, - /// cursor decoding, and full SQL generation matching blog-documented patterns. + /// Unit tests for AggregateRecordsTool helper methods. + /// Validates alias computation, cursor decoding, and input validation logic. + /// SQL generation is delegated to the engine's query builder (GroupByMetadata/AggregationColumn). /// [TestClass] public class AggregateRecordsToolTests { - /// - /// Creates a mock IQueryBuilder that wraps identifiers with square brackets (MsSql-style). - /// - private static Mock CreateMockQueryBuilder() - { - Mock mock = new(); - mock.Setup(qb => qb.QuoteIdentifier(It.IsAny())) - .Returns((string id) => $"[{id}]"); - return mock; - } - #region ComputeAlias tests [TestMethod] @@ -50,80 +35,6 @@ public void ComputeAlias_ReturnsExpectedAlias(string function, string field, str #endregion - #region BuildAggregateExpression tests - - [TestMethod] - public void BuildAggregateExpression_CountStar_ReturnsCountStar() - { - Mock qb = CreateMockQueryBuilder(); - string expr = AggregateRecordsTool.BuildAggregateExpression("count", null, false, true, qb.Object); - Assert.AreEqual("COUNT(*)", expr); - } - - [TestMethod] - public void BuildAggregateExpression_SumField_ReturnsSumQuotedColumn() - { - Mock qb = CreateMockQueryBuilder(); - string expr = AggregateRecordsTool.BuildAggregateExpression("sum", "totalRevenue", false, false, qb.Object); - Assert.AreEqual("SUM([totalRevenue])", expr); - } - - [TestMethod] - public void BuildAggregateExpression_AvgDistinct_ReturnsAvgDistinct() - { - Mock qb = CreateMockQueryBuilder(); - string expr = AggregateRecordsTool.BuildAggregateExpression("avg", "price", true, false, qb.Object); - Assert.AreEqual("AVG(DISTINCT [price])", expr); - } - - [TestMethod] - public void BuildAggregateExpression_CountDistinctField_ReturnsCountDistinct() - { - Mock qb = CreateMockQueryBuilder(); - string expr = AggregateRecordsTool.BuildAggregateExpression("count", "supplierId", true, false, qb.Object); - Assert.AreEqual("COUNT(DISTINCT [supplierId])", expr); - } - - [TestMethod] - public void BuildAggregateExpression_MinField_ReturnsMin() - { - Mock qb = CreateMockQueryBuilder(); - string expr = AggregateRecordsTool.BuildAggregateExpression("min", "price", false, false, qb.Object); - Assert.AreEqual("MIN([price])", expr); - } - - [TestMethod] - public void BuildAggregateExpression_MaxField_ReturnsMax() - { - Mock qb = CreateMockQueryBuilder(); - string expr = AggregateRecordsTool.BuildAggregateExpression("max", "price", false, false, qb.Object); - Assert.AreEqual("MAX([price])", expr); - } - - #endregion - - #region BuildQuotedTableRef tests - - [TestMethod] - public void BuildQuotedTableRef_WithSchema_ReturnsSchemaQualified() - { - Mock qb = CreateMockQueryBuilder(); - DatabaseTable table = new("dbo", "Products"); - string result = AggregateRecordsTool.BuildQuotedTableRef(table, qb.Object); - Assert.AreEqual("[dbo].[Products]", result); - } - - [TestMethod] - public void BuildQuotedTableRef_WithoutSchema_ReturnsTableOnly() - { - Mock qb = CreateMockQueryBuilder(); - DatabaseTable table = new("", "Products"); - string result = AggregateRecordsTool.BuildQuotedTableRef(table, qb.Object); - Assert.AreEqual("[Products]", result); - } - - #endregion - #region DecodeCursorOffset tests [TestMethod] @@ -220,23 +131,16 @@ public void ValidateDistinctCountField_IsValid() #endregion - #region Blog scenario tests - SQL generation patterns + #region Blog scenario tests - alias and type validation /// /// Blog Example 1: Strategic customer importance /// "Who is our most important customer based on total revenue?" - /// Expected: SELECT customerId, customerName, SUM(totalRevenue) ... GROUP BY ... ORDER BY ... DESC LIMIT 1 + /// SUM(totalRevenue) grouped by customerId, customerName, ORDER BY DESC, FIRST 1 /// [TestMethod] - public void BlogScenario_StrategicCustomerImportance_SqlContainsGroupByAndOrderByDesc() + public void BlogScenario_StrategicCustomerImportance_AliasAndTypeCorrect() { - Mock qb = CreateMockQueryBuilder(); - - // Validate the aggregate expression - string aggExpr = AggregateRecordsTool.BuildAggregateExpression("sum", "totalRevenue", false, false, qb.Object); - Assert.AreEqual("SUM([totalRevenue])", aggExpr); - - // Validate the alias string alias = AggregateRecordsTool.ComputeAlias("sum", "totalRevenue"); Assert.AreEqual("sum_totalRevenue", alias); } @@ -246,13 +150,8 @@ public void BlogScenario_StrategicCustomerImportance_SqlContainsGroupByAndOrderB /// Lowest totalRevenue with orderby=asc, first=1 /// [TestMethod] - public void BlogScenario_ProductDiscontinuation_SqlContainsOrderByAsc() + public void BlogScenario_ProductDiscontinuation_AliasAndTypeCorrect() { - Mock qb = CreateMockQueryBuilder(); - - string aggExpr = AggregateRecordsTool.BuildAggregateExpression("sum", "totalRevenue", false, false, qb.Object); - Assert.AreEqual("SUM([totalRevenue])", aggExpr); - string alias = AggregateRecordsTool.ComputeAlias("sum", "totalRevenue"); Assert.AreEqual("sum_totalRevenue", alias); } @@ -262,13 +161,8 @@ public void BlogScenario_ProductDiscontinuation_SqlContainsOrderByAsc() /// AVG quarterlyRevenue with HAVING gt 2000000 /// [TestMethod] - public void BlogScenario_QuarterlyPerformance_AvgWithHaving() + public void BlogScenario_QuarterlyPerformance_AliasAndTypeCorrect() { - Mock qb = CreateMockQueryBuilder(); - - string aggExpr = AggregateRecordsTool.BuildAggregateExpression("avg", "quarterlyRevenue", false, false, qb.Object); - Assert.AreEqual("AVG([quarterlyRevenue])", aggExpr); - string alias = AggregateRecordsTool.ComputeAlias("avg", "quarterlyRevenue"); Assert.AreEqual("avg_quarterlyRevenue", alias); } @@ -278,13 +172,8 @@ public void BlogScenario_QuarterlyPerformance_AvgWithHaving() /// SUM totalRevenue grouped by region and customerTier, HAVING gt 5000000 /// [TestMethod] - public void BlogScenario_RevenueConcentration_MultipleGroupByFields() + public void BlogScenario_RevenueConcentration_AliasAndTypeCorrect() { - Mock qb = CreateMockQueryBuilder(); - - string aggExpr = AggregateRecordsTool.BuildAggregateExpression("sum", "totalRevenue", false, false, qb.Object); - Assert.AreEqual("SUM([totalRevenue])", aggExpr); - string alias = AggregateRecordsTool.ComputeAlias("sum", "totalRevenue"); Assert.AreEqual("sum_totalRevenue", alias); } @@ -294,13 +183,8 @@ public void BlogScenario_RevenueConcentration_MultipleGroupByFields() /// SUM onHandValue grouped by productLine and warehouseRegion, HAVING gt 2500000 /// [TestMethod] - public void BlogScenario_RiskExposure_SumWithMultiGroupByAndHaving() + public void BlogScenario_RiskExposure_AliasAndTypeCorrect() { - Mock qb = CreateMockQueryBuilder(); - - string aggExpr = AggregateRecordsTool.BuildAggregateExpression("sum", "onHandValue", false, false, qb.Object); - Assert.AreEqual("SUM([onHandValue])", aggExpr); - string alias = AggregateRecordsTool.ComputeAlias("sum", "onHandValue"); Assert.AreEqual("sum_onHandValue", alias); } From c5920c7bb33714f013907f8d0820d4abb189945a Mon Sep 17 00:00:00 2001 From: Jerry Nixon Date: Mon, 2 Mar 2026 18:28:49 -0700 Subject: [PATCH 17/32] Fix SQL generation bugs in AggregateRecordsTool - Fix COUNT(*): Use primary key column (PK NOT NULL, so COUNT(pk) COUNT(*)) instead of AggregationColumn with empty schema/table/'*' which produced invalid SQL like count([].[*]) - Fix TOP + OFFSET/FETCH conflict: Remove TOP N when pagination is used since SQL Server forbids both in the same query - Add database type validation: Return error for PostgreSQL/MySQL/ CosmosDB since engine only supports aggregation for MsSql/DWSQL - Add HAVING validation: Reject having without groupby - Add tests for star-field-with-avg, distinct-count-star, and having-without-groupby validation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../BuiltInTools/AggregateRecordsTool.cs | 91 +++++++++++-------- .../Mcp/AggregateRecordsToolTests.cs | 45 +++++++++ 2 files changed, 100 insertions(+), 36 deletions(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index bb2aa4efad..2509629231 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -5,6 +5,7 @@ using System.Text; using System.Text.Json; using System.Text.Json.Nodes; +using System.Text.RegularExpressions; using Azure.DataApiBuilder.Auth; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; @@ -251,6 +252,12 @@ public async Task ExecuteAsync( List? havingIn = null; if (root.TryGetProperty("having", out JsonElement havingEl) && havingEl.ValueKind == JsonValueKind.Object) { + if (groupby.Count == 0) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + "The 'having' parameter requires 'groupby' to be specified. HAVING filters groups after aggregation.", logger); + } + havingOps = new Dictionary(StringComparer.OrdinalIgnoreCase); foreach (JsonProperty prop in havingEl.EnumerateObject()) { @@ -350,6 +357,14 @@ public async Task ExecuteAsync( // Get database-specific components DatabaseType databaseType = runtimeConfig.GetDataSourceFromDataSourceName(dataSourceName).DatabaseType; + + // Aggregation is only supported for MsSql/DWSQL (matching engine's GraphQL aggregation support) + if (databaseType != DatabaseType.MSSQL && databaseType != DatabaseType.DWSQL) + { + return McpResponseBuilder.BuildErrorResult(toolName, "UnsupportedDatabase", + $"Aggregation is not supported for database type '{databaseType}'. Aggregation is only available for Azure SQL, SQL Server, and SQL Data Warehouse.", logger); + } + IAbstractQueryManagerFactory queryManagerFactory = serviceProvider.GetRequiredService(); IQueryBuilder queryBuilder = queryManagerFactory.GetQueryBuilder(databaseType); IQueryExecutor queryExecutor = queryManagerFactory.GetQueryExecutor(databaseType); @@ -364,6 +379,17 @@ public async Task ExecuteAsync( $"Field '{field}' not found for entity '{entityName}'.", logger); } } + else + { + // For COUNT(*), use primary key column since PK is always NOT NULL, + // making COUNT(pk) equivalent to COUNT(*). The engine's Build(AggregationColumn) + // does not support "*" as a column name (it would produce invalid SQL like count([].[*])). + SourceDefinition sourceDefinition = sqlMetadataProvider.GetSourceDefinition(entityName); + if (sourceDefinition.PrimaryKey.Count > 0) + { + backingField = sourceDefinition.PrimaryKey[0]; + } + } // Resolve backing column names for groupby fields List<(string entityField, string backingCol)> groupbyMapping = new(); @@ -392,11 +418,11 @@ public async Task ExecuteAsync( dbObject.SchemaName, dbObject.Name, backingCol, structure.SourceAlias); } - // Build aggregation column using engine's AggregationColumn type + // Build aggregation column using engine's AggregationColumn type. + // For COUNT(*), we use the primary key column (PK is always NOT NULL, so COUNT(pk) ≡ COUNT(*)). AggregationType aggType = Enum.Parse(function); - AggregationColumn aggColumn = isCountStar - ? new AggregationColumn("", "", "*", AggregationType.count, alias, false) - : new AggregationColumn(dbObject.SchemaName, dbObject.Name, backingField!, aggType, alias, distinct, structure.SourceAlias); + AggregationColumn aggColumn = new( + dbObject.SchemaName, dbObject.Name, backingField!, aggType, alias, distinct, structure.SourceAlias); // Build HAVING predicates using engine's Predicate model List havingPredicates = new(); @@ -464,36 +490,20 @@ public async Task ExecuteAsync( // Use engine's query builder to generate SQL string sql = queryBuilder.Build(structure); - // For groupby queries: add ORDER BY aggregate expression before FOR JSON PATH + // For groupby queries: add ORDER BY aggregate expression and pagination if (groupbyMapping.Count > 0) { string direction = orderby.Equals("asc", StringComparison.OrdinalIgnoreCase) ? "ASC" : "DESC"; - string orderByAggExpr = isCountStar - ? "COUNT(*)" - : distinct - ? $"{function.ToUpperInvariant()}(DISTINCT {queryBuilder.QuoteIdentifier(structure.SourceAlias)}.{queryBuilder.QuoteIdentifier(backingField!)})" - : $"{function.ToUpperInvariant()}({queryBuilder.QuoteIdentifier(structure.SourceAlias)}.{queryBuilder.QuoteIdentifier(backingField!)})"; + string quotedCol = $"{queryBuilder.QuoteIdentifier(structure.SourceAlias)}.{queryBuilder.QuoteIdentifier(backingField!)}"; + string orderByAggExpr = distinct + ? $"{function.ToUpperInvariant()}(DISTINCT {quotedCol})" + : $"{function.ToUpperInvariant()}({quotedCol})"; string orderByClause = $" ORDER BY {orderByAggExpr} {direction}"; - // Insert ORDER BY before FOR JSON PATH (MsSql/DWSQL) or before LIMIT (PG/MySQL) - int insertIdx = sql.IndexOf(" FOR JSON PATH", StringComparison.OrdinalIgnoreCase); - if (insertIdx < 0) - { - insertIdx = sql.IndexOf(" LIMIT ", StringComparison.OrdinalIgnoreCase); - } - - if (insertIdx > 0) - { - sql = sql.Insert(insertIdx, orderByClause); - } - else - { - sql += orderByClause; - } - - // Add pagination (OFFSET/FETCH or LIMIT/OFFSET) for grouped results if (first.HasValue) { + // With pagination: SQL Server requires ORDER BY for OFFSET/FETCH and + // does not allow both TOP and OFFSET/FETCH. Remove TOP and add ORDER BY + OFFSET/FETCH. int offset = DecodeCursorOffset(after); int fetchCount = first.Value + 1; string offsetParam = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); @@ -501,24 +511,33 @@ public async Task ExecuteAsync( string limitParam = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); structure.Parameters.Add(limitParam, new DbConnectionParam(fetchCount)); - int paginationIdx = sql.IndexOf(" FOR JSON PATH", StringComparison.OrdinalIgnoreCase); - string paginationClause; - if (databaseType == DatabaseType.MSSQL || databaseType == DatabaseType.DWSQL) + string paginationClause = $" OFFSET {offsetParam} ROWS FETCH NEXT {limitParam} ROWS ONLY"; + + // Remove TOP N from the SELECT clause (TOP conflicts with OFFSET/FETCH) + sql = Regex.Replace(sql, @"SELECT TOP \d+", "SELECT"); + + // Insert ORDER BY + pagination before FOR JSON PATH + int jsonPathIdx = sql.IndexOf(" FOR JSON PATH", StringComparison.OrdinalIgnoreCase); + if (jsonPathIdx > 0) { - paginationClause = $" OFFSET {offsetParam} ROWS FETCH NEXT {limitParam} ROWS ONLY"; + sql = sql.Insert(jsonPathIdx, orderByClause + paginationClause); } else { - paginationClause = $" LIMIT {limitParam} OFFSET {offsetParam}"; + sql += orderByClause + paginationClause; } - - if (paginationIdx > 0) + } + else + { + // Without pagination: insert ORDER BY before FOR JSON PATH + int jsonPathIdx = sql.IndexOf(" FOR JSON PATH", StringComparison.OrdinalIgnoreCase); + if (jsonPathIdx > 0) { - sql = sql.Insert(paginationIdx, paginationClause); + sql = sql.Insert(jsonPathIdx, orderByClause); } else { - sql += paginationClause; + sql += orderByClause; } } } diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs index dd94ae593d..9cc2790430 100644 --- a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -204,6 +204,51 @@ public async Task AggregateRecords_FirstExceedsMax_ReturnsInvalidArguments() Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("100000")); } + [TestMethod] + public async Task AggregateRecords_StarFieldWithAvg_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"avg\", \"field\": \"*\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("count")); + } + + [TestMethod] + public async Task AggregateRecords_DistinctCountStar_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\", \"distinct\": true}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("DISTINCT")); + } + + [TestMethod] + public async Task AggregateRecords_HavingWithoutGroupBy_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\", \"having\": {\"gt\": 5}}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("groupby")); + } + #endregion #region Alias Convention Tests From 203fde1bb9a31bbdcc36d24af94a80ce9b5c6115 Mon Sep 17 00:00:00 2001 From: Jerry Nixon Date: Mon, 2 Mar 2026 18:31:21 -0700 Subject: [PATCH 18/32] Add comprehensive blog scenario tests from DAB MCP blog Add 8 tests covering all 5 scenarios from the DAB MCP blog post (devblogs.microsoft.com/azure-sql/data-api-builder-mcp-questions): 1. Strategic customer importance (sum/groupby/orderby desc/first 1) 2. Product discontinuation (sum/groupby/orderby asc/first 1) 3. Quarterly performance (avg/groupby/having gt/orderby desc) 4. Revenue concentration (sum/complex filter/multi-groupby/having) 5. Risk exposure (sum/filter/multi-groupby/having gt) Each test verifies the exact blog JSON payload passes input validation, plus tests for schema completeness, describe_entities instruction, and alias convention documentation. 80 tests pass. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../Mcp/AggregateRecordsToolTests.cs | 225 ++++++++++++++++++ 1 file changed, 225 insertions(+) diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs index 9cc2790430..78c83d12be 100644 --- a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -631,6 +631,231 @@ public void SpecExample13_AvgGroupByOrderByDescFirst3_CorrectAlias() #endregion + #region Blog Scenario Tests (devblogs.microsoft.com/azure-sql/data-api-builder-mcp-questions) + + // These tests verify that the exact JSON payloads from the DAB MCP blog + // pass input validation. The tool will fail at metadata resolution (no real DB) + // but must NOT return "InvalidArguments", proving the input shape is valid. + + /// + /// Blog Scenario 1: Strategic customer importance + /// "Who is our most important customer based on total revenue?" + /// Uses: sum, totalRevenue, filter, groupby [customerId, customerName], orderby desc, first 1 + /// + [TestMethod] + public async Task BlogScenario1_StrategicCustomerImportance_PassesInputValidation() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + string json = @"{ + ""entity"": ""Book"", + ""function"": ""sum"", + ""field"": ""totalRevenue"", + ""filter"": ""isActive eq true and orderDate ge 2025-01-01"", + ""groupby"": [""customerId"", ""customerName""], + ""orderby"": ""desc"", + ""first"": 1 + }"; + + JsonDocument args = JsonDocument.Parse(json); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + string errorType = content.GetProperty("error").GetProperty("type").GetString()!; + Assert.AreNotEqual("InvalidArguments", errorType, + "Blog scenario 1 JSON must pass input validation (sum/totalRevenue/groupby/orderby/first)."); + Assert.AreEqual("sum_totalRevenue", AggregateRecordsTool.ComputeAlias("sum", "totalRevenue")); + } + + /// + /// Blog Scenario 2: Product discontinuation candidate + /// "Which product should we consider discontinuing based on lowest totalRevenue?" + /// Uses: sum, totalRevenue, filter, groupby [productId, productName], orderby asc, first 1 + /// + [TestMethod] + public async Task BlogScenario2_ProductDiscontinuation_PassesInputValidation() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + string json = @"{ + ""entity"": ""Book"", + ""function"": ""sum"", + ""field"": ""totalRevenue"", + ""filter"": ""isActive eq true and inStock gt 0 and orderDate ge 2025-01-01"", + ""groupby"": [""productId"", ""productName""], + ""orderby"": ""asc"", + ""first"": 1 + }"; + + JsonDocument args = JsonDocument.Parse(json); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + string errorType = content.GetProperty("error").GetProperty("type").GetString()!; + Assert.AreNotEqual("InvalidArguments", errorType, + "Blog scenario 2 JSON must pass input validation (sum/totalRevenue/groupby/orderby asc/first)."); + Assert.AreEqual("sum_totalRevenue", AggregateRecordsTool.ComputeAlias("sum", "totalRevenue")); + } + + /// + /// Blog Scenario 3: Forward-looking performance expectation + /// "Average quarterlyRevenue per region, regions averaging > $2,000,000?" + /// Uses: avg, quarterlyRevenue, filter, groupby [region], having {gt: 2000000}, orderby desc + /// + [TestMethod] + public async Task BlogScenario3_QuarterlyPerformance_PassesInputValidation() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + string json = @"{ + ""entity"": ""Book"", + ""function"": ""avg"", + ""field"": ""quarterlyRevenue"", + ""filter"": ""fiscalYear eq 2025"", + ""groupby"": [""region""], + ""having"": { ""gt"": 2000000 }, + ""orderby"": ""desc"" + }"; + + JsonDocument args = JsonDocument.Parse(json); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + string errorType = content.GetProperty("error").GetProperty("type").GetString()!; + Assert.AreNotEqual("InvalidArguments", errorType, + "Blog scenario 3 JSON must pass input validation (avg/quarterlyRevenue/groupby/having gt)."); + Assert.AreEqual("avg_quarterlyRevenue", AggregateRecordsTool.ComputeAlias("avg", "quarterlyRevenue")); + } + + /// + /// Blog Scenario 4: Revenue concentration across regions + /// "Total revenue of active retail customers in Midwest/Southwest, >$5M, by region and customerTier" + /// Uses: sum, totalRevenue, complex filter with OR, groupby [region, customerTier], having {gt: 5000000}, orderby desc + /// + [TestMethod] + public async Task BlogScenario4_RevenueConcentration_PassesInputValidation() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + string json = @"{ + ""entity"": ""Book"", + ""function"": ""sum"", + ""field"": ""totalRevenue"", + ""filter"": ""isActive eq true and customerType eq 'Retail' and (region eq 'Midwest' or region eq 'Southwest')"", + ""groupby"": [""region"", ""customerTier""], + ""having"": { ""gt"": 5000000 }, + ""orderby"": ""desc"" + }"; + + JsonDocument args = JsonDocument.Parse(json); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + string errorType = content.GetProperty("error").GetProperty("type").GetString()!; + Assert.AreNotEqual("InvalidArguments", errorType, + "Blog scenario 4 JSON must pass input validation (sum/totalRevenue/complex filter/multi-groupby/having)."); + Assert.AreEqual("sum_totalRevenue", AggregateRecordsTool.ComputeAlias("sum", "totalRevenue")); + } + + /// + /// Blog Scenario 5: Risk exposure by product line + /// "For discontinued products, total onHandValue by productLine and warehouseRegion, >$2.5M" + /// Uses: sum, onHandValue, filter, groupby [productLine, warehouseRegion], having {gt: 2500000}, orderby desc + /// + [TestMethod] + public async Task BlogScenario5_RiskExposure_PassesInputValidation() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + string json = @"{ + ""entity"": ""Book"", + ""function"": ""sum"", + ""field"": ""onHandValue"", + ""filter"": ""discontinued eq true and onHandValue gt 0"", + ""groupby"": [""productLine"", ""warehouseRegion""], + ""having"": { ""gt"": 2500000 }, + ""orderby"": ""desc"" + }"; + + JsonDocument args = JsonDocument.Parse(json); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + string errorType = content.GetProperty("error").GetProperty("type").GetString()!; + Assert.AreNotEqual("InvalidArguments", errorType, + "Blog scenario 5 JSON must pass input validation (sum/onHandValue/filter/multi-groupby/having)."); + Assert.AreEqual("sum_onHandValue", AggregateRecordsTool.ComputeAlias("sum", "onHandValue")); + } + + /// + /// Verifies that the tool schema supports all properties used across the 5 blog scenarios. + /// + [TestMethod] + public void BlogScenarios_ToolSchema_SupportsAllRequiredProperties() + { + AggregateRecordsTool tool = new(); + Tool metadata = tool.GetToolMetadata(); + JsonElement properties = metadata.InputSchema.GetProperty("properties"); + + string[] blogProperties = { "entity", "function", "field", "filter", "groupby", "orderby", "having", "first" }; + foreach (string prop in blogProperties) + { + Assert.IsTrue(properties.TryGetProperty(prop, out _), + $"Tool schema must include '{prop}' property used in blog scenarios."); + } + + // Additional schema properties used in spec but not blog + Assert.IsTrue(properties.TryGetProperty("distinct", out _), "Tool schema must include 'distinct'."); + Assert.IsTrue(properties.TryGetProperty("after", out _), "Tool schema must include 'after'."); + } + + /// + /// Verifies that the tool description instructs models to call describe_entities first. + /// + [TestMethod] + public void BlogScenarios_ToolDescription_ForcesDescribeEntitiesFirst() + { + AggregateRecordsTool tool = new(); + Tool metadata = tool.GetToolMetadata(); + + Assert.IsTrue(metadata.Description!.Contains("describe_entities"), + "Tool description must instruct models to call describe_entities first."); + Assert.IsTrue(metadata.Description.Contains("STEP 1"), + "Tool description must use numbered steps starting with STEP 1."); + } + + /// + /// Verifies that the tool description documents the alias convention used in blog examples. + /// + [TestMethod] + public void BlogScenarios_ToolDescription_DocumentsAliasConvention() + { + AggregateRecordsTool tool = new(); + Tool metadata = tool.GetToolMetadata(); + + Assert.IsTrue(metadata.Description!.Contains("{function}_{field}"), + "Tool description must document the alias pattern '{function}_{field}'."); + Assert.IsTrue(metadata.Description.Contains("'count'"), + "Tool description must mention the special 'count' alias for count(*)."); + } + + #endregion + #region Helper Methods private static JsonElement ParseContent(CallToolResult result) From fab587f0a96961a59fa2b9ed0cd869233ab40c68 Mon Sep 17 00:00:00 2001 From: Jerry Nixon Date: Mon, 2 Mar 2026 18:34:04 -0700 Subject: [PATCH 19/32] Tighten tool description and parameter docs to remove duplication Remove redundant parameter listings from Description (already in InputSchema). Description now covers only: workflow steps, rules not expressed elsewhere, and response alias convention. Parameter descriptions simplified to one sentence each, removing repeated phrases like 'from describe_entities' and 'ONLY applies when groupby is provided' (stated once in groupby description). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../BuiltInTools/AggregateRecordsTool.cs | 59 ++++++++----------- .../Mcp/AggregateRecordsToolTests.cs | 4 +- 2 files changed, 26 insertions(+), 37 deletions(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index 2509629231..2b028c9ec4 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -46,88 +46,77 @@ public Tool GetToolMetadata() { Name = "aggregate_records", Description = "Computes aggregations (count, avg, sum, min, max) on entity data. " - + "STEP 1: Call describe_entities to discover entities with READ permission and their field names. " - + "STEP 2: Call this tool with the exact entity name, an aggregation function, and a field name from STEP 1. " - + "REQUIRED: entity (exact entity name), function (one of: count, avg, sum, min, max), field (exact field name, or '*' ONLY for count). " - + "OPTIONAL: filter (OData WHERE clause applied before aggregating, e.g. 'unitPrice lt 10'), " - + "distinct (true to deduplicate values before aggregating), " - + "groupby (array of field names to group results by, e.g. ['categoryName']), " - + "orderby ('asc' or 'desc' to sort grouped results by aggregated value; requires groupby), " - + "having (object to filter groups after aggregating, operators: eq, neq, gt, gte, lt, lte, in; requires groupby), " - + "first (integer >= 1, maximum grouped results to return; requires groupby), " - + "after (opaque cursor string from a previous response's endCursor for pagination). " - + "RESPONSE: The aggregated value is aliased as '{function}_{field}' (e.g. avg_unitPrice, sum_revenue). " - + "For count with field '*', the alias is 'count'. " - + "When first is used with groupby, response contains: items (array), endCursor (string), hasNextPage (boolean). " - + "RULES: 1) ALWAYS call describe_entities first to get valid entity and field names. " - + "2) Use field '*' ONLY with function 'count'. " - + "3) For avg, sum, min, max: field MUST be a numeric field name from describe_entities. " - + "4) orderby, having, and first ONLY apply when groupby is provided. " - + "5) Use first and after for paginating large grouped result sets.", + + "WORKFLOW: 1) Call describe_entities first to get entity names and field names. " + + "2) Call this tool with entity, function, and field from step 1. " + + "RULES: field '*' is ONLY valid with count. " + + "orderby, having, first, and after ONLY apply when groupby is provided. " + + "RESPONSE: Result is aliased as '{function}_{field}' (e.g. avg_unitPrice). " + + "For count(*), the alias is 'count'. " + + "With groupby and first, response includes items, endCursor, and hasNextPage for pagination.", InputSchema = JsonSerializer.Deserialize( @"{ ""type"": ""object"", ""properties"": { ""entity"": { ""type"": ""string"", - ""description"": ""Exact entity name from describe_entities that has READ permission. Must match exactly (case-sensitive)."" + ""description"": ""Entity name from describe_entities with READ permission (case-sensitive)."" }, ""function"": { ""type"": ""string"", ""enum"": [""count"", ""avg"", ""sum"", ""min"", ""max""], - ""description"": ""Aggregation function to apply. Use 'count' to count records, 'avg' for average, 'sum' for total, 'min' for minimum, 'max' for maximum. For count use field '*' or a specific field name. For avg, sum, min, max the field must be numeric."" + ""description"": ""Aggregation function. count supports field '*'; avg, sum, min, max require a numeric field."" }, ""field"": { ""type"": ""string"", - ""description"": ""Exact field name from describe_entities to aggregate. Use '*' ONLY with function 'count' to count all records. For avg, sum, min, max, provide a numeric field name."" + ""description"": ""Field name to aggregate, or '*' with count to count all rows."" }, ""distinct"": { ""type"": ""boolean"", - ""description"": ""When true, removes duplicate values before applying the aggregation function. For example, count with distinct counts unique values only. Default is false."", + ""description"": ""Remove duplicate values before aggregating. Not valid with field '*'."", ""default"": false }, ""filter"": { ""type"": ""string"", - ""description"": ""OData filter expression applied before aggregating (acts as a WHERE clause). Supported operators: eq, ne, gt, ge, lt, le, and, or, not. Example: 'unitPrice lt 10' filters to rows where unitPrice is less than 10 before aggregating. Example: 'discontinued eq true and categoryName eq ''Seafood''' filters discontinued seafood products."", + ""description"": ""OData WHERE clause applied before aggregating. Operators: eq, ne, gt, ge, lt, le, and, or, not. Example: 'unitPrice lt 10'."", ""default"": """" }, ""groupby"": { ""type"": ""array"", ""items"": { ""type"": ""string"" }, - ""description"": ""Array of exact field names from describe_entities to group results by. Each unique combination of grouped field values produces one aggregated row. Grouped field values are included in the response alongside the aggregated value. Example: ['categoryName'] groups by category. Example: ['categoryName', 'region'] groups by both fields."", + ""description"": ""Field names to group by. Each unique combination produces one aggregated row. Enables orderby, having, first, and after."", ""default"": [] }, ""orderby"": { ""type"": ""string"", ""enum"": [""asc"", ""desc""], - ""description"": ""Sort direction for grouped results by the computed aggregated value. 'desc' returns highest values first, 'asc' returns lowest first. ONLY applies when groupby is provided. Default is 'desc'."", + ""description"": ""Sort grouped results by the aggregated value. Requires groupby."", ""default"": ""desc"" }, ""having"": { ""type"": ""object"", - ""description"": ""Filter applied AFTER aggregating to filter grouped results by the computed aggregated value (acts as a HAVING clause). ONLY applies when groupby is provided. Multiple operators are AND-ed together. For example, use gt with value 20 to keep groups where the aggregated value exceeds 20. Combine gte and lte to define a range."", + ""description"": ""Filter groups by the aggregated value (HAVING clause). Requires groupby. Multiple operators are AND-ed."", ""properties"": { - ""eq"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value equals this number."" }, - ""neq"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value does not equal this number."" }, - ""gt"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value is greater than this number."" }, - ""gte"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value is greater than or equal to this number."" }, - ""lt"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value is less than this number."" }, - ""lte"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value is less than or equal to this number."" }, + ""eq"": { ""type"": ""number"", ""description"": ""Equals."" }, + ""neq"": { ""type"": ""number"", ""description"": ""Not equals."" }, + ""gt"": { ""type"": ""number"", ""description"": ""Greater than."" }, + ""gte"": { ""type"": ""number"", ""description"": ""Greater than or equal."" }, + ""lt"": { ""type"": ""number"", ""description"": ""Less than."" }, + ""lte"": { ""type"": ""number"", ""description"": ""Less than or equal."" }, ""in"": { ""type"": ""array"", ""items"": { ""type"": ""number"" }, - ""description"": ""Keep groups where the aggregated value matches any number in this list. Example: [5, 10] keeps groups with aggregated value 5 or 10."" + ""description"": ""Matches any value in the list."" } } }, ""first"": { ""type"": ""integer"", - ""description"": ""Maximum number of grouped results to return. Used for pagination of grouped results. ONLY applies when groupby is provided. Must be >= 1. When set, the response includes 'items', 'endCursor', and 'hasNextPage' fields for pagination."", + ""description"": ""Max grouped results to return. Requires groupby. Enables paginated response with endCursor and hasNextPage."", ""minimum"": 1 }, ""after"": { ""type"": ""string"", - ""description"": ""Opaque cursor string for pagination. Pass the 'endCursor' value from a previous response to get the next page of results. REQUIRES both groupby and first to be set. Do not construct this value manually; always use the endCursor from a previous response."" + ""description"": ""Opaque cursor from a previous endCursor for next-page retrieval. Requires groupby and first. Do not construct manually."" } }, ""required"": [""entity"", ""function"", ""field""] diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs index 78c83d12be..242fdb8b6f 100644 --- a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -835,8 +835,8 @@ public void BlogScenarios_ToolDescription_ForcesDescribeEntitiesFirst() Assert.IsTrue(metadata.Description!.Contains("describe_entities"), "Tool description must instruct models to call describe_entities first."); - Assert.IsTrue(metadata.Description.Contains("STEP 1"), - "Tool description must use numbered steps starting with STEP 1."); + Assert.IsTrue(metadata.Description.Contains("1)"), + "Tool description must use numbered workflow steps."); } /// From 5c93f9283b1018c334a3442f5bf269318c8eab31 Mon Sep 17 00:00:00 2001 From: Jerry Nixon Date: Mon, 2 Mar 2026 18:38:34 -0700 Subject: [PATCH 20/32] Add early field validation and FieldNotFound error helper Validate field and groupby field names immediately after metadata resolution, before authorization or query building. Invalid field names now return a FieldNotFound error with model-friendly guidance to call describe_entities for valid field names. - Add McpErrorHelpers.FieldNotFound() with entity name, field name, parameter name, and describe_entities guidance - Move field existence checks before auth in AggregateRecordsTool - Remove redundant late validation (already caught early) - Add tests for FieldNotFound error type and message content 82 tests pass. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../BuiltInTools/AggregateRecordsTool.cs | 37 ++++++++++------ .../Utils/McpErrorHelpers.cs | 11 +++++ .../Mcp/AggregateRecordsToolTests.cs | 43 +++++++++++++++++++ 3 files changed, 77 insertions(+), 14 deletions(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index 2b028c9ec4..2dee64ccd1 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -278,6 +278,24 @@ public async Task ExecuteAsync( return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger); } + // Early field validation: check all user-supplied field names before authorization or query building. + // This lets the model discover and fix typos immediately. + if (!isCountStar) + { + if (!sqlMetadataProvider.TryGetBackingColumn(entityName, field, out _)) + { + return McpErrorHelpers.FieldNotFound(toolName, entityName, field, "field", logger); + } + } + + foreach (string gField in groupby) + { + if (!sqlMetadataProvider.TryGetBackingColumn(entityName, gField, out _)) + { + return McpErrorHelpers.FieldNotFound(toolName, entityName, gField, "groupby", logger); + } + } + // Authorization IAuthorizationResolver authResolver = serviceProvider.GetRequiredService(); IAuthorizationService authorizationService = serviceProvider.GetRequiredService(); @@ -358,15 +376,11 @@ public async Task ExecuteAsync( IQueryBuilder queryBuilder = queryManagerFactory.GetQueryBuilder(databaseType); IQueryExecutor queryExecutor = queryManagerFactory.GetQueryExecutor(databaseType); - // Resolve backing column name for the aggregation field + // Resolve backing column name for the aggregation field (already validated early) string? backingField = null; if (!isCountStar) { - if (!sqlMetadataProvider.TryGetBackingColumn(entityName, field, out backingField)) - { - return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", - $"Field '{field}' not found for entity '{entityName}'.", logger); - } + sqlMetadataProvider.TryGetBackingColumn(entityName, field, out backingField); } else { @@ -380,17 +394,12 @@ public async Task ExecuteAsync( } } - // Resolve backing column names for groupby fields + // Resolve backing column names for groupby fields (already validated early) List<(string entityField, string backingCol)> groupbyMapping = new(); foreach (string gField in groupby) { - if (!sqlMetadataProvider.TryGetBackingColumn(entityName, gField, out string? backingGCol)) - { - return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", - $"GroupBy field '{gField}' not found for entity '{entityName}'.", logger); - } - - groupbyMapping.Add((gField, backingGCol)); + sqlMetadataProvider.TryGetBackingColumn(entityName, gField, out string? backingGCol); + groupbyMapping.Add((gField, backingGCol!)); } string alias = ComputeAlias(function, field); diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpErrorHelpers.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpErrorHelpers.cs index 1a5c223798..13835b2fa9 100644 --- a/src/Azure.DataApiBuilder.Mcp/Utils/McpErrorHelpers.cs +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpErrorHelpers.cs @@ -24,5 +24,16 @@ public static CallToolResult ToolDisabled(string toolName, ILogger? logger, stri string message = customMessage ?? $"The {toolName} tool is disabled in the configuration."; return McpResponseBuilder.BuildErrorResult(toolName, Model.McpErrorCode.ToolDisabled.ToString(), message, logger); } + + /// + /// Returns a model-friendly error when a field name is not found for an entity. + /// Guides the model to call describe_entities to discover valid field names. + /// + public static CallToolResult FieldNotFound(string toolName, string entityName, string fieldName, string parameterName, ILogger? logger) + { + string message = $"Field '{fieldName}' in '{parameterName}' was not found for entity '{entityName}'. " + + $"Call describe_entities to get valid field names for '{entityName}'."; + return McpResponseBuilder.BuildErrorResult(toolName, "FieldNotFound", message, logger); + } } } diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs index 242fdb8b6f..41e028e28a 100644 --- a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -15,6 +15,7 @@ using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Mcp.BuiltInTools; using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Mcp.Utils; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.DependencyInjection; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -856,6 +857,48 @@ public void BlogScenarios_ToolDescription_DocumentsAliasConvention() #endregion + #region FieldNotFound Error Helper Tests + + /// + /// Verifies the FieldNotFound error helper produces the correct error type + /// and a model-friendly message that includes the field name, entity, and guidance. + /// + [TestMethod] + public void FieldNotFound_ReturnsCorrectErrorTypeAndMessage() + { + CallToolResult result = McpErrorHelpers.FieldNotFound("aggregate_records", "Product", "badField", "field", null); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + JsonElement error = content.GetProperty("error"); + + Assert.AreEqual("FieldNotFound", error.GetProperty("type").GetString()); + string message = error.GetProperty("message").GetString()!; + Assert.IsTrue(message.Contains("badField"), "Message must include the invalid field name."); + Assert.IsTrue(message.Contains("Product"), "Message must include the entity name."); + Assert.IsTrue(message.Contains("field"), "Message must identify which parameter was invalid."); + Assert.IsTrue(message.Contains("describe_entities"), "Message must guide the model to call describe_entities."); + } + + /// + /// Verifies the FieldNotFound error helper identifies the groupby parameter. + /// + [TestMethod] + public void FieldNotFound_GroupBy_IdentifiesParameter() + { + CallToolResult result = McpErrorHelpers.FieldNotFound("aggregate_records", "Product", "invalidCol", "groupby", null); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + string message = content.GetProperty("error").GetProperty("message").GetString()!; + + Assert.IsTrue(message.Contains("invalidCol"), "Message must include the invalid field name."); + Assert.IsTrue(message.Contains("groupby"), "Message must identify 'groupby' as the parameter."); + Assert.IsTrue(message.Contains("describe_entities"), "Message must guide the model to call describe_entities."); + } + + #endregion + #region Helper Methods private static JsonElement ParseContent(CallToolResult result) From d1268f2f34c2cc8f3c40343cca2d795e2f401758 Mon Sep 17 00:00:00 2001 From: Jerry Nixon Date: Mon, 2 Mar 2026 18:40:14 -0700 Subject: [PATCH 21/32] Rename truncated variables to descriptive names Rename abbreviated variable names to their full, readable forms: funcElfunctionElement, fieldElfieldElement, distinctEldistinctElement, filterElfilterElement, orderbyElorderbyElement, firstElfirstElement, afterElafterElement, groupbyElgroupbyElement, ggroupbyItem, gValgroupbyFieldName, gFieldgroupbyField, havingElhavingElement, havingOpshavingOperators, havingInhavingInValues, aggTypeaggregationType, aggColumnaggregationColumn, predOppredicateOperation, ophavingOperator, predpredicate, backingColbackingColumn, backingGColbackingGroupbyColumn, timeoutExtimeoutException, taskExtaskCanceledException, dbExdbException, argExargumentException/dabException. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../BuiltInTools/AggregateRecordsTool.cs | 124 +++++++++--------- 1 file changed, 62 insertions(+), 62 deletions(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index 2dee64ccd1..debce4f090 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -168,23 +168,23 @@ public async Task ExecuteAsync( return McpErrorHelpers.ToolDisabled(toolName, logger, $"DML tools are disabled for entity '{entityName}'."); } - if (!root.TryGetProperty("function", out JsonElement funcEl) || string.IsNullOrWhiteSpace(funcEl.GetString())) + if (!root.TryGetProperty("function", out JsonElement functionElement) || string.IsNullOrWhiteSpace(functionElement.GetString())) { return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Missing required argument 'function'.", logger); } - string function = funcEl.GetString()!.ToLowerInvariant(); + string function = functionElement.GetString()!.ToLowerInvariant(); if (!_validFunctions.Contains(function)) { return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", $"Invalid function '{function}'. Must be one of: count, avg, sum, min, max.", logger); } - if (!root.TryGetProperty("field", out JsonElement fieldEl) || string.IsNullOrWhiteSpace(fieldEl.GetString())) + if (!root.TryGetProperty("field", out JsonElement fieldElement) || string.IsNullOrWhiteSpace(fieldElement.GetString())) { return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Missing required argument 'field'.", logger); } - string field = fieldEl.GetString()!; + string field = fieldElement.GetString()!; // Validate field/function compatibility bool isCountStar = function == "count" && field == "*"; @@ -195,7 +195,7 @@ public async Task ExecuteAsync( $"Field '*' is only valid with function 'count'. For function '{function}', provide a specific field name.", logger); } - bool distinct = root.TryGetProperty("distinct", out JsonElement distinctEl) && distinctEl.GetBoolean(); + bool distinct = root.TryGetProperty("distinct", out JsonElement distinctElement) && distinctElement.GetBoolean(); // Reject count(*) with distinct as it is semantically undefined if (isCountStar && distinct) @@ -204,13 +204,13 @@ public async Task ExecuteAsync( "Cannot use distinct=true with field='*'. DISTINCT requires a specific field name. Use a field name instead of '*' to count distinct values.", logger); } - string? filter = root.TryGetProperty("filter", out JsonElement filterEl) ? filterEl.GetString() : null; - string orderby = root.TryGetProperty("orderby", out JsonElement orderbyEl) ? (orderbyEl.GetString() ?? "desc") : "desc"; + string? filter = root.TryGetProperty("filter", out JsonElement filterElement) ? filterElement.GetString() : null; + string orderby = root.TryGetProperty("orderby", out JsonElement orderbyElement) ? (orderbyElement.GetString() ?? "desc") : "desc"; int? first = null; - if (root.TryGetProperty("first", out JsonElement firstEl) && firstEl.ValueKind == JsonValueKind.Number) + if (root.TryGetProperty("first", out JsonElement firstElement) && firstElement.ValueKind == JsonValueKind.Number) { - first = firstEl.GetInt32(); + first = firstElement.GetInt32(); if (first < 1) { return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Argument 'first' must be at least 1.", logger); @@ -222,24 +222,24 @@ public async Task ExecuteAsync( } } - string? after = root.TryGetProperty("after", out JsonElement afterEl) ? afterEl.GetString() : null; + string? after = root.TryGetProperty("after", out JsonElement afterElement) ? afterElement.GetString() : null; List groupby = new(); - if (root.TryGetProperty("groupby", out JsonElement groupbyEl) && groupbyEl.ValueKind == JsonValueKind.Array) + if (root.TryGetProperty("groupby", out JsonElement groupbyElement) && groupbyElement.ValueKind == JsonValueKind.Array) { - foreach (JsonElement g in groupbyEl.EnumerateArray()) + foreach (JsonElement groupbyItem in groupbyElement.EnumerateArray()) { - string? gVal = g.GetString(); - if (!string.IsNullOrWhiteSpace(gVal)) + string? groupbyFieldName = groupbyItem.GetString(); + if (!string.IsNullOrWhiteSpace(groupbyFieldName)) { - groupby.Add(gVal); + groupby.Add(groupbyFieldName); } } } - Dictionary? havingOps = null; - List? havingIn = null; - if (root.TryGetProperty("having", out JsonElement havingEl) && havingEl.ValueKind == JsonValueKind.Object) + Dictionary? havingOperators = null; + List? havingInValues = null; + if (root.TryGetProperty("having", out JsonElement havingElement) && havingElement.ValueKind == JsonValueKind.Object) { if (groupby.Count == 0) { @@ -247,20 +247,20 @@ public async Task ExecuteAsync( "The 'having' parameter requires 'groupby' to be specified. HAVING filters groups after aggregation.", logger); } - havingOps = new Dictionary(StringComparer.OrdinalIgnoreCase); - foreach (JsonProperty prop in havingEl.EnumerateObject()) + havingOperators = new Dictionary(StringComparer.OrdinalIgnoreCase); + foreach (JsonProperty prop in havingElement.EnumerateObject()) { if (prop.Name.Equals("in", StringComparison.OrdinalIgnoreCase) && prop.Value.ValueKind == JsonValueKind.Array) { - havingIn = new List(); + havingInValues = new List(); foreach (JsonElement item in prop.Value.EnumerateArray()) { - havingIn.Add(item.GetDouble()); + havingInValues.Add(item.GetDouble()); } } else if (prop.Value.ValueKind == JsonValueKind.Number) { - havingOps[prop.Name] = prop.Value.GetDouble(); + havingOperators[prop.Name] = prop.Value.GetDouble(); } } } @@ -288,11 +288,11 @@ public async Task ExecuteAsync( } } - foreach (string gField in groupby) + foreach (string groupbyField in groupby) { - if (!sqlMetadataProvider.TryGetBackingColumn(entityName, gField, out _)) + if (!sqlMetadataProvider.TryGetBackingColumn(entityName, groupbyField, out _)) { - return McpErrorHelpers.FieldNotFound(toolName, entityName, gField, "groupby", logger); + return McpErrorHelpers.FieldNotFound(toolName, entityName, groupbyField, "groupby", logger); } } @@ -395,11 +395,11 @@ public async Task ExecuteAsync( } // Resolve backing column names for groupby fields (already validated early) - List<(string entityField, string backingCol)> groupbyMapping = new(); - foreach (string gField in groupby) + List<(string entityField, string backingColumn)> groupbyMapping = new(); + foreach (string groupbyField in groupby) { - sqlMetadataProvider.TryGetBackingColumn(entityName, gField, out string? backingGCol); - groupbyMapping.Add((gField, backingGCol!)); + sqlMetadataProvider.TryGetBackingColumn(entityName, groupbyField, out string? backingGroupbyColumn); + groupbyMapping.Add((groupbyField, backingGroupbyColumn!)); } string alias = ComputeAlias(function, field); @@ -408,27 +408,27 @@ public async Task ExecuteAsync( structure.Columns.Clear(); // Add groupby columns as LabelledColumns and GroupByMetadata.Fields - foreach (var (entityField, backingCol) in groupbyMapping) + foreach (var (entityField, backingColumn) in groupbyMapping) { structure.Columns.Add(new LabelledColumn( - dbObject.SchemaName, dbObject.Name, backingCol, entityField, structure.SourceAlias)); - structure.GroupByMetadata.Fields[backingCol] = new Column( - dbObject.SchemaName, dbObject.Name, backingCol, structure.SourceAlias); + dbObject.SchemaName, dbObject.Name, backingColumn, entityField, structure.SourceAlias)); + structure.GroupByMetadata.Fields[backingColumn] = new Column( + dbObject.SchemaName, dbObject.Name, backingColumn, structure.SourceAlias); } // Build aggregation column using engine's AggregationColumn type. // For COUNT(*), we use the primary key column (PK is always NOT NULL, so COUNT(pk) ≡ COUNT(*)). - AggregationType aggType = Enum.Parse(function); - AggregationColumn aggColumn = new( - dbObject.SchemaName, dbObject.Name, backingField!, aggType, alias, distinct, structure.SourceAlias); + AggregationType aggregationType = Enum.Parse(function); + AggregationColumn aggregationColumn = new( + dbObject.SchemaName, dbObject.Name, backingField!, aggregationType, alias, distinct, structure.SourceAlias); // Build HAVING predicates using engine's Predicate model List havingPredicates = new(); - if (havingOps != null) + if (havingOperators != null) { - foreach (var op in havingOps) + foreach (var havingOperator in havingOperators) { - PredicateOperation predOp = op.Key.ToLowerInvariant() switch + PredicateOperation predicateOperation = havingOperator.Key.ToLowerInvariant() switch { "eq" => PredicateOperation.Equal, "neq" => PredicateOperation.NotEqual, @@ -436,21 +436,21 @@ public async Task ExecuteAsync( "gte" => PredicateOperation.GreaterThanOrEqual, "lt" => PredicateOperation.LessThan, "lte" => PredicateOperation.LessThanOrEqual, - _ => throw new ArgumentException($"Invalid having operator: {op.Key}") + _ => throw new ArgumentException($"Invalid having operator: {havingOperator.Key}") }; string paramName = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); - structure.Parameters.Add(paramName, new DbConnectionParam(op.Value)); + structure.Parameters.Add(paramName, new DbConnectionParam(havingOperator.Value)); havingPredicates.Add(new Predicate( - new PredicateOperand(aggColumn), - predOp, + new PredicateOperand(aggregationColumn), + predicateOperation, new PredicateOperand(paramName))); } } - if (havingIn != null && havingIn.Count > 0) + if (havingInValues != null && havingInValues.Count > 0) { List inParams = new(); - foreach (double val in havingIn) + foreach (double val in havingInValues) { string paramName = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); structure.Parameters.Add(paramName, new DbConnectionParam(val)); @@ -458,22 +458,22 @@ public async Task ExecuteAsync( } havingPredicates.Add(new Predicate( - new PredicateOperand(aggColumn), + new PredicateOperand(aggregationColumn), PredicateOperation.IN, new PredicateOperand($"({string.Join(", ", inParams)})"))); } // Combine multiple HAVING predicates with AND Predicate? combinedHaving = null; - foreach (var pred in havingPredicates) + foreach (var predicate in havingPredicates) { combinedHaving = combinedHaving == null - ? pred - : new Predicate(new PredicateOperand(combinedHaving), PredicateOperation.AND, new PredicateOperand(pred)); + ? predicate + : new Predicate(new PredicateOperand(combinedHaving), PredicateOperation.AND, new PredicateOperand(predicate)); } structure.GroupByMetadata.Aggregations.Add( - new AggregationOperation(aggColumn, having: combinedHaving != null ? new List { combinedHaving } : null)); + new AggregationOperation(aggregationColumn, having: combinedHaving != null ? new List { combinedHaving } : null)); structure.GroupByMetadata.RequestedAggregations = true; // Clear default OrderByColumns (PK-based) @@ -564,9 +564,9 @@ public async Task ExecuteAsync( return BuildSimpleResponse(resultArray, entityName, alias, logger); } - catch (TimeoutException timeoutEx) + catch (TimeoutException timeoutException) { - logger?.LogError(timeoutEx, "Aggregation operation timed out for entity {Entity}.", entityName); + logger?.LogError(timeoutException, "Aggregation operation timed out for entity {Entity}.", entityName); return McpResponseBuilder.BuildErrorResult( toolName, "TimeoutError", @@ -576,9 +576,9 @@ public async Task ExecuteAsync( + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination.", logger); } - catch (TaskCanceledException taskEx) + catch (TaskCanceledException taskCanceledException) { - logger?.LogError(taskEx, "Aggregation task was canceled for entity {Entity}.", entityName); + logger?.LogError(taskCanceledException, "Aggregation task was canceled for entity {Entity}.", entityName); return McpResponseBuilder.BuildErrorResult( toolName, "TimeoutError", @@ -598,18 +598,18 @@ public async Task ExecuteAsync( + "No results were returned. You may retry the same request.", logger); } - catch (DbException dbEx) + catch (DbException dbException) { - logger?.LogError(dbEx, "Database error during aggregation for entity {Entity}.", entityName); - return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseOperationFailed", dbEx.Message, logger); + logger?.LogError(dbException, "Database error during aggregation for entity {Entity}.", entityName); + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseOperationFailed", dbException.Message, logger); } - catch (ArgumentException argEx) + catch (ArgumentException argumentException) { - return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", argumentException.Message, logger); } - catch (DataApiBuilderException argEx) + catch (DataApiBuilderException dabException) { - return McpResponseBuilder.BuildErrorResult(toolName, argEx.StatusCode.ToString(), argEx.Message, logger); + return McpResponseBuilder.BuildErrorResult(toolName, dabException.StatusCode.ToString(), dabException.Message, logger); } catch (Exception ex) { From 539047153949d14f124098b1f102cba62dec9994 Mon Sep 17 00:00:00 2001 From: Jerry Nixon Date: Mon, 2 Mar 2026 18:41:35 -0700 Subject: [PATCH 22/32] Remove hallucinated first > 100000 validation DAB config already has MaxResponseSize property that handles this downstream through structure.Limit(). The engine applies the configured limit automatically, making this artificial cap redundant. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../BuiltInTools/AggregateRecordsTool.cs | 5 +---- .../Mcp/AggregateRecordsToolTests.cs | 15 --------------- 2 files changed, 1 insertion(+), 19 deletions(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index debce4f090..11d629378e 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -216,10 +216,7 @@ public async Task ExecuteAsync( return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Argument 'first' must be at least 1.", logger); } - if (first > 100_000) - { - return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Argument 'first' must not exceed 100000.", logger); - } + } string? after = root.TryGetProperty("after", out JsonElement afterElement) ? afterElement.GetString() : null; diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs index 41e028e28a..9898b16c68 100644 --- a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -190,21 +190,6 @@ public async Task AggregateRecords_InvalidFunction_ReturnsInvalidArguments() Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("median")); } - [TestMethod] - public async Task AggregateRecords_FirstExceedsMax_ReturnsInvalidArguments() - { - RuntimeConfig config = CreateConfig(); - IServiceProvider sp = CreateServiceProvider(config); - AggregateRecordsTool tool = new(); - - JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\", \"first\": 200000, \"groupby\": [\"title\"]}"); - CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); - Assert.IsTrue(result.IsError == true); - JsonElement content = ParseContent(result); - Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); - Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("100000")); - } - [TestMethod] public async Task AggregateRecords_StarFieldWithAvg_ReturnsInvalidArguments() { From b55cdde6efee4ee30df648b7a34edfc417ab1d73 Mon Sep 17 00:00:00 2001 From: Jerry Nixon Date: Mon, 2 Mar 2026 18:43:04 -0700 Subject: [PATCH 23/32] Clean up extra blank line from validation removal Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../BuiltInTools/AggregateRecordsTool.cs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index 11d629378e..efccffdb3a 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -215,8 +215,6 @@ public async Task ExecuteAsync( { return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Argument 'first' must be at least 1.", logger); } - - } string? after = root.TryGetProperty("after", out JsonElement afterElement) ? afterElement.GetString() : null; From 7f4e2596d568e858cca93524ab5b46a1b12d55ee Mon Sep 17 00:00:00 2001 From: Jerry Nixon Date: Mon, 2 Mar 2026 18:43:54 -0700 Subject: [PATCH 24/32] Add AggregateRecordsTool documentation for SQL-level aggregations --- .../BuiltInTools/AggregateRecordsTool.md | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.md diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.md b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.md new file mode 100644 index 0000000000..002dad9a89 --- /dev/null +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.md @@ -0,0 +1,104 @@ +# AggregateRecordsTool + +MCP tool that computes SQL-level aggregations (COUNT, AVG, SUM, MIN, MAX) on DAB entities. All aggregation is pushed to the database engine — no in-memory computation. + +## Class Structure + +| Member | Kind | Purpose | +|---|---|---| +| `ToolType` | Property | Returns `ToolType.BuiltIn` for reflection-based discovery. | +| `_validFunctions` | Static field | Allowlist of aggregation functions: count, avg, sum, min, max. | +| `GetToolMetadata()` | Method | Returns the MCP `Tool` descriptor (name, description, JSON input schema). | +| `ExecuteAsync()` | Method | Main entry point — validates input, resolves metadata, authorizes, builds the SQL query via the engine's `IQueryBuilder.Build(SqlQueryStructure)`, executes it, and formats the response. | +| `ComputeAlias()` | Static method | Produces the result column alias: `"count"` for count(\*), otherwise `"{function}_{field}"`. | +| `DecodeCursorOffset()` | Static method | Decodes a base64 opaque cursor string to an integer offset for OFFSET/FETCH pagination. Returns 0 on any invalid input. | +| `BuildPaginatedResponse()` | Private method | Formats a grouped result set into `{ items, endCursor, hasNextPage }` when `first` is provided. | +| `BuildSimpleResponse()` | Private method | Formats a scalar or grouped result set without pagination. | + +## ExecuteAsync Sequence + +```mermaid +sequenceDiagram + participant Model as LLM / MCP Client + participant Tool as AggregateRecordsTool + participant Config as RuntimeConfigProvider + participant Meta as ISqlMetadataProvider + participant Auth as IAuthorizationService + participant QB as IQueryBuilder (engine) + participant QE as IQueryExecutor + participant DB as Database + + Model->>Tool: ExecuteAsync(arguments, serviceProvider, cancellationToken) + + Note over Tool: 1. Input validation + Tool->>Config: GetConfig() + Config-->>Tool: RuntimeConfig + Tool->>Tool: Validate tool enabled (runtime + entity level) + Tool->>Tool: Parse & validate arguments (entity, function, field, distinct, filter, groupby, having, first, after) + + Note over Tool: 2. Metadata resolution + Tool->>Meta: TryResolveMetadata(entityName) + Meta-->>Tool: sqlMetadataProvider, dbObject, dataSourceName + + Note over Tool: 3. Early field validation + Tool->>Meta: TryGetBackingColumn(entityName, field) + Meta-->>Tool: backingColumn (or FieldNotFound error) + loop Each groupby field + Tool->>Meta: TryGetBackingColumn(entityName, groupbyField) + Meta-->>Tool: backingColumn (or FieldNotFound error) + end + + Note over Tool: 4. Authorization + Tool->>Auth: AuthorizeAsync(user, FindRequestContext, ColumnsPermissionsRequirement) + Auth-->>Tool: AuthorizationResult + + Note over Tool: 5. Build SqlQueryStructure + Tool->>Tool: Create SqlQueryStructure from FindRequestContext + Tool->>Tool: Populate GroupByMetadata (fields, AggregationColumn, HAVING predicates) + Tool->>Tool: Clear default columns/OrderBy, set aggregation flag + + Note over Tool: 6. Generate SQL via engine + Tool->>QB: Build(SqlQueryStructure) + QB-->>Tool: SQL string (SELECT ... GROUP BY ... HAVING ... FOR JSON PATH) + + Note over Tool: 7. Post-process SQL + Tool->>Tool: Insert ORDER BY aggregate expression before FOR JSON PATH + opt Pagination (first provided) + Tool->>Tool: Remove TOP N (conflicts with OFFSET/FETCH) + Tool->>Tool: Append OFFSET/FETCH NEXT + end + + Note over Tool: 8. Execute query + Tool->>QE: ExecuteQueryAsync(sql, parameters, GetJsonResultAsync, dataSourceName) + QE->>DB: Execute SQL + DB-->>QE: JSON result + QE-->>Tool: JsonDocument + + Note over Tool: 9. Format response + alt first provided (paginated) + Tool->>Tool: BuildPaginatedResponse(resultArray, first, after) + Tool-->>Model: { items, endCursor, hasNextPage } + else simple + Tool->>Tool: BuildSimpleResponse(resultArray, alias) + Tool-->>Model: { entity, result: [{alias: value}] } + end + + Note over Tool: Exception handling + alt TimeoutException + Tool-->>Model: TimeoutError — "query timed out, narrow filters or paginate" + else TaskCanceledException + Tool-->>Model: TimeoutError — "canceled, likely timeout" + else OperationCanceledException + Tool-->>Model: OperationCanceled — "interrupted, retry" + else DbException + Tool-->>Model: DatabaseOperationFailed + end +``` + +## Key Design Decisions + +- **No in-memory aggregation.** The engine's `GroupByMetadata` / `AggregationColumn` types drive SQL generation via `queryBuilder.Build(structure)`. +- **COUNT(\*) workaround.** The engine's `Build(AggregationColumn)` doesn't support `*` as a column name, so the primary key column is used instead (`COUNT(pk)` ≡ `COUNT(*)` since PK is NOT NULL). +- **ORDER BY aggregate.** Neither the GraphQL nor REST paths support ORDER BY on an aggregate expression, so the tool post-processes the generated SQL to insert it before `FOR JSON PATH`. +- **TOP vs OFFSET/FETCH.** SQL Server forbids both in the same query. When pagination is used, `TOP N` is stripped via regex. +- **Database support.** Only MsSql / DWSQL — matches the engine's GraphQL aggregation support. PostgreSQL, MySQL, and CosmosDB return an `UnsupportedDatabase` error. From d83ded22f36afaa2eeb6f5d5b46083f3558af7ce Mon Sep 17 00:00:00 2001 From: Jerry Nixon Date: Mon, 2 Mar 2026 18:55:23 -0700 Subject: [PATCH 25/32] Simplify sequence diagram and expand design decisions Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../BuiltInTools/AggregateRecordsTool.md | 93 ++++--------------- 1 file changed, 20 insertions(+), 73 deletions(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.md b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.md index 002dad9a89..718aa360e1 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.md +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.md @@ -19,86 +19,33 @@ MCP tool that computes SQL-level aggregations (COUNT, AVG, SUM, MIN, MAX) on DAB ```mermaid sequenceDiagram - participant Model as LLM / MCP Client + participant Client as MCP Client participant Tool as AggregateRecordsTool - participant Config as RuntimeConfigProvider - participant Meta as ISqlMetadataProvider - participant Auth as IAuthorizationService - participant QB as IQueryBuilder (engine) - participant QE as IQueryExecutor + participant Engine as DAB Engine participant DB as Database - Model->>Tool: ExecuteAsync(arguments, serviceProvider, cancellationToken) - - Note over Tool: 1. Input validation - Tool->>Config: GetConfig() - Config-->>Tool: RuntimeConfig - Tool->>Tool: Validate tool enabled (runtime + entity level) - Tool->>Tool: Parse & validate arguments (entity, function, field, distinct, filter, groupby, having, first, after) - - Note over Tool: 2. Metadata resolution - Tool->>Meta: TryResolveMetadata(entityName) - Meta-->>Tool: sqlMetadataProvider, dbObject, dataSourceName - - Note over Tool: 3. Early field validation - Tool->>Meta: TryGetBackingColumn(entityName, field) - Meta-->>Tool: backingColumn (or FieldNotFound error) - loop Each groupby field - Tool->>Meta: TryGetBackingColumn(entityName, groupbyField) - Meta-->>Tool: backingColumn (or FieldNotFound error) - end - - Note over Tool: 4. Authorization - Tool->>Auth: AuthorizeAsync(user, FindRequestContext, ColumnsPermissionsRequirement) - Auth-->>Tool: AuthorizationResult - - Note over Tool: 5. Build SqlQueryStructure - Tool->>Tool: Create SqlQueryStructure from FindRequestContext - Tool->>Tool: Populate GroupByMetadata (fields, AggregationColumn, HAVING predicates) - Tool->>Tool: Clear default columns/OrderBy, set aggregation flag - - Note over Tool: 6. Generate SQL via engine - Tool->>QB: Build(SqlQueryStructure) - QB-->>Tool: SQL string (SELECT ... GROUP BY ... HAVING ... FOR JSON PATH) - - Note over Tool: 7. Post-process SQL - Tool->>Tool: Insert ORDER BY aggregate expression before FOR JSON PATH - opt Pagination (first provided) - Tool->>Tool: Remove TOP N (conflicts with OFFSET/FETCH) - Tool->>Tool: Append OFFSET/FETCH NEXT - end - - Note over Tool: 8. Execute query - Tool->>QE: ExecuteQueryAsync(sql, parameters, GetJsonResultAsync, dataSourceName) - QE->>DB: Execute SQL - DB-->>QE: JSON result - QE-->>Tool: JsonDocument - - Note over Tool: 9. Format response - alt first provided (paginated) - Tool->>Tool: BuildPaginatedResponse(resultArray, first, after) - Tool-->>Model: { items, endCursor, hasNextPage } - else simple - Tool->>Tool: BuildSimpleResponse(resultArray, alias) - Tool-->>Model: { entity, result: [{alias: value}] } + Client->>Tool: ExecuteAsync(arguments) + Tool->>Tool: Validate inputs & check tool enabled + Tool->>Engine: Resolve entity metadata & validate fields + Tool->>Engine: Authorize (column-level permissions) + Tool->>Engine: Build SQL via queryBuilder.Build(SqlQueryStructure) + Tool->>Tool: Post-process SQL (ORDER BY, pagination) + Tool->>DB: ExecuteQueryAsync → JSON result + alt Paginated (first provided) + Tool-->>Client: { items, endCursor, hasNextPage } + else Simple + Tool-->>Client: { entity, result: [{alias: value}] } end - Note over Tool: Exception handling - alt TimeoutException - Tool-->>Model: TimeoutError — "query timed out, narrow filters or paginate" - else TaskCanceledException - Tool-->>Model: TimeoutError — "canceled, likely timeout" - else OperationCanceledException - Tool-->>Model: OperationCanceled — "interrupted, retry" - else DbException - Tool-->>Model: DatabaseOperationFailed - end + Note over Tool,Client: On error: TimeoutError, OperationCanceled, or DatabaseOperationFailed ``` ## Key Design Decisions -- **No in-memory aggregation.** The engine's `GroupByMetadata` / `AggregationColumn` types drive SQL generation via `queryBuilder.Build(structure)`. -- **COUNT(\*) workaround.** The engine's `Build(AggregationColumn)` doesn't support `*` as a column name, so the primary key column is used instead (`COUNT(pk)` ≡ `COUNT(*)` since PK is NOT NULL). -- **ORDER BY aggregate.** Neither the GraphQL nor REST paths support ORDER BY on an aggregate expression, so the tool post-processes the generated SQL to insert it before `FOR JSON PATH`. -- **TOP vs OFFSET/FETCH.** SQL Server forbids both in the same query. When pagination is used, `TOP N` is stripped via regex. +- **No in-memory aggregation.** The engine's `GroupByMetadata` / `AggregationColumn` types drive SQL generation via `queryBuilder.Build(structure)`. All aggregation is performed by the database. +- **COUNT(\*) workaround.** The engine's `Build(AggregationColumn)` doesn't support `*` as a column name (it produces invalid SQL like `count([].[*])`), so the primary key column is used instead. `COUNT(pk)` ≡ `COUNT(*)` since PK is NOT NULL. +- **ORDER BY post-processing.** Neither the GraphQL nor REST code paths support ORDER BY on an aggregate expression, so this tool inserts `ORDER BY {func}({col}) ASC|DESC` into the generated SQL before `FOR JSON PATH`. +- **TOP vs OFFSET/FETCH.** SQL Server forbids both in the same query. When pagination (`first`) is used, `TOP N` is stripped via regex before appending `OFFSET/FETCH NEXT`. +- **Early field validation.** All user-supplied field names (aggregation field, groupby fields) are validated against the entity's metadata before authorization or query building, so typos surface immediately with actionable guidance. +- **Timeout vs cancellation.** `TimeoutException` (from `query-timeout` config) and `OperationCanceledException` (from client disconnect) are handled separately with distinct model-facing messages. Timeouts guide the model to narrow filters or paginate; cancellations suggest retry. - **Database support.** Only MsSql / DWSQL — matches the engine's GraphQL aggregation support. PostgreSQL, MySQL, and CosmosDB return an `UnsupportedDatabase` error. From 6815b656a54b4b074704462ecbfdc8d28cbebb5b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Mar 2026 10:13:21 +0000 Subject: [PATCH 26/32] Changes before error encountered Co-authored-by: souvikghosh04 <210500244+souvikghosh04@users.noreply.github.com> --- schemas/dab.draft.schema.json | 5 +- .../BuiltInTools/AggregateRecordsTool.cs | 182 ++++++++++-------- .../Utils/McpTelemetryHelper.cs | 6 + src/Config/ObjectModel/McpRuntimeOptions.cs | 1 + .../Configurations/RuntimeConfigValidator.cs | 5 +- 5 files changed, 116 insertions(+), 83 deletions(-) diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index e78861807d..8f283fba36 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -277,9 +277,10 @@ }, "query-timeout": { "type": "integer", - "description": "Execution timeout in seconds for MCP tool operations. Applies to all MCP tools.", + "description": "Execution timeout in seconds for MCP tool operations. Applies to all MCP tools. Range: 1-600.", "default": 30, - "minimum": 1 + "minimum": 1, + "maximum": 600 }, "dml-tools": { "oneOf": [ diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index efccffdb3a..806671fdaa 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -40,89 +40,91 @@ public class AggregateRecordsTool : IMcpTool private static readonly HashSet _validFunctions = new(StringComparer.OrdinalIgnoreCase) { "count", "avg", "sum", "min", "max" }; - public Tool GetToolMetadata() + private static readonly Tool _cachedToolMetadata = new() { - return new Tool - { - Name = "aggregate_records", - Description = "Computes aggregations (count, avg, sum, min, max) on entity data. " - + "WORKFLOW: 1) Call describe_entities first to get entity names and field names. " - + "2) Call this tool with entity, function, and field from step 1. " - + "RULES: field '*' is ONLY valid with count. " - + "orderby, having, first, and after ONLY apply when groupby is provided. " - + "RESPONSE: Result is aliased as '{function}_{field}' (e.g. avg_unitPrice). " - + "For count(*), the alias is 'count'. " - + "With groupby and first, response includes items, endCursor, and hasNextPage for pagination.", - InputSchema = JsonSerializer.Deserialize( - @"{ - ""type"": ""object"", - ""properties"": { - ""entity"": { - ""type"": ""string"", - ""description"": ""Entity name from describe_entities with READ permission (case-sensitive)."" - }, - ""function"": { - ""type"": ""string"", - ""enum"": [""count"", ""avg"", ""sum"", ""min"", ""max""], - ""description"": ""Aggregation function. count supports field '*'; avg, sum, min, max require a numeric field."" - }, - ""field"": { - ""type"": ""string"", - ""description"": ""Field name to aggregate, or '*' with count to count all rows."" - }, - ""distinct"": { - ""type"": ""boolean"", - ""description"": ""Remove duplicate values before aggregating. Not valid with field '*'."", - ""default"": false - }, - ""filter"": { - ""type"": ""string"", - ""description"": ""OData WHERE clause applied before aggregating. Operators: eq, ne, gt, ge, lt, le, and, or, not. Example: 'unitPrice lt 10'."", - ""default"": """" - }, - ""groupby"": { - ""type"": ""array"", - ""items"": { ""type"": ""string"" }, - ""description"": ""Field names to group by. Each unique combination produces one aggregated row. Enables orderby, having, first, and after."", - ""default"": [] - }, - ""orderby"": { - ""type"": ""string"", - ""enum"": [""asc"", ""desc""], - ""description"": ""Sort grouped results by the aggregated value. Requires groupby."", - ""default"": ""desc"" - }, - ""having"": { - ""type"": ""object"", - ""description"": ""Filter groups by the aggregated value (HAVING clause). Requires groupby. Multiple operators are AND-ed."", - ""properties"": { - ""eq"": { ""type"": ""number"", ""description"": ""Equals."" }, - ""neq"": { ""type"": ""number"", ""description"": ""Not equals."" }, - ""gt"": { ""type"": ""number"", ""description"": ""Greater than."" }, - ""gte"": { ""type"": ""number"", ""description"": ""Greater than or equal."" }, - ""lt"": { ""type"": ""number"", ""description"": ""Less than."" }, - ""lte"": { ""type"": ""number"", ""description"": ""Less than or equal."" }, - ""in"": { - ""type"": ""array"", - ""items"": { ""type"": ""number"" }, - ""description"": ""Matches any value in the list."" - } + Name = "aggregate_records", + Description = "Computes aggregations (count, avg, sum, min, max) on entity data. " + + "WORKFLOW: 1) Call describe_entities first to get entity names and field names. " + + "2) Call this tool with entity, function, and field from step 1. " + + "RULES: field '*' is ONLY valid with count. " + + "orderby, having, first, and after ONLY apply when groupby is provided. " + + "RESPONSE: Result is aliased as '{function}_{field}' (e.g. avg_unitPrice). " + + "For count(*), the alias is 'count'. " + + "With groupby and first, response includes items, endCursor, and hasNextPage for pagination.", + InputSchema = JsonSerializer.Deserialize( + @"{ + ""type"": ""object"", + ""properties"": { + ""entity"": { + ""type"": ""string"", + ""description"": ""Entity name from describe_entities with READ permission (case-sensitive)."" + }, + ""function"": { + ""type"": ""string"", + ""enum"": [""count"", ""avg"", ""sum"", ""min"", ""max""], + ""description"": ""Aggregation function. count supports field '*'; avg, sum, min, max require a numeric field."" + }, + ""field"": { + ""type"": ""string"", + ""description"": ""Field name to aggregate, or '*' with count to count all rows."" + }, + ""distinct"": { + ""type"": ""boolean"", + ""description"": ""Remove duplicate values before aggregating. Not valid with field '*'."", + ""default"": false + }, + ""filter"": { + ""type"": ""string"", + ""description"": ""OData WHERE clause applied before aggregating. Operators: eq, ne, gt, ge, lt, le, and, or, not. Example: 'unitPrice lt 10'."", + ""default"": """" + }, + ""groupby"": { + ""type"": ""array"", + ""items"": { ""type"": ""string"" }, + ""description"": ""Field names to group by. Each unique combination produces one aggregated row. Enables orderby, having, first, and after."", + ""default"": [] + }, + ""orderby"": { + ""type"": ""string"", + ""enum"": [""asc"", ""desc""], + ""description"": ""Sort grouped results by the aggregated value. Requires groupby."", + ""default"": ""desc"" + }, + ""having"": { + ""type"": ""object"", + ""description"": ""Filter groups by the aggregated value (HAVING clause). Requires groupby. Multiple operators are AND-ed."", + ""properties"": { + ""eq"": { ""type"": ""number"", ""description"": ""Equals."" }, + ""neq"": { ""type"": ""number"", ""description"": ""Not equals."" }, + ""gt"": { ""type"": ""number"", ""description"": ""Greater than."" }, + ""gte"": { ""type"": ""number"", ""description"": ""Greater than or equal."" }, + ""lt"": { ""type"": ""number"", ""description"": ""Less than."" }, + ""lte"": { ""type"": ""number"", ""description"": ""Less than or equal."" }, + ""in"": { + ""type"": ""array"", + ""items"": { ""type"": ""number"" }, + ""description"": ""Matches any value in the list."" } - }, - ""first"": { - ""type"": ""integer"", - ""description"": ""Max grouped results to return. Requires groupby. Enables paginated response with endCursor and hasNextPage."", - ""minimum"": 1 - }, - ""after"": { - ""type"": ""string"", - ""description"": ""Opaque cursor from a previous endCursor for next-page retrieval. Requires groupby and first. Do not construct manually."" } }, - ""required"": [""entity"", ""function"", ""field""] - }" - ) - }; + ""first"": { + ""type"": ""integer"", + ""description"": ""Max grouped results to return. Requires groupby. Enables paginated response with endCursor and hasNextPage."", + ""minimum"": 1 + }, + ""after"": { + ""type"": ""string"", + ""description"": ""Opaque cursor from a previous endCursor for next-page retrieval. Requires groupby and first. Do not construct manually."" + } + }, + ""required"": [""entity"", ""function"", ""field""] + }" + ) + }; + + public Tool GetToolMetadata() + { + return _cachedToolMetadata; } public async Task ExecuteAsync( @@ -232,6 +234,28 @@ public async Task ExecuteAsync( } } + // Validate that first, after, and non-default orderby require groupby + if (groupby.Count == 0) + { + if (first.HasValue) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + "The 'first' parameter requires 'groupby' to be specified. Pagination applies to grouped aggregation results.", logger); + } + + if (!string.IsNullOrEmpty(after)) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + "The 'after' parameter requires 'groupby' to be specified. Pagination applies to grouped aggregation results.", logger); + } + } + + if (!string.IsNullOrEmpty(after) && !first.HasValue) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + "The 'after' parameter requires 'first' to be specified. Provide 'first' to enable pagination.", logger); + } + Dictionary? havingOperators = null; List? havingInValues = null; if (root.TryGetProperty("having", out JsonElement havingElement) && havingElement.ValueKind == JsonValueKind.Object) diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs index ac567d4d8c..31a92ef6e4 100644 --- a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs @@ -69,6 +69,12 @@ public static async Task ExecuteWithTelemetryAsync( timeoutSeconds = config.Runtime?.Mcp?.EffectiveQueryTimeoutSeconds ?? McpRuntimeOptions.DEFAULT_QUERY_TIMEOUT_SECONDS; } + // Defensive runtime guard: clamp timeout to valid range [1, MAX_QUERY_TIMEOUT_SECONDS]. + if (timeoutSeconds < 1 || timeoutSeconds > McpRuntimeOptions.MAX_QUERY_TIMEOUT_SECONDS) + { + timeoutSeconds = McpRuntimeOptions.DEFAULT_QUERY_TIMEOUT_SECONDS; + } + // Wrap tool execution with the configured timeout using a linked CancellationTokenSource. using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); timeoutCts.CancelAfter(TimeSpan.FromSeconds(timeoutSeconds)); diff --git a/src/Config/ObjectModel/McpRuntimeOptions.cs b/src/Config/ObjectModel/McpRuntimeOptions.cs index f4b4281a14..5b48b2fcc3 100644 --- a/src/Config/ObjectModel/McpRuntimeOptions.cs +++ b/src/Config/ObjectModel/McpRuntimeOptions.cs @@ -11,6 +11,7 @@ public record McpRuntimeOptions { public const string DEFAULT_PATH = "/mcp"; public const int DEFAULT_QUERY_TIMEOUT_SECONDS = 30; + public const int MAX_QUERY_TIMEOUT_SECONDS = 600; /// /// Whether MCP endpoints are enabled diff --git a/src/Core/Configurations/RuntimeConfigValidator.cs b/src/Core/Configurations/RuntimeConfigValidator.cs index ea2299bc6f..90550fbaba 100644 --- a/src/Core/Configurations/RuntimeConfigValidator.cs +++ b/src/Core/Configurations/RuntimeConfigValidator.cs @@ -916,10 +916,11 @@ public void ValidateMcpUri(RuntimeConfig runtimeConfig) } // Validate query-timeout if provided - if (runtimeConfig.Runtime.Mcp.QueryTimeout is not null && runtimeConfig.Runtime.Mcp.QueryTimeout < 1) + if (runtimeConfig.Runtime.Mcp.QueryTimeout is not null && + (runtimeConfig.Runtime.Mcp.QueryTimeout < 1 || runtimeConfig.Runtime.Mcp.QueryTimeout > McpRuntimeOptions.MAX_QUERY_TIMEOUT_SECONDS)) { HandleOrRecordException(new DataApiBuilderException( - message: "MCP query-timeout must be a positive integer (>= 1 second). " + + message: $"MCP query-timeout must be between 1 and {McpRuntimeOptions.MAX_QUERY_TIMEOUT_SECONDS} seconds. " + $"Provided value: {runtimeConfig.Runtime.Mcp.QueryTimeout}.", statusCode: HttpStatusCode.ServiceUnavailable, subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); From c7010ffd6b787e5cf03a7f25a5d919c20a595ca8 Mon Sep 17 00:00:00 2001 From: souvikghosh04 Date: Tue, 3 Mar 2026 17:13:44 +0530 Subject: [PATCH 27/32] Removing duplicate registration from stdio which is failing runs --- src/Service/Utilities/McpStdioHelper.cs | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/Service/Utilities/McpStdioHelper.cs b/src/Service/Utilities/McpStdioHelper.cs index 043e9dd85d..f22e12b02f 100644 --- a/src/Service/Utilities/McpStdioHelper.cs +++ b/src/Service/Utilities/McpStdioHelper.cs @@ -78,15 +78,8 @@ public static bool RunMcpStdioHost(IHost host) { host.Start(); - Mcp.Core.McpToolRegistry registry = - host.Services.GetRequiredService(); - IEnumerable tools = - host.Services.GetServices(); - - foreach (Mcp.Model.IMcpTool tool in tools) - { - registry.RegisterTool(tool); - } + // Tools are already registered by McpToolRegistryInitializer (IHostedService) + // during host.Start(). No need to register them again here. IHostApplicationLifetime lifetime = host.Services.GetRequiredService(); From 5038cc71e86f9ce24b5c3041d82c4b619e86f127 Mon Sep 17 00:00:00 2001 From: Souvik Ghosh Date: Tue, 3 Mar 2026 17:01:08 +0000 Subject: [PATCH 28/32] update snapshot test files --- ...nTests.TestReadingRuntimeConfigForCosmos.verified.txt | 9 +++++++-- ...onTests.TestReadingRuntimeConfigForMsSql.verified.txt | 9 +++++++-- ...onTests.TestReadingRuntimeConfigForMySql.verified.txt | 9 +++++++-- ...ts.TestReadingRuntimeConfigForPostgreSql.verified.txt | 9 +++++++-- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForCosmos.verified.txt b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForCosmos.verified.txt index 9279da9d59..0b2fd67066 100644 --- a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForCosmos.verified.txt +++ b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForCosmos.verified.txt @@ -28,14 +28,19 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedPath: false, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMsSql.verified.txt b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMsSql.verified.txt index 35fd562c87..4ee73b2b4a 100644 --- a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMsSql.verified.txt +++ b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMsSql.verified.txt @@ -32,14 +32,19 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedPath: false, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMySql.verified.txt b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMySql.verified.txt index 1490309ece..5522043d3f 100644 --- a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMySql.verified.txt +++ b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMySql.verified.txt @@ -24,14 +24,19 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedPath: false, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForPostgreSql.verified.txt b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForPostgreSql.verified.txt index ceba40ae63..b52c59df32 100644 --- a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForPostgreSql.verified.txt +++ b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForPostgreSql.verified.txt @@ -24,14 +24,19 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedPath: false, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { From 88968f6c3ba794d211dfbc4eebb8b7d376da99b9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 28 Feb 2026 05:55:59 +0000 Subject: [PATCH 29/32] Initial plan From b87be6f79cf93d4b79d46899bcf15579b2bd2e7e Mon Sep 17 00:00:00 2001 From: Souvik Ghosh Date: Thu, 5 Mar 2026 06:13:54 +0000 Subject: [PATCH 30/32] Fixes from code reviews --- .../BuiltInTools/AggregateRecordsTool.cs | 42 +++++++++- .../Mcp/AggregateRecordsToolTests.cs | 82 ++++++++++++++++++- src/Service.Tests/Mcp/McpQueryTimeoutTests.cs | 2 - 3 files changed, 116 insertions(+), 10 deletions(-) diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs index 806671fdaa..2d7e8e19d8 100644 --- a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -39,6 +39,7 @@ public class AggregateRecordsTool : IMcpTool public ToolType ToolType { get; } = ToolType.BuiltIn; private static readonly HashSet _validFunctions = new(StringComparer.OrdinalIgnoreCase) { "count", "avg", "sum", "min", "max" }; + private static readonly HashSet _validHavingOperators = new(StringComparer.OrdinalIgnoreCase) { "eq", "neq", "gt", "gte", "lt", "lte", "in" }; private static readonly Tool _cachedToolMetadata = new() { @@ -207,7 +208,8 @@ public async Task ExecuteAsync( } string? filter = root.TryGetProperty("filter", out JsonElement filterElement) ? filterElement.GetString() : null; - string orderby = root.TryGetProperty("orderby", out JsonElement orderbyElement) ? (orderbyElement.GetString() ?? "desc") : "desc"; + bool userProvidedOrderby = root.TryGetProperty("orderby", out JsonElement orderbyElement) && !string.IsNullOrWhiteSpace(orderbyElement.GetString()); + string orderby = userProvidedOrderby ? (orderbyElement.GetString() ?? "desc") : "desc"; int? first = null; if (root.TryGetProperty("first", out JsonElement firstElement) && firstElement.ValueKind == JsonValueKind.Number) @@ -234,9 +236,15 @@ public async Task ExecuteAsync( } } - // Validate that first, after, and non-default orderby require groupby + // Validate that first, after, orderby, and having require groupby if (groupby.Count == 0) { + if (userProvidedOrderby) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + "The 'orderby' parameter requires 'groupby' to be specified. Sorting applies to grouped aggregation results.", logger); + } + if (first.HasValue) { return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", @@ -269,16 +277,42 @@ public async Task ExecuteAsync( havingOperators = new Dictionary(StringComparer.OrdinalIgnoreCase); foreach (JsonProperty prop in havingElement.EnumerateObject()) { - if (prop.Name.Equals("in", StringComparison.OrdinalIgnoreCase) && prop.Value.ValueKind == JsonValueKind.Array) + // Reject unsupported operators (e.g. between, notIn, like) + if (!_validHavingOperators.Contains(prop.Name)) { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + $"Unsupported having operator '{prop.Name}'. Supported operators: {string.Join(", ", _validHavingOperators)}.", logger); + } + + if (prop.Name.Equals("in", StringComparison.OrdinalIgnoreCase)) + { + if (prop.Value.ValueKind != JsonValueKind.Array) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + "The 'having.in' value must be a numeric array. Example: {\"in\": [5, 10]}.", logger); + } + havingInValues = new List(); foreach (JsonElement item in prop.Value.EnumerateArray()) { + if (item.ValueKind != JsonValueKind.Number) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + $"All values in 'having.in' must be numeric. Found non-numeric value: '{item}'.", logger); + } + havingInValues.Add(item.GetDouble()); } } - else if (prop.Value.ValueKind == JsonValueKind.Number) + else { + // Scalar operators (eq, neq, gt, gte, lt, lte) must have numeric values + if (prop.Value.ValueKind != JsonValueKind.Number) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + $"The 'having.{prop.Name}' value must be numeric. Got: '{prop.Value}'. HAVING filters compare aggregated numeric results.", logger); + } + havingOperators[prop.Name] = prop.Value.GetDouble(); } } diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs index 9898b16c68..bab8d68f2c 100644 --- a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -#nullable enable - using System; using System.Collections.Generic; using System.Text; @@ -235,6 +233,82 @@ public async Task AggregateRecords_HavingWithoutGroupBy_ReturnsInvalidArguments( Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("groupby")); } + [TestMethod] + public async Task AggregateRecords_OrderByWithoutGroupBy_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\", \"orderby\": \"desc\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("groupby")); + } + + [TestMethod] + public async Task AggregateRecords_UnsupportedHavingOperator_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\", \"groupby\": [\"title\"], \"having\": {\"between\": 5}}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("between")); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("Supported operators")); + } + + [TestMethod] + public async Task AggregateRecords_NonNumericHavingValue_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\", \"groupby\": [\"title\"], \"having\": {\"eq\": \"ten\"}}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("numeric")); + } + + [TestMethod] + public async Task AggregateRecords_NonNumericHavingInArray_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\", \"groupby\": [\"title\"], \"having\": {\"in\": [5, \"abc\"]}}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("numeric")); + } + + [TestMethod] + public async Task AggregateRecords_HavingInNotArray_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\", \"groupby\": [\"title\"], \"having\": {\"in\": 5}}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("numeric array")); + } + #endregion #region Alias Convention Tests @@ -364,8 +438,8 @@ public async Task AggregateRecords_OperationCanceled_ReturnsExplicitCanceledMess Assert.IsTrue(result.IsError == true); JsonElement content = ParseContent(result); Assert.IsTrue(content.TryGetProperty("error", out JsonElement error)); - string? errorType = error.GetProperty("type").GetString(); - string? errorMessage = error.GetProperty("message").GetString(); + string errorType = error.GetProperty("type").GetString(); + string errorMessage = error.GetProperty("message").GetString(); // Verify the error type identifies it as a cancellation Assert.IsNotNull(errorType); diff --git a/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs b/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs index 0f5ee3951a..237e40e57e 100644 --- a/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs +++ b/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -#nullable enable - using System; using System.Collections.Generic; using System.Text.Json; From 38c773d3964d2b4c0020aef304913c1095310592 Mon Sep 17 00:00:00 2001 From: souvikghosh04 Date: Thu, 5 Mar 2026 15:52:22 +0530 Subject: [PATCH 31/32] Snapshot files and test fixes --- ...stMethodsAndGraphQLOperations.verified.txt | 10 +- ...tyWithSourceAsStoredProcedure.verified.txt | 10 +- ...tityWithSourceWithDefaultType.verified.txt | 10 +- ...dingEntityWithoutIEnumerables.verified.txt | 10 +- ...ests.TestInitForCosmosDBNoSql.verified.txt | 10 +- ...toredProcedureWithRestMethods.verified.txt | 8 +- ...stMethodsAndGraphQLOperations.verified.txt | 10 +- ...itTests.CosmosDbNoSqlDatabase.verified.txt | 10 +- ...ts.CosmosDbPostgreSqlDatabase.verified.txt | 10 +- ...ionProviders_171ea8114ff71814.verified.txt | 10 +- ...ionProviders_2df7a1794712f154.verified.txt | 10 +- ...ionProviders_59fe1a10aa78899d.verified.txt | 10 +- ...ionProviders_b95b637ea87f16a7.verified.txt | 10 +- ...ionProviders_daacbd948b7ef72f.verified.txt | 10 +- ...tStartingSlashWillHaveItAdded.verified.txt | 10 +- .../InitTests.MsSQLDatabase.verified.txt | 10 +- ...tStartingSlashWillHaveItAdded.verified.txt | 10 +- ...ConfigWithoutConnectionString.verified.txt | 10 +- ...lCharactersInConnectionString.verified.txt | 10 +- ...ationOptions_0546bef37027a950.verified.txt | 10 +- ...ationOptions_0ac567dd32a2e8f5.verified.txt | 10 +- ...ationOptions_0c06949221514e77.verified.txt | 10 +- ...ationOptions_18667ab7db033e9d.verified.txt | 10 +- ...ationOptions_2f42f44c328eb020.verified.txt | 10 +- ...ationOptions_3243d3f3441fdcc1.verified.txt | 10 +- ...ationOptions_53350b8b47df2112.verified.txt | 10 +- ...ationOptions_6584e0ec46b8a11d.verified.txt | 10 +- ...ationOptions_81cc88db3d4eecfb.verified.txt | 10 +- ...ationOptions_8ea187616dbb5577.verified.txt | 10 +- ...ationOptions_905845c29560a3ef.verified.txt | 10 +- ...ationOptions_b2fd24fab5b80917.verified.txt | 10 +- ...ationOptions_bd7cd088755287c9.verified.txt | 10 +- ...ationOptions_d2eccba2f836b380.verified.txt | 10 +- ...ationOptions_d463eed7fe5e4bbe.verified.txt | 10 +- ...ationOptions_d5520dd5c33f7b8d.verified.txt | 10 +- ...ationOptions_eab4a6010e602b59.verified.txt | 10 +- ...ationOptions_ecaa688829b4030e.verified.txt | 10 +- src/Cli/Commands/ConfigureOptions.cs | 2 + ...ReadingRuntimeConfigForCosmos.verified.txt | 25 -- ...tReadingRuntimeConfigForMySql.verified.txt | 268 +++++------------- ...ingRuntimeConfigForPostgreSql.verified.txt | 25 -- 41 files changed, 332 insertions(+), 356 deletions(-) diff --git a/src/Cli.Tests/Snapshots/EndToEndTests.TestAddingStoredProcedureWithRestMethodsAndGraphQLOperations.verified.txt b/src/Cli.Tests/Snapshots/EndToEndTests.TestAddingStoredProcedureWithRestMethodsAndGraphQLOperations.verified.txt index 3fa1fbc14e..0fd0030402 100644 --- a/src/Cli.Tests/Snapshots/EndToEndTests.TestAddingStoredProcedureWithRestMethodsAndGraphQLOperations.verified.txt +++ b/src/Cli.Tests/Snapshots/EndToEndTests.TestAddingStoredProcedureWithRestMethodsAndGraphQLOperations.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithSourceAsStoredProcedure.verified.txt b/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithSourceAsStoredProcedure.verified.txt index 76ea01dfca..725eed7a83 100644 --- a/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithSourceAsStoredProcedure.verified.txt +++ b/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithSourceAsStoredProcedure.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithSourceWithDefaultType.verified.txt b/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithSourceWithDefaultType.verified.txt index 3a8c738a70..70cb42137b 100644 --- a/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithSourceWithDefaultType.verified.txt +++ b/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithSourceWithDefaultType.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithoutIEnumerables.verified.txt b/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithoutIEnumerables.verified.txt index df2cd4b009..46bec31cc9 100644 --- a/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithoutIEnumerables.verified.txt +++ b/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithoutIEnumerables.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/EndToEndTests.TestInitForCosmosDBNoSql.verified.txt b/src/Cli.Tests/Snapshots/EndToEndTests.TestInitForCosmosDBNoSql.verified.txt index 1b14a3a7f0..0932956d7a 100644 --- a/src/Cli.Tests/Snapshots/EndToEndTests.TestInitForCosmosDBNoSql.verified.txt +++ b/src/Cli.Tests/Snapshots/EndToEndTests.TestInitForCosmosDBNoSql.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { Options: { container: planet, @@ -28,14 +28,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/EndToEndTests.TestUpdatingStoredProcedureWithRestMethods.verified.txt b/src/Cli.Tests/Snapshots/EndToEndTests.TestUpdatingStoredProcedureWithRestMethods.verified.txt index 62d9e237b5..fdda324d36 100644 --- a/src/Cli.Tests/Snapshots/EndToEndTests.TestUpdatingStoredProcedureWithRestMethods.verified.txt +++ b/src/Cli.Tests/Snapshots/EndToEndTests.TestUpdatingStoredProcedureWithRestMethods.verified.txt @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/EndToEndTests.TestUpdatingStoredProcedureWithRestMethodsAndGraphQLOperations.verified.txt b/src/Cli.Tests/Snapshots/EndToEndTests.TestUpdatingStoredProcedureWithRestMethodsAndGraphQLOperations.verified.txt index fa8b16e739..2a4e8653a1 100644 --- a/src/Cli.Tests/Snapshots/EndToEndTests.TestUpdatingStoredProcedureWithRestMethodsAndGraphQLOperations.verified.txt +++ b/src/Cli.Tests/Snapshots/EndToEndTests.TestUpdatingStoredProcedureWithRestMethodsAndGraphQLOperations.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.CosmosDbNoSqlDatabase.verified.txt b/src/Cli.Tests/Snapshots/InitTests.CosmosDbNoSqlDatabase.verified.txt index 9d5458c0ee..4870537837 100644 --- a/src/Cli.Tests/Snapshots/InitTests.CosmosDbNoSqlDatabase.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.CosmosDbNoSqlDatabase.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { Options: { container: testcontainer, @@ -28,14 +28,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.CosmosDbPostgreSqlDatabase.verified.txt b/src/Cli.Tests/Snapshots/InitTests.CosmosDbPostgreSqlDatabase.verified.txt index 51f6ad8d95..e03973b91e 100644 --- a/src/Cli.Tests/Snapshots/InitTests.CosmosDbPostgreSqlDatabase.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.CosmosDbPostgreSqlDatabase.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: CosmosDB_PostgreSQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_171ea8114ff71814.verified.txt b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_171ea8114ff71814.verified.txt index 978d1a253b..d33247dcab 100644 --- a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_171ea8114ff71814.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_171ea8114ff71814.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_2df7a1794712f154.verified.txt b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_2df7a1794712f154.verified.txt index 402bf4d2bc..fa08aefa62 100644 --- a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_2df7a1794712f154.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_2df7a1794712f154.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_59fe1a10aa78899d.verified.txt b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_59fe1a10aa78899d.verified.txt index ab71a40f03..98fdb25c77 100644 --- a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_59fe1a10aa78899d.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_59fe1a10aa78899d.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_b95b637ea87f16a7.verified.txt b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_b95b637ea87f16a7.verified.txt index 25e3976685..74afea9ef6 100644 --- a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_b95b637ea87f16a7.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_b95b637ea87f16a7.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_daacbd948b7ef72f.verified.txt b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_daacbd948b7ef72f.verified.txt index 140f017b78..3145f775c0 100644 --- a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_daacbd948b7ef72f.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_daacbd948b7ef72f.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.GraphQLPathWithoutStartingSlashWillHaveItAdded.verified.txt b/src/Cli.Tests/Snapshots/InitTests.GraphQLPathWithoutStartingSlashWillHaveItAdded.verified.txt index a3a056ac0a..ae32e3b379 100644 --- a/src/Cli.Tests/Snapshots/InitTests.GraphQLPathWithoutStartingSlashWillHaveItAdded.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.GraphQLPathWithoutStartingSlashWillHaveItAdded.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.MsSQLDatabase.verified.txt b/src/Cli.Tests/Snapshots/InitTests.MsSQLDatabase.verified.txt index f40350c4da..0f2c151763 100644 --- a/src/Cli.Tests/Snapshots/InitTests.MsSQLDatabase.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.MsSQLDatabase.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.RestPathWithoutStartingSlashWillHaveItAdded.verified.txt b/src/Cli.Tests/Snapshots/InitTests.RestPathWithoutStartingSlashWillHaveItAdded.verified.txt index b792d41c9f..d9067e1b43 100644 --- a/src/Cli.Tests/Snapshots/InitTests.RestPathWithoutStartingSlashWillHaveItAdded.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.RestPathWithoutStartingSlashWillHaveItAdded.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.TestInitializingConfigWithoutConnectionString.verified.txt b/src/Cli.Tests/Snapshots/InitTests.TestInitializingConfigWithoutConnectionString.verified.txt index 173960d7b1..e48b87e1c8 100644 --- a/src/Cli.Tests/Snapshots/InitTests.TestInitializingConfigWithoutConnectionString.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.TestInitializingConfigWithoutConnectionString.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.TestSpecialCharactersInConnectionString.verified.txt b/src/Cli.Tests/Snapshots/InitTests.TestSpecialCharactersInConnectionString.verified.txt index 25e3976685..74afea9ef6 100644 --- a/src/Cli.Tests/Snapshots/InitTests.TestSpecialCharactersInConnectionString.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.TestSpecialCharactersInConnectionString.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0546bef37027a950.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0546bef37027a950.verified.txt index 63f0da701c..2cb50a06da 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0546bef37027a950.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0546bef37027a950.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: DWSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0ac567dd32a2e8f5.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0ac567dd32a2e8f5.verified.txt index f40350c4da..0f2c151763 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0ac567dd32a2e8f5.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0ac567dd32a2e8f5.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0c06949221514e77.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0c06949221514e77.verified.txt index e59070d692..bbea4aadd3 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0c06949221514e77.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0c06949221514e77.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -32,14 +32,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_18667ab7db033e9d.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_18667ab7db033e9d.verified.txt index f7de35b7ae..63f411cdb2 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_18667ab7db033e9d.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_18667ab7db033e9d.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: PostgreSQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_2f42f44c328eb020.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_2f42f44c328eb020.verified.txt index 63f0da701c..2cb50a06da 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_2f42f44c328eb020.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_2f42f44c328eb020.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: DWSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_3243d3f3441fdcc1.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_3243d3f3441fdcc1.verified.txt index f7de35b7ae..63f411cdb2 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_3243d3f3441fdcc1.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_3243d3f3441fdcc1.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: PostgreSQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_53350b8b47df2112.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_53350b8b47df2112.verified.txt index 75613db959..5af597f50a 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_53350b8b47df2112.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_53350b8b47df2112.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MySQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_6584e0ec46b8a11d.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_6584e0ec46b8a11d.verified.txt index d93aac7dc6..860fa1616c 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_6584e0ec46b8a11d.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_6584e0ec46b8a11d.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { Options: { container: testcontainer, @@ -28,14 +28,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_81cc88db3d4eecfb.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_81cc88db3d4eecfb.verified.txt index 640815babb..48f3d0ce51 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_81cc88db3d4eecfb.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_81cc88db3d4eecfb.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -32,14 +32,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_8ea187616dbb5577.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_8ea187616dbb5577.verified.txt index 5900015d5a..f56dcad7d7 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_8ea187616dbb5577.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_8ea187616dbb5577.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: CosmosDB_PostgreSQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_905845c29560a3ef.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_905845c29560a3ef.verified.txt index 63f0da701c..2cb50a06da 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_905845c29560a3ef.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_905845c29560a3ef.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: DWSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_b2fd24fab5b80917.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_b2fd24fab5b80917.verified.txt index d93aac7dc6..860fa1616c 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_b2fd24fab5b80917.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_b2fd24fab5b80917.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { Options: { container: testcontainer, @@ -28,14 +28,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_bd7cd088755287c9.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_bd7cd088755287c9.verified.txt index d93aac7dc6..860fa1616c 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_bd7cd088755287c9.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_bd7cd088755287c9.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { Options: { container: testcontainer, @@ -28,14 +28,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d2eccba2f836b380.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d2eccba2f836b380.verified.txt index 75613db959..5af597f50a 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d2eccba2f836b380.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d2eccba2f836b380.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MySQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d463eed7fe5e4bbe.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d463eed7fe5e4bbe.verified.txt index 5900015d5a..f56dcad7d7 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d463eed7fe5e4bbe.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d463eed7fe5e4bbe.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: CosmosDB_PostgreSQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d5520dd5c33f7b8d.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d5520dd5c33f7b8d.verified.txt index 75613db959..5af597f50a 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d5520dd5c33f7b8d.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d5520dd5c33f7b8d.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MySQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_eab4a6010e602b59.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_eab4a6010e602b59.verified.txt index f7de35b7ae..63f411cdb2 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_eab4a6010e602b59.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_eab4a6010e602b59.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: PostgreSQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_ecaa688829b4030e.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_ecaa688829b4030e.verified.txt index 5900015d5a..f56dcad7d7 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_ecaa688829b4030e.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_ecaa688829b4030e.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: CosmosDB_PostgreSQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index bf12cd5199..99d7efc637 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -50,6 +50,7 @@ public ConfigureOptions( bool? runtimeMcpDmlToolsUpdateRecordEnabled = null, bool? runtimeMcpDmlToolsDeleteRecordEnabled = null, bool? runtimeMcpDmlToolsExecuteEntityEnabled = null, + bool? runtimeMcpDmlToolsAggregateRecordsEnabled = null, bool? runtimeCacheEnabled = null, int? runtimeCacheTtl = null, CompressionLevel? runtimeCompressionLevel = null, @@ -111,6 +112,7 @@ public ConfigureOptions( RuntimeMcpDmlToolsUpdateRecordEnabled = runtimeMcpDmlToolsUpdateRecordEnabled; RuntimeMcpDmlToolsDeleteRecordEnabled = runtimeMcpDmlToolsDeleteRecordEnabled; RuntimeMcpDmlToolsExecuteEntityEnabled = runtimeMcpDmlToolsExecuteEntityEnabled; + RuntimeMcpDmlToolsAggregateRecordsEnabled = runtimeMcpDmlToolsAggregateRecordsEnabled; // Cache RuntimeCacheEnabled = runtimeCacheEnabled; RuntimeCacheTTL = runtimeCacheTtl; diff --git a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForCosmos.verified.txt b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForCosmos.verified.txt index 0b2fd67066..d820e1b124 100644 --- a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForCosmos.verified.txt +++ b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForCosmos.verified.txt @@ -17,31 +17,6 @@ Path: /graphql, AllowIntrospection: true }, - Mcp: { - Enabled: true, - Path: /mcp, - DmlTools: { - AllToolsEnabled: true, - DescribeEntities: true, - CreateRecord: true, - ReadRecords: true, - UpdateRecord: true, - DeleteRecord: true, - ExecuteEntity: true, - AggregateRecords: true, - UserProvidedAllTools: false, - UserProvidedDescribeEntities: false, - UserProvidedCreateRecord: false, - UserProvidedReadRecords: false, - UserProvidedUpdateRecord: false, - UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false, - UserProvidedAggregateRecords: false - }, - UserProvidedPath: false, - UserProvidedQueryTimeout: false, - EffectiveQueryTimeoutSeconds: 30 - }, Host: { Cors: { Origins: [ diff --git a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMySql.verified.txt b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMySql.verified.txt index 5522043d3f..5320176e4c 100644 --- a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMySql.verified.txt +++ b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMySql.verified.txt @@ -13,31 +13,6 @@ Path: /graphql, AllowIntrospection: true }, - Mcp: { - Enabled: true, - Path: /mcp, - DmlTools: { - AllToolsEnabled: true, - DescribeEntities: true, - CreateRecord: true, - ReadRecords: true, - UpdateRecord: true, - DeleteRecord: true, - ExecuteEntity: true, - AggregateRecords: true, - UserProvidedAllTools: false, - UserProvidedDescribeEntities: false, - UserProvidedCreateRecord: false, - UserProvidedReadRecords: false, - UserProvidedUpdateRecord: false, - UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false, - UserProvidedAggregateRecords: false - }, - UserProvidedPath: false, - UserProvidedQueryTimeout: false, - EffectiveQueryTimeoutSeconds: 30 - }, Host: { Cors: { Origins: [ @@ -395,18 +370,6 @@ Object: books, Type: Table }, - Fields: [ - { - Name: id, - Alias: id, - PrimaryKey: false - }, - { - Name: title, - Alias: title, - PrimaryKey: false - } - ], GraphQL: { Singular: book, Plural: books, @@ -768,6 +731,10 @@ ] } ], + Mappings: { + id: id, + title: title + }, Relationships: { authors: { Cardinality: Many, @@ -799,41 +766,6 @@ } } }, - { - Default_Books: { - Source: { - Object: default_books, - Type: Table - }, - GraphQL: { - Singular: default_book, - Plural: default_books, - Enabled: true - }, - Rest: { - Enabled: true - }, - Permissions: [ - { - Role: anonymous, - Actions: [ - { - Action: Create - }, - { - Action: Read - }, - { - Action: Update - }, - { - Action: Delete - } - ] - } - ] - } - }, { BookNF: { Source: { @@ -1172,13 +1104,6 @@ Object: type_table, Type: Table }, - Fields: [ - { - Name: id, - Alias: typeid, - PrimaryKey: false - } - ], GraphQL: { Singular: SupportedType, Plural: SupportedTypes, @@ -1222,7 +1147,10 @@ } ] } - ] + ], + Mappings: { + id: typeid + } } }, { @@ -1274,18 +1202,6 @@ Object: trees, Type: Table }, - Fields: [ - { - Name: species, - Alias: Scientific Name, - PrimaryKey: false - }, - { - Name: region, - Alias: United State's Region, - PrimaryKey: false - } - ], GraphQL: { Singular: Tree, Plural: Trees, @@ -1329,7 +1245,11 @@ } ] } - ] + ], + Mappings: { + region: United State's Region, + species: Scientific Name + } } }, { @@ -1338,13 +1258,6 @@ Object: trees, Type: Table }, - Fields: [ - { - Name: species, - Alias: fancyName, - PrimaryKey: false - } - ], GraphQL: { Singular: Shrub, Plural: Shrubs, @@ -1390,6 +1303,9 @@ ] } ], + Mappings: { + species: fancyName + }, Relationships: { fungus: { TargetEntity: Fungus, @@ -1409,13 +1325,6 @@ Object: fungi, Type: Table }, - Fields: [ - { - Name: spores, - Alias: hazards, - PrimaryKey: false - } - ], GraphQL: { Singular: fungus, Plural: fungi, @@ -1476,8 +1385,11 @@ ] } ], + Mappings: { + spores: hazards + }, Relationships: { - Shrub: { + shrub: { TargetEntity: Shrub, SourceFields: [ habitat @@ -1493,14 +1405,11 @@ books_view_all: { Source: { Object: books_view_all, - Type: View + Type: View, + KeyFields: [ + id + ] }, - Fields: [ - { - Name: id, - PrimaryKey: true - } - ], GraphQL: { Singular: books_view_all, Plural: books_view_alls, @@ -1542,15 +1451,11 @@ books_view_with_mapping: { Source: { Object: books_view_with_mapping, - Type: View + Type: View, + KeyFields: [ + id + ] }, - Fields: [ - { - Name: id, - Alias: book_id, - PrimaryKey: true - } - ], GraphQL: { Singular: books_view_with_mapping, Plural: books_view_with_mappings, @@ -1568,25 +1473,22 @@ } ] } - ] + ], + Mappings: { + id: book_id + } } }, { stocks_view_selected: { Source: { Object: stocks_view_selected, - Type: View + Type: View, + KeyFields: [ + categoryid, + pieceid + ] }, - Fields: [ - { - Name: categoryid, - PrimaryKey: true - }, - { - Name: pieceid, - PrimaryKey: true - } - ], GraphQL: { Singular: stocks_view_selected, Plural: stocks_view_selecteds, @@ -1628,18 +1530,12 @@ books_publishers_view_composite: { Source: { Object: books_publishers_view_composite, - Type: View + Type: View, + KeyFields: [ + id, + pub_id + ] }, - Fields: [ - { - Name: id, - PrimaryKey: true - }, - { - Name: pub_id, - PrimaryKey: true - } - ], GraphQL: { Singular: books_publishers_view_composite, Plural: books_publishers_view_composites, @@ -1893,28 +1789,6 @@ Object: aow, Type: Table }, - Fields: [ - { - Name: DetailAssessmentAndPlanning, - Alias: 始計, - PrimaryKey: false - }, - { - Name: WagingWar, - Alias: 作戰, - PrimaryKey: false - }, - { - Name: StrategicAttack, - Alias: 謀攻, - PrimaryKey: false - }, - { - Name: NoteNum, - Alias: ┬─┬ノ( º _ ºノ), - PrimaryKey: false - } - ], GraphQL: { Singular: ArtOfWar, Plural: ArtOfWars, @@ -1940,7 +1814,13 @@ } ] } - ] + ], + Mappings: { + DetailAssessmentAndPlanning: 始計, + NoteNum: ┬─┬ノ( º _ ºノ), + StrategicAttack: 謀攻, + WagingWar: 作戰 + } } }, { @@ -2154,18 +2034,6 @@ Object: GQLmappings, Type: Table }, - Fields: [ - { - Name: __column1, - Alias: column1, - PrimaryKey: false - }, - { - Name: __column2, - Alias: column2, - PrimaryKey: false - } - ], GraphQL: { Singular: GQLmappings, Plural: GQLmappings, @@ -2191,7 +2059,11 @@ } ] } - ] + ], + Mappings: { + __column1: column1, + __column2: column2 + } } }, { @@ -2234,18 +2106,6 @@ Object: mappedbookmarks, Type: Table }, - Fields: [ - { - Name: id, - Alias: bkid, - PrimaryKey: false - }, - { - Name: bkname, - Alias: name, - PrimaryKey: false - } - ], GraphQL: { Singular: MappedBookmarks, Plural: MappedBookmarks, @@ -2271,7 +2131,11 @@ } ] } - ] + ], + Mappings: { + bkname: name, + id: bkid + } } }, { @@ -2324,6 +2188,9 @@ Exclude: [ current_date, next_date + ], + Include: [ + * ] } }, @@ -2360,7 +2227,16 @@ Role: anonymous, Actions: [ { - Action: * + Action: Read + }, + { + Action: Create + }, + { + Action: Update + }, + { + Action: Delete } ] } diff --git a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForPostgreSql.verified.txt b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForPostgreSql.verified.txt index b52c59df32..5d8dc31646 100644 --- a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForPostgreSql.verified.txt +++ b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForPostgreSql.verified.txt @@ -13,31 +13,6 @@ Path: /graphql, AllowIntrospection: true }, - Mcp: { - Enabled: true, - Path: /mcp, - DmlTools: { - AllToolsEnabled: true, - DescribeEntities: true, - CreateRecord: true, - ReadRecords: true, - UpdateRecord: true, - DeleteRecord: true, - ExecuteEntity: true, - AggregateRecords: true, - UserProvidedAllTools: false, - UserProvidedDescribeEntities: false, - UserProvidedCreateRecord: false, - UserProvidedReadRecords: false, - UserProvidedUpdateRecord: false, - UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false, - UserProvidedAggregateRecords: false - }, - UserProvidedPath: false, - UserProvidedQueryTimeout: false, - EffectiveQueryTimeoutSeconds: 30 - }, Host: { Cors: { Origins: [ From 7b85658dac307966f79e3bc70f0b3d0aa5b5a79f Mon Sep 17 00:00:00 2001 From: souvikghosh04 Date: Thu, 5 Mar 2026 18:21:10 +0530 Subject: [PATCH 32/32] Add AggregateRecords and query-timeout properties to Service.Tests snapshots --- ...ReadingRuntimeConfigForCosmos.verified.txt | 24 ++ ...tReadingRuntimeConfigForMsSql.verified.txt | 1 - ...tReadingRuntimeConfigForMySql.verified.txt | 267 +++++++++++++----- ...ingRuntimeConfigForPostgreSql.verified.txt | 24 ++ 4 files changed, 243 insertions(+), 73 deletions(-) diff --git a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForCosmos.verified.txt b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForCosmos.verified.txt index d820e1b124..15f242605f 100644 --- a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForCosmos.verified.txt +++ b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForCosmos.verified.txt @@ -17,6 +17,30 @@ Path: /graphql, AllowIntrospection: true }, + Mcp: { + Enabled: true, + Path: /mcp, + DmlTools: { + AllToolsEnabled: true, + DescribeEntities: true, + CreateRecord: true, + ReadRecords: true, + UpdateRecord: true, + DeleteRecord: true, + ExecuteEntity: true, + AggregateRecords: true, + UserProvidedAllTools: false, + UserProvidedDescribeEntities: false, + UserProvidedCreateRecord: false, + UserProvidedReadRecords: false, + UserProvidedUpdateRecord: false, + UserProvidedDeleteRecord: false, + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 + }, Host: { Cors: { Origins: [ diff --git a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMsSql.verified.txt b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMsSql.verified.txt index 4ee73b2b4a..966af2777f 100644 --- a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMsSql.verified.txt +++ b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMsSql.verified.txt @@ -42,7 +42,6 @@ UserProvidedExecuteEntity: false, UserProvidedAggregateRecords: false }, - UserProvidedPath: false, UserProvidedQueryTimeout: false, EffectiveQueryTimeoutSeconds: 30 }, diff --git a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMySql.verified.txt b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMySql.verified.txt index 5320176e4c..0779215cd0 100644 --- a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMySql.verified.txt +++ b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMySql.verified.txt @@ -13,6 +13,30 @@ Path: /graphql, AllowIntrospection: true }, + Mcp: { + Enabled: true, + Path: /mcp, + DmlTools: { + AllToolsEnabled: true, + DescribeEntities: true, + CreateRecord: true, + ReadRecords: true, + UpdateRecord: true, + DeleteRecord: true, + ExecuteEntity: true, + AggregateRecords: true, + UserProvidedAllTools: false, + UserProvidedDescribeEntities: false, + UserProvidedCreateRecord: false, + UserProvidedReadRecords: false, + UserProvidedUpdateRecord: false, + UserProvidedDeleteRecord: false, + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 + }, Host: { Cors: { Origins: [ @@ -370,6 +394,18 @@ Object: books, Type: Table }, + Fields: [ + { + Name: id, + Alias: id, + PrimaryKey: false + }, + { + Name: title, + Alias: title, + PrimaryKey: false + } + ], GraphQL: { Singular: book, Plural: books, @@ -731,10 +767,6 @@ ] } ], - Mappings: { - id: id, - title: title - }, Relationships: { authors: { Cardinality: Many, @@ -766,6 +798,41 @@ } } }, + { + Default_Books: { + Source: { + Object: default_books, + Type: Table + }, + GraphQL: { + Singular: default_book, + Plural: default_books, + Enabled: true + }, + Rest: { + Enabled: true + }, + Permissions: [ + { + Role: anonymous, + Actions: [ + { + Action: Create + }, + { + Action: Read + }, + { + Action: Update + }, + { + Action: Delete + } + ] + } + ] + } + }, { BookNF: { Source: { @@ -1104,6 +1171,13 @@ Object: type_table, Type: Table }, + Fields: [ + { + Name: id, + Alias: typeid, + PrimaryKey: false + } + ], GraphQL: { Singular: SupportedType, Plural: SupportedTypes, @@ -1147,10 +1221,7 @@ } ] } - ], - Mappings: { - id: typeid - } + ] } }, { @@ -1202,6 +1273,18 @@ Object: trees, Type: Table }, + Fields: [ + { + Name: species, + Alias: Scientific Name, + PrimaryKey: false + }, + { + Name: region, + Alias: United State's Region, + PrimaryKey: false + } + ], GraphQL: { Singular: Tree, Plural: Trees, @@ -1245,11 +1328,7 @@ } ] } - ], - Mappings: { - region: United State's Region, - species: Scientific Name - } + ] } }, { @@ -1258,6 +1337,13 @@ Object: trees, Type: Table }, + Fields: [ + { + Name: species, + Alias: fancyName, + PrimaryKey: false + } + ], GraphQL: { Singular: Shrub, Plural: Shrubs, @@ -1303,9 +1389,6 @@ ] } ], - Mappings: { - species: fancyName - }, Relationships: { fungus: { TargetEntity: Fungus, @@ -1325,6 +1408,13 @@ Object: fungi, Type: Table }, + Fields: [ + { + Name: spores, + Alias: hazards, + PrimaryKey: false + } + ], GraphQL: { Singular: fungus, Plural: fungi, @@ -1385,11 +1475,8 @@ ] } ], - Mappings: { - spores: hazards - }, Relationships: { - shrub: { + Shrub: { TargetEntity: Shrub, SourceFields: [ habitat @@ -1405,11 +1492,14 @@ books_view_all: { Source: { Object: books_view_all, - Type: View, - KeyFields: [ - id - ] + Type: View }, + Fields: [ + { + Name: id, + PrimaryKey: true + } + ], GraphQL: { Singular: books_view_all, Plural: books_view_alls, @@ -1451,11 +1541,15 @@ books_view_with_mapping: { Source: { Object: books_view_with_mapping, - Type: View, - KeyFields: [ - id - ] + Type: View }, + Fields: [ + { + Name: id, + Alias: book_id, + PrimaryKey: true + } + ], GraphQL: { Singular: books_view_with_mapping, Plural: books_view_with_mappings, @@ -1473,22 +1567,25 @@ } ] } - ], - Mappings: { - id: book_id - } + ] } }, { stocks_view_selected: { Source: { Object: stocks_view_selected, - Type: View, - KeyFields: [ - categoryid, - pieceid - ] + Type: View }, + Fields: [ + { + Name: categoryid, + PrimaryKey: true + }, + { + Name: pieceid, + PrimaryKey: true + } + ], GraphQL: { Singular: stocks_view_selected, Plural: stocks_view_selecteds, @@ -1530,12 +1627,18 @@ books_publishers_view_composite: { Source: { Object: books_publishers_view_composite, - Type: View, - KeyFields: [ - id, - pub_id - ] + Type: View }, + Fields: [ + { + Name: id, + PrimaryKey: true + }, + { + Name: pub_id, + PrimaryKey: true + } + ], GraphQL: { Singular: books_publishers_view_composite, Plural: books_publishers_view_composites, @@ -1789,6 +1892,28 @@ Object: aow, Type: Table }, + Fields: [ + { + Name: DetailAssessmentAndPlanning, + Alias: 始計, + PrimaryKey: false + }, + { + Name: WagingWar, + Alias: 作戰, + PrimaryKey: false + }, + { + Name: StrategicAttack, + Alias: 謀攻, + PrimaryKey: false + }, + { + Name: NoteNum, + Alias: ┬─┬ノ( º _ ºノ), + PrimaryKey: false + } + ], GraphQL: { Singular: ArtOfWar, Plural: ArtOfWars, @@ -1814,13 +1939,7 @@ } ] } - ], - Mappings: { - DetailAssessmentAndPlanning: 始計, - NoteNum: ┬─┬ノ( º _ ºノ), - StrategicAttack: 謀攻, - WagingWar: 作戰 - } + ] } }, { @@ -2034,6 +2153,18 @@ Object: GQLmappings, Type: Table }, + Fields: [ + { + Name: __column1, + Alias: column1, + PrimaryKey: false + }, + { + Name: __column2, + Alias: column2, + PrimaryKey: false + } + ], GraphQL: { Singular: GQLmappings, Plural: GQLmappings, @@ -2059,11 +2190,7 @@ } ] } - ], - Mappings: { - __column1: column1, - __column2: column2 - } + ] } }, { @@ -2106,6 +2233,18 @@ Object: mappedbookmarks, Type: Table }, + Fields: [ + { + Name: id, + Alias: bkid, + PrimaryKey: false + }, + { + Name: bkname, + Alias: name, + PrimaryKey: false + } + ], GraphQL: { Singular: MappedBookmarks, Plural: MappedBookmarks, @@ -2131,11 +2270,7 @@ } ] } - ], - Mappings: { - bkname: name, - id: bkid - } + ] } }, { @@ -2188,9 +2323,6 @@ Exclude: [ current_date, next_date - ], - Include: [ - * ] } }, @@ -2227,16 +2359,7 @@ Role: anonymous, Actions: [ { - Action: Read - }, - { - Action: Create - }, - { - Action: Update - }, - { - Action: Delete + Action: * } ] } diff --git a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForPostgreSql.verified.txt b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForPostgreSql.verified.txt index 5d8dc31646..75077c22fa 100644 --- a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForPostgreSql.verified.txt +++ b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForPostgreSql.verified.txt @@ -13,6 +13,30 @@ Path: /graphql, AllowIntrospection: true }, + Mcp: { + Enabled: true, + Path: /mcp, + DmlTools: { + AllToolsEnabled: true, + DescribeEntities: true, + CreateRecord: true, + ReadRecords: true, + UpdateRecord: true, + DeleteRecord: true, + ExecuteEntity: true, + AggregateRecords: true, + UserProvidedAllTools: false, + UserProvidedDescribeEntities: false, + UserProvidedCreateRecord: false, + UserProvidedReadRecords: false, + UserProvidedUpdateRecord: false, + UserProvidedDeleteRecord: false, + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 + }, Host: { Cors: { Origins: [