diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index cbe38b7d72..8f283fba36 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -275,6 +275,13 @@ "description": "Allow enabling/disabling MCP requests for all entities.", "default": true }, + "query-timeout": { + "type": "integer", + "description": "Execution timeout in seconds for MCP tool operations. Applies to all MCP tools. Range: 1-600.", + "default": 30, + "minimum": 1, + "maximum": 600 + }, "dml-tools": { "oneOf": [ { @@ -315,6 +322,11 @@ "type": "boolean", "description": "Enable/disable the execute-entity tool.", "default": false + }, + "aggregate-records": { + "type": "boolean", + "description": "Enable/disable the aggregate-records tool.", + "default": false } } } diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs new file mode 100644 index 0000000000..2d7e8e19d8 --- /dev/null +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.cs @@ -0,0 +1,783 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Data.Common; +using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.RegularExpressions; +using Azure.DataApiBuilder.Auth; +using Azure.DataApiBuilder.Config.DatabasePrimitives; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Authorization; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Models; +using Azure.DataApiBuilder.Core.Parsers; +using Azure.DataApiBuilder.Core.Resolvers; +using Azure.DataApiBuilder.Core.Resolvers.Factories; +using Azure.DataApiBuilder.Core.Services; +using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Mcp.Utils; +using Azure.DataApiBuilder.Service.Exceptions; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; +using static Azure.DataApiBuilder.Mcp.Model.McpEnums; +using static Azure.DataApiBuilder.Service.GraphQLBuilder.Sql.SchemaConverter; + +namespace Azure.DataApiBuilder.Mcp.BuiltInTools +{ + /// + /// Tool to aggregate records from a table/view entity configured in DAB. + /// Supports count, avg, sum, min, max with optional distinct, filter, groupby, having, orderby. + /// + public class AggregateRecordsTool : IMcpTool + { + public ToolType ToolType { get; } = ToolType.BuiltIn; + + private static readonly HashSet _validFunctions = new(StringComparer.OrdinalIgnoreCase) { "count", "avg", "sum", "min", "max" }; + private static readonly HashSet _validHavingOperators = new(StringComparer.OrdinalIgnoreCase) { "eq", "neq", "gt", "gte", "lt", "lte", "in" }; + + private static readonly Tool _cachedToolMetadata = new() + { + Name = "aggregate_records", + Description = "Computes aggregations (count, avg, sum, min, max) on entity data. " + + "WORKFLOW: 1) Call describe_entities first to get entity names and field names. " + + "2) Call this tool with entity, function, and field from step 1. " + + "RULES: field '*' is ONLY valid with count. " + + "orderby, having, first, and after ONLY apply when groupby is provided. " + + "RESPONSE: Result is aliased as '{function}_{field}' (e.g. avg_unitPrice). " + + "For count(*), the alias is 'count'. " + + "With groupby and first, response includes items, endCursor, and hasNextPage for pagination.", + InputSchema = JsonSerializer.Deserialize( + @"{ + ""type"": ""object"", + ""properties"": { + ""entity"": { + ""type"": ""string"", + ""description"": ""Entity name from describe_entities with READ permission (case-sensitive)."" + }, + ""function"": { + ""type"": ""string"", + ""enum"": [""count"", ""avg"", ""sum"", ""min"", ""max""], + ""description"": ""Aggregation function. count supports field '*'; avg, sum, min, max require a numeric field."" + }, + ""field"": { + ""type"": ""string"", + ""description"": ""Field name to aggregate, or '*' with count to count all rows."" + }, + ""distinct"": { + ""type"": ""boolean"", + ""description"": ""Remove duplicate values before aggregating. Not valid with field '*'."", + ""default"": false + }, + ""filter"": { + ""type"": ""string"", + ""description"": ""OData WHERE clause applied before aggregating. Operators: eq, ne, gt, ge, lt, le, and, or, not. Example: 'unitPrice lt 10'."", + ""default"": """" + }, + ""groupby"": { + ""type"": ""array"", + ""items"": { ""type"": ""string"" }, + ""description"": ""Field names to group by. Each unique combination produces one aggregated row. Enables orderby, having, first, and after."", + ""default"": [] + }, + ""orderby"": { + ""type"": ""string"", + ""enum"": [""asc"", ""desc""], + ""description"": ""Sort grouped results by the aggregated value. Requires groupby."", + ""default"": ""desc"" + }, + ""having"": { + ""type"": ""object"", + ""description"": ""Filter groups by the aggregated value (HAVING clause). Requires groupby. Multiple operators are AND-ed."", + ""properties"": { + ""eq"": { ""type"": ""number"", ""description"": ""Equals."" }, + ""neq"": { ""type"": ""number"", ""description"": ""Not equals."" }, + ""gt"": { ""type"": ""number"", ""description"": ""Greater than."" }, + ""gte"": { ""type"": ""number"", ""description"": ""Greater than or equal."" }, + ""lt"": { ""type"": ""number"", ""description"": ""Less than."" }, + ""lte"": { ""type"": ""number"", ""description"": ""Less than or equal."" }, + ""in"": { + ""type"": ""array"", + ""items"": { ""type"": ""number"" }, + ""description"": ""Matches any value in the list."" + } + } + }, + ""first"": { + ""type"": ""integer"", + ""description"": ""Max grouped results to return. Requires groupby. Enables paginated response with endCursor and hasNextPage."", + ""minimum"": 1 + }, + ""after"": { + ""type"": ""string"", + ""description"": ""Opaque cursor from a previous endCursor for next-page retrieval. Requires groupby and first. Do not construct manually."" + } + }, + ""required"": [""entity"", ""function"", ""field""] + }" + ) + }; + + public Tool GetToolMetadata() + { + return _cachedToolMetadata; + } + + public async Task ExecuteAsync( + JsonDocument? arguments, + IServiceProvider serviceProvider, + CancellationToken cancellationToken = default) + { + ILogger? logger = serviceProvider.GetService>(); + string toolName = GetToolMetadata().Name; + + RuntimeConfigProvider runtimeConfigProvider = serviceProvider.GetRequiredService(); + RuntimeConfig runtimeConfig = runtimeConfigProvider.GetConfig(); + + if (runtimeConfig.McpDmlTools?.AggregateRecords is not true) + { + return McpErrorHelpers.ToolDisabled(toolName, logger); + } + + string entityName = string.Empty; + + try + { + cancellationToken.ThrowIfCancellationRequested(); + + if (arguments == null) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "No arguments provided.", logger); + } + + JsonElement root = arguments.RootElement; + + // Parse required arguments + if (!McpArgumentParser.TryParseEntity(root, out string parsedEntityName, out string parseError)) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", parseError, logger); + } + + entityName = parsedEntityName; + + if (runtimeConfig.Entities?.TryGetValue(entityName, out Entity? entity) == true && + entity.Mcp?.DmlToolEnabled == false) + { + return McpErrorHelpers.ToolDisabled(toolName, logger, $"DML tools are disabled for entity '{entityName}'."); + } + + if (!root.TryGetProperty("function", out JsonElement functionElement) || string.IsNullOrWhiteSpace(functionElement.GetString())) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Missing required argument 'function'.", logger); + } + + string function = functionElement.GetString()!.ToLowerInvariant(); + if (!_validFunctions.Contains(function)) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", $"Invalid function '{function}'. Must be one of: count, avg, sum, min, max.", logger); + } + + if (!root.TryGetProperty("field", out JsonElement fieldElement) || string.IsNullOrWhiteSpace(fieldElement.GetString())) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Missing required argument 'field'.", logger); + } + + string field = fieldElement.GetString()!; + + // Validate field/function compatibility + bool isCountStar = function == "count" && field == "*"; + + if (field == "*" && function != "count") + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + $"Field '*' is only valid with function 'count'. For function '{function}', provide a specific field name.", logger); + } + + bool distinct = root.TryGetProperty("distinct", out JsonElement distinctElement) && distinctElement.GetBoolean(); + + // Reject count(*) with distinct as it is semantically undefined + if (isCountStar && distinct) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + "Cannot use distinct=true with field='*'. DISTINCT requires a specific field name. Use a field name instead of '*' to count distinct values.", logger); + } + + string? filter = root.TryGetProperty("filter", out JsonElement filterElement) ? filterElement.GetString() : null; + bool userProvidedOrderby = root.TryGetProperty("orderby", out JsonElement orderbyElement) && !string.IsNullOrWhiteSpace(orderbyElement.GetString()); + string orderby = userProvidedOrderby ? (orderbyElement.GetString() ?? "desc") : "desc"; + + int? first = null; + if (root.TryGetProperty("first", out JsonElement firstElement) && firstElement.ValueKind == JsonValueKind.Number) + { + first = firstElement.GetInt32(); + if (first < 1) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", "Argument 'first' must be at least 1.", logger); + } + } + + string? after = root.TryGetProperty("after", out JsonElement afterElement) ? afterElement.GetString() : null; + + List groupby = new(); + if (root.TryGetProperty("groupby", out JsonElement groupbyElement) && groupbyElement.ValueKind == JsonValueKind.Array) + { + foreach (JsonElement groupbyItem in groupbyElement.EnumerateArray()) + { + string? groupbyFieldName = groupbyItem.GetString(); + if (!string.IsNullOrWhiteSpace(groupbyFieldName)) + { + groupby.Add(groupbyFieldName); + } + } + } + + // Validate that first, after, orderby, and having require groupby + if (groupby.Count == 0) + { + if (userProvidedOrderby) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + "The 'orderby' parameter requires 'groupby' to be specified. Sorting applies to grouped aggregation results.", logger); + } + + if (first.HasValue) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + "The 'first' parameter requires 'groupby' to be specified. Pagination applies to grouped aggregation results.", logger); + } + + if (!string.IsNullOrEmpty(after)) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + "The 'after' parameter requires 'groupby' to be specified. Pagination applies to grouped aggregation results.", logger); + } + } + + if (!string.IsNullOrEmpty(after) && !first.HasValue) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + "The 'after' parameter requires 'first' to be specified. Provide 'first' to enable pagination.", logger); + } + + Dictionary? havingOperators = null; + List? havingInValues = null; + if (root.TryGetProperty("having", out JsonElement havingElement) && havingElement.ValueKind == JsonValueKind.Object) + { + if (groupby.Count == 0) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + "The 'having' parameter requires 'groupby' to be specified. HAVING filters groups after aggregation.", logger); + } + + havingOperators = new Dictionary(StringComparer.OrdinalIgnoreCase); + foreach (JsonProperty prop in havingElement.EnumerateObject()) + { + // Reject unsupported operators (e.g. between, notIn, like) + if (!_validHavingOperators.Contains(prop.Name)) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + $"Unsupported having operator '{prop.Name}'. Supported operators: {string.Join(", ", _validHavingOperators)}.", logger); + } + + if (prop.Name.Equals("in", StringComparison.OrdinalIgnoreCase)) + { + if (prop.Value.ValueKind != JsonValueKind.Array) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + "The 'having.in' value must be a numeric array. Example: {\"in\": [5, 10]}.", logger); + } + + havingInValues = new List(); + foreach (JsonElement item in prop.Value.EnumerateArray()) + { + if (item.ValueKind != JsonValueKind.Number) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + $"All values in 'having.in' must be numeric. Found non-numeric value: '{item}'.", logger); + } + + havingInValues.Add(item.GetDouble()); + } + } + else + { + // Scalar operators (eq, neq, gt, gte, lt, lte) must have numeric values + if (prop.Value.ValueKind != JsonValueKind.Number) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", + $"The 'having.{prop.Name}' value must be numeric. Got: '{prop.Value}'. HAVING filters compare aggregated numeric results.", logger); + } + + havingOperators[prop.Name] = prop.Value.GetDouble(); + } + } + } + + // Resolve metadata + if (!McpMetadataHelper.TryResolveMetadata( + entityName, + runtimeConfig, + serviceProvider, + out ISqlMetadataProvider sqlMetadataProvider, + out DatabaseObject dbObject, + out string dataSourceName, + out string metadataError)) + { + return McpResponseBuilder.BuildErrorResult(toolName, "EntityNotFound", metadataError, logger); + } + + // Early field validation: check all user-supplied field names before authorization or query building. + // This lets the model discover and fix typos immediately. + if (!isCountStar) + { + if (!sqlMetadataProvider.TryGetBackingColumn(entityName, field, out _)) + { + return McpErrorHelpers.FieldNotFound(toolName, entityName, field, "field", logger); + } + } + + foreach (string groupbyField in groupby) + { + if (!sqlMetadataProvider.TryGetBackingColumn(entityName, groupbyField, out _)) + { + return McpErrorHelpers.FieldNotFound(toolName, entityName, groupbyField, "groupby", logger); + } + } + + // Authorization + IAuthorizationResolver authResolver = serviceProvider.GetRequiredService(); + IAuthorizationService authorizationService = serviceProvider.GetRequiredService(); + IHttpContextAccessor httpContextAccessor = serviceProvider.GetRequiredService(); + HttpContext? httpContext = httpContextAccessor.HttpContext; + + if (!McpAuthorizationHelper.ValidateRoleContext(httpContext, authResolver, out string roleCtxError)) + { + return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", roleCtxError, logger); + } + + if (!McpAuthorizationHelper.TryResolveAuthorizedRole( + httpContext!, + authResolver, + entityName, + EntityActionOperation.Read, + out string? effectiveRole, + out string readAuthError)) + { + string finalError = readAuthError.StartsWith("You do not have permission", StringComparison.OrdinalIgnoreCase) + ? $"You do not have permission to read records for entity '{entityName}'." + : readAuthError; + return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", finalError, logger); + } + + // Build select list for authorization: groupby fields + aggregation field + List selectFields = new(groupby); + if (!isCountStar && !selectFields.Contains(field, StringComparer.OrdinalIgnoreCase)) + { + selectFields.Add(field); + } + + // Build and validate Find context (reuse for authorization and OData filter parsing) + RequestValidator requestValidator = new(serviceProvider.GetRequiredService(), runtimeConfigProvider); + FindRequestContext context = new(entityName, dbObject, true); + httpContext!.Request.Method = "GET"; + + requestValidator.ValidateEntity(entityName); + + if (selectFields.Count > 0) + { + context.UpdateReturnFields(selectFields); + } + + if (!string.IsNullOrWhiteSpace(filter)) + { + string filterQueryString = $"?{RequestParser.FILTER_URL}={filter}"; + context.FilterClauseInUrl = sqlMetadataProvider.GetODataParser().GetFilterClause(filterQueryString, $"{context.EntityName}.{context.DatabaseObject.FullName}"); + } + + requestValidator.ValidateRequestContext(context); + + AuthorizationResult authorizationResult = await authorizationService.AuthorizeAsync( + user: httpContext.User, + resource: context, + requirements: new[] { new ColumnsPermissionsRequirement() }); + if (!authorizationResult.Succeeded) + { + return McpErrorHelpers.PermissionDenied(toolName, entityName, "read", DataApiBuilderException.AUTHORIZATION_FAILURE, logger); + } + + // Build SqlQueryStructure to get OData filter → SQL predicate translation and DB policies + GQLFilterParser gQLFilterParser = serviceProvider.GetRequiredService(); + SqlQueryStructure structure = new( + context, sqlMetadataProvider, authResolver, runtimeConfigProvider, gQLFilterParser, httpContext); + + // Get database-specific components + DatabaseType databaseType = runtimeConfig.GetDataSourceFromDataSourceName(dataSourceName).DatabaseType; + + // Aggregation is only supported for MsSql/DWSQL (matching engine's GraphQL aggregation support) + if (databaseType != DatabaseType.MSSQL && databaseType != DatabaseType.DWSQL) + { + return McpResponseBuilder.BuildErrorResult(toolName, "UnsupportedDatabase", + $"Aggregation is not supported for database type '{databaseType}'. Aggregation is only available for Azure SQL, SQL Server, and SQL Data Warehouse.", logger); + } + + IAbstractQueryManagerFactory queryManagerFactory = serviceProvider.GetRequiredService(); + IQueryBuilder queryBuilder = queryManagerFactory.GetQueryBuilder(databaseType); + IQueryExecutor queryExecutor = queryManagerFactory.GetQueryExecutor(databaseType); + + // Resolve backing column name for the aggregation field (already validated early) + string? backingField = null; + if (!isCountStar) + { + sqlMetadataProvider.TryGetBackingColumn(entityName, field, out backingField); + } + else + { + // For COUNT(*), use primary key column since PK is always NOT NULL, + // making COUNT(pk) equivalent to COUNT(*). The engine's Build(AggregationColumn) + // does not support "*" as a column name (it would produce invalid SQL like count([].[*])). + SourceDefinition sourceDefinition = sqlMetadataProvider.GetSourceDefinition(entityName); + if (sourceDefinition.PrimaryKey.Count > 0) + { + backingField = sourceDefinition.PrimaryKey[0]; + } + } + + // Resolve backing column names for groupby fields (already validated early) + List<(string entityField, string backingColumn)> groupbyMapping = new(); + foreach (string groupbyField in groupby) + { + sqlMetadataProvider.TryGetBackingColumn(entityName, groupbyField, out string? backingGroupbyColumn); + groupbyMapping.Add((groupbyField, backingGroupbyColumn!)); + } + + string alias = ComputeAlias(function, field); + + // Clear default columns from FindRequestContext + structure.Columns.Clear(); + + // Add groupby columns as LabelledColumns and GroupByMetadata.Fields + foreach (var (entityField, backingColumn) in groupbyMapping) + { + structure.Columns.Add(new LabelledColumn( + dbObject.SchemaName, dbObject.Name, backingColumn, entityField, structure.SourceAlias)); + structure.GroupByMetadata.Fields[backingColumn] = new Column( + dbObject.SchemaName, dbObject.Name, backingColumn, structure.SourceAlias); + } + + // Build aggregation column using engine's AggregationColumn type. + // For COUNT(*), we use the primary key column (PK is always NOT NULL, so COUNT(pk) ≡ COUNT(*)). + AggregationType aggregationType = Enum.Parse(function); + AggregationColumn aggregationColumn = new( + dbObject.SchemaName, dbObject.Name, backingField!, aggregationType, alias, distinct, structure.SourceAlias); + + // Build HAVING predicates using engine's Predicate model + List havingPredicates = new(); + if (havingOperators != null) + { + foreach (var havingOperator in havingOperators) + { + PredicateOperation predicateOperation = havingOperator.Key.ToLowerInvariant() switch + { + "eq" => PredicateOperation.Equal, + "neq" => PredicateOperation.NotEqual, + "gt" => PredicateOperation.GreaterThan, + "gte" => PredicateOperation.GreaterThanOrEqual, + "lt" => PredicateOperation.LessThan, + "lte" => PredicateOperation.LessThanOrEqual, + _ => throw new ArgumentException($"Invalid having operator: {havingOperator.Key}") + }; + string paramName = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); + structure.Parameters.Add(paramName, new DbConnectionParam(havingOperator.Value)); + havingPredicates.Add(new Predicate( + new PredicateOperand(aggregationColumn), + predicateOperation, + new PredicateOperand(paramName))); + } + } + + if (havingInValues != null && havingInValues.Count > 0) + { + List inParams = new(); + foreach (double val in havingInValues) + { + string paramName = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); + structure.Parameters.Add(paramName, new DbConnectionParam(val)); + inParams.Add(paramName); + } + + havingPredicates.Add(new Predicate( + new PredicateOperand(aggregationColumn), + PredicateOperation.IN, + new PredicateOperand($"({string.Join(", ", inParams)})"))); + } + + // Combine multiple HAVING predicates with AND + Predicate? combinedHaving = null; + foreach (var predicate in havingPredicates) + { + combinedHaving = combinedHaving == null + ? predicate + : new Predicate(new PredicateOperand(combinedHaving), PredicateOperation.AND, new PredicateOperand(predicate)); + } + + structure.GroupByMetadata.Aggregations.Add( + new AggregationOperation(aggregationColumn, having: combinedHaving != null ? new List { combinedHaving } : null)); + structure.GroupByMetadata.RequestedAggregations = true; + + // Clear default OrderByColumns (PK-based) + structure.OrderByColumns.Clear(); + + // Set pagination limit if using first + if (first.HasValue && groupbyMapping.Count > 0) + { + structure.IsListQuery = true; + } + + // Use engine's query builder to generate SQL + string sql = queryBuilder.Build(structure); + + // For groupby queries: add ORDER BY aggregate expression and pagination + if (groupbyMapping.Count > 0) + { + string direction = orderby.Equals("asc", StringComparison.OrdinalIgnoreCase) ? "ASC" : "DESC"; + string quotedCol = $"{queryBuilder.QuoteIdentifier(structure.SourceAlias)}.{queryBuilder.QuoteIdentifier(backingField!)}"; + string orderByAggExpr = distinct + ? $"{function.ToUpperInvariant()}(DISTINCT {quotedCol})" + : $"{function.ToUpperInvariant()}({quotedCol})"; + string orderByClause = $" ORDER BY {orderByAggExpr} {direction}"; + + if (first.HasValue) + { + // With pagination: SQL Server requires ORDER BY for OFFSET/FETCH and + // does not allow both TOP and OFFSET/FETCH. Remove TOP and add ORDER BY + OFFSET/FETCH. + int offset = DecodeCursorOffset(after); + int fetchCount = first.Value + 1; + string offsetParam = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); + structure.Parameters.Add(offsetParam, new DbConnectionParam(offset)); + string limitParam = BaseQueryStructure.GetEncodedParamName(structure.Counter.Next()); + structure.Parameters.Add(limitParam, new DbConnectionParam(fetchCount)); + + string paginationClause = $" OFFSET {offsetParam} ROWS FETCH NEXT {limitParam} ROWS ONLY"; + + // Remove TOP N from the SELECT clause (TOP conflicts with OFFSET/FETCH) + sql = Regex.Replace(sql, @"SELECT TOP \d+", "SELECT"); + + // Insert ORDER BY + pagination before FOR JSON PATH + int jsonPathIdx = sql.IndexOf(" FOR JSON PATH", StringComparison.OrdinalIgnoreCase); + if (jsonPathIdx > 0) + { + sql = sql.Insert(jsonPathIdx, orderByClause + paginationClause); + } + else + { + sql += orderByClause + paginationClause; + } + } + else + { + // Without pagination: insert ORDER BY before FOR JSON PATH + int jsonPathIdx = sql.IndexOf(" FOR JSON PATH", StringComparison.OrdinalIgnoreCase); + if (jsonPathIdx > 0) + { + sql = sql.Insert(jsonPathIdx, orderByClause); + } + else + { + sql += orderByClause; + } + } + } + + // Execute the SQL aggregate query against the database + cancellationToken.ThrowIfCancellationRequested(); + JsonDocument? queryResult = await queryExecutor.ExecuteQueryAsync( + sql, + structure.Parameters, + queryExecutor.GetJsonResultAsync, + dataSourceName, + httpContext); + + // Parse result + JsonArray? resultArray = null; + if (queryResult != null) + { + resultArray = JsonSerializer.Deserialize(queryResult.RootElement.GetRawText()); + } + + // Format and return results + if (first.HasValue && groupby.Count > 0) + { + return BuildPaginatedResponse(resultArray, first.Value, after, entityName, logger); + } + + return BuildSimpleResponse(resultArray, entityName, alias, logger); + } + catch (TimeoutException timeoutException) + { + logger?.LogError(timeoutException, "Aggregation operation timed out for entity {Entity}.", entityName); + return McpResponseBuilder.BuildErrorResult( + toolName, + "TimeoutError", + $"The aggregation query for entity '{entityName}' timed out. " + + "This is NOT a tool error. The database did not respond in time. " + + "This may occur with large datasets or complex aggregations. " + + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination.", + logger); + } + catch (TaskCanceledException taskCanceledException) + { + logger?.LogError(taskCanceledException, "Aggregation task was canceled for entity {Entity}.", entityName); + return McpResponseBuilder.BuildErrorResult( + toolName, + "TimeoutError", + $"The aggregation query for entity '{entityName}' was canceled, likely due to a timeout. " + + "This is NOT a tool error. The database did not respond in time. " + + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination.", + logger); + } + catch (OperationCanceledException) + { + logger?.LogWarning("Aggregation operation was canceled for entity {Entity}.", entityName); + return McpResponseBuilder.BuildErrorResult( + toolName, + "OperationCanceled", + $"The aggregation query for entity '{entityName}' was canceled before completion. " + + "This is NOT a tool error. The operation was interrupted, possibly due to a timeout or client disconnect. " + + "No results were returned. You may retry the same request.", + logger); + } + catch (DbException dbException) + { + logger?.LogError(dbException, "Database error during aggregation for entity {Entity}.", entityName); + return McpResponseBuilder.BuildErrorResult(toolName, "DatabaseOperationFailed", dbException.Message, logger); + } + catch (ArgumentException argumentException) + { + return McpResponseBuilder.BuildErrorResult(toolName, "InvalidArguments", argumentException.Message, logger); + } + catch (DataApiBuilderException dabException) + { + return McpResponseBuilder.BuildErrorResult(toolName, dabException.StatusCode.ToString(), dabException.Message, logger); + } + catch (Exception ex) + { + logger?.LogError(ex, "Unexpected error in AggregateRecordsTool."); + return McpResponseBuilder.BuildErrorResult(toolName, "UnexpectedError", "Unexpected error occurred in AggregateRecordsTool.", logger); + } + } + + /// + /// Computes the response alias for the aggregation result. + /// For count with "*", the alias is "count". Otherwise it's "{function}_{field}". + /// + internal static string ComputeAlias(string function, string field) + { + if (function == "count" && field == "*") + { + return "count"; + } + + return $"{function}_{field}"; + } + + /// + /// Decodes a base64-encoded cursor string to an integer offset. + /// Returns 0 if the cursor is null, empty, or invalid. + /// + internal static int DecodeCursorOffset(string? after) + { + if (string.IsNullOrWhiteSpace(after)) + { + return 0; + } + + try + { + byte[] bytes = Convert.FromBase64String(after); + string decoded = Encoding.UTF8.GetString(bytes); + return int.TryParse(decoded, out int cursorOffset) && cursorOffset >= 0 ? cursorOffset : 0; + } + catch (FormatException) + { + return 0; + } + } + + /// + /// Builds the paginated response from a SQL result that fetched first+1 rows. + /// + private static CallToolResult BuildPaginatedResponse( + JsonArray? resultArray, int first, string? after, string entityName, ILogger? logger) + { + int startOffset = DecodeCursorOffset(after); + int actualCount = resultArray?.Count ?? 0; + bool hasNextPage = actualCount > first; + int returnCount = hasNextPage ? first : actualCount; + + // Build page items from the SQL result + JsonArray pageItems = new(); + for (int i = 0; i < returnCount && resultArray != null && i < resultArray.Count; i++) + { + pageItems.Add(resultArray[i]?.DeepClone()); + } + + string? endCursor = null; + if (returnCount > 0) + { + int lastItemIndex = startOffset + returnCount; + endCursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(lastItemIndex.ToString())); + } + + JsonElement itemsElement = JsonSerializer.Deserialize(pageItems.ToJsonString()); + + return McpResponseBuilder.BuildSuccessResult( + new Dictionary + { + ["entity"] = entityName, + ["result"] = new Dictionary + { + ["items"] = itemsElement, + ["endCursor"] = endCursor, + ["hasNextPage"] = hasNextPage + }, + ["message"] = $"Successfully aggregated records for entity '{entityName}'" + }, + logger, + $"AggregateRecordsTool success for entity {entityName}."); + } + + /// + /// Builds the simple (non-paginated) response from a SQL result. + /// + private static CallToolResult BuildSimpleResponse( + JsonArray? resultArray, string entityName, string alias, ILogger? logger) + { + JsonElement resultElement; + if (resultArray == null || resultArray.Count == 0) + { + // For non-grouped aggregate with no results, return null value + JsonArray nullArray = new() { new JsonObject { [alias] = null } }; + resultElement = JsonSerializer.Deserialize(nullArray.ToJsonString()); + } + else + { + resultElement = JsonSerializer.Deserialize(resultArray.ToJsonString()); + } + + return McpResponseBuilder.BuildSuccessResult( + new Dictionary + { + ["entity"] = entityName, + ["result"] = resultElement, + ["message"] = $"Successfully aggregated records for entity '{entityName}'" + }, + logger, + $"AggregateRecordsTool success for entity {entityName}."); + } + } +} diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.md b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.md new file mode 100644 index 0000000000..718aa360e1 --- /dev/null +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/AggregateRecordsTool.md @@ -0,0 +1,51 @@ +# AggregateRecordsTool + +MCP tool that computes SQL-level aggregations (COUNT, AVG, SUM, MIN, MAX) on DAB entities. All aggregation is pushed to the database engine — no in-memory computation. + +## Class Structure + +| Member | Kind | Purpose | +|---|---|---| +| `ToolType` | Property | Returns `ToolType.BuiltIn` for reflection-based discovery. | +| `_validFunctions` | Static field | Allowlist of aggregation functions: count, avg, sum, min, max. | +| `GetToolMetadata()` | Method | Returns the MCP `Tool` descriptor (name, description, JSON input schema). | +| `ExecuteAsync()` | Method | Main entry point — validates input, resolves metadata, authorizes, builds the SQL query via the engine's `IQueryBuilder.Build(SqlQueryStructure)`, executes it, and formats the response. | +| `ComputeAlias()` | Static method | Produces the result column alias: `"count"` for count(\*), otherwise `"{function}_{field}"`. | +| `DecodeCursorOffset()` | Static method | Decodes a base64 opaque cursor string to an integer offset for OFFSET/FETCH pagination. Returns 0 on any invalid input. | +| `BuildPaginatedResponse()` | Private method | Formats a grouped result set into `{ items, endCursor, hasNextPage }` when `first` is provided. | +| `BuildSimpleResponse()` | Private method | Formats a scalar or grouped result set without pagination. | + +## ExecuteAsync Sequence + +```mermaid +sequenceDiagram + participant Client as MCP Client + participant Tool as AggregateRecordsTool + participant Engine as DAB Engine + participant DB as Database + + Client->>Tool: ExecuteAsync(arguments) + Tool->>Tool: Validate inputs & check tool enabled + Tool->>Engine: Resolve entity metadata & validate fields + Tool->>Engine: Authorize (column-level permissions) + Tool->>Engine: Build SQL via queryBuilder.Build(SqlQueryStructure) + Tool->>Tool: Post-process SQL (ORDER BY, pagination) + Tool->>DB: ExecuteQueryAsync → JSON result + alt Paginated (first provided) + Tool-->>Client: { items, endCursor, hasNextPage } + else Simple + Tool-->>Client: { entity, result: [{alias: value}] } + end + + Note over Tool,Client: On error: TimeoutError, OperationCanceled, or DatabaseOperationFailed +``` + +## Key Design Decisions + +- **No in-memory aggregation.** The engine's `GroupByMetadata` / `AggregationColumn` types drive SQL generation via `queryBuilder.Build(structure)`. All aggregation is performed by the database. +- **COUNT(\*) workaround.** The engine's `Build(AggregationColumn)` doesn't support `*` as a column name (it produces invalid SQL like `count([].[*])`), so the primary key column is used instead. `COUNT(pk)` ≡ `COUNT(*)` since PK is NOT NULL. +- **ORDER BY post-processing.** Neither the GraphQL nor REST code paths support ORDER BY on an aggregate expression, so this tool inserts `ORDER BY {func}({col}) ASC|DESC` into the generated SQL before `FOR JSON PATH`. +- **TOP vs OFFSET/FETCH.** SQL Server forbids both in the same query. When pagination (`first`) is used, `TOP N` is stripped via regex before appending `OFFSET/FETCH NEXT`. +- **Early field validation.** All user-supplied field names (aggregation field, groupby fields) are validated against the entity's metadata before authorization or query building, so typos surface immediately with actionable guidance. +- **Timeout vs cancellation.** `TimeoutException` (from `query-timeout` config) and `OperationCanceledException` (from client disconnect) are handled separately with distinct model-facing messages. Timeouts guide the model to narrow filters or paginate; cancellations suggest retry. +- **Database support.** Only MsSql / DWSQL — matches the engine's GraphQL aggregation support. PostgreSQL, MySQL, and CosmosDB return an `UnsupportedDatabase` error. diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpErrorHelpers.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpErrorHelpers.cs index 1a5c223798..13835b2fa9 100644 --- a/src/Azure.DataApiBuilder.Mcp/Utils/McpErrorHelpers.cs +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpErrorHelpers.cs @@ -24,5 +24,16 @@ public static CallToolResult ToolDisabled(string toolName, ILogger? logger, stri string message = customMessage ?? $"The {toolName} tool is disabled in the configuration."; return McpResponseBuilder.BuildErrorResult(toolName, Model.McpErrorCode.ToolDisabled.ToString(), message, logger); } + + /// + /// Returns a model-friendly error when a field name is not found for an entity. + /// Guides the model to call describe_entities to discover valid field names. + /// + public static CallToolResult FieldNotFound(string toolName, string entityName, string fieldName, string parameterName, ILogger? logger) + { + string message = $"Field '{fieldName}' in '{parameterName}' was not found for entity '{entityName}'. " + + $"Call describe_entities to get valid field names for '{entityName}'."; + return McpResponseBuilder.BuildErrorResult(toolName, "FieldNotFound", message, logger); + } } } diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs index f69a26fa5d..3ef3aa4d74 100644 --- a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryErrorCodes.cs @@ -37,5 +37,10 @@ internal static class McpTelemetryErrorCodes /// Operation cancelled error code. /// public const string OPERATION_CANCELLED = "OperationCancelled"; + + /// + /// Operation timed out error code. + /// + public const string TIMEOUT = "Timeout"; } } diff --git a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs index 2a60557f8d..31a92ef6e4 100644 --- a/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs +++ b/src/Azure.DataApiBuilder.Mcp/Utils/McpTelemetryHelper.cs @@ -60,8 +60,39 @@ public static async Task ExecuteWithTelemetryAsync( operation: operation, dbProcedure: dbProcedure); - // Execute the tool - CallToolResult result = await tool.ExecuteAsync(arguments, serviceProvider, cancellationToken); + // Read query-timeout from current config per invocation (hot-reload safe). + int timeoutSeconds = McpRuntimeOptions.DEFAULT_QUERY_TIMEOUT_SECONDS; + RuntimeConfigProvider? runtimeConfigProvider = serviceProvider.GetService(); + if (runtimeConfigProvider is not null) + { + RuntimeConfig config = runtimeConfigProvider.GetConfig(); + timeoutSeconds = config.Runtime?.Mcp?.EffectiveQueryTimeoutSeconds ?? McpRuntimeOptions.DEFAULT_QUERY_TIMEOUT_SECONDS; + } + + // Defensive runtime guard: clamp timeout to valid range [1, MAX_QUERY_TIMEOUT_SECONDS]. + if (timeoutSeconds < 1 || timeoutSeconds > McpRuntimeOptions.MAX_QUERY_TIMEOUT_SECONDS) + { + timeoutSeconds = McpRuntimeOptions.DEFAULT_QUERY_TIMEOUT_SECONDS; + } + + // Wrap tool execution with the configured timeout using a linked CancellationTokenSource. + using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + timeoutCts.CancelAfter(TimeSpan.FromSeconds(timeoutSeconds)); + + CallToolResult result; + try + { + result = await tool.ExecuteAsync(arguments, serviceProvider, timeoutCts.Token); + } + catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested) + { + // The timeout CTS fired, not the caller's token. Surface as TimeoutException + // so downstream telemetry and tool handlers see TIMEOUT, not cancellation. + throw new TimeoutException( + $"The MCP tool '{toolName}' did not complete within {timeoutSeconds} {(timeoutSeconds == 1 ? "second" : "seconds")}. " + + "This is NOT a tool error. The operation exceeded the configured query-timeout. " + + "Try narrowing results with a filter, reducing groupby fields, or using pagination."); + } // Check if the tool returned an error result (tools catch exceptions internally // and return CallToolResult with IsError=true instead of throwing) @@ -124,6 +155,7 @@ public static string InferOperationFromTool(IMcpTool tool, string toolName) "delete_record" => "delete", "describe_entities" => "describe", "execute_entity" => "execute", + "aggregate_records" => "aggregate", _ => "execute" // Fallback for any unknown built-in tools }; } @@ -188,6 +220,7 @@ public static string MapExceptionToErrorCode(Exception ex) return ex switch { OperationCanceledException => McpTelemetryErrorCodes.OPERATION_CANCELLED, + TimeoutException => McpTelemetryErrorCodes.TIMEOUT, DataApiBuilderException dabEx when dabEx.SubStatusCode == DataApiBuilderException.SubStatusCodes.AuthenticationChallenge => McpTelemetryErrorCodes.AUTHENTICATION_FAILED, DataApiBuilderException dabEx when dabEx.SubStatusCode == DataApiBuilderException.SubStatusCodes.AuthorizationCheckFailed diff --git a/src/Cli.Tests/Snapshots/EndToEndTests.TestAddingStoredProcedureWithRestMethodsAndGraphQLOperations.verified.txt b/src/Cli.Tests/Snapshots/EndToEndTests.TestAddingStoredProcedureWithRestMethodsAndGraphQLOperations.verified.txt index 3fa1fbc14e..0fd0030402 100644 --- a/src/Cli.Tests/Snapshots/EndToEndTests.TestAddingStoredProcedureWithRestMethodsAndGraphQLOperations.verified.txt +++ b/src/Cli.Tests/Snapshots/EndToEndTests.TestAddingStoredProcedureWithRestMethodsAndGraphQLOperations.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithSourceAsStoredProcedure.verified.txt b/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithSourceAsStoredProcedure.verified.txt index 76ea01dfca..725eed7a83 100644 --- a/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithSourceAsStoredProcedure.verified.txt +++ b/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithSourceAsStoredProcedure.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithSourceWithDefaultType.verified.txt b/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithSourceWithDefaultType.verified.txt index 3a8c738a70..70cb42137b 100644 --- a/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithSourceWithDefaultType.verified.txt +++ b/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithSourceWithDefaultType.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithoutIEnumerables.verified.txt b/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithoutIEnumerables.verified.txt index df2cd4b009..46bec31cc9 100644 --- a/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithoutIEnumerables.verified.txt +++ b/src/Cli.Tests/Snapshots/EndToEndTests.TestConfigGeneratedAfterAddingEntityWithoutIEnumerables.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/EndToEndTests.TestInitForCosmosDBNoSql.verified.txt b/src/Cli.Tests/Snapshots/EndToEndTests.TestInitForCosmosDBNoSql.verified.txt index 1b14a3a7f0..0932956d7a 100644 --- a/src/Cli.Tests/Snapshots/EndToEndTests.TestInitForCosmosDBNoSql.verified.txt +++ b/src/Cli.Tests/Snapshots/EndToEndTests.TestInitForCosmosDBNoSql.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { Options: { container: planet, @@ -28,14 +28,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/EndToEndTests.TestUpdatingStoredProcedureWithRestMethods.verified.txt b/src/Cli.Tests/Snapshots/EndToEndTests.TestUpdatingStoredProcedureWithRestMethods.verified.txt index 62d9e237b5..fdda324d36 100644 --- a/src/Cli.Tests/Snapshots/EndToEndTests.TestUpdatingStoredProcedureWithRestMethods.verified.txt +++ b/src/Cli.Tests/Snapshots/EndToEndTests.TestUpdatingStoredProcedureWithRestMethods.verified.txt @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/EndToEndTests.TestUpdatingStoredProcedureWithRestMethodsAndGraphQLOperations.verified.txt b/src/Cli.Tests/Snapshots/EndToEndTests.TestUpdatingStoredProcedureWithRestMethodsAndGraphQLOperations.verified.txt index fa8b16e739..2a4e8653a1 100644 --- a/src/Cli.Tests/Snapshots/EndToEndTests.TestUpdatingStoredProcedureWithRestMethodsAndGraphQLOperations.verified.txt +++ b/src/Cli.Tests/Snapshots/EndToEndTests.TestUpdatingStoredProcedureWithRestMethodsAndGraphQLOperations.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.CosmosDbNoSqlDatabase.verified.txt b/src/Cli.Tests/Snapshots/InitTests.CosmosDbNoSqlDatabase.verified.txt index 9d5458c0ee..4870537837 100644 --- a/src/Cli.Tests/Snapshots/InitTests.CosmosDbNoSqlDatabase.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.CosmosDbNoSqlDatabase.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { Options: { container: testcontainer, @@ -28,14 +28,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.CosmosDbPostgreSqlDatabase.verified.txt b/src/Cli.Tests/Snapshots/InitTests.CosmosDbPostgreSqlDatabase.verified.txt index 51f6ad8d95..e03973b91e 100644 --- a/src/Cli.Tests/Snapshots/InitTests.CosmosDbPostgreSqlDatabase.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.CosmosDbPostgreSqlDatabase.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: CosmosDB_PostgreSQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_171ea8114ff71814.verified.txt b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_171ea8114ff71814.verified.txt index 978d1a253b..d33247dcab 100644 --- a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_171ea8114ff71814.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_171ea8114ff71814.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_2df7a1794712f154.verified.txt b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_2df7a1794712f154.verified.txt index 402bf4d2bc..fa08aefa62 100644 --- a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_2df7a1794712f154.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_2df7a1794712f154.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_59fe1a10aa78899d.verified.txt b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_59fe1a10aa78899d.verified.txt index ab71a40f03..98fdb25c77 100644 --- a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_59fe1a10aa78899d.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_59fe1a10aa78899d.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_b95b637ea87f16a7.verified.txt b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_b95b637ea87f16a7.verified.txt index 25e3976685..74afea9ef6 100644 --- a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_b95b637ea87f16a7.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_b95b637ea87f16a7.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_daacbd948b7ef72f.verified.txt b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_daacbd948b7ef72f.verified.txt index 140f017b78..3145f775c0 100644 --- a/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_daacbd948b7ef72f.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.EnsureCorrectConfigGenerationWithDifferentAuthenticationProviders_daacbd948b7ef72f.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.GraphQLPathWithoutStartingSlashWillHaveItAdded.verified.txt b/src/Cli.Tests/Snapshots/InitTests.GraphQLPathWithoutStartingSlashWillHaveItAdded.verified.txt index a3a056ac0a..ae32e3b379 100644 --- a/src/Cli.Tests/Snapshots/InitTests.GraphQLPathWithoutStartingSlashWillHaveItAdded.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.GraphQLPathWithoutStartingSlashWillHaveItAdded.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.MsSQLDatabase.verified.txt b/src/Cli.Tests/Snapshots/InitTests.MsSQLDatabase.verified.txt index f40350c4da..0f2c151763 100644 --- a/src/Cli.Tests/Snapshots/InitTests.MsSQLDatabase.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.MsSQLDatabase.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.RestPathWithoutStartingSlashWillHaveItAdded.verified.txt b/src/Cli.Tests/Snapshots/InitTests.RestPathWithoutStartingSlashWillHaveItAdded.verified.txt index b792d41c9f..d9067e1b43 100644 --- a/src/Cli.Tests/Snapshots/InitTests.RestPathWithoutStartingSlashWillHaveItAdded.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.RestPathWithoutStartingSlashWillHaveItAdded.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.TestInitializingConfigWithoutConnectionString.verified.txt b/src/Cli.Tests/Snapshots/InitTests.TestInitializingConfigWithoutConnectionString.verified.txt index 173960d7b1..e48b87e1c8 100644 --- a/src/Cli.Tests/Snapshots/InitTests.TestInitializingConfigWithoutConnectionString.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.TestInitializingConfigWithoutConnectionString.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.TestSpecialCharactersInConnectionString.verified.txt b/src/Cli.Tests/Snapshots/InitTests.TestSpecialCharactersInConnectionString.verified.txt index 25e3976685..74afea9ef6 100644 --- a/src/Cli.Tests/Snapshots/InitTests.TestSpecialCharactersInConnectionString.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.TestSpecialCharactersInConnectionString.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0546bef37027a950.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0546bef37027a950.verified.txt index 63f0da701c..2cb50a06da 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0546bef37027a950.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0546bef37027a950.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: DWSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0ac567dd32a2e8f5.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0ac567dd32a2e8f5.verified.txt index f40350c4da..0f2c151763 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0ac567dd32a2e8f5.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0ac567dd32a2e8f5.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0c06949221514e77.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0c06949221514e77.verified.txt index e59070d692..bbea4aadd3 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0c06949221514e77.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_0c06949221514e77.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -32,14 +32,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_18667ab7db033e9d.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_18667ab7db033e9d.verified.txt index f7de35b7ae..63f411cdb2 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_18667ab7db033e9d.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_18667ab7db033e9d.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: PostgreSQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_2f42f44c328eb020.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_2f42f44c328eb020.verified.txt index 63f0da701c..2cb50a06da 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_2f42f44c328eb020.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_2f42f44c328eb020.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: DWSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_3243d3f3441fdcc1.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_3243d3f3441fdcc1.verified.txt index f7de35b7ae..63f411cdb2 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_3243d3f3441fdcc1.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_3243d3f3441fdcc1.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: PostgreSQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_53350b8b47df2112.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_53350b8b47df2112.verified.txt index 75613db959..5af597f50a 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_53350b8b47df2112.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_53350b8b47df2112.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MySQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_6584e0ec46b8a11d.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_6584e0ec46b8a11d.verified.txt index d93aac7dc6..860fa1616c 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_6584e0ec46b8a11d.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_6584e0ec46b8a11d.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { Options: { container: testcontainer, @@ -28,14 +28,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_81cc88db3d4eecfb.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_81cc88db3d4eecfb.verified.txt index 640815babb..48f3d0ce51 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_81cc88db3d4eecfb.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_81cc88db3d4eecfb.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MSSQL, Options: { @@ -32,14 +32,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_8ea187616dbb5577.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_8ea187616dbb5577.verified.txt index 5900015d5a..f56dcad7d7 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_8ea187616dbb5577.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_8ea187616dbb5577.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: CosmosDB_PostgreSQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_905845c29560a3ef.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_905845c29560a3ef.verified.txt index 63f0da701c..2cb50a06da 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_905845c29560a3ef.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_905845c29560a3ef.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: DWSQL, Options: { @@ -27,14 +27,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_b2fd24fab5b80917.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_b2fd24fab5b80917.verified.txt index d93aac7dc6..860fa1616c 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_b2fd24fab5b80917.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_b2fd24fab5b80917.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { Options: { container: testcontainer, @@ -28,14 +28,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_bd7cd088755287c9.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_bd7cd088755287c9.verified.txt index d93aac7dc6..860fa1616c 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_bd7cd088755287c9.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_bd7cd088755287c9.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { Options: { container: testcontainer, @@ -28,14 +28,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d2eccba2f836b380.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d2eccba2f836b380.verified.txt index 75613db959..5af597f50a 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d2eccba2f836b380.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d2eccba2f836b380.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MySQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d463eed7fe5e4bbe.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d463eed7fe5e4bbe.verified.txt index 5900015d5a..f56dcad7d7 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d463eed7fe5e4bbe.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d463eed7fe5e4bbe.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: CosmosDB_PostgreSQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d5520dd5c33f7b8d.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d5520dd5c33f7b8d.verified.txt index 75613db959..5af597f50a 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d5520dd5c33f7b8d.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_d5520dd5c33f7b8d.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: MySQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_eab4a6010e602b59.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_eab4a6010e602b59.verified.txt index f7de35b7ae..63f411cdb2 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_eab4a6010e602b59.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_eab4a6010e602b59.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: PostgreSQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_ecaa688829b4030e.verified.txt b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_ecaa688829b4030e.verified.txt index 5900015d5a..f56dcad7d7 100644 --- a/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_ecaa688829b4030e.verified.txt +++ b/src/Cli.Tests/Snapshots/InitTests.VerifyCorrectConfigGenerationWithMultipleMutationOptions_ecaa688829b4030e.verified.txt @@ -1,4 +1,4 @@ -{ +{ DataSource: { DatabaseType: CosmosDB_PostgreSQL }, @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index 262cbc9145..99d7efc637 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -42,6 +42,7 @@ public ConfigureOptions( bool? runtimeMcpEnabled = null, string? runtimeMcpPath = null, string? runtimeMcpDescription = null, + int? runtimeMcpQueryTimeout = null, bool? runtimeMcpDmlToolsEnabled = null, bool? runtimeMcpDmlToolsDescribeEntitiesEnabled = null, bool? runtimeMcpDmlToolsCreateRecordEnabled = null, @@ -49,6 +50,7 @@ public ConfigureOptions( bool? runtimeMcpDmlToolsUpdateRecordEnabled = null, bool? runtimeMcpDmlToolsDeleteRecordEnabled = null, bool? runtimeMcpDmlToolsExecuteEntityEnabled = null, + bool? runtimeMcpDmlToolsAggregateRecordsEnabled = null, bool? runtimeCacheEnabled = null, int? runtimeCacheTtl = null, CompressionLevel? runtimeCompressionLevel = null, @@ -102,6 +104,7 @@ public ConfigureOptions( RuntimeMcpEnabled = runtimeMcpEnabled; RuntimeMcpPath = runtimeMcpPath; RuntimeMcpDescription = runtimeMcpDescription; + RuntimeMcpQueryTimeout = runtimeMcpQueryTimeout; RuntimeMcpDmlToolsEnabled = runtimeMcpDmlToolsEnabled; RuntimeMcpDmlToolsDescribeEntitiesEnabled = runtimeMcpDmlToolsDescribeEntitiesEnabled; RuntimeMcpDmlToolsCreateRecordEnabled = runtimeMcpDmlToolsCreateRecordEnabled; @@ -109,6 +112,7 @@ public ConfigureOptions( RuntimeMcpDmlToolsUpdateRecordEnabled = runtimeMcpDmlToolsUpdateRecordEnabled; RuntimeMcpDmlToolsDeleteRecordEnabled = runtimeMcpDmlToolsDeleteRecordEnabled; RuntimeMcpDmlToolsExecuteEntityEnabled = runtimeMcpDmlToolsExecuteEntityEnabled; + RuntimeMcpDmlToolsAggregateRecordsEnabled = runtimeMcpDmlToolsAggregateRecordsEnabled; // Cache RuntimeCacheEnabled = runtimeCacheEnabled; RuntimeCacheTTL = runtimeCacheTtl; @@ -203,6 +207,9 @@ public ConfigureOptions( [Option("runtime.mcp.description", Required = false, HelpText = "Set the MCP server description to be exposed in the initialize response.")] public string? RuntimeMcpDescription { get; } + [Option("runtime.mcp.query-timeout", Required = false, HelpText = "Set the execution timeout in seconds for MCP tool operations. Applies to all MCP tools. Default: 30 (integer). Must be >= 1.")] + public int? RuntimeMcpQueryTimeout { get; } + [Option("runtime.mcp.dml-tools.enabled", Required = false, HelpText = "Enable DAB's MCP DML tools endpoint. Default: true (boolean).")] public bool? RuntimeMcpDmlToolsEnabled { get; } @@ -224,6 +231,9 @@ public ConfigureOptions( [Option("runtime.mcp.dml-tools.execute-entity.enabled", Required = false, HelpText = "Enable DAB's MCP execute entity tool. Default: true (boolean).")] public bool? RuntimeMcpDmlToolsExecuteEntityEnabled { get; } + [Option("runtime.mcp.dml-tools.aggregate-records.enabled", Required = false, HelpText = "Enable DAB's MCP aggregate records tool. Default: true (boolean).")] + public bool? RuntimeMcpDmlToolsAggregateRecordsEnabled { get; } + [Option("runtime.cache.enabled", Required = false, HelpText = "Enable DAB's cache globally. (You must also enable each entity's cache separately.). Default: false (boolean).")] public bool? RuntimeCacheEnabled { get; } diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index 6d1fcf946d..2dfe14796a 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -876,13 +876,15 @@ private static bool TryUpdateConfiguredRuntimeOptions( if (options.RuntimeMcpEnabled != null || options.RuntimeMcpPath != null || options.RuntimeMcpDescription != null || + options.RuntimeMcpQueryTimeout != null || options.RuntimeMcpDmlToolsEnabled != null || options.RuntimeMcpDmlToolsDescribeEntitiesEnabled != null || options.RuntimeMcpDmlToolsCreateRecordEnabled != null || options.RuntimeMcpDmlToolsReadRecordsEnabled != null || options.RuntimeMcpDmlToolsUpdateRecordEnabled != null || options.RuntimeMcpDmlToolsDeleteRecordEnabled != null || - options.RuntimeMcpDmlToolsExecuteEntityEnabled != null) + options.RuntimeMcpDmlToolsExecuteEntityEnabled != null || + options.RuntimeMcpDmlToolsAggregateRecordsEnabled != null) { McpRuntimeOptions updatedMcpOptions = runtimeConfig?.Runtime?.Mcp ?? new(); bool status = TryUpdateConfiguredMcpValues(options, ref updatedMcpOptions); @@ -1161,6 +1163,14 @@ private static bool TryUpdateConfiguredMcpValues( _logger.LogInformation("Updated RuntimeConfig with Runtime.Mcp.Description as '{updatedValue}'", updatedValue); } + // Runtime.Mcp.QueryTimeout + updatedValue = options?.RuntimeMcpQueryTimeout; + if (updatedValue != null) + { + updatedMcpOptions = updatedMcpOptions! with { QueryTimeout = (int)updatedValue }; + _logger.LogInformation("Updated RuntimeConfig with Runtime.Mcp.QueryTimeout as '{updatedValue}'", updatedValue); + } + // Handle DML tools configuration bool hasToolUpdates = false; DmlToolsConfig? currentDmlTools = updatedMcpOptions?.DmlTools; @@ -1181,6 +1191,7 @@ private static bool TryUpdateConfiguredMcpValues( bool? updateRecord = currentDmlTools?.UpdateRecord; bool? deleteRecord = currentDmlTools?.DeleteRecord; bool? executeEntity = currentDmlTools?.ExecuteEntity; + bool? aggregateRecords = currentDmlTools?.AggregateRecords; updatedValue = options?.RuntimeMcpDmlToolsDescribeEntitiesEnabled; if (updatedValue != null) @@ -1230,6 +1241,14 @@ private static bool TryUpdateConfiguredMcpValues( _logger.LogInformation("Updated RuntimeConfig with runtime.mcp.dml-tools.execute-entity as '{updatedValue}'", updatedValue); } + updatedValue = options?.RuntimeMcpDmlToolsAggregateRecordsEnabled; + if (updatedValue != null) + { + aggregateRecords = (bool)updatedValue; + hasToolUpdates = true; + _logger.LogInformation("Updated RuntimeConfig with runtime.mcp.dml-tools.aggregate-records as '{updatedValue}'", updatedValue); + } + if (hasToolUpdates) { updatedMcpOptions = updatedMcpOptions! with @@ -1242,7 +1261,8 @@ private static bool TryUpdateConfiguredMcpValues( ReadRecords = readRecord, UpdateRecord = updateRecord, DeleteRecord = deleteRecord, - ExecuteEntity = executeEntity + ExecuteEntity = executeEntity, + AggregateRecords = aggregateRecords } }; } diff --git a/src/Config/Converters/DmlToolsConfigConverter.cs b/src/Config/Converters/DmlToolsConfigConverter.cs index 82ac3f6069..7e049c7926 100644 --- a/src/Config/Converters/DmlToolsConfigConverter.cs +++ b/src/Config/Converters/DmlToolsConfigConverter.cs @@ -44,6 +44,7 @@ internal class DmlToolsConfigConverter : JsonConverter bool? updateRecord = null; bool? deleteRecord = null; bool? executeEntity = null; + bool? aggregateRecords = null; while (reader.Read()) { @@ -82,6 +83,9 @@ internal class DmlToolsConfigConverter : JsonConverter case "execute-entity": executeEntity = value; break; + case "aggregate-records": + aggregateRecords = value; + break; default: // Skip unknown properties break; @@ -91,7 +95,8 @@ internal class DmlToolsConfigConverter : JsonConverter { // Error on non-boolean values for known properties if (property?.ToLowerInvariant() is "describe-entities" or "create-record" - or "read-records" or "update-record" or "delete-record" or "execute-entity") + or "read-records" or "update-record" or "delete-record" or "execute-entity" + or "aggregate-records") { throw new JsonException($"Property '{property}' must be a boolean value."); } @@ -110,7 +115,8 @@ internal class DmlToolsConfigConverter : JsonConverter readRecords: readRecords, updateRecord: updateRecord, deleteRecord: deleteRecord, - executeEntity: executeEntity); + executeEntity: executeEntity, + aggregateRecords: aggregateRecords); } // For any other unexpected token type, return default (all enabled) @@ -135,7 +141,8 @@ public override void Write(Utf8JsonWriter writer, DmlToolsConfig? value, JsonSer value.UserProvidedReadRecords || value.UserProvidedUpdateRecord || value.UserProvidedDeleteRecord || - value.UserProvidedExecuteEntity; + value.UserProvidedExecuteEntity || + value.UserProvidedAggregateRecords; // Only write the boolean value if it's provided by user // This prevents writing "dml-tools": true when it's the default @@ -181,6 +188,11 @@ public override void Write(Utf8JsonWriter writer, DmlToolsConfig? value, JsonSer writer.WriteBoolean("execute-entity", value.ExecuteEntity.Value); } + if (value.UserProvidedAggregateRecords && value.AggregateRecords.HasValue) + { + writer.WriteBoolean("aggregate-records", value.AggregateRecords.Value); + } + writer.WriteEndObject(); } } diff --git a/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs b/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs index 8b3c640725..ad4edc229e 100644 --- a/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs +++ b/src/Config/Converters/McpRuntimeOptionsConverterFactory.cs @@ -66,12 +66,13 @@ internal McpRuntimeOptionsConverter(DeserializationVariableReplacementSettings? string? path = null; DmlToolsConfig? dmlTools = null; string? description = null; + int? queryTimeout = null; while (reader.Read()) { if (reader.TokenType == JsonTokenType.EndObject) { - return new McpRuntimeOptions(enabled, path, dmlTools, description); + return new McpRuntimeOptions(enabled, path, dmlTools, description, queryTimeout); } string? propertyName = reader.GetString(); @@ -107,6 +108,14 @@ internal McpRuntimeOptionsConverter(DeserializationVariableReplacementSettings? break; + case "query-timeout": + if (reader.TokenType is not JsonTokenType.Null) + { + queryTimeout = reader.GetInt32(); + } + + break; + default: throw new JsonException($"Unexpected property {propertyName}"); } @@ -150,6 +159,12 @@ public override void Write(Utf8JsonWriter writer, McpRuntimeOptions value, JsonS JsonSerializer.Serialize(writer, value.Description, options); } + // Write query-timeout if it's user provided + if (value?.UserProvidedQueryTimeout is true && value.QueryTimeout.HasValue) + { + writer.WriteNumber("query-timeout", value.QueryTimeout.Value); + } + writer.WriteEndObject(); } } diff --git a/src/Config/ObjectModel/DmlToolsConfig.cs b/src/Config/ObjectModel/DmlToolsConfig.cs index 2a09e9d53c..c1f8b278cd 100644 --- a/src/Config/ObjectModel/DmlToolsConfig.cs +++ b/src/Config/ObjectModel/DmlToolsConfig.cs @@ -51,6 +51,11 @@ public record DmlToolsConfig /// public bool? ExecuteEntity { get; init; } + /// + /// Whether aggregate-records tool is enabled + /// + public bool? AggregateRecords { get; init; } + [JsonConstructor] public DmlToolsConfig( bool? allToolsEnabled = null, @@ -59,7 +64,8 @@ public DmlToolsConfig( bool? readRecords = null, bool? updateRecord = null, bool? deleteRecord = null, - bool? executeEntity = null) + bool? executeEntity = null, + bool? aggregateRecords = null) { if (allToolsEnabled is not null) { @@ -75,6 +81,7 @@ public DmlToolsConfig( UpdateRecord = updateRecord ?? toolDefault; DeleteRecord = deleteRecord ?? toolDefault; ExecuteEntity = executeEntity ?? toolDefault; + AggregateRecords = aggregateRecords ?? toolDefault; } else { @@ -87,6 +94,7 @@ public DmlToolsConfig( UpdateRecord = updateRecord ?? DEFAULT_ENABLED; DeleteRecord = deleteRecord ?? DEFAULT_ENABLED; ExecuteEntity = executeEntity ?? DEFAULT_ENABLED; + AggregateRecords = aggregateRecords ?? DEFAULT_ENABLED; } // Track user-provided status - only true if the parameter was not null @@ -96,6 +104,7 @@ public DmlToolsConfig( UserProvidedUpdateRecord = updateRecord is not null; UserProvidedDeleteRecord = deleteRecord is not null; UserProvidedExecuteEntity = executeEntity is not null; + UserProvidedAggregateRecords = aggregateRecords is not null; } /// @@ -112,7 +121,8 @@ public static DmlToolsConfig FromBoolean(bool enabled) readRecords: null, updateRecord: null, deleteRecord: null, - executeEntity: null + executeEntity: null, + aggregateRecords: null ); } @@ -127,7 +137,8 @@ public static DmlToolsConfig FromBoolean(bool enabled) readRecords: null, updateRecord: null, deleteRecord: null, - executeEntity: null + executeEntity: null, + aggregateRecords: null ); /// @@ -185,4 +196,12 @@ public static DmlToolsConfig FromBoolean(bool enabled) [JsonIgnore(Condition = JsonIgnoreCondition.Always)] [MemberNotNullWhen(true, nameof(ExecuteEntity))] public bool UserProvidedExecuteEntity { get; init; } = false; + + /// + /// Flag which informs CLI and JSON serializer whether to write aggregate-records + /// property/value to the runtime config file. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(AggregateRecords))] + public bool UserProvidedAggregateRecords { get; init; } = false; } diff --git a/src/Config/ObjectModel/McpRuntimeOptions.cs b/src/Config/ObjectModel/McpRuntimeOptions.cs index e17d53fc8f..5b48b2fcc3 100644 --- a/src/Config/ObjectModel/McpRuntimeOptions.cs +++ b/src/Config/ObjectModel/McpRuntimeOptions.cs @@ -10,6 +10,8 @@ namespace Azure.DataApiBuilder.Config.ObjectModel; public record McpRuntimeOptions { public const string DEFAULT_PATH = "/mcp"; + public const int DEFAULT_QUERY_TIMEOUT_SECONDS = 30; + public const int MAX_QUERY_TIMEOUT_SECONDS = 600; /// /// Whether MCP endpoints are enabled @@ -36,12 +38,22 @@ public record McpRuntimeOptions [JsonPropertyName("description")] public string? Description { get; init; } + /// + /// Execution timeout in seconds for MCP tool operations. + /// This timeout wraps the entire tool execution including database queries. + /// It applies to all MCP tools, not just aggregation. + /// Default: 30 seconds. + /// + [JsonPropertyName("query-timeout")] + public int? QueryTimeout { get; init; } + [JsonConstructor] public McpRuntimeOptions( bool? Enabled = null, string? Path = null, DmlToolsConfig? DmlTools = null, - string? Description = null) + string? Description = null, + int? QueryTimeout = null) { this.Enabled = Enabled ?? true; @@ -67,6 +79,12 @@ public McpRuntimeOptions( } this.Description = Description; + + if (QueryTimeout is not null) + { + this.QueryTimeout = QueryTimeout; + UserProvidedQueryTimeout = true; + } } /// @@ -78,4 +96,17 @@ public McpRuntimeOptions( [JsonIgnore(Condition = JsonIgnoreCondition.Always)] [MemberNotNullWhen(true, nameof(Enabled))] public bool UserProvidedPath { get; init; } = false; + + /// + /// Flag which informs CLI and JSON serializer whether to write query-timeout + /// property and value to the runtime config file. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedQueryTimeout { get; init; } = false; + + /// + /// Gets the effective query timeout in seconds, using the default if not specified. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public int EffectiveQueryTimeoutSeconds => QueryTimeout ?? DEFAULT_QUERY_TIMEOUT_SECONDS; } diff --git a/src/Core/Configurations/RuntimeConfigValidator.cs b/src/Core/Configurations/RuntimeConfigValidator.cs index f5112844da..90550fbaba 100644 --- a/src/Core/Configurations/RuntimeConfigValidator.cs +++ b/src/Core/Configurations/RuntimeConfigValidator.cs @@ -914,6 +914,17 @@ public void ValidateMcpUri(RuntimeConfig runtimeConfig) statusCode: HttpStatusCode.ServiceUnavailable, subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); } + + // Validate query-timeout if provided + if (runtimeConfig.Runtime.Mcp.QueryTimeout is not null && + (runtimeConfig.Runtime.Mcp.QueryTimeout < 1 || runtimeConfig.Runtime.Mcp.QueryTimeout > McpRuntimeOptions.MAX_QUERY_TIMEOUT_SECONDS)) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"MCP query-timeout must be between 1 and {McpRuntimeOptions.MAX_QUERY_TIMEOUT_SECONDS} seconds. " + + $"Provided value: {runtimeConfig.Runtime.Mcp.QueryTimeout}.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } } private void ValidateAuthenticationOptions(RuntimeConfig runtimeConfig) diff --git a/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs new file mode 100644 index 0000000000..bab8d68f2c --- /dev/null +++ b/src/Service.Tests/Mcp/AggregateRecordsToolTests.cs @@ -0,0 +1,1089 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Auth; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Authorization; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Mcp.BuiltInTools; +using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Mcp.Utils; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using ModelContextProtocol.Protocol; +using Moq; + +namespace Azure.DataApiBuilder.Service.Tests.Mcp +{ + /// + /// Tests for the AggregateRecordsTool MCP tool. + /// Covers: + /// - Tool metadata and schema validation + /// - Runtime-level enabled/disabled configuration + /// - Entity-level DML tool configuration + /// - Input validation (missing/invalid arguments) + /// - SQL expression generation (count, avg, sum, min, max, distinct) + /// - Table reference quoting, cursor/pagination logic + /// - Alias convention + /// + [TestClass] + public class AggregateRecordsToolTests + { + #region Tool Metadata Tests + + [TestMethod] + public void GetToolMetadata_ReturnsCorrectName() + { + AggregateRecordsTool tool = new(); + Tool metadata = tool.GetToolMetadata(); + Assert.AreEqual("aggregate_records", metadata.Name); + } + + [TestMethod] + public void GetToolMetadata_ReturnsCorrectToolType() + { + AggregateRecordsTool tool = new(); + Assert.AreEqual(McpEnums.ToolType.BuiltIn, tool.ToolType); + } + + [TestMethod] + public void GetToolMetadata_HasInputSchema() + { + AggregateRecordsTool tool = new(); + Tool metadata = tool.GetToolMetadata(); + Assert.AreEqual(JsonValueKind.Object, metadata.InputSchema.ValueKind); + Assert.IsTrue(metadata.InputSchema.TryGetProperty("properties", out JsonElement properties)); + Assert.IsTrue(metadata.InputSchema.TryGetProperty("required", out JsonElement required)); + + List requiredFields = new(); + foreach (JsonElement r in required.EnumerateArray()) + { + requiredFields.Add(r.GetString()!); + } + + CollectionAssert.Contains(requiredFields, "entity"); + CollectionAssert.Contains(requiredFields, "function"); + CollectionAssert.Contains(requiredFields, "field"); + + // Verify first and after properties exist in schema + Assert.IsTrue(properties.TryGetProperty("first", out JsonElement firstProp)); + Assert.AreEqual("integer", firstProp.GetProperty("type").GetString()); + Assert.IsTrue(properties.TryGetProperty("after", out JsonElement afterProp)); + Assert.AreEqual("string", afterProp.GetProperty("type").GetString()); + } + + #endregion + + #region Configuration Tests + + [TestMethod] + public async Task AggregateRecords_DisabledAtRuntimeLevel_ReturnsToolDisabledError() + { + RuntimeConfig config = CreateConfig(aggregateRecordsEnabled: false); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + AssertToolDisabledError(content); + } + + [TestMethod] + public async Task AggregateRecords_DisabledAtEntityLevel_ReturnsToolDisabledError() + { + RuntimeConfig config = CreateConfigWithEntityDmlDisabled(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + AssertToolDisabledError(content); + } + + #endregion + + #region Input Validation Tests + + [TestMethod] + public async Task AggregateRecords_NullArguments_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + CallToolResult result = await tool.ExecuteAsync(null, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.IsTrue(content.TryGetProperty("error", out JsonElement error)); + Assert.AreEqual("InvalidArguments", error.GetProperty("type").GetString()); + } + + [TestMethod] + public async Task AggregateRecords_MissingEntity_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"function\": \"count\", \"field\": \"*\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + } + + [TestMethod] + public async Task AggregateRecords_MissingFunction_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"field\": \"*\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + } + + [TestMethod] + public async Task AggregateRecords_MissingField_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + } + + [TestMethod] + public async Task AggregateRecords_InvalidFunction_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"median\", \"field\": \"price\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("median")); + } + + [TestMethod] + public async Task AggregateRecords_StarFieldWithAvg_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"avg\", \"field\": \"*\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("count")); + } + + [TestMethod] + public async Task AggregateRecords_DistinctCountStar_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\", \"distinct\": true}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("DISTINCT")); + } + + [TestMethod] + public async Task AggregateRecords_HavingWithoutGroupBy_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\", \"having\": {\"gt\": 5}}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("groupby")); + } + + [TestMethod] + public async Task AggregateRecords_OrderByWithoutGroupBy_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\", \"orderby\": \"desc\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("groupby")); + } + + [TestMethod] + public async Task AggregateRecords_UnsupportedHavingOperator_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\", \"groupby\": [\"title\"], \"having\": {\"between\": 5}}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("between")); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("Supported operators")); + } + + [TestMethod] + public async Task AggregateRecords_NonNumericHavingValue_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\", \"groupby\": [\"title\"], \"having\": {\"eq\": \"ten\"}}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("numeric")); + } + + [TestMethod] + public async Task AggregateRecords_NonNumericHavingInArray_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\", \"groupby\": [\"title\"], \"having\": {\"in\": [5, \"abc\"]}}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("numeric")); + } + + [TestMethod] + public async Task AggregateRecords_HavingInNotArray_ReturnsInvalidArguments() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\", \"groupby\": [\"title\"], \"having\": {\"in\": 5}}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.AreEqual("InvalidArguments", content.GetProperty("error").GetProperty("type").GetString()); + Assert.IsTrue(content.GetProperty("error").GetProperty("message").GetString()!.Contains("numeric array")); + } + + #endregion + + #region Alias Convention Tests + + [TestMethod] + public void ComputeAlias_CountStar_ReturnsCount() + { + Assert.AreEqual("count", AggregateRecordsTool.ComputeAlias("count", "*")); + } + + [TestMethod] + public void ComputeAlias_CountField_ReturnsFunctionField() + { + Assert.AreEqual("count_supplierId", AggregateRecordsTool.ComputeAlias("count", "supplierId")); + } + + [TestMethod] + public void ComputeAlias_AvgField_ReturnsFunctionField() + { + Assert.AreEqual("avg_unitPrice", AggregateRecordsTool.ComputeAlias("avg", "unitPrice")); + } + + [TestMethod] + public void ComputeAlias_SumField_ReturnsFunctionField() + { + Assert.AreEqual("sum_unitPrice", AggregateRecordsTool.ComputeAlias("sum", "unitPrice")); + } + + [TestMethod] + public void ComputeAlias_MinField_ReturnsFunctionField() + { + Assert.AreEqual("min_price", AggregateRecordsTool.ComputeAlias("min", "price")); + } + + [TestMethod] + public void ComputeAlias_MaxField_ReturnsFunctionField() + { + Assert.AreEqual("max_price", AggregateRecordsTool.ComputeAlias("max", "price")); + } + + #endregion + + #region Cursor and Pagination Tests + + [TestMethod] + public void DecodeCursorOffset_NullCursor_ReturnsZero() + { + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(null)); + } + + [TestMethod] + public void DecodeCursorOffset_EmptyCursor_ReturnsZero() + { + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset("")); + } + + [TestMethod] + public void DecodeCursorOffset_WhitespaceCursor_ReturnsZero() + { + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(" ")); + } + + [TestMethod] + public void DecodeCursorOffset_ValidBase64Cursor_ReturnsDecodedOffset() + { + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("5")); + Assert.AreEqual(5, AggregateRecordsTool.DecodeCursorOffset(cursor)); + } + + [TestMethod] + public void DecodeCursorOffset_InvalidBase64_ReturnsZero() + { + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset("not-valid-base64!!!")); + } + + [TestMethod] + public void DecodeCursorOffset_NonNumericBase64_ReturnsZero() + { + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("abc")); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(cursor)); + } + + [TestMethod] + public void DecodeCursorOffset_RoundTrip_PreservesOffset() + { + int expectedOffset = 15; + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(expectedOffset.ToString())); + Assert.AreEqual(expectedOffset, AggregateRecordsTool.DecodeCursorOffset(cursor)); + } + + [TestMethod] + public void DecodeCursorOffset_ZeroOffset_ReturnsZero() + { + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("0")); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(cursor)); + } + + [TestMethod] + public void DecodeCursorOffset_LargeOffset_ReturnsCorrectValue() + { + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("1000")); + Assert.AreEqual(1000, AggregateRecordsTool.DecodeCursorOffset(cursor)); + } + + #endregion + + #region Timeout and Cancellation Tests + + /// + /// Verifies that OperationCanceledException produces a model-explicit error + /// that clearly states the operation was canceled, not errored. + /// + [TestMethod] + public async Task AggregateRecords_OperationCanceled_ReturnsExplicitCanceledMessage() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + // Create a pre-canceled token + CancellationTokenSource cts = new(); + cts.Cancel(); + + JsonDocument args = JsonDocument.Parse("{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\"}"); + CallToolResult result = await tool.ExecuteAsync(args, sp, cts.Token); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + Assert.IsTrue(content.TryGetProperty("error", out JsonElement error)); + string errorType = error.GetProperty("type").GetString(); + string errorMessage = error.GetProperty("message").GetString(); + + // Verify the error type identifies it as a cancellation + Assert.IsNotNull(errorType); + Assert.AreEqual("OperationCanceled", errorType); + + // Verify the message explicitly tells the model this is NOT a tool error + Assert.IsNotNull(errorMessage); + Assert.IsTrue(errorMessage!.Contains("NOT a tool error"), "Message must explicitly state this is NOT a tool error."); + + // Verify the message tells the model what happened + Assert.IsTrue(errorMessage.Contains("canceled"), "Message must mention the operation was canceled."); + + // Verify the message tells the model it can retry + Assert.IsTrue(errorMessage.Contains("retry"), "Message must tell the model it can retry."); + } + + /// + /// Verifies that the timeout error message provides explicit guidance to the model + /// about what happened and what to do next. + /// + [TestMethod] + public void TimeoutErrorMessage_ContainsModelGuidance() + { + // Simulate what the tool builds for a TimeoutException response + string entityName = "Product"; + string expectedMessage = $"The aggregation query for entity '{entityName}' timed out. " + + "This is NOT a tool error. The database did not respond in time. " + + "This may occur with large datasets or complex aggregations. " + + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination."; + + // Verify message explicitly states it's NOT a tool error + Assert.IsTrue(expectedMessage.Contains("NOT a tool error"), "Timeout message must state this is NOT a tool error."); + + // Verify message explains the cause + Assert.IsTrue(expectedMessage.Contains("database did not respond"), "Timeout message must explain the database didn't respond."); + + // Verify message mentions large datasets + Assert.IsTrue(expectedMessage.Contains("large datasets"), "Timeout message must mention large datasets as a possible cause."); + + // Verify message provides actionable remediation steps + Assert.IsTrue(expectedMessage.Contains("filter"), "Timeout message must suggest using a filter."); + Assert.IsTrue(expectedMessage.Contains("groupby"), "Timeout message must suggest reducing groupby fields."); + Assert.IsTrue(expectedMessage.Contains("first"), "Timeout message must suggest using pagination with first."); + } + + /// + /// Verifies that TaskCanceledException (which typically signals HTTP/DB timeout) + /// produces a TimeoutError, not a cancellation error. + /// + [TestMethod] + public void TaskCanceledErrorMessage_ContainsTimeoutGuidance() + { + // Simulate what the tool builds for a TaskCanceledException response + string entityName = "Product"; + string expectedMessage = $"The aggregation query for entity '{entityName}' was canceled, likely due to a timeout. " + + "This is NOT a tool error. The database did not respond in time. " + + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination."; + + // TaskCanceledException should produce a TimeoutError, not OperationCanceled + Assert.IsTrue(expectedMessage.Contains("NOT a tool error"), "TaskCanceled message must state this is NOT a tool error."); + Assert.IsTrue(expectedMessage.Contains("timeout"), "TaskCanceled message must reference timeout as the cause."); + Assert.IsTrue(expectedMessage.Contains("filter"), "TaskCanceled message must suggest filter as remediation."); + Assert.IsTrue(expectedMessage.Contains("first"), "TaskCanceled message must suggest first for pagination."); + } + + /// + /// Verifies that the OperationCanceled error message for a specific entity + /// includes the entity name so the model knows which aggregation failed. + /// + [TestMethod] + public void CanceledErrorMessage_IncludesEntityName() + { + string entityName = "LargeProductCatalog"; + string expectedMessage = $"The aggregation query for entity '{entityName}' was canceled before completion. " + + "This is NOT a tool error. The operation was interrupted, possibly due to a timeout or client disconnect. " + + "No results were returned. You may retry the same request."; + + Assert.IsTrue(expectedMessage.Contains(entityName), "Canceled message must include the entity name."); + Assert.IsTrue(expectedMessage.Contains("No results were returned"), "Canceled message must state no results were returned."); + } + + /// + /// Verifies that the timeout error message for a specific entity + /// includes the entity name so the model knows which aggregation timed out. + /// + [TestMethod] + public void TimeoutErrorMessage_IncludesEntityName() + { + string entityName = "HugeTransactionLog"; + string expectedMessage = $"The aggregation query for entity '{entityName}' timed out. " + + "This is NOT a tool error. The database did not respond in time. " + + "This may occur with large datasets or complex aggregations. " + + "Try narrowing results with a 'filter', reducing 'groupby' fields, or adding 'first' for pagination."; + + Assert.IsTrue(expectedMessage.Contains(entityName), "Timeout message must include the entity name."); + } + + #endregion + + #region Spec Example Tests + + /// + /// Spec Example 1: "How many products are there?" + /// COUNT(*) - expects alias "count" + /// + [TestMethod] + public void SpecExample01_CountStar_CorrectAlias() + { + string alias = AggregateRecordsTool.ComputeAlias("count", "*"); + Assert.AreEqual("count", alias); + } + + /// + /// Spec Example 2: "What is the average price of products under $10?" + /// AVG(unitPrice) with filter + /// + [TestMethod] + public void SpecExample02_AvgWithFilter_CorrectAlias() + { + string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice"); + Assert.AreEqual("avg_unitPrice", alias); + } + + /// + /// Spec Example 3: "Which categories have more than 20 products?" + /// COUNT(*) GROUP BY categoryName HAVING gt 20 + /// + [TestMethod] + public void SpecExample03_CountGroupByHavingGt_CorrectAlias() + { + string alias = AggregateRecordsTool.ComputeAlias("count", "*"); + Assert.AreEqual("count", alias); + } + + /// + /// Spec Example 4: "For discontinued products, which categories have total revenue between $500 and $10,000?" + /// SUM(unitPrice) GROUP BY categoryName HAVING gte 500 AND lte 10000 + /// + [TestMethod] + public void SpecExample04_SumFilterGroupByHavingRange_CorrectAlias() + { + string alias = AggregateRecordsTool.ComputeAlias("sum", "unitPrice"); + Assert.AreEqual("sum_unitPrice", alias); + } + + /// + /// Spec Example 5: "How many distinct suppliers do we have?" + /// COUNT(DISTINCT supplierId) + /// + [TestMethod] + public void SpecExample05_CountDistinct_CorrectAlias() + { + string alias = AggregateRecordsTool.ComputeAlias("count", "supplierId"); + Assert.AreEqual("count_supplierId", alias); + } + + /// + /// Spec Example 6: "Which categories have exactly 5 or 10 products?" + /// COUNT(*) GROUP BY categoryName HAVING IN (5, 10) + /// + [TestMethod] + public void SpecExample06_CountGroupByHavingIn_CorrectAlias() + { + string alias = AggregateRecordsTool.ComputeAlias("count", "*"); + Assert.AreEqual("count", alias); + } + + /// + /// Spec Example 7: "Average distinct unit price per category, for categories averaging over $25" + /// AVG(DISTINCT unitPrice) GROUP BY categoryName HAVING gt 25 + /// + [TestMethod] + public void SpecExample07_AvgDistinctGroupByHavingGt_CorrectAlias() + { + string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice"); + Assert.AreEqual("avg_unitPrice", alias); + } + + /// + /// Spec Example 8: "Which categories have the most products?" + /// COUNT(*) GROUP BY categoryName ORDER BY DESC + /// + [TestMethod] + public void SpecExample08_CountGroupByOrderByDesc_CorrectAlias() + { + string alias = AggregateRecordsTool.ComputeAlias("count", "*"); + Assert.AreEqual("count", alias); + } + + /// + /// Spec Example 9: "What are the cheapest categories by average price?" + /// AVG(unitPrice) GROUP BY categoryName ORDER BY ASC + /// + [TestMethod] + public void SpecExample09_AvgGroupByOrderByAsc_CorrectAlias() + { + string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice"); + Assert.AreEqual("avg_unitPrice", alias); + } + + /// + /// Spec Example 10: "For categories with over $500 revenue, which has the highest total?" + /// SUM(unitPrice) GROUP BY categoryName HAVING gt 500 ORDER BY DESC + /// + [TestMethod] + public void SpecExample10_SumFilterGroupByHavingGtOrderByDesc_CorrectAlias() + { + string alias = AggregateRecordsTool.ComputeAlias("sum", "unitPrice"); + Assert.AreEqual("sum_unitPrice", alias); + } + + /// + /// Spec Example 11: "Show me the first 5 categories by product count" + /// COUNT(*) GROUP BY categoryName ORDER BY DESC FIRST 5 + /// + [TestMethod] + public void SpecExample11_CountGroupByOrderByDescFirst5_CorrectAliasAndCursor() + { + string alias = AggregateRecordsTool.ComputeAlias("count", "*"); + Assert.AreEqual("count", alias); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(null)); + } + + /// + /// Spec Example 12: "Show me the next 5 categories" (continuation of Example 11) + /// COUNT(*) GROUP BY categoryName ORDER BY DESC FIRST 5 AFTER cursor + /// + [TestMethod] + public void SpecExample12_CountGroupByOrderByDescFirst5After_CorrectCursorDecode() + { + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("5")); + int offset = AggregateRecordsTool.DecodeCursorOffset(cursor); + Assert.AreEqual(5, offset); + + string alias = AggregateRecordsTool.ComputeAlias("count", "*"); + Assert.AreEqual("count", alias); + } + + /// + /// Spec Example 13: "Show me the top 3 most expensive categories by average price" + /// AVG(unitPrice) GROUP BY categoryName ORDER BY DESC FIRST 3 + /// + [TestMethod] + public void SpecExample13_AvgGroupByOrderByDescFirst3_CorrectAlias() + { + string alias = AggregateRecordsTool.ComputeAlias("avg", "unitPrice"); + Assert.AreEqual("avg_unitPrice", alias); + } + + #endregion + + #region Blog Scenario Tests (devblogs.microsoft.com/azure-sql/data-api-builder-mcp-questions) + + // These tests verify that the exact JSON payloads from the DAB MCP blog + // pass input validation. The tool will fail at metadata resolution (no real DB) + // but must NOT return "InvalidArguments", proving the input shape is valid. + + /// + /// Blog Scenario 1: Strategic customer importance + /// "Who is our most important customer based on total revenue?" + /// Uses: sum, totalRevenue, filter, groupby [customerId, customerName], orderby desc, first 1 + /// + [TestMethod] + public async Task BlogScenario1_StrategicCustomerImportance_PassesInputValidation() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + string json = @"{ + ""entity"": ""Book"", + ""function"": ""sum"", + ""field"": ""totalRevenue"", + ""filter"": ""isActive eq true and orderDate ge 2025-01-01"", + ""groupby"": [""customerId"", ""customerName""], + ""orderby"": ""desc"", + ""first"": 1 + }"; + + JsonDocument args = JsonDocument.Parse(json); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + string errorType = content.GetProperty("error").GetProperty("type").GetString()!; + Assert.AreNotEqual("InvalidArguments", errorType, + "Blog scenario 1 JSON must pass input validation (sum/totalRevenue/groupby/orderby/first)."); + Assert.AreEqual("sum_totalRevenue", AggregateRecordsTool.ComputeAlias("sum", "totalRevenue")); + } + + /// + /// Blog Scenario 2: Product discontinuation candidate + /// "Which product should we consider discontinuing based on lowest totalRevenue?" + /// Uses: sum, totalRevenue, filter, groupby [productId, productName], orderby asc, first 1 + /// + [TestMethod] + public async Task BlogScenario2_ProductDiscontinuation_PassesInputValidation() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + string json = @"{ + ""entity"": ""Book"", + ""function"": ""sum"", + ""field"": ""totalRevenue"", + ""filter"": ""isActive eq true and inStock gt 0 and orderDate ge 2025-01-01"", + ""groupby"": [""productId"", ""productName""], + ""orderby"": ""asc"", + ""first"": 1 + }"; + + JsonDocument args = JsonDocument.Parse(json); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + string errorType = content.GetProperty("error").GetProperty("type").GetString()!; + Assert.AreNotEqual("InvalidArguments", errorType, + "Blog scenario 2 JSON must pass input validation (sum/totalRevenue/groupby/orderby asc/first)."); + Assert.AreEqual("sum_totalRevenue", AggregateRecordsTool.ComputeAlias("sum", "totalRevenue")); + } + + /// + /// Blog Scenario 3: Forward-looking performance expectation + /// "Average quarterlyRevenue per region, regions averaging > $2,000,000?" + /// Uses: avg, quarterlyRevenue, filter, groupby [region], having {gt: 2000000}, orderby desc + /// + [TestMethod] + public async Task BlogScenario3_QuarterlyPerformance_PassesInputValidation() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + string json = @"{ + ""entity"": ""Book"", + ""function"": ""avg"", + ""field"": ""quarterlyRevenue"", + ""filter"": ""fiscalYear eq 2025"", + ""groupby"": [""region""], + ""having"": { ""gt"": 2000000 }, + ""orderby"": ""desc"" + }"; + + JsonDocument args = JsonDocument.Parse(json); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + string errorType = content.GetProperty("error").GetProperty("type").GetString()!; + Assert.AreNotEqual("InvalidArguments", errorType, + "Blog scenario 3 JSON must pass input validation (avg/quarterlyRevenue/groupby/having gt)."); + Assert.AreEqual("avg_quarterlyRevenue", AggregateRecordsTool.ComputeAlias("avg", "quarterlyRevenue")); + } + + /// + /// Blog Scenario 4: Revenue concentration across regions + /// "Total revenue of active retail customers in Midwest/Southwest, >$5M, by region and customerTier" + /// Uses: sum, totalRevenue, complex filter with OR, groupby [region, customerTier], having {gt: 5000000}, orderby desc + /// + [TestMethod] + public async Task BlogScenario4_RevenueConcentration_PassesInputValidation() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + string json = @"{ + ""entity"": ""Book"", + ""function"": ""sum"", + ""field"": ""totalRevenue"", + ""filter"": ""isActive eq true and customerType eq 'Retail' and (region eq 'Midwest' or region eq 'Southwest')"", + ""groupby"": [""region"", ""customerTier""], + ""having"": { ""gt"": 5000000 }, + ""orderby"": ""desc"" + }"; + + JsonDocument args = JsonDocument.Parse(json); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + string errorType = content.GetProperty("error").GetProperty("type").GetString()!; + Assert.AreNotEqual("InvalidArguments", errorType, + "Blog scenario 4 JSON must pass input validation (sum/totalRevenue/complex filter/multi-groupby/having)."); + Assert.AreEqual("sum_totalRevenue", AggregateRecordsTool.ComputeAlias("sum", "totalRevenue")); + } + + /// + /// Blog Scenario 5: Risk exposure by product line + /// "For discontinued products, total onHandValue by productLine and warehouseRegion, >$2.5M" + /// Uses: sum, onHandValue, filter, groupby [productLine, warehouseRegion], having {gt: 2500000}, orderby desc + /// + [TestMethod] + public async Task BlogScenario5_RiskExposure_PassesInputValidation() + { + RuntimeConfig config = CreateConfig(); + IServiceProvider sp = CreateServiceProvider(config); + AggregateRecordsTool tool = new(); + + string json = @"{ + ""entity"": ""Book"", + ""function"": ""sum"", + ""field"": ""onHandValue"", + ""filter"": ""discontinued eq true and onHandValue gt 0"", + ""groupby"": [""productLine"", ""warehouseRegion""], + ""having"": { ""gt"": 2500000 }, + ""orderby"": ""desc"" + }"; + + JsonDocument args = JsonDocument.Parse(json); + CallToolResult result = await tool.ExecuteAsync(args, sp, CancellationToken.None); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + string errorType = content.GetProperty("error").GetProperty("type").GetString()!; + Assert.AreNotEqual("InvalidArguments", errorType, + "Blog scenario 5 JSON must pass input validation (sum/onHandValue/filter/multi-groupby/having)."); + Assert.AreEqual("sum_onHandValue", AggregateRecordsTool.ComputeAlias("sum", "onHandValue")); + } + + /// + /// Verifies that the tool schema supports all properties used across the 5 blog scenarios. + /// + [TestMethod] + public void BlogScenarios_ToolSchema_SupportsAllRequiredProperties() + { + AggregateRecordsTool tool = new(); + Tool metadata = tool.GetToolMetadata(); + JsonElement properties = metadata.InputSchema.GetProperty("properties"); + + string[] blogProperties = { "entity", "function", "field", "filter", "groupby", "orderby", "having", "first" }; + foreach (string prop in blogProperties) + { + Assert.IsTrue(properties.TryGetProperty(prop, out _), + $"Tool schema must include '{prop}' property used in blog scenarios."); + } + + // Additional schema properties used in spec but not blog + Assert.IsTrue(properties.TryGetProperty("distinct", out _), "Tool schema must include 'distinct'."); + Assert.IsTrue(properties.TryGetProperty("after", out _), "Tool schema must include 'after'."); + } + + /// + /// Verifies that the tool description instructs models to call describe_entities first. + /// + [TestMethod] + public void BlogScenarios_ToolDescription_ForcesDescribeEntitiesFirst() + { + AggregateRecordsTool tool = new(); + Tool metadata = tool.GetToolMetadata(); + + Assert.IsTrue(metadata.Description!.Contains("describe_entities"), + "Tool description must instruct models to call describe_entities first."); + Assert.IsTrue(metadata.Description.Contains("1)"), + "Tool description must use numbered workflow steps."); + } + + /// + /// Verifies that the tool description documents the alias convention used in blog examples. + /// + [TestMethod] + public void BlogScenarios_ToolDescription_DocumentsAliasConvention() + { + AggregateRecordsTool tool = new(); + Tool metadata = tool.GetToolMetadata(); + + Assert.IsTrue(metadata.Description!.Contains("{function}_{field}"), + "Tool description must document the alias pattern '{function}_{field}'."); + Assert.IsTrue(metadata.Description.Contains("'count'"), + "Tool description must mention the special 'count' alias for count(*)."); + } + + #endregion + + #region FieldNotFound Error Helper Tests + + /// + /// Verifies the FieldNotFound error helper produces the correct error type + /// and a model-friendly message that includes the field name, entity, and guidance. + /// + [TestMethod] + public void FieldNotFound_ReturnsCorrectErrorTypeAndMessage() + { + CallToolResult result = McpErrorHelpers.FieldNotFound("aggregate_records", "Product", "badField", "field", null); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + JsonElement error = content.GetProperty("error"); + + Assert.AreEqual("FieldNotFound", error.GetProperty("type").GetString()); + string message = error.GetProperty("message").GetString()!; + Assert.IsTrue(message.Contains("badField"), "Message must include the invalid field name."); + Assert.IsTrue(message.Contains("Product"), "Message must include the entity name."); + Assert.IsTrue(message.Contains("field"), "Message must identify which parameter was invalid."); + Assert.IsTrue(message.Contains("describe_entities"), "Message must guide the model to call describe_entities."); + } + + /// + /// Verifies the FieldNotFound error helper identifies the groupby parameter. + /// + [TestMethod] + public void FieldNotFound_GroupBy_IdentifiesParameter() + { + CallToolResult result = McpErrorHelpers.FieldNotFound("aggregate_records", "Product", "invalidCol", "groupby", null); + + Assert.IsTrue(result.IsError == true); + JsonElement content = ParseContent(result); + string message = content.GetProperty("error").GetProperty("message").GetString()!; + + Assert.IsTrue(message.Contains("invalidCol"), "Message must include the invalid field name."); + Assert.IsTrue(message.Contains("groupby"), "Message must identify 'groupby' as the parameter."); + Assert.IsTrue(message.Contains("describe_entities"), "Message must guide the model to call describe_entities."); + } + + #endregion + + #region Helper Methods + + private static JsonElement ParseContent(CallToolResult result) + { + TextContentBlock firstContent = (TextContentBlock)result.Content[0]; + return JsonDocument.Parse(firstContent.Text).RootElement; + } + + private static void AssertToolDisabledError(JsonElement content) + { + Assert.IsTrue(content.TryGetProperty("error", out JsonElement error)); + Assert.IsTrue(error.TryGetProperty("type", out JsonElement errorType)); + Assert.AreEqual("ToolDisabled", errorType.GetString()); + } + + private static RuntimeConfig CreateConfig(bool aggregateRecordsEnabled = true) + { + Dictionary entities = new() + { + ["Book"] = new Entity( + Source: new("books", EntitySourceType.Table, null, null), + GraphQL: new("Book", "Books"), + Fields: null, + Rest: new(Enabled: true), + Permissions: new[] { new EntityPermission(Role: "anonymous", Actions: new[] { + new EntityAction(Action: EntityActionOperation.Read, Fields: null, Policy: null) + }) }, + Mappings: null, + Relationships: null, + Mcp: null + ) + }; + + return new RuntimeConfig( + Schema: "test-schema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new( + Enabled: true, + Path: "/mcp", + DmlTools: new( + describeEntities: true, + readRecords: true, + createRecord: true, + updateRecord: true, + deleteRecord: true, + executeEntity: true, + aggregateRecords: aggregateRecordsEnabled + ) + ), + Host: new(Cors: null, Authentication: null, Mode: HostMode.Development) + ), + Entities: new(entities) + ); + } + + private static RuntimeConfig CreateConfigWithEntityDmlDisabled() + { + Dictionary entities = new() + { + ["Book"] = new Entity( + Source: new("books", EntitySourceType.Table, null, null), + GraphQL: new("Book", "Books"), + Fields: null, + Rest: new(Enabled: true), + Permissions: new[] { new EntityPermission(Role: "anonymous", Actions: new[] { + new EntityAction(Action: EntityActionOperation.Read, Fields: null, Policy: null) + }) }, + Mappings: null, + Relationships: null, + Mcp: new EntityMcpOptions(customToolEnabled: false, dmlToolsEnabled: false) + ) + }; + + return new RuntimeConfig( + Schema: "test-schema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new( + Enabled: true, + Path: "/mcp", + DmlTools: new( + describeEntities: true, + readRecords: true, + createRecord: true, + updateRecord: true, + deleteRecord: true, + executeEntity: true, + aggregateRecords: true + ) + ), + Host: new(Cors: null, Authentication: null, Mode: HostMode.Development) + ), + Entities: new(entities) + ); + } + + private static IServiceProvider CreateServiceProvider(RuntimeConfig config) + { + ServiceCollection services = new(); + + RuntimeConfigProvider configProvider = TestHelper.GenerateInMemoryRuntimeConfigProvider(config); + services.AddSingleton(configProvider); + + Mock mockAuthResolver = new(); + mockAuthResolver.Setup(x => x.IsValidRoleContext(It.IsAny())).Returns(true); + services.AddSingleton(mockAuthResolver.Object); + + Mock mockHttpContext = new(); + Mock mockRequest = new(); + mockRequest.Setup(x => x.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER]).Returns("anonymous"); + mockHttpContext.Setup(x => x.Request).Returns(mockRequest.Object); + + Mock mockHttpContextAccessor = new(); + mockHttpContextAccessor.Setup(x => x.HttpContext).Returns(mockHttpContext.Object); + services.AddSingleton(mockHttpContextAccessor.Object); + + services.AddLogging(); + + return services.BuildServiceProvider(); + } + + #endregion + } +} diff --git a/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs b/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs index d2f6554cd3..b4ae074207 100644 --- a/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs +++ b/src/Service.Tests/Mcp/EntityLevelDmlToolConfigurationTests.cs @@ -48,6 +48,7 @@ public class EntityLevelDmlToolConfigurationTests [DataRow("UpdateRecord", "{\"entity\": \"Book\", \"keys\": {\"id\": 1}, \"fields\": {\"title\": \"Updated\"}}", false, DisplayName = "UpdateRecord respects entity-level DmlToolEnabled=false")] [DataRow("DeleteRecord", "{\"entity\": \"Book\", \"keys\": {\"id\": 1}}", false, DisplayName = "DeleteRecord respects entity-level DmlToolEnabled=false")] [DataRow("ExecuteEntity", "{\"entity\": \"GetBook\"}", true, DisplayName = "ExecuteEntity respects entity-level DmlToolEnabled=false")] + [DataRow("AggregateRecords", "{\"entity\": \"Book\", \"function\": \"count\", \"field\": \"*\"}", false, DisplayName = "AggregateRecords respects entity-level DmlToolEnabled=false")] public async Task DmlTool_RespectsEntityLevelDmlToolDisabled(string toolType, string jsonArguments, bool isStoredProcedure) { // Arrange @@ -238,6 +239,7 @@ private static IMcpTool CreateTool(string toolType) "UpdateRecord" => new UpdateRecordTool(), "DeleteRecord" => new DeleteRecordTool(), "ExecuteEntity" => new ExecuteEntityTool(), + "AggregateRecords" => new AggregateRecordsTool(), _ => throw new ArgumentException($"Unknown tool type: {toolType}", nameof(toolType)) }; } diff --git a/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs b/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs new file mode 100644 index 0000000000..237e40e57e --- /dev/null +++ b/src/Service.Tests/Mcp/McpQueryTimeoutTests.cs @@ -0,0 +1,450 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Config; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Mcp.Utils; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using ModelContextProtocol.Protocol; +using static Azure.DataApiBuilder.Mcp.Model.McpEnums; + +namespace Azure.DataApiBuilder.Service.Tests.Mcp +{ + /// + /// Tests for the MCP query-timeout configuration property. + /// Verifies: + /// - Default value of 30 seconds when not configured + /// - Custom value overrides default + /// - Timeout wrapping applies to all MCP tools via ExecuteWithTelemetryAsync + /// - Hot reload: changing config value updates behavior without restart + /// - Timeout surfaces as TimeoutException, not generic cancellation + /// - Telemetry maps timeout to TIMEOUT error code + /// + [TestClass] + public class McpQueryTimeoutTests + { + #region Default Value Tests + + [TestMethod] + public void McpRuntimeOptions_DefaultQueryTimeout_Is30Seconds() + { + Assert.AreEqual(30, McpRuntimeOptions.DEFAULT_QUERY_TIMEOUT_SECONDS); + } + + [TestMethod] + public void McpRuntimeOptions_EffectiveTimeout_ReturnsDefault_WhenNotConfigured() + { + McpRuntimeOptions options = new(); + Assert.IsNull(options.QueryTimeout); + Assert.AreEqual(30, options.EffectiveQueryTimeoutSeconds); + } + + [TestMethod] + public void McpRuntimeOptions_EffectiveTimeout_ReturnsConfiguredValue() + { + McpRuntimeOptions options = new(QueryTimeout: 60); + Assert.AreEqual(60, options.QueryTimeout); + Assert.AreEqual(60, options.EffectiveQueryTimeoutSeconds); + } + + [TestMethod] + public void McpRuntimeOptions_UserProvidedQueryTimeout_FalseByDefault() + { + McpRuntimeOptions options = new(); + Assert.IsFalse(options.UserProvidedQueryTimeout); + } + + [TestMethod] + public void McpRuntimeOptions_UserProvidedQueryTimeout_TrueWhenSet() + { + McpRuntimeOptions options = new(QueryTimeout: 45); + Assert.IsTrue(options.UserProvidedQueryTimeout); + } + + #endregion + + #region Custom Value Tests + + [TestMethod] + public void McpRuntimeOptions_CustomTimeout_1Second() + { + McpRuntimeOptions options = new(QueryTimeout: 1); + Assert.AreEqual(1, options.EffectiveQueryTimeoutSeconds); + } + + [TestMethod] + public void McpRuntimeOptions_CustomTimeout_120Seconds() + { + McpRuntimeOptions options = new(QueryTimeout: 120); + Assert.AreEqual(120, options.EffectiveQueryTimeoutSeconds); + } + + [TestMethod] + public void RuntimeConfig_McpQueryTimeout_ExposedInConfig() + { + RuntimeConfig config = CreateConfigWithQueryTimeout(45); + Assert.AreEqual(45, config.Runtime?.Mcp?.QueryTimeout); + Assert.AreEqual(45, config.Runtime?.Mcp?.EffectiveQueryTimeoutSeconds); + } + + [TestMethod] + public void RuntimeConfig_McpQueryTimeout_DefaultWhenNotSet() + { + RuntimeConfig config = CreateConfigWithoutQueryTimeout(); + Assert.IsNull(config.Runtime?.Mcp?.QueryTimeout); + Assert.AreEqual(30, config.Runtime?.Mcp?.EffectiveQueryTimeoutSeconds); + } + + #endregion + + #region Timeout Wrapping Tests + + [TestMethod] + public async Task ExecuteWithTelemetry_CompletesSuccessfully_WithinTimeout() + { + // A tool that completes immediately should succeed + RuntimeConfig config = CreateConfigWithQueryTimeout(30); + IServiceProvider sp = CreateServiceProviderWithConfig(config); + IMcpTool tool = new ImmediateCompletionTool(); + + CallToolResult result = await McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, "test_tool", null, sp, CancellationToken.None); + + // Tool should complete without throwing TimeoutException + Assert.IsNotNull(result); + Assert.IsTrue(result.IsError != true, "Tool result should not be an error"); + } + + [TestMethod] + public async Task ExecuteWithTelemetry_ThrowsTimeoutException_WhenToolExceedsTimeout() + { + // Configure a very short timeout (1 second) and a tool that takes longer + RuntimeConfig config = CreateConfigWithQueryTimeout(1); + IServiceProvider sp = CreateServiceProviderWithConfig(config); + IMcpTool tool = new SlowTool(delaySeconds: 30); + + await Assert.ThrowsExceptionAsync(async () => + { + await McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, "slow_tool", null, sp, CancellationToken.None); + }); + } + + [TestMethod] + public async Task ExecuteWithTelemetry_TimeoutMessage_ContainsToolName() + { + RuntimeConfig config = CreateConfigWithQueryTimeout(1); + IServiceProvider sp = CreateServiceProviderWithConfig(config); + IMcpTool tool = new SlowTool(delaySeconds: 30); + + try + { + await McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, "aggregate_records", null, sp, CancellationToken.None); + Assert.Fail("Expected TimeoutException"); + } + catch (TimeoutException ex) + { + Assert.IsTrue(ex.Message.Contains("aggregate_records"), "Message should contain tool name"); + Assert.IsTrue(ex.Message.Contains("1 second"), "Message should contain timeout value"); + Assert.IsTrue(ex.Message.Contains("NOT a tool error"), "Message should clarify it is not a tool error"); + } + } + + [TestMethod] + public async Task ExecuteWithTelemetry_ClientCancellation_PropagatesAsCancellation() + { + // Client cancellation (not timeout) should propagate as OperationCanceledException + // rather than being converted to TimeoutException. + RuntimeConfig config = CreateConfigWithQueryTimeout(30); + IServiceProvider sp = CreateServiceProviderWithConfig(config); + IMcpTool tool = new SlowTool(delaySeconds: 30); + + using CancellationTokenSource cts = new(); + cts.Cancel(); // Cancel immediately + + try + { + await McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, "test_tool", null, sp, cts.Token); + Assert.Fail("Expected OperationCanceledException or subclass to be thrown"); + } + catch (TimeoutException) + { + Assert.Fail("Client cancellation should NOT be converted to TimeoutException"); + } + catch (OperationCanceledException) + { + // Expected: client-initiated cancellation propagates as OperationCanceledException + // (or subclass TaskCanceledException) + } + } + + [TestMethod] + public async Task ExecuteWithTelemetry_AppliesTimeout_ToAllToolTypes() + { + // Verify timeout applies to both built-in and custom tool types + RuntimeConfig config = CreateConfigWithQueryTimeout(1); + IServiceProvider sp = CreateServiceProviderWithConfig(config); + + // Test with built-in tool type + IMcpTool builtInTool = new SlowTool(delaySeconds: 30, toolType: ToolType.BuiltIn); + await Assert.ThrowsExceptionAsync(async () => + { + await McpTelemetryHelper.ExecuteWithTelemetryAsync( + builtInTool, "builtin_slow", null, sp, CancellationToken.None); + }); + + // Test with custom tool type + IMcpTool customTool = new SlowTool(delaySeconds: 30, toolType: ToolType.Custom); + await Assert.ThrowsExceptionAsync(async () => + { + await McpTelemetryHelper.ExecuteWithTelemetryAsync( + customTool, "custom_slow", null, sp, CancellationToken.None); + }); + } + + #endregion + + #region Hot Reload Tests + + [TestMethod] + public async Task ExecuteWithTelemetry_ReadsConfigPerInvocation_HotReload() + { + // First invocation with long timeout should succeed + RuntimeConfig config1 = CreateConfigWithQueryTimeout(30); + IServiceProvider sp1 = CreateServiceProviderWithConfig(config1); + + IMcpTool fastTool = new ImmediateCompletionTool(); + CallToolResult result1 = await McpTelemetryHelper.ExecuteWithTelemetryAsync( + fastTool, "test_tool", null, sp1, CancellationToken.None); + Assert.IsNotNull(result1); + + // Second invocation with very short timeout and a slow tool should timeout. + // This demonstrates that each invocation reads the current config value. + RuntimeConfig config2 = CreateConfigWithQueryTimeout(1); + IServiceProvider sp2 = CreateServiceProviderWithConfig(config2); + + IMcpTool slowTool = new SlowTool(delaySeconds: 30); + await Assert.ThrowsExceptionAsync(async () => + { + await McpTelemetryHelper.ExecuteWithTelemetryAsync( + slowTool, "test_tool", null, sp2, CancellationToken.None); + }); + } + + #endregion + + #region Telemetry Tests + + [TestMethod] + public void MapExceptionToErrorCode_TimeoutException_ReturnsTIMEOUT() + { + string errorCode = McpTelemetryHelper.MapExceptionToErrorCode(new TimeoutException()); + Assert.AreEqual(McpTelemetryErrorCodes.TIMEOUT, errorCode); + } + + [TestMethod] + public void MapExceptionToErrorCode_OperationCanceled_ReturnsOPERATION_CANCELLED() + { + string errorCode = McpTelemetryHelper.MapExceptionToErrorCode(new OperationCanceledException()); + Assert.AreEqual(McpTelemetryErrorCodes.OPERATION_CANCELLED, errorCode); + } + + [TestMethod] + public void MapExceptionToErrorCode_TaskCanceled_ReturnsOPERATION_CANCELLED() + { + string errorCode = McpTelemetryHelper.MapExceptionToErrorCode(new TaskCanceledException()); + Assert.AreEqual(McpTelemetryErrorCodes.OPERATION_CANCELLED, errorCode); + } + + #endregion + + #region JSON Serialization Tests + + [TestMethod] + public void McpRuntimeOptions_Serialization_IncludesQueryTimeout_WhenUserProvided() + { + McpRuntimeOptions options = new(QueryTimeout: 45); + JsonSerializerOptions serializerOptions = RuntimeConfigLoader.GetSerializationOptions(); + string json = JsonSerializer.Serialize(options, serializerOptions); + Assert.IsTrue(json.Contains("\"query-timeout\": 45") || json.Contains("\"query-timeout\":45")); + } + + [TestMethod] + public void McpRuntimeOptions_Deserialization_ReadsQueryTimeout() + { + string json = @"{""enabled"": true, ""query-timeout"": 60}"; + JsonSerializerOptions serializerOptions = RuntimeConfigLoader.GetSerializationOptions(); + McpRuntimeOptions? options = JsonSerializer.Deserialize(json, serializerOptions); + Assert.IsNotNull(options); + Assert.AreEqual(60, options.QueryTimeout); + Assert.AreEqual(60, options.EffectiveQueryTimeoutSeconds); + } + + [TestMethod] + public void McpRuntimeOptions_Deserialization_DefaultsWhenOmitted() + { + string json = @"{""enabled"": true}"; + JsonSerializerOptions serializerOptions = RuntimeConfigLoader.GetSerializationOptions(); + McpRuntimeOptions? options = JsonSerializer.Deserialize(json, serializerOptions); + Assert.IsNotNull(options); + Assert.IsNull(options.QueryTimeout); + Assert.AreEqual(30, options.EffectiveQueryTimeoutSeconds); + } + + #endregion + + #region Helpers + + private static RuntimeConfig CreateConfigWithQueryTimeout(int timeoutSeconds) + { + return new RuntimeConfig( + Schema: "test-schema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new( + Enabled: true, + Path: "/mcp", + QueryTimeout: timeoutSeconds, + DmlTools: new( + describeEntities: true, + readRecords: true, + createRecord: true, + updateRecord: true, + deleteRecord: true, + executeEntity: true, + aggregateRecords: true + ) + ), + Host: new(Cors: null, Authentication: null, Mode: HostMode.Development) + ), + Entities: new(new Dictionary()) + ); + } + + private static RuntimeConfig CreateConfigWithoutQueryTimeout() + { + return new RuntimeConfig( + Schema: "test-schema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, ConnectionString: "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new( + Enabled: true, + Path: "/mcp", + DmlTools: new( + describeEntities: true, + readRecords: true, + createRecord: true, + updateRecord: true, + deleteRecord: true, + executeEntity: true, + aggregateRecords: true + ) + ), + Host: new(Cors: null, Authentication: null, Mode: HostMode.Development) + ), + Entities: new(new Dictionary()) + ); + } + + private static IServiceProvider CreateServiceProviderWithConfig(RuntimeConfig config) + { + ServiceCollection services = new(); + RuntimeConfigProvider configProvider = TestHelper.GenerateInMemoryRuntimeConfigProvider(config); + services.AddSingleton(configProvider); + services.AddLogging(); + return services.BuildServiceProvider(); + } + + /// + /// A mock tool that completes immediately with a success result. + /// + private class ImmediateCompletionTool : IMcpTool + { + public ToolType ToolType { get; } = ToolType.BuiltIn; + + public Tool GetToolMetadata() + { + using JsonDocument doc = JsonDocument.Parse("{\"type\": \"object\"}"); + return new Tool + { + Name = "test_tool", + Description = "A test tool that completes immediately", + InputSchema = doc.RootElement.Clone() + }; + } + + public Task ExecuteAsync( + JsonDocument? arguments, + IServiceProvider serviceProvider, + CancellationToken cancellationToken = default) + { + return Task.FromResult(new CallToolResult + { + Content = new List + { + new TextContentBlock { Text = "{\"result\": \"success\"}" } + } + }); + } + } + + /// + /// A mock tool that delays for a specified duration, respecting cancellation. + /// Used to test timeout behavior. + /// + private class SlowTool : IMcpTool + { + private readonly int _delaySeconds; + + public SlowTool(int delaySeconds, ToolType toolType = ToolType.BuiltIn) + { + _delaySeconds = delaySeconds; + ToolType = toolType; + } + + public ToolType ToolType { get; } + + public Tool GetToolMetadata() + { + using JsonDocument doc = JsonDocument.Parse("{\"type\": \"object\"}"); + return new Tool + { + Name = "slow_tool", + Description = "A test tool that takes a long time", + InputSchema = doc.RootElement.Clone() + }; + } + + public async Task ExecuteAsync( + JsonDocument? arguments, + IServiceProvider serviceProvider, + CancellationToken cancellationToken = default) + { + await Task.Delay(TimeSpan.FromSeconds(_delaySeconds), cancellationToken); + return new CallToolResult + { + Content = new List + { + new TextContentBlock { Text = "{\"result\": \"completed\"}" } + } + }; + } + } + + #endregion + } +} diff --git a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForCosmos.verified.txt b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForCosmos.verified.txt index 9279da9d59..15f242605f 100644 --- a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForCosmos.verified.txt +++ b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForCosmos.verified.txt @@ -28,14 +28,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMsSql.verified.txt b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMsSql.verified.txt index 35fd562c87..966af2777f 100644 --- a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMsSql.verified.txt +++ b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMsSql.verified.txt @@ -32,14 +32,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMySql.verified.txt b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMySql.verified.txt index 1490309ece..0779215cd0 100644 --- a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMySql.verified.txt +++ b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForMySql.verified.txt @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForPostgreSql.verified.txt b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForPostgreSql.verified.txt index ceba40ae63..75077c22fa 100644 --- a/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForPostgreSql.verified.txt +++ b/src/Service.Tests/Snapshots/ConfigurationTests.TestReadingRuntimeConfigForPostgreSql.verified.txt @@ -24,14 +24,18 @@ UpdateRecord: true, DeleteRecord: true, ExecuteEntity: true, + AggregateRecords: true, UserProvidedAllTools: false, UserProvidedDescribeEntities: false, UserProvidedCreateRecord: false, UserProvidedReadRecords: false, UserProvidedUpdateRecord: false, UserProvidedDeleteRecord: false, - UserProvidedExecuteEntity: false - } + UserProvidedExecuteEntity: false, + UserProvidedAggregateRecords: false + }, + UserProvidedQueryTimeout: false, + EffectiveQueryTimeoutSeconds: 30 }, Host: { Cors: { diff --git a/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs b/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs new file mode 100644 index 0000000000..c291d87660 --- /dev/null +++ b/src/Service.Tests/UnitTests/AggregateRecordsToolTests.cs @@ -0,0 +1,194 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Text; +using Azure.DataApiBuilder.Mcp.BuiltInTools; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests +{ + /// + /// Unit tests for AggregateRecordsTool helper methods. + /// Validates alias computation, cursor decoding, and input validation logic. + /// SQL generation is delegated to the engine's query builder (GroupByMetadata/AggregationColumn). + /// + [TestClass] + public class AggregateRecordsToolTests + { + #region ComputeAlias tests + + [TestMethod] + [DataRow("count", "*", "count", DisplayName = "count(*) alias is 'count'")] + [DataRow("count", "userId", "count_userId", DisplayName = "count(field) alias is 'count_field'")] + [DataRow("avg", "price", "avg_price", DisplayName = "avg alias")] + [DataRow("sum", "amount", "sum_amount", DisplayName = "sum alias")] + [DataRow("min", "age", "min_age", DisplayName = "min alias")] + [DataRow("max", "score", "max_score", DisplayName = "max alias")] + public void ComputeAlias_ReturnsExpectedAlias(string function, string field, string expectedAlias) + { + string result = AggregateRecordsTool.ComputeAlias(function, field); + Assert.AreEqual(expectedAlias, result); + } + + #endregion + + #region DecodeCursorOffset tests + + [TestMethod] + public void DecodeCursorOffset_NullCursor_ReturnsZero() + { + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(null)); + } + + [TestMethod] + public void DecodeCursorOffset_EmptyCursor_ReturnsZero() + { + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset("")); + } + + [TestMethod] + public void DecodeCursorOffset_ValidBase64_ReturnsOffset() + { + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("5")); + Assert.AreEqual(5, AggregateRecordsTool.DecodeCursorOffset(cursor)); + } + + [TestMethod] + public void DecodeCursorOffset_InvalidBase64_ReturnsZero() + { + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset("not-valid-base64!!")); + } + + [TestMethod] + public void DecodeCursorOffset_NonNumericBase64_ReturnsZero() + { + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("abc")); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(cursor)); + } + + [TestMethod] + public void DecodeCursorOffset_RoundTrip_FirstPage() + { + int offset = 3; + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(offset.ToString())); + Assert.AreEqual(offset, AggregateRecordsTool.DecodeCursorOffset(cursor)); + } + + [TestMethod] + public void DecodeCursorOffset_NegativeValue_ReturnsZero() + { + string cursor = Convert.ToBase64String(Encoding.UTF8.GetBytes("-5")); + Assert.AreEqual(0, AggregateRecordsTool.DecodeCursorOffset(cursor)); + } + + #endregion + + #region Validation logic tests + + [TestMethod] + [DataRow("avg", "Validation: avg with star field should be rejected")] + [DataRow("sum", "Validation: sum with star field should be rejected")] + [DataRow("min", "Validation: min with star field should be rejected")] + [DataRow("max", "Validation: max with star field should be rejected")] + public void ValidateFieldFunctionCompat_StarWithNumericFunction_IsInvalid(string function, string description) + { + bool isCountStar = function == "count" && "*" == "*"; + bool isInvalidStarUsage = "*" == "*" && function != "count"; + + Assert.IsFalse(isCountStar, $"{description}: should not be count-star"); + Assert.IsTrue(isInvalidStarUsage, $"{description}: should be identified as invalid star usage"); + } + + [TestMethod] + public void ValidateFieldFunctionCompat_CountStar_IsValid() + { + bool isCountStar = "count" == "count" && "*" == "*"; + Assert.IsTrue(isCountStar, "count(*) should be valid"); + } + + [TestMethod] + public void ValidateDistinctCountStar_IsInvalid() + { + bool isCountStar = "count" == "count" && "*" == "*"; + bool distinct = true; + + bool shouldReject = isCountStar && distinct; + Assert.IsTrue(shouldReject, "count(*) with distinct=true should be rejected"); + } + + [TestMethod] + public void ValidateDistinctCountField_IsValid() + { + bool isCountStar = "count" == "count" && "userId" == "*"; + bool distinct = true; + + bool shouldReject = isCountStar && distinct; + Assert.IsFalse(shouldReject, "count(field) with distinct=true should be valid"); + } + + #endregion + + #region Blog scenario tests - alias and type validation + + /// + /// Blog Example 1: Strategic customer importance + /// "Who is our most important customer based on total revenue?" + /// SUM(totalRevenue) grouped by customerId, customerName, ORDER BY DESC, FIRST 1 + /// + [TestMethod] + public void BlogScenario_StrategicCustomerImportance_AliasAndTypeCorrect() + { + string alias = AggregateRecordsTool.ComputeAlias("sum", "totalRevenue"); + Assert.AreEqual("sum_totalRevenue", alias); + } + + /// + /// Blog Example 2: Product discontinuation candidate + /// Lowest totalRevenue with orderby=asc, first=1 + /// + [TestMethod] + public void BlogScenario_ProductDiscontinuation_AliasAndTypeCorrect() + { + string alias = AggregateRecordsTool.ComputeAlias("sum", "totalRevenue"); + Assert.AreEqual("sum_totalRevenue", alias); + } + + /// + /// Blog Example 3: Forward-looking performance expectation + /// AVG quarterlyRevenue with HAVING gt 2000000 + /// + [TestMethod] + public void BlogScenario_QuarterlyPerformance_AliasAndTypeCorrect() + { + string alias = AggregateRecordsTool.ComputeAlias("avg", "quarterlyRevenue"); + Assert.AreEqual("avg_quarterlyRevenue", alias); + } + + /// + /// Blog Example 4: Revenue concentration across regions + /// SUM totalRevenue grouped by region and customerTier, HAVING gt 5000000 + /// + [TestMethod] + public void BlogScenario_RevenueConcentration_AliasAndTypeCorrect() + { + string alias = AggregateRecordsTool.ComputeAlias("sum", "totalRevenue"); + Assert.AreEqual("sum_totalRevenue", alias); + } + + /// + /// Blog Example 5: Risk exposure by product line + /// SUM onHandValue grouped by productLine and warehouseRegion, HAVING gt 2500000 + /// + [TestMethod] + public void BlogScenario_RiskExposure_AliasAndTypeCorrect() + { + string alias = AggregateRecordsTool.ComputeAlias("sum", "onHandValue"); + Assert.AreEqual("sum_onHandValue", alias); + } + + #endregion + } +} diff --git a/src/Service.Tests/UnitTests/McpTelemetryTests.cs b/src/Service.Tests/UnitTests/McpTelemetryTests.cs index 9a8130f012..1b7c66fbcc 100644 --- a/src/Service.Tests/UnitTests/McpTelemetryTests.cs +++ b/src/Service.Tests/UnitTests/McpTelemetryTests.cs @@ -17,7 +17,6 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using ModelContextProtocol.Protocol; using static Azure.DataApiBuilder.Mcp.Model.McpEnums; - namespace Azure.DataApiBuilder.Service.Tests.UnitTests { /// @@ -337,6 +336,98 @@ public async Task ExecuteWithTelemetryAsync_RecordsExceptionAndRethrows_WhenTool Assert.IsNotNull(exceptionEvent, "Exception event should be recorded"); } + /// + /// Test that ExecuteWithTelemetryAsync applies the configured query-timeout and throws TimeoutException + /// when a tool exceeds the configured timeout. + /// + [TestMethod] + public async Task ExecuteWithTelemetryAsync_ThrowsTimeoutException_WhenToolExceedsTimeout() + { + // Use a 1-second timeout with a tool that takes 10 seconds + IServiceProvider serviceProvider = CreateServiceProviderWithTimeout(queryTimeoutSeconds: 1); + IMcpTool tool = new SlowTool(delaySeconds: 10); + + TimeoutException thrownEx = await Assert.ThrowsExceptionAsync( + () => McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, "aggregate_records", arguments: null, serviceProvider, CancellationToken.None)); + + Assert.IsTrue(thrownEx.Message.Contains("aggregate_records"), "Exception message should contain tool name"); + Assert.IsTrue(thrownEx.Message.Contains("1 second"), "Exception message should contain timeout duration"); + } + + /// + /// Test that ExecuteWithTelemetryAsync succeeds when tool completes before the timeout. + /// + [TestMethod] + public async Task ExecuteWithTelemetryAsync_Succeeds_WhenToolCompletesBeforeTimeout() + { + // Use a 30-second timeout with a tool that completes immediately + IServiceProvider serviceProvider = CreateServiceProviderWithTimeout(queryTimeoutSeconds: 30); + IMcpTool tool = new ImmediateCompletionTool(); + + CallToolResult result = await McpTelemetryHelper.ExecuteWithTelemetryAsync( + tool, "aggregate_records", arguments: null, serviceProvider, CancellationToken.None); + + Assert.IsNotNull(result); + Assert.IsFalse(result.IsError == true); + } + + /// + /// Test that aggregate_records tool name maps to "aggregate" operation. + /// + [TestMethod] + public void InferOperationFromTool_AggregateRecords_ReturnsAggregate() + { + CallToolResult dummyResult = CreateToolResult("ok"); + IMcpTool tool = new MockMcpTool(dummyResult, ToolType.BuiltIn); + + string operation = McpTelemetryHelper.InferOperationFromTool(tool, "aggregate_records"); + + Assert.AreEqual("aggregate", operation); + } + + #endregion + + #region Helpers for timeout tests + + /// + /// Creates a service provider with a RuntimeConfigProvider configured with the given timeout. + /// + private static IServiceProvider CreateServiceProviderWithTimeout(int queryTimeoutSeconds) + { + Azure.DataApiBuilder.Config.ObjectModel.RuntimeConfig config = CreateConfigWithQueryTimeout(queryTimeoutSeconds); + ServiceCollection services = new(); + Azure.DataApiBuilder.Core.Configurations.RuntimeConfigProvider configProvider = + TestHelper.GenerateInMemoryRuntimeConfigProvider(config); + services.AddSingleton(configProvider); + services.AddLogging(); + return services.BuildServiceProvider(); + } + + private static Azure.DataApiBuilder.Config.ObjectModel.RuntimeConfig CreateConfigWithQueryTimeout(int queryTimeoutSeconds) + { + return new Azure.DataApiBuilder.Config.ObjectModel.RuntimeConfig( + Schema: "test-schema", + DataSource: new Azure.DataApiBuilder.Config.ObjectModel.DataSource( + DatabaseType: Azure.DataApiBuilder.Config.ObjectModel.DatabaseType.MSSQL, + ConnectionString: "", + Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new( + Enabled: true, + Path: "/mcp", + DmlTools: null, + Description: null, + QueryTimeout: queryTimeoutSeconds + ), + Host: new(Cors: null, Authentication: null, Mode: Azure.DataApiBuilder.Config.ObjectModel.HostMode.Development) + ), + Entities: new(new System.Collections.Generic.Dictionary()) + ); + } + #endregion #region Test Mocks @@ -377,6 +468,81 @@ public Task ExecuteAsync(JsonDocument? arguments, IServiceProvid } } + /// + /// A mock tool that completes immediately with a success result. + /// + private class ImmediateCompletionTool : IMcpTool + { + public ToolType ToolType { get; } = ToolType.BuiltIn; + + public Tool GetToolMetadata() + { + using JsonDocument doc = JsonDocument.Parse("{\"type\": \"object\"}"); + return new Tool + { + Name = "test_tool", + Description = "A test tool that completes immediately", + InputSchema = doc.RootElement.Clone() + }; + } + + public Task ExecuteAsync( + JsonDocument? arguments, + IServiceProvider serviceProvider, + CancellationToken cancellationToken = default) + { + return Task.FromResult(new CallToolResult + { + Content = new List + { + new TextContentBlock { Text = "{\"result\": \"success\"}" } + } + }); + } + } + + /// + /// A mock tool that delays for a specified duration, respecting cancellation. + /// Used to test timeout behavior. + /// + private class SlowTool : IMcpTool + { + private readonly int _delaySeconds; + + public SlowTool(int delaySeconds) + { + _delaySeconds = delaySeconds; + } + + public ToolType ToolType { get; } = ToolType.BuiltIn; + + public Tool GetToolMetadata() + { + using JsonDocument doc = JsonDocument.Parse("{\"type\": \"object\"}"); + return new Tool + { + Name = "slow_tool", + Description = "A test tool that takes a long time", + InputSchema = doc.RootElement.Clone() + }; + } + + public async Task ExecuteAsync( + JsonDocument? arguments, + IServiceProvider serviceProvider, + CancellationToken cancellationToken = default) + { + await Task.Delay(TimeSpan.FromSeconds(_delaySeconds), cancellationToken); + return new CallToolResult + { + Content = new List + { + new TextContentBlock { Text = "{\"result\": \"completed\"}" } + } + }; + } + } + #endregion } } diff --git a/src/Service/Utilities/McpStdioHelper.cs b/src/Service/Utilities/McpStdioHelper.cs index 043e9dd85d..f22e12b02f 100644 --- a/src/Service/Utilities/McpStdioHelper.cs +++ b/src/Service/Utilities/McpStdioHelper.cs @@ -78,15 +78,8 @@ public static bool RunMcpStdioHost(IHost host) { host.Start(); - Mcp.Core.McpToolRegistry registry = - host.Services.GetRequiredService(); - IEnumerable tools = - host.Services.GetServices(); - - foreach (Mcp.Model.IMcpTool tool in tools) - { - registry.RegisterTool(tool); - } + // Tools are already registered by McpToolRegistryInitializer (IHostedService) + // during host.Start(). No need to register them again here. IHostApplicationLifetime lifetime = host.Services.GetRequiredService();