diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json
index cbe38b7d72..e78861807d 100644
--- a/schemas/dab.draft.schema.json
+++ b/schemas/dab.draft.schema.json
@@ -275,6 +275,12 @@
"description": "Allow enabling/disabling MCP requests for all entities.",
"default": true
},
+ "query-timeout": {
+ "type": "integer",
+ "description": "Execution timeout in seconds for MCP tool operations. Applies to all MCP tools.",
+ "default": 30,
+ "minimum": 1
+ },
"dml-tools": {
"oneOf": [
{
@@ -315,6 +321,11 @@
"type": "boolean",
"description": "Enable/disable the execute-entity tool.",
"default": false
+ },
+ "aggregate-records": {
+ "type": "boolean",
+ "description": "Enable/disable the aggregate-records tool.",
+ "default": false
}
}
}
diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs
new file mode 100644
index 0000000000..b8dd85c175
--- /dev/null
+++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs
@@ -0,0 +1,769 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System.Data.Common;
+using System.Text.Json;
+using Azure.DataApiBuilder.Auth;
+using Azure.DataApiBuilder.Config.DatabasePrimitives;
+using Azure.DataApiBuilder.Config.ObjectModel;
+using Azure.DataApiBuilder.Core.Authorization;
+using Azure.DataApiBuilder.Core.Configurations;
+using Azure.DataApiBuilder.Core.Models;
+using Azure.DataApiBuilder.Core.Parsers;
+using Azure.DataApiBuilder.Core.Resolvers;
+using Azure.DataApiBuilder.Core.Resolvers.Factories;
+using Azure.DataApiBuilder.Core.Services;
+using Azure.DataApiBuilder.Core.Services.MetadataProviders;
+using Azure.DataApiBuilder.Mcp.Model;
+using Azure.DataApiBuilder.Mcp.Utils;
+using Azure.DataApiBuilder.Service.Exceptions;
+using Microsoft.AspNetCore.Authorization;
+using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Logging;
+using ModelContextProtocol.Protocol;
+using static Azure.DataApiBuilder.Mcp.Model.McpEnums;
+
+namespace Azure.DataApiBuilder.Mcp.BuiltInTools
+{
+ /// <summary>
+ /// Tool to aggregate records from a table/view entity configured in DAB.
+ /// Supports count, avg, sum, min, max with optional distinct, filter, groupby, having, orderby.
+ /// </summary>
+ public class AggregateRecordsTool : IMcpTool
+ {
+ public ToolType ToolType { get; } = ToolType.BuiltIn;
+
+ private static readonly HashSet<string> _validFunctions = new(StringComparer.OrdinalIgnoreCase) { "count", "avg", "sum", "min", "max" };
+
+ public Tool GetToolMetadata()
+ {
+ return new Tool
+ {
+ Name = "aggregate_records",
+ Description = "Computes aggregations (count, avg, sum, min, max) on entity data. "
+ + "STEP 1: Call describe_entities to discover entities with READ permission and their field names. "
+ + "STEP 2: Call this tool with the exact entity name, an aggregation function, and a field name from STEP 1. "
+ + "REQUIRED: entity (exact entity name), function (one of: count, avg, sum, min, max), field (exact field name, or '*' ONLY for count). "
+ + "OPTIONAL: filter (OData WHERE clause applied before aggregating, e.g. 'unitPrice lt 10'), "
+ + "distinct (true to deduplicate values before aggregating), "
+ + "groupby (array of field names to group results by, e.g. ['categoryName']), "
+ + "orderby ('asc' or 'desc' to sort grouped results by aggregated value; requires groupby), "
+ + "having (object to filter groups after aggregating, operators: eq, neq, gt, gte, lt, lte, in; requires groupby), "
+ + "first (integer >= 1, maximum grouped results to return; requires groupby), "
+ + "after (opaque cursor string from a previous response's endCursor for pagination). "
+ + "RESPONSE: The aggregated value is aliased as '{function}_{field}' (e.g. avg_unitPrice, sum_revenue). "
+ + "For count with field '*', the alias is 'count'. "
+ + "When first is used with groupby, response contains: items (array), endCursor (string), hasNextPage (boolean). "
+ + "RULES: 1) ALWAYS call describe_entities first to get valid entity and field names. "
+ + "2) Use field '*' ONLY with function 'count'. "
+ + "3) For avg, sum, min, max: field MUST be a numeric field name from describe_entities. "
+ + "4) orderby, having, and first ONLY apply when groupby is provided. "
+ + "5) Use first and after for paginating large grouped result sets.",
+ InputSchema = JsonSerializer.Deserialize<JsonElement>(
+ @"{
+ ""type"": ""object"",
+ ""properties"": {
+ ""entity"": {
+ ""type"": ""string"",
+ ""description"": ""Exact entity name from describe_entities that has READ permission. Must match exactly (case-sensitive).""
+ },
+ ""function"": {
+ ""type"": ""string"",
+ ""enum"": [""count"", ""avg"", ""sum"", ""min"", ""max""],
+ ""description"": ""Aggregation function to apply. Use 'count' to count records, 'avg' for average, 'sum' for total, 'min' for minimum, 'max' for maximum. For count use field '*' or a specific field name. For avg, sum, min, max the field must be numeric.""
+ },
+ ""field"": {
+ ""type"": ""string"",
+ ""description"": ""Exact field name from describe_entities to aggregate. Use '*' ONLY with function 'count' to count all records. For avg, sum, min, max, provide a numeric field name.""
+ },
+ ""distinct"": {
+ ""type"": ""boolean"",
+ ""description"": ""When true, removes duplicate values before applying the aggregation function. For example, count with distinct counts unique values only. Default is false."",
+ ""default"": false
+ },
+ ""filter"": {
+ ""type"": ""string"",
+ ""description"": ""OData filter expression applied before aggregating (acts as a WHERE clause). Supported operators: eq, ne, gt, ge, lt, le, and, or, not. Example: 'unitPrice lt 10' filters to rows where unitPrice is less than 10 before aggregating. Example: 'discontinued eq true and categoryName eq ''Seafood''' filters discontinued seafood products."",
+ ""default"": """"
+ },
+ ""groupby"": {
+ ""type"": ""array"",
+ ""items"": { ""type"": ""string"" },
+ ""description"": ""Array of exact field names from describe_entities to group results by. Each unique combination of grouped field values produces one aggregated row. Grouped field values are included in the response alongside the aggregated value. Example: ['categoryName'] groups by category. Example: ['categoryName', 'region'] groups by both fields."",
+ ""default"": []
+ },
+ ""orderby"": {
+ ""type"": ""string"",
+ ""enum"": [""asc"", ""desc""],
+ ""description"": ""Sort direction for grouped results by the computed aggregated value. 'desc' returns highest values first, 'asc' returns lowest first. ONLY applies when groupby is provided. Default is 'desc'."",
+ ""default"": ""desc""
+ },
+ ""having"": {
+ ""type"": ""object"",
+ ""description"": ""Filter applied AFTER aggregating to filter grouped results by the computed aggregated value (acts as a HAVING clause). ONLY applies when groupby is provided. Multiple operators are AND-ed together. For example, use gt with value 20 to keep groups where the aggregated value exceeds 20. Combine gte and lte to define a range."",
+ ""properties"": {
+ ""eq"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value equals this number."" },
+ ""neq"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value does not equal this number."" },
+ ""gt"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value is greater than this number."" },
+ ""gte"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value is greater than or equal to this number."" },
+ ""lt"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value is less than this number."" },
+ ""lte"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value is less than or equal to this number."" },
+ ""in"": {
+ ""type"": ""array"",
+ ""items"": { ""type"": ""number"" },
+ ""description"": ""Keep groups where the aggregated value matches any number in this list. Example: [5, 10] keeps groups with aggregated value 5 or 10.""
+ }
+ }
+ },
+ ""first"": {
+ ""type"": ""integer"",
+ ""description"": ""Maximum number of grouped results to return. Used for pagination of grouped results. ONLY applies when groupby is provided. Must be >= 1. When set, the response includes 'items', 'endCursor', and 'hasNextPage' fields for pagination."",
+ ""minimum"": 1
+ },
+ ""after"": {
+ ""type"": ""string"",
+ ""description"": ""Opaque cursor string for pagination. Pass the 'endCursor' value from a previous response to get the next page of results. REQUIRES both groupby and first to be set. Do not construct this value manually; always use the endCursor from a previous response.""
+ }
+ },
+ ""required"": [""entity"", ""function"", ""field""]
+ }"
+ )
+ };
+ }
+
+ public async Task<CallToolResult> ExecuteAsync(
+ JsonDocument? arguments,
+ IServiceProvider serviceProvider,
+ CancellationToken cancellationToken = default)
+ {
+ ILogger<AggregateRecordsTool>? logger = serviceProvider.GetService<ILogger<AggregateRecordsTool>>();
+ string toolName = GetToolMetadata().Name;
+
+ RuntimeConfigProvider runtimeConfigProvider = serviceProvider.GetRequiredService<RuntimeConfigProvider>();
+ RuntimeConfig runtimeConfig = runtimeConfigProvider.GetConfig();
+
+ if (runtimeConfig.McpDmlTools?.AggregateRecords is not true)
+ {
+ return McpErrorHelpers.ToolDisabled(toolName, logger);
+ }
+
+ string entityName = string.Empty;
+
+ try
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
+ if (arguments == null)
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger);
+ }
+
+ JsonElement root = arguments.RootElement;
+
+ // Parse required arguments
+ if (!McpArgumentParser.TryParseEntity(root, out string parsedEntityName, out string parseError))
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger);
+ }
+
+ entityName = parsedEntityName;
+
+ if (runtimeConfig.Entities?.TryGetValue(entityName, out Entity? entity) == true &&
+ entity.Mcp?.DmlToolEnabled == false)
+ {
+ return McpErrorHelpers.ToolDisabled(toolName, logger, $"DML tools are disabled for entity '{entityName}'.");
+ }
+
+ if (!root.TryGetProperty("function", out JsonElement funcEl) || string.IsNullOrWhiteSpace(funcEl.GetString()))
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Missing required argument 'function'.", logger);
+ }
+
+ string function = funcEl.GetString()!.ToLowerInvariant();
+ if (!_validFunctions.Contains(function))
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", $"Invalid function '{function}'. Must be one of: count, avg, sum, min, max.", logger);
+ }
+
+ if (!root.TryGetProperty("field", out JsonElement fieldEl) || string.IsNullOrWhiteSpace(fieldEl.GetString()))
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Missing required argument 'field'.", logger);
+ }
+
+ string field = fieldEl.GetString()!;
+
+ // Validate field/function compatibility
+ bool isCountStar = function == "count" && field == "*";
+
+ if (field == "*" && function != "count")
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments",
+ $"Field '*' is only valid with function 'count'. For function '{function}', provide a specific field name.", logger);
+ }
+
+ bool distinct = root.TryGetProperty("distinct", out JsonElement distinctEl) && distinctEl.GetBoolean();
+
+ // Reject count(*) with distinct as it is semantically undefined
+ if (isCountStar && distinct)
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments",
+ "Cannot use distinct=true with field='*'. DISTINCT requires a specific field name. Use a field name instead of '*' to count distinct values.", logger);
+ }
+
+ string? filter = root.TryGetProperty("filter", out JsonElement filterEl) ? filterEl.GetString() : null;
+ string orderby = root.TryGetProperty("orderby", out JsonElement orderbyEl) ? (orderbyEl.GetString() ?? "desc") : "desc";
+
+ int? first = null;
+ if (root.TryGetProperty("first", out JsonElement firstEl) && firstEl.ValueKind == JsonValueKind.Number)
+ {
+ first = firstEl.GetInt32();
+ if (first < 1)
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Argument 'first' must be at least 1.", logger);
+ }
+ }
+
+ string? after = root.TryGetProperty("after", out JsonElement afterEl) ? afterEl.GetString() : null;
+
+ List<string> groupby = new();
+ if (root.TryGetProperty("groupby", out JsonElement groupbyEl) && groupbyEl.ValueKind == JsonValueKind.Array)
+ {
+ foreach (JsonElement g in groupbyEl.EnumerateArray())
+ {
+ string? gVal = g.GetString();
+ if (!string.IsNullOrWhiteSpace(gVal))
+ {
+ groupby.Add(gVal);
+ }
+ }
+ }
+
+ Dictionary<string, double>? havingOps = null;
+ List<double>? havingIn = null;
+ if (root.TryGetProperty("having", out JsonElement havingEl) && havingEl.ValueKind == JsonValueKind.Object)
+ {
+ havingOps = new Dictionary<string, double>(StringComparer.OrdinalIgnoreCase);
+ foreach (JsonProperty prop in havingEl.EnumerateObject())
+ {
+ if (prop.Name.Equals("in", StringComparison.OrdinalIgnoreCase) && prop.Value.ValueKind == JsonValueKind.Array)
+ {
+ havingIn = new List<double>();
+ foreach (JsonElement item in prop.Value.EnumerateArray())
+ {
+ havingIn.Add(item.GetDouble());
+ }
+ }
+ else if (prop.Value.ValueKind == JsonValueKind.Number)
+ {
+ havingOps[prop.Name] = prop.Value.GetDouble();
+ }
+ }
+ }
+
+ // Resolve metadata
+ if (!McpMetadataHelper.TryResolveMetadata(
+ entityName,
+ runtimeConfig,
+ serviceProvider,
+ out ISqlMetadataProvider sqlMetadataProvider,
+ out DatabaseObject dbObject,
+ out string dataSourceName,
+ out string metadataError))
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger);
+ }
+
+ // Authorization
+ IAuthorizationResolver authResolver = serviceProvider.GetRequiredService<IAuthorizationResolver>();
+ IAuthorizationService authorizationService = serviceProvider.GetRequiredService<IAuthorizationService>();
+ IHttpContextAccessor httpContextAccessor = serviceProvider.GetRequiredService<IHttpContextAccessor>();
+ HttpContext? httpContext = httpContextAccessor.HttpContext;
+
+ if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authResolver, out string roleCtxError))
+ {
+ return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", roleCtxError, logger);
+ }
+
+ if (!McpAuthorizationHelper.TryResolveAuthorizedRole(
+ httpContext!,
+ authResolver,
+ entityName,
+ EntityActionOperation.Read,
+ out string? effectiveRole,
+ out string readAuthError))
+ {
+ string finalError = readAuthError.StartsWith("You do not have permission", StringComparison.OrdinalIgnoreCase)
+ ? $"You do not have permission to read records for entity '{entityName}'."
+ : readAuthError;
+ return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", finalError, logger);
+ }
+
+ // Build select list: groupby fields + aggregation field
+ List<string> selectFields = new(groupby);
+ if (!isCountStar && !selectFields.Contains(field, StringComparer.OrdinalIgnoreCase))
+ {
+ selectFields.Add(field);
+ }
+
+ // Build and validate Find context
+ RequestValidator requestValidator = new(serviceProvider.GetRequiredService<ISqlMetadataProviderFactory>(), runtimeConfigProvider);
+ FindRequestContext context = new(entityName, dbObject, true);
+ httpContext!.Request.Method = "GET";
+
+ requestValidator.ValidateEntity(entityName);
+
+ if (selectFields.Count > 0)
+ {
+ context.UpdateReturnFields(selectFields);
+ }
+
+ if (!string.IsNullOrWhiteSpace(filter))
+ {
+ string filterQueryString = $"?{RequestParser.FILTER_URL}={filter}";
+ context.FilterClauseInUrl = sqlMetadataProvider.GetODataParser().GetFilterClause(filterQueryString, $"{context.EntityName}.{context.DatabaseObject.FullName}");
+ }
+
+ requestValidator.ValidateRequestContext(context);
+
+ AuthorizationResult authorizationResult = await authorizationService.AuthorizeAsync(
+ user: httpContext.User,
+ resource: context,
+ requirements: new[] { new ColumnsPermissionsRequirement() });
+ if (!authorizationResult.Succeeded)
+ {
+ return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", DataApiBuilderException.AUTHORIZATION_FAILURE, logger);
+ }
+
+ // Execute query to get records
+ IQueryEngineFactory queryEngineFactory = serviceProvider.GetRequiredService<IQueryEngineFactory>();
+ IQueryEngine queryEngine = queryEngineFactory.GetQueryEngine(sqlMetadataProvider.GetDatabaseType());
+ JsonDocument? queryResult = await queryEngine.ExecuteAsync(context);
+
+ IActionResult actionResult = queryResult is null
+ ? SqlResponseHelpers.FormatFindResult(JsonDocument.Parse("[]").RootElement.Clone(), context, sqlMetadataProvider, runtimeConfig, httpContext, true)
+ : SqlResponseHelpers.FormatFindResult(queryResult.RootElement.Clone(), context, sqlMetadataProvider, runtimeConfig, httpContext, true);
+
+ string rawPayloadJson = McpResponseBuilder.ExtractResultJson(actionResult);
+ using JsonDocument resultDoc = JsonDocument.Parse(rawPayloadJson);
+ JsonElement resultRoot = resultDoc.RootElement;
+
+ // Extract the records array from the response
+ JsonElement records;
+ if (resultRoot.TryGetProperty("value", out JsonElement valueArray))
+ {
+ records = valueArray;
+ }
+ else if (resultRoot.ValueKind == JsonValueKind.Array)
+ {
+ records = resultRoot;
+ }
+ else
+ {
+ records = resultRoot;
+ }
+
+ // Compute alias for the response
+ string alias = ComputeAlias(function, field);
+
+ // Perform in-memory aggregation
+ List<Dictionary<string, object?>> aggregatedResults = PerformAggregation(
+ records, function, field, distinct, groupby, havingOps, havingIn, orderby, alias);
+
+ // Apply pagination if first is specified with groupby
+ if (first.HasValue && groupby.Count > 0)
+ {
+ PaginationResult paginatedResult = ApplyPagination(aggregatedResults, first.Value, after);
+ return McpResponseBuilder.BuildSuccessResult(
+ new Dictionary<string, object?>
+ {
+ ["entity"] = entityName,
+ ["result"] = new Dictionary<string, object?>
+ {
+ ["items"] = paginatedResult.Items,
+ ["endCursor"] = paginatedResult.EndCursor,
+ ["hasNextPage"] = paginatedResult.HasNextPage
+ },
+ ["message"] = $"Successfully aggregated records for entity '{entityName}'"
+ },
+ logger,
+ $"AggregateRecordsTool success for entity {entityName}.");
+ }
+
+ return McpResponseBuilder.BuildSuccessResult(
+ new Dictionary<string, object?>
+ {
+ ["entity"] = entityName,
+ ["result"] = aggregatedResults,
+ ["message"] = $"Successfully aggregated records for entity '{entityName}'"
+ },
+ logger,
+ $"AggregateRecordsTool success for entity {entityName}.");
+ }
+ catch (TimeoutException timeoutEx)
+ {
+ logger?.LogError(timeoutEx, "Aggregation operation timed out for entity {Entity}.", entityName);
+ return McpResponseBuilder.BuildErrorResult(
+ toolName,
+ "TimeoutError",
+ $"The aggregation query for entity '{entityName}' timed out. "
+ + "This is NOT a tool error. The database did not respond in time. "
+ + "This may occur with large datasets or complex aggregations. "
+ + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination.",
+ logger);
+ }
+ catch (TaskCanceledException taskEx)
+ {
+ logger?.LogError(taskEx, "Aggregation task was canceled for entity {Entity}.", entityName);
+ return McpResponseBuilder.BuildErrorResult(
+ toolName,
+ "TimeoutError",
+ $"The aggregation query for entity '{entityName}' was canceled, likely due to a timeout. "
+ + "This is NOT a tool error. The database did not respond in time. "
+ + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination.",
+ logger);
+ }
+ catch (OperationCanceledException)
+ {
+ logger?.LogWarning("Aggregation operation was canceled for entity {Entity}.", entityName);
+ return McpResponseBuilder.BuildErrorResult(
+ toolName,
+ "OperationCanceled",
+ $"The aggregation query for entity '{entityName}' was canceled before completion. "
+ + "This is NOT a tool error. The operation was interrupted, possibly due to a timeout or client disconnect. "
+ + "No results were returned. You may retry the same request.",
+ logger);
+ }
+ catch (DbException dbEx)
+ {
+ logger?.LogError(dbEx, "Database error during aggregation for entity {Entity}.", entityName);
+ return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseOperationFailed", dbEx.Message, logger);
+ }
+ catch (ArgumentException argEx)
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", argEx.Message, logger);
+ }
+ catch (DataApiBuilderException argEx)
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, argEx.StatusCode.ToString(), argEx.Message, logger);
+ }
+ catch (Exception ex)
+ {
+ logger?.LogError(ex, "Unexpected error in AggregateRecordsTool.");
+ return McpResponseBuilder.BuildErrorResult(toolName, "UnexpectedError", "Unexpected error occurred in AggregateRecordsTool.", logger);
+ }
+ }
+
+ /// <summary>
+ /// Computes the response alias for the aggregation result.
+ /// For count with "*", the alias is "count". Otherwise it's "{function}_{field}".
+ /// </summary>
+ internal static string ComputeAlias(string function, string field)
+ {
+ if (function == "count" && field == "*")
+ {
+ return "count";
+ }
+
+ return $"{function}_{field}";
+ }
+
+ /// <summary>
+ /// Performs in-memory aggregation over a JSON array of records.
+ /// </summary>
+ internal static List<Dictionary<string, object?>> PerformAggregation(
+ JsonElement records,
+ string function,
+ string field,
+ bool distinct,
+ List<string> groupby,
+ Dictionary<string, double>? havingOps,
+ List<double>? havingIn,
+ string orderby,
+ string alias)
+ {
+ if (records.ValueKind != JsonValueKind.Array)
+ {
+ return new List<Dictionary<string, object?>> { new() { [alias] = null } };
+ }
+
+ bool isCountStar = function == "count" && field == "*";
+
+ if (groupby.Count == 0)
+ {
+ // No groupby - single result
+ List<JsonElement> items = new();
+ foreach (JsonElement record in records.EnumerateArray())
+ {
+ items.Add(record);
+ }
+
+ double? aggregatedValue = ComputeAggregateValue(items, function, field, distinct, isCountStar);
+
+ // Apply having
+ if (!PassesHavingFilter(aggregatedValue, havingOps, havingIn))
+ {
+ return new List<Dictionary<string, object?>>();
+ }
+
+ return new List<Dictionary<string, object?>>
+ {
+ new() { [alias] = aggregatedValue }
+ };
+ }
+ else
+ {
+ // Group by
+ Dictionary<string, List<JsonElement>> groups = new();
+ Dictionary<string, Dictionary<string, object?>> groupKeys = new();
+
+ foreach (JsonElement record in records.EnumerateArray())
+ {
+ string key = BuildGroupKey(record, groupby);
+ if (!groups.ContainsKey(key))
+ {
+ groups[key] = new List<JsonElement>();
+ groupKeys[key] = ExtractGroupFields(record, groupby);
+ }
+
+ groups[key].Add(record);
+ }
+
+ List<Dictionary<string, object?>> results = new();
+ foreach (KeyValuePair<string, List<JsonElement>> group in groups)
+ {
+ double? aggregatedValue = ComputeAggregateValue(group.Value, function, field, distinct, isCountStar);
+
+ if (!PassesHavingFilter(aggregatedValue, havingOps, havingIn))
+ {
+ continue;
+ }
+
+ Dictionary<string, object?> row = new(groupKeys[group.Key])
+ {
+ [alias] = aggregatedValue
+ };
+ results.Add(row);
+ }
+
+ // Apply orderby
+ if (orderby.Equals("asc", StringComparison.OrdinalIgnoreCase))
+ {
+ results.Sort((a, b) => CompareNullableDoubles(a[alias] as double?, b[alias] as double?));
+ }
+ else
+ {
+ results.Sort((a, b) => CompareNullableDoubles(b[alias] as double?, a[alias] as double?));
+ }
+
+ return results;
+ }
+ }
+
+ /// <summary>
+ /// Represents the result of applying pagination to aggregated results.
+ /// </summary>
+ internal sealed class PaginationResult
+ {
+ public List<Dictionary<string, object?>> Items { get; set; } = new();
+ public string? EndCursor { get; set; }
+ public bool HasNextPage { get; set; }
+ }
+
+ /// <summary>
+ /// Applies cursor-based pagination to aggregated results.
+ /// The cursor is an opaque base64-encoded offset integer.
+ /// </summary>
+ internal static PaginationResult ApplyPagination(
+ List<Dictionary<string, object?>> allResults,
+ int first,
+ string? after)
+ {
+ int startIndex = 0;
+
+ if (!string.IsNullOrWhiteSpace(after))
+ {
+ try
+ {
+ byte[] bytes = Convert.FromBase64String(after);
+ string decoded = System.Text.Encoding.UTF8.GetString(bytes);
+ if (int.TryParse(decoded, out int cursorOffset))
+ {
+ startIndex = cursorOffset;
+ }
+ }
+ catch (FormatException)
+ {
+ // Invalid cursor format; start from beginning
+ }
+ }
+
+ List<Dictionary<string, object?>> pageItems = allResults
+ .Skip(startIndex)
+ .Take(first)
+ .ToList();
+
+ bool hasNextPage = startIndex + first < allResults.Count;
+ string? endCursor = null;
+
+ if (pageItems.Count > 0)
+ {
+ int lastItemIndex = startIndex + pageItems.Count;
+ endCursor = Convert.ToBase64String(
+ System.Text.Encoding.UTF8.GetBytes(lastItemIndex.ToString()));
+ }
+
+ return new PaginationResult
+ {
+ Items = pageItems,
+ EndCursor = endCursor,
+ HasNextPage = hasNextPage
+ };
+ }
+
+ private static double? ComputeAggregateValue(List<JsonElement> records, string function, string field, bool distinct, bool isCountStar)
+ {
+ if (isCountStar)
+ {
+ // count(*) always counts all rows; distinct is rejected at ExecuteAsync validation level
+ return records.Count;
+ }
+
+ List<double> values = new();
+ foreach (JsonElement record in records)
+ {
+ if (record.TryGetProperty(field, out JsonElement val) && val.ValueKind == JsonValueKind.Number)
+ {
+ values.Add(val.GetDouble());
+ }
+ }
+
+ if (distinct)
+ {
+ values = values.Distinct().ToList();
+ }
+
+ if (function == "count")
+ {
+ return values.Count;
+ }
+
+ if (values.Count == 0)
+ {
+ return null;
+ }
+
+ return function switch
+ {
+ "avg" => Math.Round(values.Average(), 2),
+ "sum" => values.Sum(),
+ "min" => values.Min(),
+ "max" => values.Max(),
+ _ => null
+ };
+ }
+
+ private static bool PassesHavingFilter(double? value, Dictionary<string, double>? havingOps, List<double>? havingIn)
+ {
+ if (havingOps == null && havingIn == null)
+ {
+ return true;
+ }
+
+ if (value == null)
+ {
+ return false;
+ }
+
+ double v = value.Value;
+
+ if (havingOps != null)
+ {
+ foreach (KeyValuePair<string, double> op in havingOps)
+ {
+ bool passes = op.Key.ToLowerInvariant() switch
+ {
+ "eq" => v == op.Value,
+ "neq" => v != op.Value,
+ "gt" => v > op.Value,
+ "gte" => v >= op.Value,
+ "lt" => v < op.Value,
+ "lte" => v <= op.Value,
+ _ => true
+ };
+
+ if (!passes)
+ {
+ return false;
+ }
+ }
+ }
+
+ if (havingIn != null && !havingIn.Contains(v))
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ private static string BuildGroupKey(JsonElement record, List<string> groupby)
+ {
+ List<string> parts = new();
+ foreach (string g in groupby)
+ {
+ if (record.TryGetProperty(g, out JsonElement val))
+ {
+ parts.Add(val.ToString());
+ }
+ else
+ {
+ parts.Add("__null__");
+ }
+ }
+
+ // Use null character (\0) as delimiter to avoid collisions with
+ // field values that may contain printable characters like '|'.
+ return string.Join("\0", parts);
+ }
+
+ private static Dictionary<string, object?> ExtractGroupFields(JsonElement record, List<string> groupby)
+ {
+ Dictionary<string, object?> result = new();
+ foreach (string g in groupby)
+ {
+ if (record.TryGetProperty(g, out JsonElement val))
+ {
+ result[g] = McpResponseBuilder.GetJsonValue(val);
+ }
+ else
+ {
+ result[g] = null;
+ }
+ }
+
+ return result;
+ }
+
+ private static int CompareNullableDoubles(double? a, double? b)
+ {
+ if (a == null && b == null)
+ {
+ return 0;
+ }
+
+ if (a == null)
+ {
+ return -1;
+ }
+
+ if (b == null)
+ {
+ return 1;
+ }
+
+ return a.Value.CompareTo(b.Value);
+ }
+ }
+}
diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs
index f69a26fa5d..3ef3aa4d74 100644
--- a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs
+++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs
@@ -37,5 +37,10 @@ internal static class McpTelemetryErrorCodes
/// Operation cancelled error code.
///
public const string OPERATION_CANCELLED = "OperationCancelled";
+
+ /// <summary>
+ /// Operation timed out error code.
+ /// </summary>
+ public const string TIMEOUT = "Timeout";
}
}
diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs
index 2a60557f8d..ac567d4d8c 100644
--- a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs
+++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs
@@ -60,8 +60,33 @@ public static async Task ExecuteWithTelemetryAsync(
operation: operation,
dbProcedure: dbProcedure);
- // Execute the tool
- CallToolResult result = await tool.ExecuteAsync(arguments, serviceProvider, cancellationToken);
+ // Read query-timeout from current config per invocation (hot-reload safe).
+ int timeoutSeconds = McpRuntimeOptions.DEFAULT_QUERY_TIMEOUT_SECONDS;
+ RuntimeConfigProvider? runtimeConfigProvider = serviceProvider.GetService<RuntimeConfigProvider>();
+ if (runtimeConfigProvider is not null)
+ {
+ RuntimeConfig config = runtimeConfigProvider.GetConfig();
+ timeoutSeconds = config.Runtime?.Mcp?.EffectiveQueryTimeoutSeconds ?? McpRuntimeOptions.DEFAULT_QUERY_TIMEOUT_SECONDS;
+ }
+
+ // Wrap tool execution with the configured timeout using a linked CancellationTokenSource.
+ using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+ timeoutCts.CancelAfter(TimeSpan.FromSeconds(timeoutSeconds));
+
+ CallToolResult result;
+ try
+ {
+ result = await tool.ExecuteAsync(arguments, serviceProvider, timeoutCts.Token);
+ }
+ catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
+ {
+ // The timeout CTS fired, not the caller's token. Surface as TimeoutException
+ // so downstream telemetry and tool handlers see TIMEOUT, not cancellation.
+ throw new TimeoutException(
+ $"The MCP tool '{toolName}' did not complete within {timeoutSeconds} {(timeoutSeconds == 1 ? "second" : "seconds")}. "
+ + "This is NOT a tool error. The operation exceeded the configured query-timeout. "
+ + "Try narrowing results with a filter, reducing groupby fields, or using pagination.");
+ }
// Check if the tool returned an error result (tools catch exceptions internally
// and return CallToolResult with IsError=true instead of throwing)
@@ -124,6 +149,7 @@ public static string InferOperationFromTool(IMcpTool tool, string toolName)
"delete_record" => "delete",
"describe_entities" => "describe",
"execute_entity" => "execute",
+ "aggregate_records" => "aggregate",
_ => "execute" // Fallback for any unknown built-in tools
};
}
@@ -188,6 +214,7 @@ public static string MapExceptionToErrorCode(Exception ex)
return ex switch
{
OperationCanceledException => McpTelemetryErrorCodes.OPERATION_CANCELLED,
+ TimeoutException => McpTelemetryErrorCodes.TIMEOUT,
DataApiBuilderException dabEx when dabEx.SubStatusCode == DataApiBuilderException.SubStatusCodes.AuthenticationChallenge
=> McpTelemetryErrorCodes.AUTHENTICATION_FAILED,
DataApiBuilderException dabEx when dabEx.SubStatusCode == DataApiBuilderException.SubStatusCodes.AuthorizationCheckFailed
diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs
index 262cbc9145..bf12cd5199 100644
--- a/src/Cli/Commands/ConfigureOptions.cs
+++ b/src/Cli/Commands/ConfigureOptions.cs
@@ -42,6 +42,7 @@ public ConfigureOptions(
bool? runtimeMcpEnabled = null,
string? runtimeMcpPath = null,
string? runtimeMcpDescription = null,
+ int? runtimeMcpQueryTimeout = null,
bool? runtimeMcpDmlToolsEnabled = null,
bool? runtimeMcpDmlToolsDescribeEntitiesEnabled = null,
bool? runtimeMcpDmlToolsCreateRecordEnabled = null,
@@ -102,6 +103,7 @@ public ConfigureOptions(
RuntimeMcpEnabled = runtimeMcpEnabled;
RuntimeMcpPath = runtimeMcpPath;
RuntimeMcpDescription = runtimeMcpDescription;
+ RuntimeMcpQueryTimeout = runtimeMcpQueryTimeout;
RuntimeMcpDmlToolsEnabled = runtimeMcpDmlToolsEnabled;
RuntimeMcpDmlToolsDescribeEntitiesEnabled = runtimeMcpDmlToolsDescribeEntitiesEnabled;
RuntimeMcpDmlToolsCreateRecordEnabled = runtimeMcpDmlToolsCreateRecordEnabled;
@@ -203,6 +205,9 @@ public ConfigureOptions(
[Option("runtime.mcp.description", Required = false, HelpText = "Set the MCP server description to be exposed in the initialize response.")]
public string? RuntimeMcpDescription { get; }
+ [Option("runtime.mcp.query-timeout", Required = false, HelpText = "Set the execution timeout in seconds for MCP tool operations. Applies to all MCP tools. Default: 30 (integer). Must be >= 1.")]
+ public int? RuntimeMcpQueryTimeout { get; }
+
[Option("runtime.mcp.dml-tools.enabled", Required = false, HelpText = "Enable DAB's MCP DML tools endpoint. Default: true (boolean).")]
public bool? RuntimeMcpDmlToolsEnabled { get; }
@@ -224,6 +229,9 @@ public ConfigureOptions(
[Option("runtime.mcp.dml-tools.execute-entity.enabled", Required = false, HelpText = "Enable DAB's MCP execute entity tool. Default: true (boolean).")]
public bool? RuntimeMcpDmlToolsExecuteEntityEnabled { get; }
+ [Option("runtime.mcp.dml-tools.aggregate-records.enabled", Required = false, HelpText = "Enable DAB's MCP aggregate records tool. Default: true (boolean).")]
+ public bool? RuntimeMcpDmlToolsAggregateRecordsEnabled { get; }
+
[Option("runtime.cache.enabled", Required = false, HelpText = "Enable DAB's cache globally. (You must also enable each entity's cache separately.). Default: false (boolean).")]
public bool? RuntimeCacheEnabled { get; }
diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs
index 6c51f002b7..fa632f10ac 100644
--- a/src/Cli/ConfigGenerator.cs
+++ b/src/Cli/ConfigGenerator.cs
@@ -876,13 +876,15 @@ private static bool TryUpdateConfiguredRuntimeOptions(
if (options.RuntimeMcpEnabled != null ||
options.RuntimeMcpPath != null ||
options.RuntimeMcpDescription != null ||
+ options.RuntimeMcpQueryTimeout != null ||
options.RuntimeMcpDmlToolsEnabled != null ||
options.RuntimeMcpDmlToolsDescribeEntitiesEnabled != null ||
options.RuntimeMcpDmlToolsCreateRecordEnabled != null ||
options.RuntimeMcpDmlToolsReadRecordsEnabled != null ||
options.RuntimeMcpDmlToolsUpdateRecordEnabled != null ||
options.RuntimeMcpDmlToolsDeleteRecordEnabled != null ||
- options.RuntimeMcpDmlToolsExecuteEntityEnabled != null)
+ options.RuntimeMcpDmlToolsExecuteEntityEnabled != null ||
+ options.RuntimeMcpDmlToolsAggregateRecordsEnabled != null)
{
McpRuntimeOptions updatedMcpOptions = runtimeConfig?.Runtime?.Mcp ?? new();
bool status = TryUpdateConfiguredMcpValues(options, ref updatedMcpOptions);
@@ -1161,6 +1163,14 @@ private static bool TryUpdateConfiguredMcpValues(
_logger.LogInformation("Updated RuntimeConfig with Runtime.Mcp.Description as '{updatedValue}'", updatedValue);
}
+ // Runtime.Mcp.QueryTimeout
+ updatedValue = options?.RuntimeMcpQueryTimeout;
+ if (updatedValue != null)
+ {
+ updatedMcpOptions = updatedMcpOptions! with { QueryTimeout = (int)updatedValue };
+ _logger.LogInformation("Updated RuntimeConfig with Runtime.Mcp.QueryTimeout as '{updatedValue}'", updatedValue);
+ }
+
// Handle DML tools configuration
bool hasToolUpdates = false;
DmlToolsConfig? currentDmlTools = updatedMcpOptions?.DmlTools;
@@ -1181,6 +1191,7 @@ private static bool TryUpdateConfiguredMcpValues(
bool? updateRecord = currentDmlTools?.UpdateRecord;
bool? deleteRecord = currentDmlTools?.DeleteRecord;
bool? executeEntity = currentDmlTools?.ExecuteEntity;
+ bool? aggregateRecords = currentDmlTools?.AggregateRecords;
updatedValue = options?.RuntimeMcpDmlToolsDescribeEntitiesEnabled;
if (updatedValue != null)
@@ -1230,6 +1241,14 @@ private static bool TryUpdateConfiguredMcpValues(
_logger.LogInformation("Updated RuntimeConfig with runtime.mcp.dml-tools.execute-entity as '{updatedValue}'", updatedValue);
}
+ updatedValue = options?.RuntimeMcpDmlToolsAggregateRecordsEnabled;
+ if (updatedValue != null)
+ {
+ aggregateRecords = (bool)updatedValue;
+ hasToolUpdates = true;
+ _logger.LogInformation("Updated RuntimeConfig with runtime.mcp.dml-tools.aggregate-records as '{updatedValue}'", updatedValue);
+ }
+
if (hasToolUpdates)
{
updatedMcpOptions = updatedMcpOptions! with
@@ -1242,7 +1261,8 @@ private static bool TryUpdateConfiguredMcpValues(
ReadRecords = readRecord,
UpdateRecord = updateRecord,
DeleteRecord = deleteRecord,
- ExecuteEntity = executeEntity
+ ExecuteEntity = executeEntity,
+ AggregateRecords = aggregateRecords
}
};
}
diff --git a/src/Config/Converters/DmlToolsConfigConverter.cs b/src/Config/Converters/DmlToolsConfigConverter.cs
index 82ac3f6069..7e049c7926 100644
--- a/src/Config/Converters/DmlToolsConfigConverter.cs
+++ b/src/Config/Converters/DmlToolsConfigConverter.cs
@@ -44,6 +44,7 @@ internal class DmlToolsConfigConverter : JsonConverter<DmlToolsConfig?>
bool? updateRecord = null;
bool? deleteRecord = null;
bool? executeEntity = null;
+ bool? aggregateRecords = null;
while (reader.Read())
{
@@ -82,6 +83,9 @@ internal class DmlToolsConfigConverter : JsonConverter
case "execute-entity":
executeEntity = value;
break;
+ case "aggregate-records":
+ aggregateRecords = value;
+ break;
default:
// Skip unknown properties
break;
@@ -91,7 +95,8 @@ internal class DmlToolsConfigConverter : JsonConverter
{
// Error on non-boolean values for known properties
if (property?.ToLowerInvariant() is "describe-entities" or "create-record"
- or "read-records" or "update-record" or "delete-record" or "execute-entity")
+ or "read-records" or "update-record" or "delete-record" or "execute-entity"
+ or "aggregate-records")
{
throw new JsonException($"Property '{property}' must be a boolean value.");
}
@@ -110,7 +115,8 @@ internal class DmlToolsConfigConverter : JsonConverter
readRecords: readRecords,
updateRecord: updateRecord,
deleteRecord: deleteRecord,
- executeEntity: executeEntity);
+ executeEntity: executeEntity,
+ aggregateRecords: aggregateRecords);
}
// For any other unexpected token type, return default (all enabled)
@@ -135,7 +141,8 @@ public override void Write(Utf8JsonWriter writer, DmlToolsConfig? value, JsonSer
value.UserProvidedReadRecords ||
value.UserProvidedUpdateRecord ||
value.UserProvidedDeleteRecord ||
- value.UserProvidedExecuteEntity;
+ value.UserProvidedExecuteEntity ||
+ value.UserProvidedAggregateRecords;
// Only write the boolean value if it's provided by user
// This prevents writing "dml-tools": true when it's the default
@@ -181,6 +188,11 @@ public override void Write(Utf8JsonWriter writer, DmlToolsConfig? value, JsonSer
writer.WriteBoolean("execute-entity", value.ExecuteEntity.Value);
}
+ if (value.UserProvidedAggregateRecords && value.AggregateRecords.HasValue)
+ {
+ writer.WriteBoolean("aggregate-records", value.AggregateRecords.Value);
+ }
+
writer.WriteEndObject();
}
}
diff --git a/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs b/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs
index 8b3c640725..ad4edc229e 100644
--- a/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs
+++ b/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs
@@ -66,12 +66,13 @@ internal McpRuntimeOptionsConverter(DeserializationVariableReplacementSettings?
string? path = null;
DmlToolsConfig? dmlTools = null;
string? description = null;
+ int? queryTimeout = null;
while (reader.Read())
{
if (reader.TokenType == JsonTokenType.EndObject)
{
- return new McpRuntimeOptions(enabled, path, dmlTools, description);
+ return new McpRuntimeOptions(enabled, path, dmlTools, description, queryTimeout);
}
string? propertyName = reader.GetString();
@@ -107,6 +108,14 @@ internal McpRuntimeOptionsConverter(DeserializationVariableReplacementSettings?
break;
+ case "query-timeout":
+ if (reader.TokenType is not JsonTokenType.Null)
+ {
+ queryTimeout = reader.GetInt32();
+ }
+
+ break;
+
default:
throw new JsonException($"Unexpected property {propertyName}");
}
@@ -150,6 +159,12 @@ public override void Write(Utf8JsonWriter writer, McpRuntimeOptions value, JsonS
JsonSerializer.Serialize(writer, value.Description, options);
}
+ // Write query-timeout if it's user provided
+ if (value?.UserProvidedQueryTimeout is true && value.QueryTimeout.HasValue)
+ {
+ writer.WriteNumber("query-timeout", value.QueryTimeout.Value);
+ }
+
writer.WriteEndObject();
}
}
diff --git a/src/Config/ObjectModel/DmlToolsConfig.cs b/src/Config/ObjectModel/DmlToolsConfig.cs
index 2a09e9d53c..c1f8b278cd 100644
--- a/src/Config/ObjectModel/DmlToolsConfig.cs
+++ b/src/Config/ObjectModel/DmlToolsConfig.cs
@@ -51,6 +51,11 @@ public record DmlToolsConfig
/// </summary>
public bool? ExecuteEntity { get; init; }
+ /// <summary>
+ /// Whether aggregate-records tool is enabled
+ /// </summary>
+ public bool? AggregateRecords { get; init; }
+
[JsonConstructor]
public DmlToolsConfig(
bool? allToolsEnabled = null,
@@ -59,7 +64,8 @@ public DmlToolsConfig(
bool? readRecords = null,
bool? updateRecord = null,
bool? deleteRecord = null,
- bool? executeEntity = null)
+ bool? executeEntity = null,
+ bool? aggregateRecords = null)
{
if (allToolsEnabled is not null)
{
@@ -75,6 +81,7 @@ public DmlToolsConfig(
UpdateRecord = updateRecord ?? toolDefault;
DeleteRecord = deleteRecord ?? toolDefault;
ExecuteEntity = executeEntity ?? toolDefault;
+ AggregateRecords = aggregateRecords ?? toolDefault;
}
else
{
@@ -87,6 +94,7 @@ public DmlToolsConfig(
UpdateRecord = updateRecord ?? DEFAULT_ENABLED;
DeleteRecord = deleteRecord ?? DEFAULT_ENABLED;
ExecuteEntity = executeEntity ?? DEFAULT_ENABLED;
+ AggregateRecords = aggregateRecords ?? DEFAULT_ENABLED;
}
// Track user-provided status - only true if the parameter was not null
@@ -96,6 +104,7 @@ public DmlToolsConfig(
UserProvidedUpdateRecord = updateRecord is not null;
UserProvidedDeleteRecord = deleteRecord is not null;
UserProvidedExecuteEntity = executeEntity is not null;
+ UserProvidedAggregateRecords = aggregateRecords is not null;
}
/// <summary>
@@ -112,7 +121,8 @@ public static DmlToolsConfig FromBoolean(bool enabled)
readRecords: null,
updateRecord: null,
deleteRecord: null,
- executeEntity: null
+ executeEntity: null,
+ aggregateRecords: null
);
}
@@ -127,7 +137,8 @@ public static DmlToolsConfig FromBoolean(bool enabled)
readRecords: null,
updateRecord: null,
deleteRecord: null,
- executeEntity: null
+ executeEntity: null,
+ aggregateRecords: null
);
/// <summary>
@@ -185,4 +196,12 @@ public static DmlToolsConfig FromBoolean(bool enabled)
[JsonIgnore(Condition = JsonIgnoreCondition.Always)]
[MemberNotNullWhen(true, nameof(ExecuteEntity))]
public bool UserProvidedExecuteEntity { get; init; } = false;
+
+ /// <summary>
+ /// Flag which informs CLI and JSON serializer whether to write aggregate-records
+ /// property/value to the runtime config file.
+ /// </summary>
+ [JsonIgnore(Condition = JsonIgnoreCondition.Always)]
+ [MemberNotNullWhen(true, nameof(AggregateRecords))]
+ public bool UserProvidedAggregateRecords { get; init; } = false;
}
diff --git a/src/Config/ObjectModel/McpRuntimeOptions.cs b/src/Config/ObjectModel/McpRuntimeOptions.cs
index e17d53fc8f..f4b4281a14 100644
--- a/src/Config/ObjectModel/McpRuntimeOptions.cs
+++ b/src/Config/ObjectModel/McpRuntimeOptions.cs
@@ -10,6 +10,7 @@ namespace Azure.DataApiBuilder.Config.ObjectModel;
public record McpRuntimeOptions
{
public const string DEFAULT_PATH = "/mcp";
+ public const int DEFAULT_QUERY_TIMEOUT_SECONDS = 30;
/// <summary>
/// Whether MCP endpoints are enabled
@@ -36,12 +37,22 @@ public record McpRuntimeOptions
[JsonPropertyName("description")]
public string? Description { get; init; }
+ /// <summary>
+ /// Execution timeout in seconds for MCP tool operations.
+ /// This timeout wraps the entire tool execution including database queries.
+ /// It applies to all MCP tools, not just aggregation.
+ /// Default: 30 seconds.
+ /// </summary>
+ [JsonPropertyName("query-timeout")]
+ public int? QueryTimeout { get; init; }
+
[JsonConstructor]
public McpRuntimeOptions(
bool? Enabled = null,
string? Path = null,
DmlToolsConfig? DmlTools = null,
- string? Description = null)
+ string? Description = null,
+ int? QueryTimeout = null)
{
this.Enabled = Enabled ?? true;
@@ -67,6 +78,12 @@ public McpRuntimeOptions(
}
this.Description = Description;
+
+ if (QueryTimeout is not null)
+ {
+ this.QueryTimeout = QueryTimeout;
+ UserProvidedQueryTimeout = true;
+ }
}
/// <summary>
@@ -78,4 +95,17 @@ public McpRuntimeOptions(
[JsonIgnore(Condition = JsonIgnoreCondition.Always)]
[MemberNotNullWhen(true, nameof(Enabled))]
public bool UserProvidedPath { get; init; } = false;
+
+ /// <summary>
+ /// Flag which informs CLI and JSON serializer whether to write query-timeout
+ /// property and value to the runtime config file.
+ /// </summary>
+ [JsonIgnore(Condition = JsonIgnoreCondition.Always)]
+ public bool UserProvidedQueryTimeout { get; init; } = false;
+
+ /// <summary>
+ /// Gets the effective query timeout in seconds, using the default if not specified.
+ /// </summary>
+ [JsonIgnore(Condition = JsonIgnoreCondition.Always)]
+ public int EffectiveQueryTimeoutSeconds => QueryTimeout ?? DEFAULT_QUERY_TIMEOUT_SECONDS;
}
diff --git a/src/Core/Configurations/RuntimeConfigValidator.cs b/src/Core/Configurations/RuntimeConfigValidator.cs
index f5112844da..ea2299bc6f 100644
--- a/src/Core/Configurations/RuntimeConfigValidator.cs
+++ b/src/Core/Configurations/RuntimeConfigValidator.cs
@@ -914,6 +914,16 @@ public void ValidateMcpUri(RuntimeConfig runtimeConfig)
statusCode: HttpStatusCode.ServiceUnavailable,
subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError));
}
+
+ // Validate query-timeout if provided
+ if (runtimeConfig.Runtime.Mcp.QueryTimeout is not null && runtimeConfig.Runtime.Mcp.QueryTimeout < 1)
+ {
+ HandleOrRecordException(new DataApiBuilderException(
+ message: "MCP query-timeout must be a positive integer (>= 1 second). " +
+ $"Provided value: {runtimeConfig.Runtime.Mcp.QueryTimeout}.",
+ statusCode: HttpStatusCode.ServiceUnavailable,
+ subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError));
+ }
}
private void ValidateAuthenticationOptions(RuntimeConfig runtimeConfig)
diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs
new file mode 100644
index 0000000000..ce578e746e
--- /dev/null
+++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs
@@ -0,0 +1,1317 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#nullable enable
+
+using System;
+using System.Collections.Generic;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
+using Azure.DataApiBuilder.Auth;
+using Azure.DataApiBuilder.Config.ObjectModel;
+using Azure.DataApiBuilder.Core.Authorization;
+using Azure.DataApiBuilder.Core.Configurations;
+using Azure.DataApiBuilder.Mcp.BuiltInTools;
+using Azure.DataApiBuilder.Mcp.Model;
+using Microsoft.AspNetCore.Http;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using ModelContextProtocol.Protocol;
+using Moq;
+
+namespace Azure.DataApiBuilder.Service.Tests.Mcp
+{
+ /// <summary>
+ /// Tests for the AggregateRecordsTool MCP tool.
+ /// Covers:
+ /// - Tool metadata and schema validation
+ /// - Runtime-level enabled/disabled configuration
+ /// - Entity-level DML tool configuration
+ /// - Input validation (missing/invalid arguments)
+ /// - In-memory aggregation logic (count, avg, sum, min, max)
+ /// - distinct, groupby, having, orderby
+ /// - Alias convention
+ /// </summary>
+ [TestClass]
+ public class AggregateRecordsToolTests
+ {
+ #region Tool Metadata Tests
+
+ [TestMethod]
+ public void GetToolMetadata_ReturnsCorrectName()
+ {
+ AggregateRecordsTool tool = new();
+ Tool metadata = tool.GetToolMetadata();
+ Assert.AreEqual("aggregate_records", metadata.Name);
+ }
+
+ [TestMethod]
+ public void GetToolMetadata_ReturnsCorrectToolType()
+ {
+ AggregateRecordsTool tool = new();
+ Assert.AreEqual(McpEnums.ToolType.BuiltIn, tool.ToolType);
+ }
+
+ [TestMethod]
+ public void GetToolMetadata_HasInputSchema()
+ {
+ AggregateRecordsTool tool = new();
+ Tool metadata = tool.GetToolMetadata();
+ Assert.AreEqual(JsonValueKind.Object, metadata.InputSchema.ValueKind);
+ Assert.IsTrue(metadata.InputSchema.TryGetProperty("properties", out JsonElement properties));
+ Assert.IsTrue(metadata.InputSchema.TryGetProperty("required", out JsonElement required));
+
+ List<string> requiredFields = new();
+ foreach (JsonElement r in required.EnumerateArray())
+ {
+ requiredFields.Add(r.GetString()!);
+ }
+
+ CollectionAssert.Contains(requiredFields, "entity");
+ CollectionAssert.Contains(requiredFields, "function");
+ CollectionAssert.Contains(requiredFields, "field");
+
+ // Verify first and after properties exist in schema
+ Assert.IsTrue(properties.TryGetProperty("first", out JsonElement firstProp));
+ Assert.AreEqual("integer", firstProp.GetProperty("type").GetString());
+ Assert.IsTrue(properties.TryGetProperty("after", out JsonElement afterProp));
+ Assert.AreEqual("string", afterProp.GetProperty("type").GetString());
+ }
+
+ #endregion
+
+ #region Configuration Tests
+
+ [TestMethod]
+ public async Task AggregateRecords_DisabledAtRuntimeLevel_ReturnsToolDisabledError()
+ {
+ RuntimeConfig config = CreateConfig(aggregateRecordsEnabled: false);
+ IServiceProvider sp = CreateServiceProvider(config);
+ AggregateRecordsTool tool = new();
+
+ JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\"}");
+ CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None);
+
+ Assert.IsTrue(result.IsError == true);
+ JsonElement content = ParseContent(result);
+ AssertToolDisabledError(content);
+ }
+
+ [TestMethod]
+ public async Task AggregateRecords_DisabledAtEntityLevel_ReturnsToolDisabledError()
+ {
+ RuntimeConfig config = CreateConfigWithEntityDmlDisabled();
+ IServiceProvider sp = CreateServiceProvider(config);
+ AggregateRecordsTool tool = new();
+
+ JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\"}");
+ CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None);
+
+ Assert.IsTrue(result.IsError == true);
+ JsonElement content = ParseContent(result);
+ AssertToolDisabledError(content);
+ }
+
+ #endregion
+
+ #region Input Validation Tests
+
+ [TestMethod]
+ public async Task AggregateRecords_NullArguments_ReturnsInvalidArguments()
+ {
+ RuntimeConfig config = CreateConfig();
+ IServiceProvider sp = CreateServiceProvider(config);
+ AggregateRecordsTool tool = new();
+
+ CallToolResult result = await tool.ExecuteAsync(null, sp, CancellationToken.None);
+ Assert.IsTrue(result.IsError == true);
+ JsonElement content = ParseContent(result);
+ Assert.IsTrue(content.TryGetProperty("error", out JsonElement error));
+ Assert.AreEqual("InvalidArguments", error.GetProperty("type").GetString());
+ }
+
+ [TestMethod]
+ public async Task AggregateRecords_MissingEntity_ReturnsInvalidArguments()
+ {
+ RuntimeConfig config = CreateConfig();
+ IServiceProvider sp = CreateServiceProvider(config);
+ AggregateRecordsTool tool = new();
+
+ JsonDocument args = JsonDocument.Parse("{\"function\": \"count\", \"field\": \"*\"}");
+ CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None);
+ Assert.IsTrue(result.IsError == true);
+ JsonElement content = ParseContent(result);
+ Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString());
+ }
+
+ [TestMethod]
+ public async Task AggregateRecords_MissingFunction_ReturnsInvalidArguments()
+ {
+ RuntimeConfig config = CreateConfig();
+ IServiceProvider sp = CreateServiceProvider(config);
+ AggregateRecordsTool tool = new();
+
+ JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"field\": \"*\"}");
+ CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None);
+ Assert.IsTrue(result.IsError == true);
+ JsonElement content = ParseContent(result);
+ Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString());
+ }
+
+ [TestMethod]
+ public async Task AggregateRecords_MissingField_ReturnsInvalidArguments()
+ {
+ RuntimeConfig config = CreateConfig();
+ IServiceProvider sp = CreateServiceProvider(config);
+ AggregateRecordsTool tool = new();
+
+ JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\"}");
+ CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None);
+ Assert.IsTrue(result.IsError == true);
+ JsonElement content = ParseContent(result);
+ Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString());
+ }
+
+ [TestMethod]
+ public async Task AggregateRecords_InvalidFunction_ReturnsInvalidArguments()
+ {
+ RuntimeConfig config = CreateConfig();
+ IServiceProvider sp = CreateServiceProvider(config);
+ AggregateRecordsTool tool = new();
+
+ JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"median\", \"field\": \"price\"}");
+ CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None);
+ Assert.IsTrue(result.IsError == true);
+ JsonElement content = ParseContent(result);
+ Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString());
+ Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("median"));
+ }
+
+ #endregion
+
+ #region Alias Convention Tests
+
+ [TestMethod]
+ public void ComputeAlias_CountStar_ReturnsCount()
+ {
+ Assert.AreEqual("count", AggregateRecordsTool.ComputeAlias("count", "*"));
+ }
+
+ [TestMethod]
+ public void ComputeAlias_CountField_ReturnsFunctionField()
+ {
+ Assert.AreEqual("count_supplierId", AggregateRecordsTool.ComputeAlias("count", "supplierId"));
+ }
+
+ [TestMethod]
+ public void ComputeAlias_AvgField_ReturnsFunctionField()
+ {
+ Assert.AreEqual("avg_unitPrice", AggregateRecordsTool.ComputeAlias("avg", "unitPrice"));
+ }
+
+ [TestMethod]
+ public void ComputeAlias_SumField_ReturnsFunctionField()
+ {
+ Assert.AreEqual("sum_unitPrice", AggregateRecordsTool.ComputeAlias("sum", "unitPrice"));
+ }
+
+ [TestMethod]
+ public void ComputeAlias_MinField_ReturnsFunctionField()
+ {
+ Assert.AreEqual("min_price", AggregateRecordsTool.ComputeAlias("min", "price"));
+ }
+
+ [TestMethod]
+ public void ComputeAlias_MaxField_ReturnsFunctionField()
+ {
+ Assert.AreEqual("max_price", AggregateRecordsTool.ComputeAlias("max", "price"));
+ }
+
+ #endregion
+
+ #region In-Memory Aggregation Tests
+
+ [TestMethod]
+ public void PerformAggregation_CountStar_ReturnsCount()
+ {
+ JsonElement records = ParseArray("[{\"id\":1},{\"id\":2},{\"id\":3}]");
+ var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new(), null, null, "desc", "count");
+
+ Assert.AreEqual(1, result.Count);
+ Assert.AreEqual(3.0, result[0]["count"]);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_Avg_ReturnsAverage()
+ {
+ JsonElement records = ParseArray("[{\"price\":10},{\"price\":20},{\"price\":30}]");
+ var result = AggregateRecordsTool.PerformAggregation(records, "avg", "price", false, new(), null, null, "desc", "avg_price");
+
+ Assert.AreEqual(1, result.Count);
+ Assert.AreEqual(20.0, result[0]["avg_price"]);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_Sum_ReturnsSum()
+ {
+ JsonElement records = ParseArray("[{\"price\":10},{\"price\":20},{\"price\":30}]");
+ var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new(), null, null, "desc", "sum_price");
+
+ Assert.AreEqual(1, result.Count);
+ Assert.AreEqual(60.0, result[0]["sum_price"]);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_Min_ReturnsMin()
+ {
+ JsonElement records = ParseArray("[{\"price\":10},{\"price\":20},{\"price\":5}]");
+ var result = AggregateRecordsTool.PerformAggregation(records, "min", "price", false, new(), null, null, "desc", "min_price");
+
+ Assert.AreEqual(1, result.Count);
+ Assert.AreEqual(5.0, result[0]["min_price"]);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_Max_ReturnsMax()
+ {
+ JsonElement records = ParseArray("[{\"price\":10},{\"price\":20},{\"price\":5}]");
+ var result = AggregateRecordsTool.PerformAggregation(records, "max", "price", false, new(), null, null, "desc", "max_price");
+
+ Assert.AreEqual(1, result.Count);
+ Assert.AreEqual(20.0, result[0]["max_price"]);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_CountDistinct_ReturnsDistinctCount()
+ {
+ JsonElement records = ParseArray("[{\"supplierId\":1},{\"supplierId\":2},{\"supplierId\":1},{\"supplierId\":3}]");
+ var result = AggregateRecordsTool.PerformAggregation(records, "count", "supplierId", true, new(), null, null, "desc", "count_supplierId");
+
+ Assert.AreEqual(1, result.Count);
+ Assert.AreEqual(3.0, result[0]["count_supplierId"]);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_AvgDistinct_ReturnsDistinctAvg()
+ {
+ JsonElement records = ParseArray("[{\"price\":10},{\"price\":10},{\"price\":20},{\"price\":30}]");
+ var result = AggregateRecordsTool.PerformAggregation(records, "avg", "price", true, new(), null, null, "desc", "avg_price");
+
+ Assert.AreEqual(1, result.Count);
+ Assert.AreEqual(20.0, result[0]["avg_price"]);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_GroupBy_ReturnsGroupedResults()
+ {
+ JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"A\",\"price\":20},{\"category\":\"B\",\"price\":50}]");
+ var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, null, null, "desc", "sum_price");
+
+ Assert.AreEqual(2, result.Count);
+ // Desc order: B(50) first, then A(30)
+ Assert.AreEqual("B", result[0]["category"]?.ToString());
+ Assert.AreEqual(50.0, result[0]["sum_price"]);
+ Assert.AreEqual("A", result[1]["category"]?.ToString());
+ Assert.AreEqual(30.0, result[1]["sum_price"]);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_GroupBy_Asc_ReturnsSortedAsc()
+ {
+ JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"B\",\"price\":30},{\"category\":\"A\",\"price\":20}]");
+ var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, null, null, "asc", "sum_price");
+
+ Assert.AreEqual(2, result.Count);
+ Assert.AreEqual("A", result[0]["category"]?.ToString());
+ Assert.AreEqual(30.0, result[0]["sum_price"]);
+ Assert.AreEqual("B", result[1]["category"]?.ToString());
+ Assert.AreEqual(30.0, result[1]["sum_price"]);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_CountStar_GroupBy_ReturnsGroupCounts()
+ {
+ JsonElement records = ParseArray("[{\"category\":\"A\"},{\"category\":\"A\"},{\"category\":\"B\"}]");
+ var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "category" }, null, null, "desc", "count");
+
+ Assert.AreEqual(2, result.Count);
+ Assert.AreEqual("A", result[0]["category"]?.ToString());
+ Assert.AreEqual(2.0, result[0]["count"]);
+ Assert.AreEqual("B", result[1]["category"]?.ToString());
+ Assert.AreEqual(1.0, result[1]["count"]);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_HavingGt_FiltersResults()
+ {
+ JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"A\",\"price\":20},{\"category\":\"B\",\"price\":5}]");
+ var having = new Dictionary<string, object> { ["gt"] = 10 };
+ var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price");
+
+ Assert.AreEqual(1, result.Count);
+ Assert.AreEqual("A", result[0]["category"]?.ToString());
+ Assert.AreEqual(30.0, result[0]["sum_price"]);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_HavingGteLte_FiltersRange()
+ {
+ JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":100},{\"category\":\"B\",\"price\":20},{\"category\":\"C\",\"price\":1}]");
+ var having = new Dictionary<string, object> { ["gte"] = 10, ["lte"] = 50 };
+ var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price");
+
+ Assert.AreEqual(1, result.Count);
+ Assert.AreEqual("B", result[0]["category"]?.ToString());
+ }
+
+ [TestMethod]
+ public void PerformAggregation_HavingIn_FiltersExactValues()
+ {
+ JsonElement records = ParseArray("[{\"category\":\"A\"},{\"category\":\"A\"},{\"category\":\"B\"},{\"category\":\"C\"},{\"category\":\"C\"},{\"category\":\"C\"}]");
+ var havingIn = new List<object> { 2, 3 };
+ var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "category" }, null, havingIn, "desc", "count");
+
+ Assert.AreEqual(2, result.Count);
+ // C(3) desc, A(2)
+ Assert.AreEqual("C", result[0]["category"]?.ToString());
+ Assert.AreEqual(3.0, result[0]["count"]);
+ Assert.AreEqual("A", result[1]["category"]?.ToString());
+ Assert.AreEqual(2.0, result[1]["count"]);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_HavingEq_FiltersSingleValue()
+ {
+ JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"B\",\"price\":20}]");
+ var having = new Dictionary<string, object> { ["eq"] = 10 };
+ var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price");
+
+ Assert.AreEqual(1, result.Count);
+ Assert.AreEqual("A", result[0]["category"]?.ToString());
+ }
+
+ [TestMethod]
+ public void PerformAggregation_HavingNeq_FiltersOutValue()
+ {
+ JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10},{\"category\":\"B\",\"price\":20}]");
+ var having = new Dictionary<string, object> { ["neq"] = 10 };
+ var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price");
+
+ Assert.AreEqual(1, result.Count);
+ Assert.AreEqual("B", result[0]["category"]?.ToString());
+ }
+
+ [TestMethod]
+ public void PerformAggregation_EmptyRecords_ReturnsNull()
+ {
+ JsonElement records = ParseArray("[]");
+ var result = AggregateRecordsTool.PerformAggregation(records, "avg", "price", false, new(), null, null, "desc", "avg_price");
+
+ Assert.AreEqual(1, result.Count);
+ Assert.IsNull(result[0]["avg_price"]);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_EmptyRecordsCountStar_ReturnsZero()
+ {
+ JsonElement records = ParseArray("[]");
+ var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new(), null, null, "desc", "count");
+
+ Assert.AreEqual(1, result.Count);
+ Assert.AreEqual(0.0, result[0]["count"]);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_MultipleGroupByFields_ReturnsCorrectGroups()
+ {
+ JsonElement records = ParseArray("[{\"cat\":\"A\",\"region\":\"East\",\"price\":10},{\"cat\":\"A\",\"region\":\"East\",\"price\":20},{\"cat\":\"A\",\"region\":\"West\",\"price\":5}]");
+ var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "cat", "region" }, null, null, "desc", "sum_price");
+
+ Assert.AreEqual(2, result.Count);
+ // (A,East)=30 desc, (A,West)=5
+ Assert.AreEqual("A", result[0]["cat"]?.ToString());
+ Assert.AreEqual("East", result[0]["region"]?.ToString());
+ Assert.AreEqual(30.0, result[0]["sum_price"]);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_HavingNoResults_ReturnsEmpty()
+ {
+ JsonElement records = ParseArray("[{\"category\":\"A\",\"price\":10}]");
+ var having = new Dictionary<string, object> { ["gt"] = 100 };
+ var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new() { "category" }, having, null, "desc", "sum_price");
+
+ Assert.AreEqual(0, result.Count);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_HavingOnSingleResult_Passes()
+ {
+ JsonElement records = ParseArray("[{\"price\":50},{\"price\":60}]");
+ var having = new Dictionary<string, object> { ["gte"] = 100 };
+ var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new(), having, null, "desc", "sum_price");
+
+ Assert.AreEqual(1, result.Count);
+ Assert.AreEqual(110.0, result[0]["sum_price"]);
+ }
+
+ [TestMethod]
+ public void PerformAggregation_HavingOnSingleResult_Fails()
+ {
+ JsonElement records = ParseArray("[{\"price\":50},{\"price\":60}]");
+ var having = new Dictionary<string, object> { ["gt"] = 200 };
+ var result = AggregateRecordsTool.PerformAggregation(records, "sum", "price", false, new(), having, null, "desc", "sum_price");
+
+ Assert.AreEqual(0, result.Count);
+ }
+
+ #endregion
+
+ #region Pagination Tests
+
+ [TestMethod]
+ public void ApplyPagination_FirstOnly_ReturnsFirstNItems()
+ {
+ List<Dictionary<string, object?>> allResults = new()
+ {
+ new() { ["category"] = "A", ["count"] = 10.0 },
+ new() { ["category"] = "B", ["count"] = 8.0 },
+ new() { ["category"] = "C", ["count"] = 6.0 },
+ new() { ["category"] = "D", ["count"] = 4.0 },
+ new() { ["category"] = "E", ["count"] = 2.0 }
+ };
+
+ AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 3, null);
+
+ Assert.AreEqual(3, result.Items.Count);
+ Assert.AreEqual("A", result.Items[0]["category"]?.ToString());
+ Assert.AreEqual("C", result.Items[2]["category"]?.ToString());
+ Assert.IsTrue(result.HasNextPage);
+ Assert.IsNotNull(result.EndCursor);
+ }
+
+ /// <summary>
+ /// The end cursor of one page, passed as 'after', continues pagination from where the page left off.
+ /// </summary>
+ [TestMethod]
+ public void ApplyPagination_FirstWithAfter_ReturnsNextPage()
+ {
+ List<Dictionary<string, object?>> allResults = new()
+ {
+ new() { ["category"] = "A", ["count"] = 10.0 },
+ new() { ["category"] = "B", ["count"] = 8.0 },
+ new() { ["category"] = "C", ["count"] = 6.0 },
+ new() { ["category"] = "D", ["count"] = 4.0 },
+ new() { ["category"] = "E", ["count"] = 2.0 }
+ };
+
+ // First page
+ AggregateRecordsTool.PaginationResult firstPage = AggregateRecordsTool.ApplyPagination(allResults, 3, null);
+ Assert.AreEqual(3, firstPage.Items.Count);
+ Assert.IsTrue(firstPage.HasNextPage);
+
+ // Second page using cursor from first page
+ AggregateRecordsTool.PaginationResult secondPage = AggregateRecordsTool.ApplyPagination(allResults, 3, firstPage.EndCursor);
+ Assert.AreEqual(2, secondPage.Items.Count);
+ Assert.AreEqual("D", secondPage.Items[0]["category"]?.ToString());
+ Assert.AreEqual("E", secondPage.Items[1]["category"]?.ToString());
+ Assert.IsFalse(secondPage.HasNextPage);
+ }
+
+ /// <summary>
+ /// When 'first' exceeds the total count, all items are returned and no next page is signaled.
+ /// </summary>
+ [TestMethod]
+ public void ApplyPagination_FirstExceedsTotalCount_ReturnsAllItems()
+ {
+ List<Dictionary<string, object?>> allResults = new()
+ {
+ new() { ["category"] = "A", ["count"] = 10.0 },
+ new() { ["category"] = "B", ["count"] = 8.0 }
+ };
+
+ AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 5, null);
+
+ Assert.AreEqual(2, result.Items.Count);
+ Assert.IsFalse(result.HasNextPage);
+ }
+
+ /// <summary>
+ /// When 'first' exactly equals the total count, the whole set is one page with no next page.
+ /// </summary>
+ [TestMethod]
+ public void ApplyPagination_FirstExactlyMatchesTotalCount_HasNextPageIsFalse()
+ {
+ List<Dictionary<string, object?>> allResults = new()
+ {
+ new() { ["category"] = "A", ["count"] = 10.0 },
+ new() { ["category"] = "B", ["count"] = 8.0 },
+ new() { ["category"] = "C", ["count"] = 6.0 }
+ };
+
+ AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 3, null);
+
+ Assert.AreEqual(3, result.Items.Count);
+ Assert.IsFalse(result.HasNextPage);
+ }
+
+ /// <summary>
+ /// An empty result set paginates to an empty page with no next page and no cursor.
+ /// </summary>
+ [TestMethod]
+ public void ApplyPagination_EmptyResults_ReturnsEmptyPage()
+ {
+ List<Dictionary<string, object?>> allResults = new();
+
+ AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 5, null);
+
+ Assert.AreEqual(0, result.Items.Count);
+ Assert.IsFalse(result.HasNextPage);
+ Assert.IsNull(result.EndCursor);
+ }
+
+ /// <summary>
+ /// A cursor that is not valid base64 is ignored and pagination restarts from the beginning.
+ /// </summary>
+ [TestMethod]
+ public void ApplyPagination_InvalidCursor_StartsFromBeginning()
+ {
+ List<Dictionary<string, object?>> allResults = new()
+ {
+ new() { ["category"] = "A", ["count"] = 10.0 },
+ new() { ["category"] = "B", ["count"] = 8.0 }
+ };
+
+ AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 5, "not-valid-base64!!!");
+
+ Assert.AreEqual(2, result.Items.Count);
+ Assert.AreEqual("A", result.Items[0]["category"]?.ToString());
+ Assert.IsFalse(result.HasNextPage);
+ Assert.IsNotNull(result.EndCursor);
+ }
+
+ /// <summary>
+ /// A well-formed cursor pointing past the last item yields an empty page with no cursor.
+ /// </summary>
+ [TestMethod]
+ public void ApplyPagination_CursorBeyondResults_ReturnsEmptyPage()
+ {
+ List<Dictionary<string, object?>> allResults = new()
+ {
+ new() { ["category"] = "A", ["count"] = 10.0 }
+ };
+
+ // Cursor pointing beyond the end
+ string cursor = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("100"));
+ AggregateRecordsTool.PaginationResult result = AggregateRecordsTool.ApplyPagination(allResults, 5, cursor);
+
+ Assert.AreEqual(0, result.Items.Count);
+ Assert.IsFalse(result.HasNextPage);
+ Assert.IsNull(result.EndCursor);
+ }
+
+ /// <summary>
+ /// Repeated cursor-based paging of 8 results in pages of 3 visits every item: 3 + 3 + 2.
+ /// </summary>
+ [TestMethod]
+ public void ApplyPagination_MultiplePages_TraversesAllResults()
+ {
+ List<Dictionary<string, object?>> allResults = new();
+ for (int i = 0; i < 8; i++)
+ {
+ allResults.Add(new() { ["category"] = $"Cat{i}", ["count"] = (double)(8 - i) });
+ }
+
+ // Page 1
+ AggregateRecordsTool.PaginationResult page1 = AggregateRecordsTool.ApplyPagination(allResults, 3, null);
+ Assert.AreEqual(3, page1.Items.Count);
+ Assert.IsTrue(page1.HasNextPage);
+
+ // Page 2
+ AggregateRecordsTool.PaginationResult page2 = AggregateRecordsTool.ApplyPagination(allResults, 3, page1.EndCursor);
+ Assert.AreEqual(3, page2.Items.Count);
+ Assert.IsTrue(page2.HasNextPage);
+
+ // Page 3 (last page)
+ AggregateRecordsTool.PaginationResult page3 = AggregateRecordsTool.ApplyPagination(allResults, 3, page2.EndCursor);
+ Assert.AreEqual(2, page3.Items.Count);
+ Assert.IsFalse(page3.HasNextPage);
+ }
+
+ #endregion
+
+ #region Timeout and Cancellation Tests
+
+ /// <summary>
+ /// Verifies that OperationCanceledException produces a model-explicit error
+ /// that clearly states the operation was canceled, not errored.
+ /// </summary>
+ [TestMethod]
+ public async Task AggregateRecords_OperationCanceled_ReturnsExplicitCanceledMessage()
+ {
+ RuntimeConfig config = CreateConfig();
+ IServiceProvider sp = CreateServiceProvider(config);
+ AggregateRecordsTool tool = new();
+
+ // Create a pre-canceled token
+ CancellationTokenSource cts = new();
+ cts.Cancel();
+
+ JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\"}");
+ CallToolResult result = await tool.ExecuteAsync(args, sp, cts.Token);
+
+ Assert.IsTrue(result.IsError == true);
+ JsonElement content = ParseContent(result);
+ Assert.IsTrue(content.TryGetProperty("error", out JsonElement error));
+ string? errorType = error.GetProperty("type").GetString();
+ string? errorMessage = error.GetProperty("message").GetString();
+
+ // Verify the error type identifies it as a cancellation
+ Assert.IsNotNull(errorType);
+ Assert.AreEqual("OperationCanceled", errorType);
+
+ // Verify the message explicitly tells the model this is NOT a tool error
+ Assert.IsNotNull(errorMessage);
+ Assert.IsTrue(errorMessage!.Contains("NOT a tool error"), "Message must explicitly state this is NOT a tool error.");
+
+ // Verify the message tells the model what happened
+ Assert.IsTrue(errorMessage.Contains("canceled"), "Message must mention the operation was canceled.");
+
+ // Verify the message tells the model it can retry
+ Assert.IsTrue(errorMessage.Contains("retry"), "Message must tell the model it can retry.");
+ }
+
+ /// <summary>
+ /// Verifies that the timeout error message provides explicit guidance to the model
+ /// about what happened and what to do next.
+ /// </summary>
+ [TestMethod]
+ public void TimeoutErrorMessage_ContainsModelGuidance()
+ {
+ // Simulate what the tool builds for a TimeoutException response
+ string entityName = "Product";
+ string expectedMessage = $"The aggregation query for entity '{entityName}' timed out. "
+ + "This is NOT a tool error. The database did not respond in time. "
+ + "This may occur with large datasets or complex aggregations. "
+ + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination.";
+
+ // Verify message explicitly states it's NOT a tool error
+ Assert.IsTrue(expectedMessage.Contains("NOT a tool error"), "Timeout message must state this is NOT a tool error.");
+
+ // Verify message explains the cause
+ Assert.IsTrue(expectedMessage.Contains("database did not respond"), "Timeout message must explain the database didn't respond.");
+
+ // Verify message mentions large datasets
+ Assert.IsTrue(expectedMessage.Contains("large datasets"), "Timeout message must mention large datasets as a possible cause.");
+
+ // Verify message provides actionable remediation steps
+ Assert.IsTrue(expectedMessage.Contains("filter"), "Timeout message must suggest using a filter.");
+ Assert.IsTrue(expectedMessage.Contains("groupby"), "Timeout message must suggest reducing groupby fields.");
+ Assert.IsTrue(expectedMessage.Contains("first"), "Timeout message must suggest using pagination with first.");
+ }
+
+ /// <summary>
+ /// Verifies that TaskCanceledException (which typically signals HTTP/DB timeout)
+ /// produces a TimeoutError, not a cancellation error.
+ /// </summary>
+ [TestMethod]
+ public void TaskCanceledErrorMessage_ContainsTimeoutGuidance()
+ {
+ // Simulate what the tool builds for a TaskCanceledException response
+ string entityName = "Product";
+ string expectedMessage = $"The aggregation query for entity '{entityName}' was canceled, likely due to a timeout. "
+ + "This is NOT a tool error. The database did not respond in time. "
+ + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination.";
+
+ // TaskCanceledException should produce a TimeoutError, not OperationCanceled
+ Assert.IsTrue(expectedMessage.Contains("NOT a tool error"), "TaskCanceled message must state this is NOT a tool error.");
+ Assert.IsTrue(expectedMessage.Contains("timeout"), "TaskCanceled message must reference timeout as the cause.");
+ Assert.IsTrue(expectedMessage.Contains("filter"), "TaskCanceled message must suggest filter as remediation.");
+ Assert.IsTrue(expectedMessage.Contains("first"), "TaskCanceled message must suggest first for pagination.");
+ }
+
+ /// <summary>
+ /// Verifies that the OperationCanceled error message for a specific entity
+ /// includes the entity name so the model knows which aggregation failed.
+ /// </summary>
+ [TestMethod]
+ public void CanceledErrorMessage_IncludesEntityName()
+ {
+ string entityName = "LargeProductCatalog";
+ string expectedMessage = $"The aggregation query for entity '{entityName}' was canceled before completion. "
+ + "This is NOT a tool error. The operation was interrupted, possibly due to a timeout or client disconnect. "
+ + "No results were returned. You may retry the same request.";
+
+ Assert.IsTrue(expectedMessage.Contains(entityName), "Canceled message must include the entity name.");
+ Assert.IsTrue(expectedMessage.Contains("No results were returned"), "Canceled message must state no results were returned.");
+ }
+
+ /// <summary>
+ /// Verifies that the timeout error message for a specific entity
+ /// includes the entity name so the model knows which aggregation timed out.
+ /// </summary>
+ [TestMethod]
+ public void TimeoutErrorMessage_IncludesEntityName()
+ {
+ string entityName = "HugeTransactionLog";
+ string expectedMessage = $"The aggregation query for entity '{entityName}' timed out. "
+ + "This is NOT a tool error. The database did not respond in time. "
+ + "This may occur with large datasets or complex aggregations. "
+ + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination.";
+
+ Assert.IsTrue(expectedMessage.Contains(entityName), "Timeout message must include the entity name.");
+ }
+
+ #endregion
+
+ #region Spec Example Tests
+
+ /// <summary>
+ /// Spec Example 1: "How many products are there?"
+ /// COUNT(*) → 77
+ /// </summary>
+ [TestMethod]
+ public void SpecExample01_CountStar_ReturnsTotal()
+ {
+ // Build 77 product records
+ List<string> items = new();
+ for (int i = 1; i <= 77; i++)
+ {
+ items.Add($"{{\"id\":{i}}}");
+ }
+
+ JsonElement records = ParseArray($"[{string.Join(",", items)}]");
+ string alias = AggregateRecordsTool.ComputeAlias("count", "*");
+ var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new(), null, null, "desc", alias);
+
+ Assert.AreEqual(1, result.Count);
+ Assert.AreEqual("count", alias);
+ Assert.AreEqual(77.0, result[0]["count"]);
+ }
+
+ /// <summary>
+ /// Spec Example 2: "What is the average price of products under $10?"
+ /// AVG(unitPrice) WHERE unitPrice &lt; 10 → 6.74
+ /// Filter is applied at DB level; we supply pre-filtered records.
+ /// </summary>
+ [TestMethod]
+ public void SpecExample02_AvgWithFilter_ReturnsFilteredAverage()
+ {
+ // Pre-filtered records (unitPrice < 10) that average to 6.74
+ // 4.50 + 6.00 + 9.72 = 20.22 / 3 = 6.74
+ JsonElement records = ParseArray("[{\"unitPrice\":4.5},{\"unitPrice\":6.0},{\"unitPrice\":9.72}]");
+ string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice");
+ var result = AggregateRecordsTool.PerformAggregation(records, "avg", "unitPrice", false, new(), null, null, "desc", alias);
+
+ Assert.AreEqual(1, result.Count);
+ Assert.AreEqual("avg_unitPrice", alias);
+ Assert.AreEqual(6.74, result[0]["avg_unitPrice"]);
+ }
+
+ /// <summary>
+ /// Spec Example 3: "Which categories have more than 20 products?"
+ /// COUNT(*) GROUP BY categoryName HAVING COUNT(*) &gt; 20
+ /// Expected: Beverages=24, Condiments=22
+ /// </summary>
+ [TestMethod]
+ public void SpecExample03_CountGroupByHavingGt_FiltersGroups()
+ {
+ List<string> items = new();
+ for (int i = 0; i < 24; i++)
+ {
+ items.Add("{\"categoryName\":\"Beverages\"}");
+ }
+
+ for (int i = 0; i < 22; i++)
+ {
+ items.Add("{\"categoryName\":\"Condiments\"}");
+ }
+
+ for (int i = 0; i < 12; i++)
+ {
+ items.Add("{\"categoryName\":\"Seafood\"}");
+ }
+
+ JsonElement records = ParseArray($"[{string.Join(",", items)}]");
+ string alias = AggregateRecordsTool.ComputeAlias("count", "*");
+ var having = new Dictionary<string, object> { ["gt"] = 20 };
+ var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "categoryName" }, having, null, "desc", alias);
+
+ Assert.AreEqual(2, result.Count);
+ // Desc order: Beverages(24), Condiments(22)
+ Assert.AreEqual("Beverages", result[0]["categoryName"]?.ToString());
+ Assert.AreEqual(24.0, result[0]["count"]);
+ Assert.AreEqual("Condiments", result[1]["categoryName"]?.ToString());
+ Assert.AreEqual(22.0, result[1]["count"]);
+ }
+
+ /// <summary>
+ /// Spec Example 4: "For discontinued products, which categories have a total revenue between $500 and $10,000?"
+ /// SUM(unitPrice) WHERE discontinued=1 GROUP BY categoryName HAVING SUM &gt;= 500 AND &lt;= 10000
+ /// Expected: Seafood=1834.50, Produce=742.00
+ /// </summary>
+ [TestMethod]
+ public void SpecExample04_SumFilterGroupByHavingRange_ReturnsMatchingGroups()
+ {
+ // Pre-filtered (discontinued) records with prices summing per category
+ JsonElement records = ParseArray(
+ "[" +
+ "{\"categoryName\":\"Seafood\",\"unitPrice\":900}," +
+ "{\"categoryName\":\"Seafood\",\"unitPrice\":934.5}," +
+ "{\"categoryName\":\"Produce\",\"unitPrice\":400}," +
+ "{\"categoryName\":\"Produce\",\"unitPrice\":342}," +
+ "{\"categoryName\":\"Dairy\",\"unitPrice\":50}" + // Sum 50, below 500
+ "]");
+ string alias = AggregateRecordsTool.ComputeAlias("sum", "unitPrice");
+ var having = new Dictionary<string, object> { ["gte"] = 500, ["lte"] = 10000 };
+ var result = AggregateRecordsTool.PerformAggregation(records, "sum", "unitPrice", false, new() { "categoryName" }, having, null, "desc", alias);
+
+ Assert.AreEqual(2, result.Count);
+ Assert.AreEqual("sum_unitPrice", alias);
+ // Desc order: Seafood(1834.5), Produce(742)
+ Assert.AreEqual("Seafood", result[0]["categoryName"]?.ToString());
+ Assert.AreEqual(1834.5, result[0]["sum_unitPrice"]);
+ Assert.AreEqual("Produce", result[1]["categoryName"]?.ToString());
+ Assert.AreEqual(742.0, result[1]["sum_unitPrice"]);
+ }
+
+ /// <summary>
+ /// Spec Example 5: "How many distinct suppliers do we have?"
+ /// COUNT(DISTINCT supplierId) → 29
+ /// </summary>
+ [TestMethod]
+ public void SpecExample05_CountDistinct_ReturnsDistinctCount()
+ {
+ // Build records with 29 distinct supplierIds plus duplicates
+ List<string> items = new();
+ for (int i = 1; i <= 29; i++)
+ {
+ items.Add($"{{\"supplierId\":{i}}}");
+ }
+
+ // Add duplicates
+ items.Add("{\"supplierId\":1}");
+ items.Add("{\"supplierId\":5}");
+ items.Add("{\"supplierId\":10}");
+
+ JsonElement records = ParseArray($"[{string.Join(",", items)}]");
+ string alias = AggregateRecordsTool.ComputeAlias("count", "supplierId");
+ var result = AggregateRecordsTool.PerformAggregation(records, "count", "supplierId", true, new(), null, null, "desc", alias);
+
+ Assert.AreEqual(1, result.Count);
+ Assert.AreEqual("count_supplierId", alias);
+ Assert.AreEqual(29.0, result[0]["count_supplierId"]);
+ }
+
+ /// <summary>
+ /// Spec Example 6: "Which categories have exactly 5 or 10 products?"
+ /// COUNT(*) GROUP BY categoryName HAVING COUNT(*) IN (5, 10)
+ /// Expected: Grains=5, Produce=5
+ /// </summary>
+ [TestMethod]
+ public void SpecExample06_CountGroupByHavingIn_FiltersExactCounts()
+ {
+ List<string> items = new();
+ for (int i = 0; i < 5; i++)
+ {
+ items.Add("{\"categoryName\":\"Grains\"}");
+ }
+
+ for (int i = 0; i < 5; i++)
+ {
+ items.Add("{\"categoryName\":\"Produce\"}");
+ }
+
+ for (int i = 0; i < 12; i++)
+ {
+ items.Add("{\"categoryName\":\"Beverages\"}");
+ }
+
+ JsonElement records = ParseArray($"[{string.Join(",", items)}]");
+ string alias = AggregateRecordsTool.ComputeAlias("count", "*");
+ var havingIn = new List<object> { 5, 10 };
+ var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "categoryName" }, null, havingIn, "desc", alias);
+
+ Assert.AreEqual(2, result.Count);
+ // Both have count=5, same order as grouped
+ Assert.AreEqual(5.0, result[0]["count"]);
+ Assert.AreEqual(5.0, result[1]["count"]);
+ }
+
+ /// <summary>
+ /// Spec Example 7: "What is the average distinct unit price per category, for categories averaging over $25?"
+ /// AVG(DISTINCT unitPrice) GROUP BY categoryName HAVING AVG(DISTINCT unitPrice) &gt; 25
+ /// Expected: Meat/Poultry=54.01, Beverages=32.50
+ /// </summary>
+ [TestMethod]
+ public void SpecExample07_AvgDistinctGroupByHavingGt_FiltersAboveThreshold()
+ {
+ // Meat/Poultry: distinct prices {40.00, 68.02} → avg = 54.01
+ // Beverages: distinct prices {25.00, 40.00} → avg = 32.50
+ // Condiments: distinct prices {10.00, 15.00} → avg = 12.50 (below threshold)
+ JsonElement records = ParseArray(
+ "[" +
+ "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":40.00}," +
+ "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":68.02}," +
+ "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":40.00}," + // duplicate
+ "{\"categoryName\":\"Beverages\",\"unitPrice\":25.00}," +
+ "{\"categoryName\":\"Beverages\",\"unitPrice\":40.00}," +
+ "{\"categoryName\":\"Beverages\",\"unitPrice\":25.00}," + // duplicate
+ "{\"categoryName\":\"Condiments\",\"unitPrice\":10.00}," +
+ "{\"categoryName\":\"Condiments\",\"unitPrice\":15.00}" +
+ "]");
+ string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice");
+ var having = new Dictionary<string, object> { ["gt"] = 25 };
+ var result = AggregateRecordsTool.PerformAggregation(records, "avg", "unitPrice", true, new() { "categoryName" }, having, null, "desc", alias);
+
+ Assert.AreEqual(2, result.Count);
+ Assert.AreEqual("avg_unitPrice", alias);
+ // Desc order: Meat/Poultry(54.01), Beverages(32.5)
+ Assert.AreEqual("Meat/Poultry", result[0]["categoryName"]?.ToString());
+ Assert.AreEqual(54.01, result[0]["avg_unitPrice"]);
+ Assert.AreEqual("Beverages", result[1]["categoryName"]?.ToString());
+ Assert.AreEqual(32.5, result[1]["avg_unitPrice"]);
+ }
+
+ /// <summary>
+ /// Spec Example 8: "Which categories have the most products?"
+ /// COUNT(*) GROUP BY categoryName ORDER BY DESC
+ /// Expected: Confections=13, Beverages=12, Condiments=12, Seafood=12
+ /// </summary>
+ [TestMethod]
+ public void SpecExample08_CountGroupByOrderByDesc_ReturnsSortedDesc()
+ {
+ List<string> items = new();
+ for (int i = 0; i < 13; i++)
+ {
+ items.Add("{\"categoryName\":\"Confections\"}");
+ }
+
+ for (int i = 0; i < 12; i++)
+ {
+ items.Add("{\"categoryName\":\"Beverages\"}");
+ }
+
+ for (int i = 0; i < 12; i++)
+ {
+ items.Add("{\"categoryName\":\"Condiments\"}");
+ }
+
+ for (int i = 0; i < 12; i++)
+ {
+ items.Add("{\"categoryName\":\"Seafood\"}");
+ }
+
+ JsonElement records = ParseArray($"[{string.Join(",", items)}]");
+ string alias = AggregateRecordsTool.ComputeAlias("count", "*");
+ var result = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "categoryName" }, null, null, "desc", alias);
+
+ Assert.AreEqual(4, result.Count);
+ Assert.AreEqual("Confections", result[0]["categoryName"]?.ToString());
+ Assert.AreEqual(13.0, result[0]["count"]);
+ // Remaining 3 all have count=12
+ Assert.AreEqual(12.0, result[1]["count"]);
+ Assert.AreEqual(12.0, result[2]["count"]);
+ Assert.AreEqual(12.0, result[3]["count"]);
+ }
+
+ /// <summary>
+ /// Spec Example 9: "What are the cheapest categories by average price?"
+ /// AVG(unitPrice) GROUP BY categoryName ORDER BY ASC
+ /// Expected: Grains/Cereals=20.25, Condiments=23.06, Produce=32.37
+ /// </summary>
+ [TestMethod]
+ public void SpecExample09_AvgGroupByOrderByAsc_ReturnsSortedAsc()
+ {
+ // Grains/Cereals: {15.50, 25.00} → avg = 20.25
+ // Condiments: {20.12, 26.00} → avg = 23.06
+ // Produce: {28.74, 36.00} → avg = 32.37
+ JsonElement records = ParseArray(
+ "[" +
+ "{\"categoryName\":\"Grains/Cereals\",\"unitPrice\":15.50}," +
+ "{\"categoryName\":\"Grains/Cereals\",\"unitPrice\":25.00}," +
+ "{\"categoryName\":\"Condiments\",\"unitPrice\":20.12}," +
+ "{\"categoryName\":\"Condiments\",\"unitPrice\":26.00}," +
+ "{\"categoryName\":\"Produce\",\"unitPrice\":28.74}," +
+ "{\"categoryName\":\"Produce\",\"unitPrice\":36.00}" +
+ "]");
+ string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice");
+ var result = AggregateRecordsTool.PerformAggregation(records, "avg", "unitPrice", false, new() { "categoryName" }, null, null, "asc", alias);
+
+ Assert.AreEqual(3, result.Count);
+ // Asc order: Grains/Cereals(20.25), Condiments(23.06), Produce(32.37)
+ Assert.AreEqual("Grains/Cereals", result[0]["categoryName"]?.ToString());
+ Assert.AreEqual(20.25, result[0]["avg_unitPrice"]);
+ Assert.AreEqual("Condiments", result[1]["categoryName"]?.ToString());
+ Assert.AreEqual(23.06, result[1]["avg_unitPrice"]);
+ Assert.AreEqual("Produce", result[2]["categoryName"]?.ToString());
+ Assert.AreEqual(32.37, result[2]["avg_unitPrice"]);
+ }
+
+ /// <summary>
+ /// Spec Example 10: "For categories with over $500 revenue from discontinued products, which has the highest total?"
+ /// SUM(unitPrice) WHERE discontinued=1 GROUP BY categoryName HAVING SUM &gt; 500 ORDER BY DESC
+ /// Expected: Seafood=1834.50, Meat/Poultry=1062.50, Produce=742.00
+ /// </summary>
+ [TestMethod]
+ public void SpecExample10_SumFilterGroupByHavingGtOrderByDesc_ReturnsSortedFiltered()
+ {
+ // Pre-filtered (discontinued) records
+ JsonElement records = ParseArray(
+ "[" +
+ "{\"categoryName\":\"Seafood\",\"unitPrice\":900}," +
+ "{\"categoryName\":\"Seafood\",\"unitPrice\":934.5}," +
+ "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":500}," +
+ "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":562.5}," +
+ "{\"categoryName\":\"Produce\",\"unitPrice\":400}," +
+ "{\"categoryName\":\"Produce\",\"unitPrice\":342}," +
+ "{\"categoryName\":\"Dairy\",\"unitPrice\":50}" + // Sum 50, below 500
+ "]");
+ string alias = AggregateRecordsTool.ComputeAlias("sum", "unitPrice");
+ var having = new Dictionary<string, object> { ["gt"] = 500 };
+ var result = AggregateRecordsTool.PerformAggregation(records, "sum", "unitPrice", false, new() { "categoryName" }, having, null, "desc", alias);
+
+ Assert.AreEqual(3, result.Count);
+ // Desc order: Seafood(1834.5), Meat/Poultry(1062.5), Produce(742)
+ Assert.AreEqual("Seafood", result[0]["categoryName"]?.ToString());
+ Assert.AreEqual(1834.5, result[0]["sum_unitPrice"]);
+ Assert.AreEqual("Meat/Poultry", result[1]["categoryName"]?.ToString());
+ Assert.AreEqual(1062.5, result[1]["sum_unitPrice"]);
+ Assert.AreEqual("Produce", result[2]["categoryName"]?.ToString());
+ Assert.AreEqual(742.0, result[2]["sum_unitPrice"]);
+ }
+
+ /// <summary>
+ /// Spec Example 11: "Show me the first 5 categories by product count"
+ /// COUNT(*) GROUP BY categoryName ORDER BY DESC FIRST 5
+ /// Expected: 5 items with hasNextPage=true, endCursor set
+ /// </summary>
+ [TestMethod]
+ public void SpecExample11_CountGroupByOrderByDescFirst5_ReturnsPaginatedResults()
+ {
+ List<string> items = new();
+ string[] categories = { "Confections", "Beverages", "Condiments", "Seafood", "Dairy", "Grains/Cereals", "Meat/Poultry", "Produce" };
+ int[] counts = { 13, 12, 12, 12, 10, 7, 6, 5 };
+ for (int c = 0; c < categories.Length; c++)
+ {
+ for (int i = 0; i < counts[c]; i++)
+ {
+ items.Add($"{{\"categoryName\":\"{categories[c]}\"}}");
+ }
+ }
+
+ JsonElement records = ParseArray($"[{string.Join(",", items)}]");
+ string alias = AggregateRecordsTool.ComputeAlias("count", "*");
+ var allResults = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "categoryName" }, null, null, "desc", alias);
+
+ Assert.AreEqual(8, allResults.Count);
+
+ // Apply pagination: first=5
+ AggregateRecordsTool.PaginationResult page1 = AggregateRecordsTool.ApplyPagination(allResults, 5, null);
+
+ Assert.AreEqual(5, page1.Items.Count);
+ Assert.AreEqual("Confections", page1.Items[0]["categoryName"]?.ToString());
+ Assert.AreEqual(13.0, page1.Items[0]["count"]);
+ Assert.AreEqual("Dairy", page1.Items[4]["categoryName"]?.ToString());
+ Assert.AreEqual(10.0, page1.Items[4]["count"]);
+ Assert.IsTrue(page1.HasNextPage);
+ Assert.IsNotNull(page1.EndCursor);
+ }
+
+ /// <summary>
+ /// Spec Example 12: "Show me the next 5 categories" (continuation of Example 11)
+ /// COUNT(*) GROUP BY categoryName ORDER BY DESC FIRST 5 AFTER cursor
+ /// Expected: 3 items (remaining), hasNextPage=false
+ /// </summary>
+ [TestMethod]
+ public void SpecExample12_CountGroupByOrderByDescFirst5After_ReturnsNextPage()
+ {
+ List<string> items = new();
+ string[] categories = { "Confections", "Beverages", "Condiments", "Seafood", "Dairy", "Grains/Cereals", "Meat/Poultry", "Produce" };
+ int[] counts = { 13, 12, 12, 12, 10, 7, 6, 5 };
+ for (int c = 0; c < categories.Length; c++)
+ {
+ for (int i = 0; i < counts[c]; i++)
+ {
+ items.Add($"{{\"categoryName\":\"{categories[c]}\"}}");
+ }
+ }
+
+ JsonElement records = ParseArray($"[{string.Join(",", items)}]");
+ string alias = AggregateRecordsTool.ComputeAlias("count", "*");
+ var allResults = AggregateRecordsTool.PerformAggregation(records, "count", "*", false, new() { "categoryName" }, null, null, "desc", alias);
+
+ // Page 1
+ AggregateRecordsTool.PaginationResult page1 = AggregateRecordsTool.ApplyPagination(allResults, 5, null);
+ Assert.IsTrue(page1.HasNextPage);
+
+ // Page 2 (continuation)
+ AggregateRecordsTool.PaginationResult page2 = AggregateRecordsTool.ApplyPagination(allResults, 5, page1.EndCursor);
+
+ Assert.AreEqual(3, page2.Items.Count);
+ Assert.AreEqual("Grains/Cereals", page2.Items[0]["categoryName"]?.ToString());
+ Assert.AreEqual(7.0, page2.Items[0]["count"]);
+ Assert.AreEqual("Meat/Poultry", page2.Items[1]["categoryName"]?.ToString());
+ Assert.AreEqual(6.0, page2.Items[1]["count"]);
+ Assert.AreEqual("Produce", page2.Items[2]["categoryName"]?.ToString());
+ Assert.AreEqual(5.0, page2.Items[2]["count"]);
+ Assert.IsFalse(page2.HasNextPage);
+ }
+
+ /// <summary>
+ /// Spec Example 13: "Show me the top 3 most expensive categories by average price"
+ /// AVG(unitPrice) GROUP BY categoryName ORDER BY DESC FIRST 3
+ /// Expected: Meat/Poultry=54.01, Beverages=37.98, Seafood=37.08
+ /// </summary>
+ [TestMethod]
+ public void SpecExample13_AvgGroupByOrderByDescFirst3_ReturnsTop3()
+ {
+ // Meat/Poultry: {40.00, 68.02} → avg = 54.01
+ // Beverages: {30.96, 45.00} → avg = 37.98
+ // Seafood: {25.16, 49.00} → avg = 37.08
+ // Condiments: {10.00, 15.00} → avg = 12.50
+ JsonElement records = ParseArray(
+ "[" +
+ "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":40.00}," +
+ "{\"categoryName\":\"Meat/Poultry\",\"unitPrice\":68.02}," +
+ "{\"categoryName\":\"Beverages\",\"unitPrice\":30.96}," +
+ "{\"categoryName\":\"Beverages\",\"unitPrice\":45.00}," +
+ "{\"categoryName\":\"Seafood\",\"unitPrice\":25.16}," +
+ "{\"categoryName\":\"Seafood\",\"unitPrice\":49.00}," +
+ "{\"categoryName\":\"Condiments\",\"unitPrice\":10.00}," +
+ "{\"categoryName\":\"Condiments\",\"unitPrice\":15.00}" +
+ "]");
+ string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice");
+ var allResults = AggregateRecordsTool.PerformAggregation(records, "avg", "unitPrice", false, new() { "categoryName" }, null, null, "desc", alias);
+
+ Assert.AreEqual(4, allResults.Count);
+
+ // Apply pagination: first=3
+ AggregateRecordsTool.PaginationResult page = AggregateRecordsTool.ApplyPagination(allResults, 3, null);
+
+ Assert.AreEqual(3, page.Items.Count);
+ Assert.AreEqual("Meat/Poultry", page.Items[0]["categoryName"]?.ToString());
+ Assert.AreEqual(54.01, page.Items[0]["avg_unitPrice"]);
+ Assert.AreEqual("Beverages", page.Items[1]["categoryName"]?.ToString());
+ Assert.AreEqual(37.98, page.Items[1]["avg_unitPrice"]);
+ Assert.AreEqual("Seafood", page.Items[2]["categoryName"]?.ToString());
+ Assert.AreEqual(37.08, page.Items[2]["avg_unitPrice"]);
+ Assert.IsTrue(page.HasNextPage);
+ }
+
+ #endregion
+
+ #region Helper Methods
+
+ /// <summary>Parses a JSON array literal into its root <see cref="JsonElement"/>.</summary>
+ // NOTE(review): the JsonDocument is never disposed; the element stays valid for the
+ // test's lifetime, which is acceptable in test code.
+ private static JsonElement ParseArray(string json)
+ {
+ return JsonDocument.Parse(json).RootElement;
+ }
+
+ /// <summary>Extracts the first text content block of a tool result and parses it as JSON.</summary>
+ private static JsonElement ParseContent(CallToolResult result)
+ {
+ TextContentBlock firstContent = (TextContentBlock)result.Content[0];
+ return JsonDocument.Parse(firstContent.Text).RootElement;
+ }
+
+ /// <summary>Asserts that the tool payload is a structured error with type 'ToolDisabled'.</summary>
+ private static void AssertToolDisabledError(JsonElement content)
+ {
+ Assert.IsTrue(content.TryGetProperty("error", out JsonElement error));
+ Assert.IsTrue(error.TryGetProperty("type", out JsonElement errorType));
+ Assert.AreEqual("ToolDisabled", errorType.GetString());
+ }
+
+ /// <summary>
+ /// Builds a minimal in-memory runtime config with a single 'Book' table entity (anonymous read)
+ /// and all MCP DML tools enabled; 'aggregate-records' is toggled via the parameter.
+ /// </summary>
+ private static RuntimeConfig CreateConfig(bool aggregateRecordsEnabled = true)
+ {
+ Dictionary<string, Entity> entities = new()
+ {
+ ["Book"] = new Entity(
+ Source: new("books", EntitySourceType.Table, null, null),
+ GraphQL: new("Book", "Books"),
+ Fields: null,
+ Rest: new(Enabled: true),
+ Permissions: new[] { new EntityPermission(Role: "anonymous", Actions: new[] {
+ new EntityAction(Action: EntityActionOperation.Read, Fields: null, Policy: null)
+ }) },
+ Mappings: null,
+ Relationships: null,
+ Mcp: null
+ )
+ };
+
+ return new RuntimeConfig(
+ Schema: "test-schema",
+ DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null),
+ Runtime: new(
+ Rest: new(),
+ GraphQL: new(),
+ Mcp: new(
+ Enabled: true,
+ Path: "/mcp",
+ DmlTools: new(
+ describeEntities: true,
+ readRecords: true,
+ createRecord: true,
+ updateRecord: true,
+ deleteRecord: true,
+ executeEntity: true,
+ aggregateRecords: aggregateRecordsEnabled
+ )
+ ),
+ Host: new(Cors: null, Authentication: null, Mode: HostMode.Development)
+ ),
+ Entities: new(entities)
+ );
+ }
+
+ /// <summary>
+ /// Like <see cref="CreateConfig"/>, but the 'Book' entity itself opts out of DML tools
+ /// (entity-level MCP options disable both custom and DML tools) while the global
+ /// 'aggregate-records' tool stays enabled — used to verify entity-level precedence.
+ /// </summary>
+ private static RuntimeConfig CreateConfigWithEntityDmlDisabled()
+ {
+ Dictionary<string, Entity> entities = new()
+ {
+ ["Book"] = new Entity(
+ Source: new("books", EntitySourceType.Table, null, null),
+ GraphQL: new("Book", "Books"),
+ Fields: null,
+ Rest: new(Enabled: true),
+ Permissions: new[] { new EntityPermission(Role: "anonymous", Actions: new[] {
+ new EntityAction(Action: EntityActionOperation.Read, Fields: null, Policy: null)
+ }) },
+ Mappings: null,
+ Relationships: null,
+ Mcp: new EntityMcpOptions(customToolEnabled: false, dmlToolsEnabled: false)
+ )
+ };
+
+ return new RuntimeConfig(
+ Schema: "test-schema",
+ DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null),
+ Runtime: new(
+ Rest: new(),
+ GraphQL: new(),
+ Mcp: new(
+ Enabled: true,
+ Path: "/mcp",
+ DmlTools: new(
+ describeEntities: true,
+ readRecords: true,
+ createRecord: true,
+ updateRecord: true,
+ deleteRecord: true,
+ executeEntity: true,
+ aggregateRecords: true
+ )
+ ),
+ Host: new(Cors: null, Authentication: null, Mode: HostMode.Development)
+ ),
+ Entities: new(entities)
+ );
+ }
+
+ /// <summary>
+ /// Builds a service provider exposing the runtime config, a permissive authorization
+ /// resolver (any role context is valid), and an HTTP context whose client role header
+ /// is 'anonymous'. Mock generic arguments restored — assumes IsValidRoleContext takes
+ /// an HttpContext; confirm against IAuthorizationResolver.
+ /// </summary>
+ private static IServiceProvider CreateServiceProvider(RuntimeConfig config)
+ {
+ ServiceCollection services = new();
+
+ RuntimeConfigProvider configProvider = TestHelper.GenerateInMemoryRuntimeConfigProvider(config);
+ services.AddSingleton(configProvider);
+
+ Mock<IAuthorizationResolver> mockAuthResolver = new();
+ mockAuthResolver.Setup(x => x.IsValidRoleContext(It.IsAny<HttpContext>())).Returns(true);
+ services.AddSingleton(mockAuthResolver.Object);
+
+ Mock<HttpContext> mockHttpContext = new();
+ Mock<HttpRequest> mockRequest = new();
+ mockRequest.Setup(x => x.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER]).Returns("anonymous");
+ mockHttpContext.Setup(x => x.Request).Returns(mockRequest.Object);
+
+ Mock<IHttpContextAccessor> mockHttpContextAccessor = new();
+ mockHttpContextAccessor.Setup(x => x.HttpContext).Returns(mockHttpContext.Object);
+ services.AddSingleton(mockHttpContextAccessor.Object);
+
+ services.AddLogging();
+
+ return services.BuildServiceProvider();
+ }
+
+ #endregion
+ }
+}
diff --git a/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs b/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs
index d2f6554cd3..b4ae074207 100644
--- a/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs
+++ b/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs
@@ -48,6 +48,7 @@ public class EntityLevelDmlToolConfigurationTests
[DataRow("UpdateRecord", "{\"entity\": \"Book\", \"keys\": {\"id\": 1}, \"fields\": {\"title\": \"Updated\"}}", false, DisplayName = "UpdateRecord respects entity-level DmlToolEnabled=false")]
[DataRow("DeleteRecord", "{\"entity\": \"Book\", \"keys\": {\"id\": 1}}", false, DisplayName = "DeleteRecord respects entity-level DmlToolEnabled=false")]
[DataRow("ExecuteEntity", "{\"entity\": \"GetBook\"}", true, DisplayName = "ExecuteEntity respects entity-level DmlToolEnabled=false")]
+ [DataRow("AggregateRecords", "{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\"}", false, DisplayName = "AggregateRecords respects entity-level DmlToolEnabled=false")]
public async Task DmlTool_RespectsEntityLevelDmlToolDisabled(string toolType, string jsonArguments, bool isStoredProcedure)
{
// Arrange
@@ -238,6 +239,7 @@ private static IMcpTool CreateTool(string toolType)
"UpdateRecord" => new UpdateRecordTool(),
"DeleteRecord" => new DeleteRecordTool(),
"ExecuteEntity" => new ExecuteEntityTool(),
+ "AggregateRecords" => new AggregateRecordsTool(),
_ => throw new ArgumentException($"Unknown tool type: {toolType}", nameof(toolType))
};
}
diff --git a/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs b/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs
new file mode 100644
index 0000000000..0f5ee3951a
--- /dev/null
+++ b/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs
@@ -0,0 +1,452 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#nullable enable
+
+using System;
+using System.Collections.Generic;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
+using Azure.DataApiBuilder.Config;
+using Azure.DataApiBuilder.Config.ObjectModel;
+using Azure.DataApiBuilder.Core.Configurations;
+using Azure.DataApiBuilder.Mcp.Model;
+using Azure.DataApiBuilder.Mcp.Utils;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using ModelContextProtocol.Protocol;
+using static Azure.DataApiBuilder.Mcp.Model.McpEnums;
+
+namespace Azure.DataApiBuilder.Service.Tests.Mcp
+{
+ /// <summary>
+ /// Tests for the MCP query-timeout configuration property.
+ /// Verifies:
+ /// - Default value of 30 seconds when not configured
+ /// - Custom value overrides default
+ /// - Timeout wrapping applies to all MCP tools via ExecuteWithTelemetryAsync
+ /// - Hot reload: changing config value updates behavior without restart
+ /// - Timeout surfaces as TimeoutException, not generic cancellation
+ /// - Telemetry maps timeout to TIMEOUT error code
+ /// </summary>
+ [TestClass]
+ public class McpQueryTimeoutTests
+ {
+ #region Default Value Tests
+
+ [TestMethod]
+ public void McpRuntimeOptions_DefaultQueryTimeout_Is30Seconds()
+ {
+ Assert.AreEqual(30, McpRuntimeOptions.DEFAULT_QUERY_TIMEOUT_SECONDS);
+ }
+
+ [TestMethod]
+ public void McpRuntimeOptions_EffectiveTimeout_ReturnsDefault_WhenNotConfigured()
+ {
+ McpRuntimeOptions options = new();
+ Assert.IsNull(options.QueryTimeout);
+ Assert.AreEqual(30, options.EffectiveQueryTimeoutSeconds);
+ }
+
+ [TestMethod]
+ public void McpRuntimeOptions_EffectiveTimeout_ReturnsConfiguredValue()
+ {
+ McpRuntimeOptions options = new(QueryTimeout: 60);
+ Assert.AreEqual(60, options.QueryTimeout);
+ Assert.AreEqual(60, options.EffectiveQueryTimeoutSeconds);
+ }
+
+ [TestMethod]
+ public void McpRuntimeOptions_UserProvidedQueryTimeout_FalseByDefault()
+ {
+ McpRuntimeOptions options = new();
+ Assert.IsFalse(options.UserProvidedQueryTimeout);
+ }
+
+ [TestMethod]
+ public void McpRuntimeOptions_UserProvidedQueryTimeout_TrueWhenSet()
+ {
+ McpRuntimeOptions options = new(QueryTimeout: 45);
+ Assert.IsTrue(options.UserProvidedQueryTimeout);
+ }
+
+ #endregion
+
+ #region Custom Value Tests
+
+ [TestMethod]
+ public void McpRuntimeOptions_CustomTimeout_1Second()
+ {
+ McpRuntimeOptions options = new(QueryTimeout: 1);
+ Assert.AreEqual(1, options.EffectiveQueryTimeoutSeconds);
+ }
+
+ [TestMethod]
+ public void McpRuntimeOptions_CustomTimeout_120Seconds()
+ {
+ McpRuntimeOptions options = new(QueryTimeout: 120);
+ Assert.AreEqual(120, options.EffectiveQueryTimeoutSeconds);
+ }
+
+ [TestMethod]
+ public void RuntimeConfig_McpQueryTimeout_ExposedInConfig()
+ {
+ RuntimeConfig config = CreateConfigWithQueryTimeout(45);
+ Assert.AreEqual(45, config.Runtime?.Mcp?.QueryTimeout);
+ Assert.AreEqual(45, config.Runtime?.Mcp?.EffectiveQueryTimeoutSeconds);
+ }
+
+ [TestMethod]
+ public void RuntimeConfig_McpQueryTimeout_DefaultWhenNotSet()
+ {
+ RuntimeConfig config = CreateConfigWithoutQueryTimeout();
+ Assert.IsNull(config.Runtime?.Mcp?.QueryTimeout);
+ Assert.AreEqual(30, config.Runtime?.Mcp?.EffectiveQueryTimeoutSeconds);
+ }
+
+ #endregion
+
+ #region Timeout Wrapping Tests
+
+ [TestMethod]
+ public async Task ExecuteWithTelemetry_CompletesSuccessfully_WithinTimeout()
+ {
+ // A tool that completes immediately should succeed
+ RuntimeConfig config = CreateConfigWithQueryTimeout(30);
+ IServiceProvider sp = CreateServiceProviderWithConfig(config);
+ IMcpTool tool = new ImmediateCompletionTool();
+
+ CallToolResult result = await McpTelemetryHelper.ExecuteWithTelemetryAsync(
+ tool, "test_tool", null, sp, CancellationToken.None);
+
+ // Tool should complete without throwing TimeoutException
+ Assert.IsNotNull(result);
+ Assert.IsTrue(result.IsError != true, "Tool result should not be an error");
+ }
+
+ [TestMethod]
+ public async Task ExecuteWithTelemetry_ThrowsTimeoutException_WhenToolExceedsTimeout()
+ {
+ // Configure a very short timeout (1 second) and a tool that takes longer
+ RuntimeConfig config = CreateConfigWithQueryTimeout(1);
+ IServiceProvider sp = CreateServiceProviderWithConfig(config);
+ IMcpTool tool = new SlowTool(delaySeconds: 30);
+
+ await Assert.ThrowsExceptionAsync<TimeoutException>(async () =>
+ {
+ await McpTelemetryHelper.ExecuteWithTelemetryAsync(
+ tool, "slow_tool", null, sp, CancellationToken.None);
+ });
+ }
+
+ [TestMethod]
+ public async Task ExecuteWithTelemetry_TimeoutMessage_ContainsToolName()
+ {
+ RuntimeConfig config = CreateConfigWithQueryTimeout(1);
+ IServiceProvider sp = CreateServiceProviderWithConfig(config);
+ IMcpTool tool = new SlowTool(delaySeconds: 30);
+
+ try
+ {
+ await McpTelemetryHelper.ExecuteWithTelemetryAsync(
+ tool, "aggregate_records", null, sp, CancellationToken.None);
+ Assert.Fail("Expected TimeoutException");
+ }
+ catch (TimeoutException ex)
+ {
+ Assert.IsTrue(ex.Message.Contains("aggregate_records"), "Message should contain tool name");
+ Assert.IsTrue(ex.Message.Contains("1 second"), "Message should contain timeout value");
+ Assert.IsTrue(ex.Message.Contains("NOT a tool error"), "Message should clarify it is not a tool error");
+ }
+ }
+
+ [TestMethod]
+ public async Task ExecuteWithTelemetry_ClientCancellation_PropagatesAsCancellation()
+ {
+ // Client cancellation (not timeout) should propagate as OperationCanceledException
+ // rather than being converted to TimeoutException.
+ RuntimeConfig config = CreateConfigWithQueryTimeout(30);
+ IServiceProvider sp = CreateServiceProviderWithConfig(config);
+ IMcpTool tool = new SlowTool(delaySeconds: 30);
+
+ using CancellationTokenSource cts = new();
+ cts.Cancel(); // Cancel immediately
+
+ try
+ {
+ await McpTelemetryHelper.ExecuteWithTelemetryAsync(
+ tool, "test_tool", null, sp, cts.Token);
+ Assert.Fail("Expected OperationCanceledException or subclass to be thrown");
+ }
+ catch (TimeoutException)
+ {
+ Assert.Fail("Client cancellation should NOT be converted to TimeoutException");
+ }
+ catch (OperationCanceledException)
+ {
+ // Expected: client-initiated cancellation propagates as OperationCanceledException
+ // (or subclass TaskCanceledException)
+ }
+ }
+
+ [TestMethod]
+ public async Task ExecuteWithTelemetry_AppliesTimeout_ToAllToolTypes()
+ {
+ // Verify timeout applies to both built-in and custom tool types
+ RuntimeConfig config = CreateConfigWithQueryTimeout(1);
+ IServiceProvider sp = CreateServiceProviderWithConfig(config);
+
+ // Test with built-in tool type
+ IMcpTool builtInTool = new SlowTool(delaySeconds: 30, toolType: ToolType.BuiltIn);
+ await Assert.ThrowsExceptionAsync<TimeoutException>(async () =>
+ {
+ await McpTelemetryHelper.ExecuteWithTelemetryAsync(
+ builtInTool, "builtin_slow", null, sp, CancellationToken.None);
+ });
+
+ // Test with custom tool type
+ IMcpTool customTool = new SlowTool(delaySeconds: 30, toolType: ToolType.Custom);
+ await Assert.ThrowsExceptionAsync<TimeoutException>(async () =>
+ {
+ await McpTelemetryHelper.ExecuteWithTelemetryAsync(
+ customTool, "custom_slow", null, sp, CancellationToken.None);
+ });
+ }
+
+ #endregion
+
+ #region Hot Reload Tests
+
+ [TestMethod]
+ public async Task ExecuteWithTelemetry_ReadsConfigPerInvocation_HotReload()
+ {
+ // First invocation with long timeout should succeed
+ RuntimeConfig config1 = CreateConfigWithQueryTimeout(30);
+ IServiceProvider sp1 = CreateServiceProviderWithConfig(config1);
+
+ IMcpTool fastTool = new ImmediateCompletionTool();
+ CallToolResult result1 = await McpTelemetryHelper.ExecuteWithTelemetryAsync(
+ fastTool, "test_tool", null, sp1, CancellationToken.None);
+ Assert.IsNotNull(result1);
+
+ // Second invocation with very short timeout and a slow tool should timeout.
+ // This demonstrates that each invocation reads the current config value.
+ RuntimeConfig config2 = CreateConfigWithQueryTimeout(1);
+ IServiceProvider sp2 = CreateServiceProviderWithConfig(config2);
+
+ IMcpTool slowTool = new SlowTool(delaySeconds: 30);
+ await Assert.ThrowsExceptionAsync<TimeoutException>(async () =>
+ {
+ await McpTelemetryHelper.ExecuteWithTelemetryAsync(
+ slowTool, "test_tool", null, sp2, CancellationToken.None);
+ });
+ }
+
+ #endregion
+
+ #region Telemetry Tests
+
+ [TestMethod]
+ public void MapExceptionToErrorCode_TimeoutException_ReturnsTIMEOUT()
+ {
+ string errorCode = McpTelemetryHelper.MapExceptionToErrorCode(new TimeoutException());
+ Assert.AreEqual(McpTelemetryErrorCodes.TIMEOUT, errorCode);
+ }
+
+ [TestMethod]
+ public void MapExceptionToErrorCode_OperationCanceled_ReturnsOPERATION_CANCELLED()
+ {
+ string errorCode = McpTelemetryHelper.MapExceptionToErrorCode(new OperationCanceledException());
+ Assert.AreEqual(McpTelemetryErrorCodes.OPERATION_CANCELLED, errorCode);
+ }
+
+ [TestMethod]
+ public void MapExceptionToErrorCode_TaskCanceled_ReturnsOPERATION_CANCELLED()
+ {
+ string errorCode = McpTelemetryHelper.MapExceptionToErrorCode(new TaskCanceledException());
+ Assert.AreEqual(McpTelemetryErrorCodes.OPERATION_CANCELLED, errorCode);
+ }
+
+ #endregion
+
+ #region JSON Serialization Tests
+
+ [TestMethod]
+ public void McpRuntimeOptions_Serialization_IncludesQueryTimeout_WhenUserProvided()
+ {
+ McpRuntimeOptions options = new(QueryTimeout: 45);
+ JsonSerializerOptions serializerOptions = RuntimeConfigLoader.GetSerializationOptions();
+ string json = JsonSerializer.Serialize(options, serializerOptions);
+ Assert.IsTrue(json.Contains("\"query-timeout\": 45") || json.Contains("\"query-timeout\":45"));
+ }
+
+ [TestMethod]
+ public void McpRuntimeOptions_Deserialization_ReadsQueryTimeout()
+ {
+ string json = @"{""enabled"": true, ""query-timeout"": 60}";
+ JsonSerializerOptions serializerOptions = RuntimeConfigLoader.GetSerializationOptions();
+ McpRuntimeOptions? options = JsonSerializer.Deserialize<McpRuntimeOptions>(json, serializerOptions);
+ Assert.IsNotNull(options);
+ Assert.AreEqual(60, options.QueryTimeout);
+ Assert.AreEqual(60, options.EffectiveQueryTimeoutSeconds);
+ }
+
+ [TestMethod]
+ public void McpRuntimeOptions_Deserialization_DefaultsWhenOmitted()
+ {
+ string json = @"{""enabled"": true}";
+ JsonSerializerOptions serializerOptions = RuntimeConfigLoader.GetSerializationOptions();
+ McpRuntimeOptions? options = JsonSerializer.Deserialize<McpRuntimeOptions>(json, serializerOptions);
+ Assert.IsNotNull(options);
+ Assert.IsNull(options.QueryTimeout);
+ Assert.AreEqual(30, options.EffectiveQueryTimeoutSeconds);
+ }
+
+ #endregion
+
+ #region Helpers
+
+ private static RuntimeConfig CreateConfigWithQueryTimeout(int timeoutSeconds)
+ {
+ return new RuntimeConfig(
+ Schema: "test-schema",
+ DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null),
+ Runtime: new(
+ Rest: new(),
+ GraphQL: new(),
+ Mcp: new(
+ Enabled: true,
+ Path: "/mcp",
+ QueryTimeout: timeoutSeconds,
+ DmlTools: new(
+ describeEntities: true,
+ readRecords: true,
+ createRecord: true,
+ updateRecord: true,
+ deleteRecord: true,
+ executeEntity: true,
+ aggregateRecords: true
+ )
+ ),
+ Host: new(Cors: null, Authentication: null, Mode: HostMode.Development)
+ ),
+ Entities: new(new Dictionary<string, Entity>())
+ );
+ }
+
+ private static RuntimeConfig CreateConfigWithoutQueryTimeout()
+ {
+ return new RuntimeConfig(
+ Schema: "test-schema",
+ DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null),
+ Runtime: new(
+ Rest: new(),
+ GraphQL: new(),
+ Mcp: new(
+ Enabled: true,
+ Path: "/mcp",
+ DmlTools: new(
+ describeEntities: true,
+ readRecords: true,
+ createRecord: true,
+ updateRecord: true,
+ deleteRecord: true,
+ executeEntity: true,
+ aggregateRecords: true
+ )
+ ),
+ Host: new(Cors: null, Authentication: null, Mode: HostMode.Development)
+ ),
+ Entities: new(new Dictionary<string, Entity>())
+ );
+ }
+
+ private static IServiceProvider CreateServiceProviderWithConfig(RuntimeConfig config)
+ {
+ ServiceCollection services = new();
+ RuntimeConfigProvider configProvider = TestHelper.GenerateInMemoryRuntimeConfigProvider(config);
+ services.AddSingleton(configProvider);
+ services.AddLogging();
+ return services.BuildServiceProvider();
+ }
+
+ /// <summary>
+ /// A mock tool that completes immediately with a success result.
+ /// </summary>
+ private class ImmediateCompletionTool : IMcpTool
+ {
+ public ToolType ToolType { get; } = ToolType.BuiltIn;
+
+ public Tool GetToolMetadata()
+ {
+ using JsonDocument doc = JsonDocument.Parse("{\"type\": \"object\"}");
+ return new Tool
+ {
+ Name = "test_tool",
+ Description = "A test tool that completes immediately",
+ InputSchema = doc.RootElement.Clone()
+ };
+ }
+
+ public Task<CallToolResult> ExecuteAsync(
+ JsonDocument? arguments,
+ IServiceProvider serviceProvider,
+ CancellationToken cancellationToken = default)
+ {
+ return Task.FromResult(new CallToolResult
+ {
+ Content = new List<ContentBlock>
+ {
+ new TextContentBlock { Text = "{\"result\": \"success\"}" }
+ }
+ });
+ }
+ }
+
+ /// <summary>
+ /// A mock tool that delays for a specified duration, respecting cancellation.
+ /// Used to test timeout behavior.
+ /// </summary>
+ private class SlowTool : IMcpTool
+ {
+ private readonly int _delaySeconds;
+
+ public SlowTool(int delaySeconds, ToolType toolType = ToolType.BuiltIn)
+ {
+ _delaySeconds = delaySeconds;
+ ToolType = toolType;
+ }
+
+ public ToolType ToolType { get; }
+
+ public Tool GetToolMetadata()
+ {
+ using JsonDocument doc = JsonDocument.Parse("{\"type\": \"object\"}");
+ return new Tool
+ {
+ Name = "slow_tool",
+ Description = "A test tool that takes a long time",
+ InputSchema = doc.RootElement.Clone()
+ };
+ }
+
+ public async Task<CallToolResult> ExecuteAsync(
+ JsonDocument? arguments,
+ IServiceProvider serviceProvider,
+ CancellationToken cancellationToken = default)
+ {
+ await Task.Delay(TimeSpan.FromSeconds(_delaySeconds), cancellationToken);
+ return new CallToolResult
+ {
+ Content = new List<ContentBlock>
+ {
+ new TextContentBlock { Text = "{\"result\": \"completed\"}" }
+ }
+ };
+ }
+ }
+
+ #endregion
+ }
+}
diff --git a/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs b/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs
new file mode 100644
index 0000000000..dee8842a0d
--- /dev/null
+++ b/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs
@@ -0,0 +1,411 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#nullable enable
+
+using System.Collections.Generic;
+using System.Text.Json;
+using Azure.DataApiBuilder.Mcp.BuiltInTools;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Azure.DataApiBuilder.Service.Tests.UnitTests
+{
+ /// <summary>
+ /// Unit tests for AggregateRecordsTool's internal helper methods.
+ /// Covers validation paths, aggregation logic, and pagination behavior.
+ /// </summary>
+ [TestClass]
+ public class AggregateRecordsToolTests
+ {
+ #region ComputeAlias tests
+
+ [TestMethod]
+ [DataRow("count", "*", "count", DisplayName = "count(*) alias is 'count'")]
+ [DataRow("count", "userId", "count_userId", DisplayName = "count(field) alias is 'count_field'")]
+ [DataRow("avg", "price", "avg_price", DisplayName = "avg alias")]
+ [DataRow("sum", "amount", "sum_amount", DisplayName = "sum alias")]
+ [DataRow("min", "age", "min_age", DisplayName = "min alias")]
+ [DataRow("max", "score", "max_score", DisplayName = "max alias")]
+ public void ComputeAlias_ReturnsExpectedAlias(string function, string field, string expectedAlias)
+ {
+ string result = AggregateRecordsTool.ComputeAlias(function, field);
+ Assert.AreEqual(expectedAlias, result);
+ }
+
+ #endregion
+
+ #region PerformAggregation tests - no groupby
+
+ private static JsonElement CreateRecordsArray(params double[] values)
+ {
+ var list = new List