From 614c7ee82e84c14fa66400bfc702537ae0cc9cf7 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Mon, 2 Mar 2026 18:53:06 +0000
Subject: [PATCH 1/4] Initial plan
From 42e0567a0d4be0219439d298606c7e4d0f5c9772 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Mon, 2 Mar 2026 19:15:27 +0000
Subject: [PATCH 2/4] Add aggregate_records tool with validation fixes and
query-timeout support
- Add AggregateRecordsTool with validation for field/function compatibility
- Reject count+star+distinct=true with InvalidArguments
- Reject field='*' with non-count functions
- Add QueryTimeout to McpRuntimeOptions and converter
- Add AggregateRecords to DmlToolsConfig and converter
- Add TIMEOUT error code to McpTelemetryErrorCodes
- Add timeout wrapping and aggregate_records mapping to McpTelemetryHelper
- Update JSON schema with query-timeout and aggregate-records
- Update CLI ConfigGenerator and ConfigureOptions for aggregate-records
- Add AggregateRecordsToolTests with validation, aggregation, and pagination tests
- Add timeout and aggregate_records tests to McpTelemetryTests
Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com>
---
schemas/dab.draft.schema.json | 11 +
.../BuiltInTools/AggregateRecordsTool.cs | 780 ++++++++++++++++++
.../Utils/McpTelemetryErrorCodes.cs | 5 +
.../Utils/McpTelemetryHelper.cs | 31 +-
src/Cli/Commands/ConfigureOptions.cs | 5 +
src/Cli/ConfigGenerator.cs | 15 +-
.../Converters/DmlToolsConfigConverter.cs | 18 +-
.../McpRuntimeOptionsConverterFactory.cs | 17 +-
src/Config/ObjectModel/DmlToolsConfig.cs | 25 +-
src/Config/ObjectModel/McpRuntimeOptions.cs | 32 +-
.../UnitTests/AggregateRecordsToolTests.cs | 411 +++++++++
.../UnitTests/McpTelemetryTests.cs | 168 +++-
12 files changed, 1505 insertions(+), 13 deletions(-)
create mode 100644 src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs
create mode 100644 src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs
diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json
index cbe38b7d72..e78861807d 100644
--- a/schemas/dab.draft.schema.json
+++ b/schemas/dab.draft.schema.json
@@ -275,6 +275,12 @@
"description": "Allow enabling/disabling MCP requests for all entities.",
"default": true
},
+ "query-timeout": {
+ "type": "integer",
+ "description": "Execution timeout in seconds for MCP tool operations. Applies to all MCP tools.",
+ "default": 30,
+ "minimum": 1
+ },
"dml-tools": {
"oneOf": [
{
@@ -315,6 +321,11 @@
"type": "boolean",
"description": "Enable/disable the execute-entity tool.",
"default": false
+ },
+ "aggregate-records": {
+ "type": "boolean",
+ "description": "Enable/disable the aggregate-records tool.",
+ "default": false
}
}
}
diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs
new file mode 100644
index 0000000000..2bf0836326
--- /dev/null
+++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs
@@ -0,0 +1,780 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System.Data.Common;
+using System.Text.Json;
+using Azure.DataApiBuilder.Auth;
+using Azure.DataApiBuilder.Config.DatabasePrimitives;
+using Azure.DataApiBuilder.Config.ObjectModel;
+using Azure.DataApiBuilder.Core.Authorization;
+using Azure.DataApiBuilder.Core.Configurations;
+using Azure.DataApiBuilder.Core.Models;
+using Azure.DataApiBuilder.Core.Parsers;
+using Azure.DataApiBuilder.Core.Resolvers;
+using Azure.DataApiBuilder.Core.Resolvers.Factories;
+using Azure.DataApiBuilder.Core.Services;
+using Azure.DataApiBuilder.Core.Services.MetadataProviders;
+using Azure.DataApiBuilder.Mcp.Model;
+using Azure.DataApiBuilder.Mcp.Utils;
+using Azure.DataApiBuilder.Service.Exceptions;
+using Microsoft.AspNetCore.Authorization;
+using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Logging;
+using ModelContextProtocol.Protocol;
+using static Azure.DataApiBuilder.Mcp.Model.McpEnums;
+
+namespace Azure.DataApiBuilder.Mcp.BuiltInTools
+{
+ ///
+ /// Tool to aggregate records from a table/view entity configured in DAB.
+ /// Supports count, avg, sum, min, max with optional distinct, filter, groupby, having, orderby.
+ ///
+ public class AggregateRecordsTool : IMcpTool
+ {
+ public ToolType ToolType { get; } = ToolType.BuiltIn;
+
+ private static readonly HashSet _validFunctions = new(StringComparer.OrdinalIgnoreCase) { "count", "avg", "sum", "min", "max" };
+
+ private static readonly HashSet _numericFunctions = new(StringComparer.OrdinalIgnoreCase) { "avg", "sum", "min", "max" };
+
+ public Tool GetToolMetadata()
+ {
+ return new Tool
+ {
+ Name = "aggregate_records",
+ Description = "Computes aggregations (count, avg, sum, min, max) on entity data. "
+ + "STEP 1: Call describe_entities to discover entities with READ permission and their field names. "
+ + "STEP 2: Call this tool with the exact entity name, an aggregation function, and a field name from STEP 1. "
+ + "REQUIRED: entity (exact entity name), function (one of: count, avg, sum, min, max), field (exact field name, or '*' ONLY for count). "
+ + "OPTIONAL: filter (OData WHERE clause applied before aggregating, e.g. 'unitPrice lt 10'), "
+ + "distinct (true to deduplicate values before aggregating), "
+ + "groupby (array of field names to group results by, e.g. ['categoryName']), "
+ + "orderby ('asc' or 'desc' to sort grouped results by aggregated value; requires groupby), "
+ + "having (object to filter groups after aggregating, operators: eq, neq, gt, gte, lt, lte, in; requires groupby), "
+ + "first (integer >= 1, maximum grouped results to return; requires groupby), "
+ + "after (opaque cursor string from a previous response's endCursor for pagination). "
+ + "RESPONSE: The aggregated value is aliased as '{function}_{field}' (e.g. avg_unitPrice, sum_revenue). "
+ + "For count with field '*', the alias is 'count'. "
+ + "When first is used with groupby, response contains: items (array), endCursor (string), hasNextPage (boolean). "
+ + "RULES: 1) ALWAYS call describe_entities first to get valid entity and field names. "
+ + "2) Use field '*' ONLY with function 'count'. "
+ + "3) For avg, sum, min, max: field MUST be a numeric field name from describe_entities. "
+ + "4) orderby, having, and first ONLY apply when groupby is provided. "
+ + "5) Use first and after for paginating large grouped result sets.",
+ InputSchema = JsonSerializer.Deserialize(
+ @"{
+ ""type"": ""object"",
+ ""properties"": {
+ ""entity"": {
+ ""type"": ""string"",
+ ""description"": ""Exact entity name from describe_entities that has READ permission. Must match exactly (case-sensitive).""
+ },
+ ""function"": {
+ ""type"": ""string"",
+ ""enum"": [""count"", ""avg"", ""sum"", ""min"", ""max""],
+ ""description"": ""Aggregation function to apply. Use 'count' to count records, 'avg' for average, 'sum' for total, 'min' for minimum, 'max' for maximum. For count use field '*' or a specific field name. For avg, sum, min, max the field must be numeric.""
+ },
+ ""field"": {
+ ""type"": ""string"",
+ ""description"": ""Exact field name from describe_entities to aggregate. Use '*' ONLY with function 'count' to count all records. For avg, sum, min, max, provide a numeric field name.""
+ },
+ ""distinct"": {
+ ""type"": ""boolean"",
+ ""description"": ""When true, removes duplicate values before applying the aggregation function. Not applicable when field is '*'. Default is false."",
+ ""default"": false
+ },
+ ""filter"": {
+ ""type"": ""string"",
+ ""description"": ""OData filter expression applied before aggregating (acts as a WHERE clause). Supported operators: eq, ne, gt, ge, lt, le, and, or, not. Example: 'unitPrice lt 10' filters to rows where unitPrice is less than 10 before aggregating."",
+ ""default"": """"
+ },
+ ""groupby"": {
+ ""type"": ""array"",
+ ""items"": { ""type"": ""string"" },
+ ""description"": ""Array of exact field names from describe_entities to group results by. Each unique combination of grouped field values produces one aggregated row."",
+ ""default"": []
+ },
+ ""orderby"": {
+ ""type"": ""string"",
+ ""enum"": [""asc"", ""desc""],
+ ""description"": ""Sort direction for grouped results by the computed aggregated value. 'desc' returns highest values first, 'asc' returns lowest first. ONLY applies when groupby is provided. Default is 'desc'."",
+ ""default"": ""desc""
+ },
+ ""having"": {
+ ""type"": ""object"",
+ ""description"": ""Filter applied AFTER aggregating to filter grouped results by the computed aggregated value (acts as a HAVING clause). ONLY applies when groupby is provided."",
+ ""properties"": {
+ ""eq"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value equals this number."" },
+ ""neq"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value does not equal this number."" },
+ ""gt"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value is greater than this number."" },
+ ""gte"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value is greater than or equal to this number."" },
+ ""lt"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value is less than this number."" },
+ ""lte"": { ""type"": ""number"", ""description"": ""Keep groups where the aggregated value is less than or equal to this number."" },
+ ""in"": {
+ ""type"": ""array"",
+ ""items"": { ""type"": ""number"" },
+ ""description"": ""Keep groups where the aggregated value matches any number in this list.""
+ }
+ }
+ },
+ ""first"": {
+ ""type"": ""integer"",
+ ""description"": ""Maximum number of grouped results to return. Used for pagination of grouped results. ONLY applies when groupby is provided. Must be >= 1."",
+ ""minimum"": 1
+ },
+ ""after"": {
+ ""type"": ""string"",
+ ""description"": ""Opaque cursor string for pagination. Pass the 'endCursor' value from a previous response to get the next page of results.""
+ }
+ },
+ ""required"": [""entity"", ""function"", ""field""]
+ }"
+ )
+ };
+ }
+
+ public async Task ExecuteAsync(
+ JsonDocument? arguments,
+ IServiceProvider serviceProvider,
+ CancellationToken cancellationToken = default)
+ {
+ ILogger? logger = serviceProvider.GetService>();
+ string toolName = GetToolMetadata().Name;
+
+ RuntimeConfigProvider runtimeConfigProvider = serviceProvider.GetRequiredService();
+ RuntimeConfig runtimeConfig = runtimeConfigProvider.GetConfig();
+
+ if (runtimeConfig.McpDmlTools?.AggregateRecords is not true)
+ {
+ return McpErrorHelpers.ToolDisabled(toolName, logger);
+ }
+
+ string entityName = string.Empty;
+
+ try
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
+ if (arguments == null)
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger);
+ }
+
+ JsonElement root = arguments.RootElement;
+
+ // Parse required arguments
+ if (!McpArgumentParser.TryParseEntity(root, out string parsedEntityName, out string parseError))
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger);
+ }
+
+ entityName = parsedEntityName;
+
+ if (runtimeConfig.Entities?.TryGetValue(entityName, out Entity? entity) == true &&
+ entity.Mcp?.DmlToolEnabled == false)
+ {
+ return McpErrorHelpers.ToolDisabled(toolName, logger, $"DML tools are disabled for entity '{entityName}'.");
+ }
+
+ if (!root.TryGetProperty("function", out JsonElement funcEl) || string.IsNullOrWhiteSpace(funcEl.GetString()))
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Missing required argument 'function'.", logger);
+ }
+
+ string function = funcEl.GetString()!.ToLowerInvariant();
+ if (!_validFunctions.Contains(function))
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", $"Invalid function '{function}'. Must be one of: count, avg, sum, min, max.", logger);
+ }
+
+ if (!root.TryGetProperty("field", out JsonElement fieldEl) || string.IsNullOrWhiteSpace(fieldEl.GetString()))
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Missing required argument 'field'.", logger);
+ }
+
+ string field = fieldEl.GetString()!;
+
+ // Validate field/function compatibility
+ bool isCountStar = function == "count" && field == "*";
+
+ if (field == "*" && function != "count")
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments",
+ $"Field '*' is only valid with function 'count'. For function '{function}', provide a specific field name.", logger);
+ }
+
+ bool distinct = root.TryGetProperty("distinct", out JsonElement distinctEl) && distinctEl.GetBoolean();
+
+ // Reject count(*) with distinct as it is semantically undefined
+ if (isCountStar && distinct)
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments",
+ "Cannot use distinct=true with field='*'. DISTINCT requires a specific field name. Use a field name instead of '*' to count distinct values.", logger);
+ }
+
+ // For avg/sum/min/max, warn the caller that they need a numeric field
+ if (_numericFunctions.Contains(function) && field == "*")
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments",
+ $"Function '{function}' requires a numeric field name. Use a numeric field name from describe_entities instead of '*'.", logger);
+ }
+
+ string? filter = root.TryGetProperty("filter", out JsonElement filterEl) ? filterEl.GetString() : null;
+ string orderby = root.TryGetProperty("orderby", out JsonElement orderbyEl) ? (orderbyEl.GetString() ?? "desc") : "desc";
+
+ int? first = null;
+ if (root.TryGetProperty("first", out JsonElement firstEl) && firstEl.ValueKind == JsonValueKind.Number)
+ {
+ first = firstEl.GetInt32();
+ if (first < 1)
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Argument 'first' must be at least 1.", logger);
+ }
+ }
+
+ string? after = root.TryGetProperty("after", out JsonElement afterEl) ? afterEl.GetString() : null;
+
+ List groupby = new();
+ if (root.TryGetProperty("groupby", out JsonElement groupbyEl) && groupbyEl.ValueKind == JsonValueKind.Array)
+ {
+ foreach (JsonElement g in groupbyEl.EnumerateArray())
+ {
+ string? gVal = g.GetString();
+ if (!string.IsNullOrWhiteSpace(gVal))
+ {
+ groupby.Add(gVal);
+ }
+ }
+ }
+
+ Dictionary? havingOps = null;
+ List? havingIn = null;
+ if (root.TryGetProperty("having", out JsonElement havingEl) && havingEl.ValueKind == JsonValueKind.Object)
+ {
+ havingOps = new Dictionary(StringComparer.OrdinalIgnoreCase);
+ foreach (JsonProperty prop in havingEl.EnumerateObject())
+ {
+ if (prop.Name.Equals("in", StringComparison.OrdinalIgnoreCase) && prop.Value.ValueKind == JsonValueKind.Array)
+ {
+ havingIn = new List();
+ foreach (JsonElement item in prop.Value.EnumerateArray())
+ {
+ havingIn.Add(item.GetDouble());
+ }
+ }
+ else if (prop.Value.ValueKind == JsonValueKind.Number)
+ {
+ havingOps[prop.Name] = prop.Value.GetDouble();
+ }
+ }
+ }
+
+ // Resolve metadata
+ if (!McpMetadataHelper.TryResolveMetadata(
+ entityName,
+ runtimeConfig,
+ serviceProvider,
+ out ISqlMetadataProvider sqlMetadataProvider,
+ out DatabaseObject dbObject,
+ out string dataSourceName,
+ out string metadataError))
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger);
+ }
+
+ // Authorization
+ IAuthorizationResolver authResolver = serviceProvider.GetRequiredService();
+ IAuthorizationService authorizationService = serviceProvider.GetRequiredService();
+ IHttpContextAccessor httpContextAccessor = serviceProvider.GetRequiredService();
+ HttpContext? httpContext = httpContextAccessor.HttpContext;
+
+ if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authResolver, out string roleCtxError))
+ {
+ return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", roleCtxError, logger);
+ }
+
+ if (!McpAuthorizationHelper.TryResolveAuthorizedRole(
+ httpContext!,
+ authResolver,
+ entityName,
+ EntityActionOperation.Read,
+ out string? effectiveRole,
+ out string readAuthError))
+ {
+ string finalError = readAuthError.StartsWith("You do not have permission", StringComparison.OrdinalIgnoreCase)
+ ? $"You do not have permission to read records for entity '{entityName}'."
+ : readAuthError;
+ return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", finalError, logger);
+ }
+
+ // Build select list: groupby fields + aggregation field
+ List selectFields = new(groupby);
+ if (!isCountStar && !selectFields.Contains(field, StringComparer.OrdinalIgnoreCase))
+ {
+ selectFields.Add(field);
+ }
+
+ // Build and validate Find context
+ RequestValidator requestValidator = new(serviceProvider.GetRequiredService(), runtimeConfigProvider);
+ FindRequestContext context = new(entityName, dbObject, true);
+ httpContext!.Request.Method = "GET";
+
+ requestValidator.ValidateEntity(entityName);
+
+ if (selectFields.Count > 0)
+ {
+ context.UpdateReturnFields(selectFields);
+ }
+
+ if (!string.IsNullOrWhiteSpace(filter))
+ {
+ string filterQueryString = $"?{RequestParser.FILTER_URL}={filter}";
+ context.FilterClauseInUrl = sqlMetadataProvider.GetODataParser().GetFilterClause(filterQueryString, $"{context.EntityName}.{context.DatabaseObject.FullName}");
+ }
+
+ requestValidator.ValidateRequestContext(context);
+
+ AuthorizationResult authorizationResult = await authorizationService.AuthorizeAsync(
+ user: httpContext.User,
+ resource: context,
+ requirements: new[] { new ColumnsPermissionsRequirement() });
+ if (!authorizationResult.Succeeded)
+ {
+ return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", DataApiBuilderException.AUTHORIZATION_FAILURE, logger);
+ }
+
+ // Execute query to get records
+ IQueryEngineFactory queryEngineFactory = serviceProvider.GetRequiredService();
+ IQueryEngine queryEngine = queryEngineFactory.GetQueryEngine(sqlMetadataProvider.GetDatabaseType());
+
+ JsonDocument? queryResult = await queryEngine.ExecuteAsync(context);
+
+ IActionResult actionResult = queryResult is null
+ ? SqlResponseHelpers.FormatFindResult(JsonDocument.Parse("[]").RootElement.Clone(), context, sqlMetadataProvider, runtimeConfig, httpContext, true)
+ : SqlResponseHelpers.FormatFindResult(queryResult.RootElement.Clone(), context, sqlMetadataProvider, runtimeConfig, httpContext, true);
+
+ string rawPayloadJson = McpResponseBuilder.ExtractResultJson(actionResult);
+ using JsonDocument resultDoc = JsonDocument.Parse(rawPayloadJson);
+ JsonElement resultRoot = resultDoc.RootElement;
+
+ // Extract the records array from the response
+ JsonElement records;
+ if (resultRoot.TryGetProperty("value", out JsonElement valueArray))
+ {
+ records = valueArray;
+ }
+ else if (resultRoot.ValueKind == JsonValueKind.Array)
+ {
+ records = resultRoot;
+ }
+ else
+ {
+ records = resultRoot;
+ }
+
+ // Compute alias for the response
+ string alias = ComputeAlias(function, field);
+
+ // Perform in-memory aggregation
+ List> aggregatedResults = PerformAggregation(
+ records, function, field, distinct, groupby, havingOps, havingIn, orderby, alias);
+
+ // Apply pagination if first is specified with groupby
+ if (first.HasValue && groupby.Count > 0)
+ {
+ PaginationResult paginatedResult = ApplyPagination(aggregatedResults, first.Value, after);
+ return McpResponseBuilder.BuildSuccessResult(
+ new Dictionary
+ {
+ ["entity"] = entityName,
+ ["result"] = new Dictionary
+ {
+ ["items"] = paginatedResult.Items,
+ ["endCursor"] = paginatedResult.EndCursor,
+ ["hasNextPage"] = paginatedResult.HasNextPage
+ },
+ ["message"] = $"Successfully aggregated records for entity '{entityName}'"
+ },
+ logger,
+ $"AggregateRecordsTool success for entity {entityName}.");
+ }
+
+ return McpResponseBuilder.BuildSuccessResult(
+ new Dictionary
+ {
+ ["entity"] = entityName,
+ ["result"] = aggregatedResults,
+ ["message"] = $"Successfully aggregated records for entity '{entityName}'"
+ },
+ logger,
+ $"AggregateRecordsTool success for entity {entityName}.");
+ }
+ catch (TimeoutException timeoutEx)
+ {
+ logger?.LogError(timeoutEx, "Aggregation operation timed out for entity {Entity}.", entityName);
+ return McpResponseBuilder.BuildErrorResult(
+ toolName,
+ "TimeoutError",
+ $"The aggregation query for entity '{entityName}' timed out. "
+ + "This is NOT a tool error. The database did not respond in time. "
+ + "This may occur with large datasets or complex aggregations. "
+ + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination.",
+ logger);
+ }
+ catch (TaskCanceledException taskEx)
+ {
+ logger?.LogError(taskEx, "Aggregation task was canceled for entity {Entity}.", entityName);
+ return McpResponseBuilder.BuildErrorResult(
+ toolName,
+ "TimeoutError",
+ $"The aggregation query for entity '{entityName}' was canceled, likely due to a timeout. "
+ + "This is NOT a tool error. The database did not respond in time. "
+ + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination.",
+ logger);
+ }
+ catch (OperationCanceledException)
+ {
+ logger?.LogWarning("Aggregation operation was canceled for entity {Entity}.", entityName);
+ return McpResponseBuilder.BuildErrorResult(
+ toolName,
+ "OperationCanceled",
+ $"The aggregation query for entity '{entityName}' was canceled before completion. "
+ + "This is NOT a tool error. The operation was interrupted, possibly due to a timeout or client disconnect. "
+ + "No results were returned. You may retry the same request.",
+ logger);
+ }
+ catch (DbException dbEx)
+ {
+ logger?.LogError(dbEx, "Database error during aggregation for entity {Entity}.", entityName);
+ return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseOperationFailed", dbEx.Message, logger);
+ }
+ catch (ArgumentException argEx)
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", argEx.Message, logger);
+ }
+ catch (DataApiBuilderException argEx)
+ {
+ return McpResponseBuilder.BuildErrorResult(toolName, argEx.StatusCode.ToString(), argEx.Message, logger);
+ }
+ catch (Exception ex)
+ {
+ logger?.LogError(ex, "Unexpected error in AggregateRecordsTool.");
+ return McpResponseBuilder.BuildErrorResult(toolName, "UnexpectedError", "Unexpected error occurred in AggregateRecordsTool.", logger);
+ }
+ }
+
+ ///
+ /// Computes the response alias for the aggregation result.
+ /// For count with "*", the alias is "count". Otherwise it's "{function}_{field}".
+ ///
+ internal static string ComputeAlias(string function, string field)
+ {
+ if (function == "count" && field == "*")
+ {
+ return "count";
+ }
+
+ return $"{function}_{field}";
+ }
+
+ ///
+ /// Performs in-memory aggregation over a JSON array of records.
+ ///
+ internal static List> PerformAggregation(
+ JsonElement records,
+ string function,
+ string field,
+ bool distinct,
+ List groupby,
+ Dictionary? havingOps,
+ List? havingIn,
+ string orderby,
+ string alias)
+ {
+ if (records.ValueKind != JsonValueKind.Array)
+ {
+ return new List> { new() { [alias] = null } };
+ }
+
+ bool isCountStar = function == "count" && field == "*";
+
+ if (groupby.Count == 0)
+ {
+ // No groupby - single result
+ List items = new();
+ foreach (JsonElement record in records.EnumerateArray())
+ {
+ items.Add(record);
+ }
+
+ double? aggregatedValue = ComputeAggregateValue(items, function, field, distinct, isCountStar);
+
+ // Apply having
+ if (!PassesHavingFilter(aggregatedValue, havingOps, havingIn))
+ {
+ return new List>();
+ }
+
+ return new List>
+ {
+ new() { [alias] = aggregatedValue }
+ };
+ }
+ else
+ {
+ // Group by
+ Dictionary> groups = new();
+ Dictionary> groupKeys = new();
+
+ foreach (JsonElement record in records.EnumerateArray())
+ {
+ string key = BuildGroupKey(record, groupby);
+ if (!groups.ContainsKey(key))
+ {
+ groups[key] = new List();
+ groupKeys[key] = ExtractGroupFields(record, groupby);
+ }
+
+ groups[key].Add(record);
+ }
+
+ List> results = new();
+ foreach (KeyValuePair> group in groups)
+ {
+ double? aggregatedValue = ComputeAggregateValue(group.Value, function, field, distinct, isCountStar);
+
+ if (!PassesHavingFilter(aggregatedValue, havingOps, havingIn))
+ {
+ continue;
+ }
+
+ Dictionary row = new(groupKeys[group.Key])
+ {
+ [alias] = aggregatedValue
+ };
+ results.Add(row);
+ }
+
+ // Apply orderby
+ if (orderby.Equals("asc", StringComparison.OrdinalIgnoreCase))
+ {
+ results.Sort((a, b) => CompareNullableDoubles(a[alias] as double?, b[alias] as double?));
+ }
+ else
+ {
+ results.Sort((a, b) => CompareNullableDoubles(b[alias] as double?, a[alias] as double?));
+ }
+
+ return results;
+ }
+ }
+
+ ///
+ /// Represents the result of applying pagination to aggregated results.
+ ///
+ internal sealed class PaginationResult
+ {
+ public List> Items { get; set; } = new();
+ public string? EndCursor { get; set; }
+ public bool HasNextPage { get; set; }
+ }
+
+ ///
+ /// Applies cursor-based pagination to aggregated results.
+ /// The cursor is an opaque base64-encoded offset integer.
+ /// When after is provided without first, the cursor is ignored and all results from the start are returned.
+ ///
+ internal static PaginationResult ApplyPagination(
+ List> allResults,
+ int first,
+ string? after)
+ {
+ int startIndex = 0;
+
+ if (!string.IsNullOrWhiteSpace(after))
+ {
+ try
+ {
+ byte[] bytes = Convert.FromBase64String(after);
+ string decoded = System.Text.Encoding.UTF8.GetString(bytes);
+ if (int.TryParse(decoded, out int cursorOffset))
+ {
+ startIndex = cursorOffset;
+ }
+ }
+ catch (FormatException)
+ {
+ // Invalid cursor format; start from beginning
+ }
+ }
+
+ List> pageItems = allResults
+ .Skip(startIndex)
+ .Take(first)
+ .ToList();
+
+ bool hasNextPage = startIndex + first < allResults.Count;
+ string? endCursor = null;
+
+ if (pageItems.Count > 0)
+ {
+ int lastItemIndex = startIndex + pageItems.Count;
+ endCursor = Convert.ToBase64String(
+ System.Text.Encoding.UTF8.GetBytes(lastItemIndex.ToString()));
+ }
+
+ return new PaginationResult
+ {
+ Items = pageItems,
+ EndCursor = endCursor,
+ HasNextPage = hasNextPage
+ };
+ }
+
+ private static double? ComputeAggregateValue(List records, string function, string field, bool distinct, bool isCountStar)
+ {
+ if (isCountStar)
+ {
+ // count(*) always counts all rows; distinct is rejected at ExecuteAsync validation level
+ return records.Count;
+ }
+
+ List values = new();
+ foreach (JsonElement record in records)
+ {
+ if (record.TryGetProperty(field, out JsonElement val) && val.ValueKind == JsonValueKind.Number)
+ {
+ values.Add(val.GetDouble());
+ }
+ }
+
+ if (distinct)
+ {
+ values = values.Distinct().ToList();
+ }
+
+ if (function == "count")
+ {
+ return values.Count;
+ }
+
+ if (values.Count == 0)
+ {
+ return null;
+ }
+
+ return function switch
+ {
+ "avg" => Math.Round(values.Average(), 2),
+ "sum" => values.Sum(),
+ "min" => values.Min(),
+ "max" => values.Max(),
+ _ => null
+ };
+ }
+
+ private static bool PassesHavingFilter(double? value, Dictionary? havingOps, List? havingIn)
+ {
+ if (havingOps == null && havingIn == null)
+ {
+ return true;
+ }
+
+ if (value == null)
+ {
+ return false;
+ }
+
+ double v = value.Value;
+
+ if (havingOps != null)
+ {
+ foreach (KeyValuePair op in havingOps)
+ {
+ bool passes = op.Key.ToLowerInvariant() switch
+ {
+ "eq" => v == op.Value,
+ "neq" => v != op.Value,
+ "gt" => v > op.Value,
+ "gte" => v >= op.Value,
+ "lt" => v < op.Value,
+ "lte" => v <= op.Value,
+ _ => true
+ };
+
+ if (!passes)
+ {
+ return false;
+ }
+ }
+ }
+
+ if (havingIn != null && !havingIn.Contains(v))
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ private static string BuildGroupKey(JsonElement record, List groupby)
+ {
+ List parts = new();
+ foreach (string g in groupby)
+ {
+ if (record.TryGetProperty(g, out JsonElement val))
+ {
+ parts.Add(val.ToString());
+ }
+ else
+ {
+ parts.Add("__null__");
+ }
+ }
+
+ // Use null character (\0) as delimiter to avoid collisions with
+ // field values that may contain printable characters like '|'.
+ return string.Join("\0", parts);
+ }
+
+ private static Dictionary ExtractGroupFields(JsonElement record, List groupby)
+ {
+ Dictionary result = new();
+ foreach (string g in groupby)
+ {
+ if (record.TryGetProperty(g, out JsonElement val))
+ {
+ result[g] = McpResponseBuilder.GetJsonValue(val);
+ }
+ else
+ {
+ result[g] = null;
+ }
+ }
+
+ return result;
+ }
+
+ private static int CompareNullableDoubles(double? a, double? b)
+ {
+ if (a == null && b == null)
+ {
+ return 0;
+ }
+
+ if (a == null)
+ {
+ return -1;
+ }
+
+ if (b == null)
+ {
+ return 1;
+ }
+
+ return a.Value.CompareTo(b.Value);
+ }
+ }
+}
diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs
index f69a26fa5d..4362ecaf3c 100644
--- a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs
+++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs
@@ -37,5 +37,10 @@ internal static class McpTelemetryErrorCodes
/// Operation cancelled error code.
///
public const string OPERATION_CANCELLED = "OperationCancelled";
+
+ ///
+ /// Timeout error code.
+ ///
+ public const string TIMEOUT = "Timeout";
}
}
diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs
index 2a60557f8d..105bb57ced 100644
--- a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs
+++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs
@@ -60,8 +60,33 @@ public static async Task ExecuteWithTelemetryAsync(
operation: operation,
dbProcedure: dbProcedure);
- // Execute the tool
- CallToolResult result = await tool.ExecuteAsync(arguments, serviceProvider, cancellationToken);
+ // Read query-timeout from current config per invocation (hot-reload safe).
+ int timeoutSeconds = McpRuntimeOptions.DEFAULT_QUERY_TIMEOUT_SECONDS;
+ RuntimeConfigProvider? runtimeConfigProvider = serviceProvider.GetService();
+ if (runtimeConfigProvider is not null)
+ {
+ RuntimeConfig config = runtimeConfigProvider.GetConfig();
+ timeoutSeconds = config.Runtime?.Mcp?.EffectiveQueryTimeoutSeconds ?? McpRuntimeOptions.DEFAULT_QUERY_TIMEOUT_SECONDS;
+ }
+
+ // Wrap tool execution with the configured timeout using a linked CancellationTokenSource.
+ using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+ timeoutCts.CancelAfter(TimeSpan.FromSeconds(timeoutSeconds));
+
+ CallToolResult result;
+ try
+ {
+ result = await tool.ExecuteAsync(arguments, serviceProvider, timeoutCts.Token);
+ }
+ catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
+ {
+ // The timeout CTS fired, not the caller's token. Surface as TimeoutException
+ // so downstream telemetry and tool handlers see TIMEOUT, not cancellation.
+ throw new TimeoutException(
+ $"The MCP tool '{toolName}' did not complete within {timeoutSeconds} seconds. "
+ + "This is NOT a tool error. The operation exceeded the configured query-timeout. "
+ + "Try narrowing results with a filter, reducing groupby fields, or using pagination.");
+ }
// Check if the tool returned an error result (tools catch exceptions internally
// and return CallToolResult with IsError=true instead of throwing)
@@ -124,6 +149,7 @@ public static string InferOperationFromTool(IMcpTool tool, string toolName)
"delete_record" => "delete",
"describe_entities" => "describe",
"execute_entity" => "execute",
+ "aggregate_records" => "aggregate",
_ => "execute" // Fallback for any unknown built-in tools
};
}
@@ -188,6 +214,7 @@ public static string MapExceptionToErrorCode(Exception ex)
return ex switch
{
OperationCanceledException => McpTelemetryErrorCodes.OPERATION_CANCELLED,
+ TimeoutException => McpTelemetryErrorCodes.TIMEOUT,
DataApiBuilderException dabEx when dabEx.SubStatusCode == DataApiBuilderException.SubStatusCodes.AuthenticationChallenge
=> McpTelemetryErrorCodes.AUTHENTICATION_FAILED,
DataApiBuilderException dabEx when dabEx.SubStatusCode == DataApiBuilderException.SubStatusCodes.AuthorizationCheckFailed
diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs
index 262cbc9145..2fd20ea4ca 100644
--- a/src/Cli/Commands/ConfigureOptions.cs
+++ b/src/Cli/Commands/ConfigureOptions.cs
@@ -49,6 +49,7 @@ public ConfigureOptions(
bool? runtimeMcpDmlToolsUpdateRecordEnabled = null,
bool? runtimeMcpDmlToolsDeleteRecordEnabled = null,
bool? runtimeMcpDmlToolsExecuteEntityEnabled = null,
+ bool? runtimeMcpDmlToolsAggregateRecordsEnabled = null,
bool? runtimeCacheEnabled = null,
int? runtimeCacheTtl = null,
CompressionLevel? runtimeCompressionLevel = null,
@@ -109,6 +110,7 @@ public ConfigureOptions(
RuntimeMcpDmlToolsUpdateRecordEnabled = runtimeMcpDmlToolsUpdateRecordEnabled;
RuntimeMcpDmlToolsDeleteRecordEnabled = runtimeMcpDmlToolsDeleteRecordEnabled;
RuntimeMcpDmlToolsExecuteEntityEnabled = runtimeMcpDmlToolsExecuteEntityEnabled;
+ RuntimeMcpDmlToolsAggregateRecordsEnabled = runtimeMcpDmlToolsAggregateRecordsEnabled;
// Cache
RuntimeCacheEnabled = runtimeCacheEnabled;
RuntimeCacheTTL = runtimeCacheTtl;
@@ -224,6 +226,9 @@ public ConfigureOptions(
[Option("runtime.mcp.dml-tools.execute-entity.enabled", Required = false, HelpText = "Enable DAB's MCP execute entity tool. Default: true (boolean).")]
public bool? RuntimeMcpDmlToolsExecuteEntityEnabled { get; }
+ [Option("runtime.mcp.dml-tools.aggregate-records.enabled", Required = false, HelpText = "Enable DAB's MCP aggregate records tool. Default: true (boolean).")]
+ public bool? RuntimeMcpDmlToolsAggregateRecordsEnabled { get; }
+
[Option("runtime.cache.enabled", Required = false, HelpText = "Enable DAB's cache globally. (You must also enable each entity's cache separately.). Default: false (boolean).")]
public bool? RuntimeCacheEnabled { get; }
diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs
index 6c51f002b7..ac377c51de 100644
--- a/src/Cli/ConfigGenerator.cs
+++ b/src/Cli/ConfigGenerator.cs
@@ -882,7 +882,8 @@ private static bool TryUpdateConfiguredRuntimeOptions(
options.RuntimeMcpDmlToolsReadRecordsEnabled != null ||
options.RuntimeMcpDmlToolsUpdateRecordEnabled != null ||
options.RuntimeMcpDmlToolsDeleteRecordEnabled != null ||
- options.RuntimeMcpDmlToolsExecuteEntityEnabled != null)
+ options.RuntimeMcpDmlToolsExecuteEntityEnabled != null ||
+ options.RuntimeMcpDmlToolsAggregateRecordsEnabled != null)
{
McpRuntimeOptions updatedMcpOptions = runtimeConfig?.Runtime?.Mcp ?? new();
bool status = TryUpdateConfiguredMcpValues(options, ref updatedMcpOptions);
@@ -1181,6 +1182,7 @@ private static bool TryUpdateConfiguredMcpValues(
bool? updateRecord = currentDmlTools?.UpdateRecord;
bool? deleteRecord = currentDmlTools?.DeleteRecord;
bool? executeEntity = currentDmlTools?.ExecuteEntity;
+ bool? aggregateRecords = currentDmlTools?.AggregateRecords;
updatedValue = options?.RuntimeMcpDmlToolsDescribeEntitiesEnabled;
if (updatedValue != null)
@@ -1230,6 +1232,14 @@ private static bool TryUpdateConfiguredMcpValues(
_logger.LogInformation("Updated RuntimeConfig with runtime.mcp.dml-tools.execute-entity as '{updatedValue}'", updatedValue);
}
+ updatedValue = options?.RuntimeMcpDmlToolsAggregateRecordsEnabled;
+ if (updatedValue != null)
+ {
+ aggregateRecords = (bool)updatedValue;
+ hasToolUpdates = true;
+ _logger.LogInformation("Updated RuntimeConfig with runtime.mcp.dml-tools.aggregate-records as '{updatedValue}'", updatedValue);
+ }
+
if (hasToolUpdates)
{
updatedMcpOptions = updatedMcpOptions! with
@@ -1242,7 +1252,8 @@ private static bool TryUpdateConfiguredMcpValues(
ReadRecords = readRecord,
UpdateRecord = updateRecord,
DeleteRecord = deleteRecord,
- ExecuteEntity = executeEntity
+ ExecuteEntity = executeEntity,
+ AggregateRecords = aggregateRecords
}
};
}
diff --git a/src/Config/Converters/DmlToolsConfigConverter.cs b/src/Config/Converters/DmlToolsConfigConverter.cs
index 82ac3f6069..7e049c7926 100644
--- a/src/Config/Converters/DmlToolsConfigConverter.cs
+++ b/src/Config/Converters/DmlToolsConfigConverter.cs
@@ -44,6 +44,7 @@ internal class DmlToolsConfigConverter : JsonConverter
bool? updateRecord = null;
bool? deleteRecord = null;
bool? executeEntity = null;
+ bool? aggregateRecords = null;
while (reader.Read())
{
@@ -82,6 +83,9 @@ internal class DmlToolsConfigConverter : JsonConverter
case "execute-entity":
executeEntity = value;
break;
+ case "aggregate-records":
+ aggregateRecords = value;
+ break;
default:
// Skip unknown properties
break;
@@ -91,7 +95,8 @@ internal class DmlToolsConfigConverter : JsonConverter
{
// Error on non-boolean values for known properties
if (property?.ToLowerInvariant() is "describe-entities" or "create-record"
- or "read-records" or "update-record" or "delete-record" or "execute-entity")
+ or "read-records" or "update-record" or "delete-record" or "execute-entity"
+ or "aggregate-records")
{
throw new JsonException($"Property '{property}' must be a boolean value.");
}
@@ -110,7 +115,8 @@ internal class DmlToolsConfigConverter : JsonConverter
readRecords: readRecords,
updateRecord: updateRecord,
deleteRecord: deleteRecord,
- executeEntity: executeEntity);
+ executeEntity: executeEntity,
+ aggregateRecords: aggregateRecords);
}
// For any other unexpected token type, return default (all enabled)
@@ -135,7 +141,8 @@ public override void Write(Utf8JsonWriter writer, DmlToolsConfig? value, JsonSer
value.UserProvidedReadRecords ||
value.UserProvidedUpdateRecord ||
value.UserProvidedDeleteRecord ||
- value.UserProvidedExecuteEntity;
+ value.UserProvidedExecuteEntity ||
+ value.UserProvidedAggregateRecords;
// Only write the boolean value if it's provided by user
// This prevents writing "dml-tools": true when it's the default
@@ -181,6 +188,11 @@ public override void Write(Utf8JsonWriter writer, DmlToolsConfig? value, JsonSer
writer.WriteBoolean("execute-entity", value.ExecuteEntity.Value);
}
+ if (value.UserProvidedAggregateRecords && value.AggregateRecords.HasValue)
+ {
+ writer.WriteBoolean("aggregate-records", value.AggregateRecords.Value);
+ }
+
writer.WriteEndObject();
}
}
diff --git a/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs b/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs
index 8b3c640725..ad4edc229e 100644
--- a/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs
+++ b/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs
@@ -66,12 +66,13 @@ internal McpRuntimeOptionsConverter(DeserializationVariableReplacementSettings?
string? path = null;
DmlToolsConfig? dmlTools = null;
string? description = null;
+ int? queryTimeout = null;
while (reader.Read())
{
if (reader.TokenType == JsonTokenType.EndObject)
{
- return new McpRuntimeOptions(enabled, path, dmlTools, description);
+ return new McpRuntimeOptions(enabled, path, dmlTools, description, queryTimeout);
}
string? propertyName = reader.GetString();
@@ -107,6 +108,14 @@ internal McpRuntimeOptionsConverter(DeserializationVariableReplacementSettings?
break;
+ case "query-timeout":
+ if (reader.TokenType is not JsonTokenType.Null)
+ {
+ queryTimeout = reader.GetInt32();
+ }
+
+ break;
+
default:
throw new JsonException($"Unexpected property {propertyName}");
}
@@ -150,6 +159,12 @@ public override void Write(Utf8JsonWriter writer, McpRuntimeOptions value, JsonS
JsonSerializer.Serialize(writer, value.Description, options);
}
+ // Write query-timeout if it's user provided
+ if (value?.UserProvidedQueryTimeout is true && value.QueryTimeout.HasValue)
+ {
+ writer.WriteNumber("query-timeout", value.QueryTimeout.Value);
+ }
+
writer.WriteEndObject();
}
}
diff --git a/src/Config/ObjectModel/DmlToolsConfig.cs b/src/Config/ObjectModel/DmlToolsConfig.cs
index 2a09e9d53c..c1f8b278cd 100644
--- a/src/Config/ObjectModel/DmlToolsConfig.cs
+++ b/src/Config/ObjectModel/DmlToolsConfig.cs
@@ -51,6 +51,11 @@ public record DmlToolsConfig
///
public bool? ExecuteEntity { get; init; }
+ ///
+ /// Whether aggregate-records tool is enabled
+ ///
+ public bool? AggregateRecords { get; init; }
+
[JsonConstructor]
public DmlToolsConfig(
bool? allToolsEnabled = null,
@@ -59,7 +64,8 @@ public DmlToolsConfig(
bool? readRecords = null,
bool? updateRecord = null,
bool? deleteRecord = null,
- bool? executeEntity = null)
+ bool? executeEntity = null,
+ bool? aggregateRecords = null)
{
if (allToolsEnabled is not null)
{
@@ -75,6 +81,7 @@ public DmlToolsConfig(
UpdateRecord = updateRecord ?? toolDefault;
DeleteRecord = deleteRecord ?? toolDefault;
ExecuteEntity = executeEntity ?? toolDefault;
+ AggregateRecords = aggregateRecords ?? toolDefault;
}
else
{
@@ -87,6 +94,7 @@ public DmlToolsConfig(
UpdateRecord = updateRecord ?? DEFAULT_ENABLED;
DeleteRecord = deleteRecord ?? DEFAULT_ENABLED;
ExecuteEntity = executeEntity ?? DEFAULT_ENABLED;
+ AggregateRecords = aggregateRecords ?? DEFAULT_ENABLED;
}
// Track user-provided status - only true if the parameter was not null
@@ -96,6 +104,7 @@ public DmlToolsConfig(
UserProvidedUpdateRecord = updateRecord is not null;
UserProvidedDeleteRecord = deleteRecord is not null;
UserProvidedExecuteEntity = executeEntity is not null;
+ UserProvidedAggregateRecords = aggregateRecords is not null;
}
///
@@ -112,7 +121,8 @@ public static DmlToolsConfig FromBoolean(bool enabled)
readRecords: null,
updateRecord: null,
deleteRecord: null,
- executeEntity: null
+ executeEntity: null,
+ aggregateRecords: null
);
}
@@ -127,7 +137,8 @@ public static DmlToolsConfig FromBoolean(bool enabled)
readRecords: null,
updateRecord: null,
deleteRecord: null,
- executeEntity: null
+ executeEntity: null,
+ aggregateRecords: null
);
///
@@ -185,4 +196,12 @@ public static DmlToolsConfig FromBoolean(bool enabled)
[JsonIgnore(Condition = JsonIgnoreCondition.Always)]
[MemberNotNullWhen(true, nameof(ExecuteEntity))]
public bool UserProvidedExecuteEntity { get; init; } = false;
+
+ ///
+ /// Flag which informs CLI and JSON serializer whether to write aggregate-records
+ /// property/value to the runtime config file.
+ ///
+ [JsonIgnore(Condition = JsonIgnoreCondition.Always)]
+ [MemberNotNullWhen(true, nameof(AggregateRecords))]
+ public bool UserProvidedAggregateRecords { get; init; } = false;
}
diff --git a/src/Config/ObjectModel/McpRuntimeOptions.cs b/src/Config/ObjectModel/McpRuntimeOptions.cs
index e17d53fc8f..f4b4281a14 100644
--- a/src/Config/ObjectModel/McpRuntimeOptions.cs
+++ b/src/Config/ObjectModel/McpRuntimeOptions.cs
@@ -10,6 +10,7 @@ namespace Azure.DataApiBuilder.Config.ObjectModel;
public record McpRuntimeOptions
{
public const string DEFAULT_PATH = "/mcp";
+ public const int DEFAULT_QUERY_TIMEOUT_SECONDS = 30;
///
/// Whether MCP endpoints are enabled
@@ -36,12 +37,22 @@ public record McpRuntimeOptions
[JsonPropertyName("description")]
public string? Description { get; init; }
+ ///
+ /// Execution timeout in seconds for MCP tool operations.
+ /// This timeout wraps the entire tool execution including database queries.
+ /// It applies to all MCP tools, not just aggregation.
+ /// Default: 30 seconds.
+ ///
+ [JsonPropertyName("query-timeout")]
+ public int? QueryTimeout { get; init; }
+
[JsonConstructor]
public McpRuntimeOptions(
bool? Enabled = null,
string? Path = null,
DmlToolsConfig? DmlTools = null,
- string? Description = null)
+ string? Description = null,
+ int? QueryTimeout = null)
{
this.Enabled = Enabled ?? true;
@@ -67,6 +78,12 @@ public McpRuntimeOptions(
}
this.Description = Description;
+
+ if (QueryTimeout is not null)
+ {
+ this.QueryTimeout = QueryTimeout;
+ UserProvidedQueryTimeout = true;
+ }
}
///
@@ -78,4 +95,17 @@ public McpRuntimeOptions(
[JsonIgnore(Condition = JsonIgnoreCondition.Always)]
[MemberNotNullWhen(true, nameof(Enabled))]
public bool UserProvidedPath { get; init; } = false;
+
+ ///
+ /// Flag which informs CLI and JSON serializer whether to write query-timeout
+ /// property and value to the runtime config file.
+ ///
+ [JsonIgnore(Condition = JsonIgnoreCondition.Always)]
+ public bool UserProvidedQueryTimeout { get; init; } = false;
+
+ ///
+ /// Gets the effective query timeout in seconds, using the default if not specified.
+ ///
+ [JsonIgnore(Condition = JsonIgnoreCondition.Always)]
+ public int EffectiveQueryTimeoutSeconds => QueryTimeout ?? DEFAULT_QUERY_TIMEOUT_SECONDS;
}
diff --git a/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs b/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs
new file mode 100644
index 0000000000..dee8842a0d
--- /dev/null
+++ b/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs
@@ -0,0 +1,411 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#nullable enable
+
+using System.Collections.Generic;
+using System.Text.Json;
+using Azure.DataApiBuilder.Mcp.BuiltInTools;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Azure.DataApiBuilder.Service.Tests.UnitTests
+{
+ ///
+ /// Unit tests for AggregateRecordsTool's internal helper methods.
+ /// Covers validation paths, aggregation logic, and pagination behavior.
+ ///
+ [TestClass]
+ public class AggregateRecordsToolTests
+ {
+ #region ComputeAlias tests
+
+ [TestMethod]
+ [DataRow("count", "*", "count", DisplayName = "count(*) alias is 'count'")]
+ [DataRow("count", "userId", "count_userId", DisplayName = "count(field) alias is 'count_field'")]
+ [DataRow("avg", "price", "avg_price", DisplayName = "avg alias")]
+ [DataRow("sum", "amount", "sum_amount", DisplayName = "sum alias")]
+ [DataRow("min", "age", "min_age", DisplayName = "min alias")]
+ [DataRow("max", "score", "max_score", DisplayName = "max alias")]
+ public void ComputeAlias_ReturnsExpectedAlias(string function, string field, string expectedAlias)
+ {
+ string result = AggregateRecordsTool.ComputeAlias(function, field);
+ Assert.AreEqual(expectedAlias, result);
+ }
+
+ #endregion
+
+ #region PerformAggregation tests - no groupby
+
+ private static JsonElement CreateRecordsArray(params double[] values)
+ {
+ var list = new List