# Recovered from the `src/openai/lib/azure.py` hunks of the diff.
#
# The two `_process_response` overrides below are methods: the synchronous one is added
# just before `class AsyncAzureOpenAI` (presumably inside `AzureOpenAI` - confirm against
# the full file), the asynchronous one at the end of `AsyncAzureOpenAI`. They rely on
# `Type` (typing) and `ResponseT` (.._types), which the same diff adds to the module's
# import block.

_SERVED_MODEL_HEADER = "x-ms-served-model"


def _served_model_for(options: FinalRequestOptions, response: httpx.Response) -> str:
    """Return the served-model header value for Responses API requests.

    Args:
        options: Request options; only ``options.url`` is inspected.
        response: The HTTP response whose headers are read.

    Returns:
        The stripped ``x-ms-served-model`` header value, or ``""`` when the request
        path is not ``/responses`` (or a ``/responses/...`` sub-path) or the header
        is missing/blank.
    """
    # Match on a path-segment boundary so that an unrelated path that merely shares
    # the prefix (e.g. "/responses_foo") is not treated as a Responses API call.
    path = options.url.partition("?")[0]
    if path != "/responses" and not path.startswith("/responses/"):
        return ""
    return response.headers.get(_SERVED_MODEL_HEADER, "").strip()


def _promote_served_model(event: Any, model: str, response_type: type) -> None:
    """Set ``event.response.model`` to ``model`` when ``event.response`` is a ``response_type``.

    Events without a ``response`` attribute, or whose ``response`` is some other
    type, are left untouched.
    """
    payload = getattr(event, "response", None)
    if isinstance(payload, response_type):
        payload.model = model


@override
def _process_response(
    self,
    *,
    cast_to: Type[ResponseT],
    options: FinalRequestOptions,
    response: httpx.Response,
    stream: bool,
    stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
    retries_taken: int = 0,
) -> ResponseT:
    """Process the response, then copy ``x-ms-served-model`` into ``Response.model``.

    All arguments are forwarded unchanged to the parent implementation. For
    Responses API requests that carry a non-blank ``x-ms-served-model`` header:

    * non-streaming: if the parsed result is a ``Response``, its ``model`` is
      overwritten with the header value;
    * streaming: if the result is a ``Stream``, its iterator is wrapped so every
      event whose ``response`` attribute is a ``Response`` gets the same overwrite.

    Any other result type is returned exactly as the parent produced it.
    """
    result = super()._process_response(
        cast_to=cast_to,
        options=options,
        response=response,
        stream=stream,
        stream_cls=stream_cls,
        retries_taken=retries_taken,
    )
    served_model = _served_model_for(options, response)
    if not served_model:
        return result

    # Function-scope import, as in the original change.
    from ..types.responses.response import Response as _ResponseType

    if not stream:
        if isinstance(result, _ResponseType):
            result.model = served_model
    elif isinstance(result, Stream):
        upstream = result._iterator

        def _with_served_model() -> Any:
            # Lazily re-yield the original events, patching them on the way through.
            for event in upstream:
                _promote_served_model(event, served_model, _ResponseType)
                yield event

        result._iterator = _with_served_model()
    return result


@override
async def _process_response(
    self,
    *,
    cast_to: Type[ResponseT],
    options: FinalRequestOptions,
    response: httpx.Response,
    stream: bool,
    stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
    retries_taken: int = 0,
) -> ResponseT:
    """Async counterpart: process the response, then copy ``x-ms-served-model`` into ``Response.model``.

    All arguments are forwarded unchanged to the awaited parent implementation.
    For Responses API requests that carry a non-blank ``x-ms-served-model`` header:

    * non-streaming: if the parsed result is a ``Response``, its ``model`` is
      overwritten with the header value;
    * streaming: if the result is an ``AsyncStream``, its iterator is wrapped so
      every event whose ``response`` attribute is a ``Response`` gets the same
      overwrite.

    Any other result type is returned exactly as the parent produced it.
    """
    result = await super()._process_response(
        cast_to=cast_to,
        options=options,
        response=response,
        stream=stream,
        stream_cls=stream_cls,
        retries_taken=retries_taken,
    )
    served_model = _served_model_for(options, response)
    if not served_model:
        return result

    # Function-scope import, as in the original change.
    from ..types.responses.response import Response as _ResponseType

    if not stream:
        if isinstance(result, _ResponseType):
            result.model = served_model
    elif isinstance(result, AsyncStream):
        upstream = result._iterator

        async def _with_served_model() -> Any:
            # Lazily re-yield the original events, patching them on the way through.
            async for event in upstream:
                _promote_served_model(event, served_model, _ResponseType)
                yield event

        result._iterator = _with_served_model()
    return result
# ---------------------------------------------------------------------------
# Recovered from the `tests/lib/test_azure.py` hunk of the diff.
# Promotion of the x-ms-served-model response header into Response.model (#3271)
# ---------------------------------------------------------------------------

# Smallest payload these tests feed through `_process_response` with cast_to=Response.
_MINIMAL_RESPONSE_JSON = {
    "id": "resp_test001",
    "created_at": 1700000000.0,
    "model": "gpt-4o",
    "object": "response",
    "output": [],
    "parallel_tool_calls": False,
    "tool_choice": "auto",
    "tools": [],
}

_RESPONSES_URL = "https://example-resource.azure.openai.com/openai/responses?api-version=2024-02-01"


def _build_http_response(url: str, status: int, headers: dict | None) -> httpx.Response:
    """Create an ``httpx.Response`` with the minimal JSON body, bound to a POST request for ``url``."""
    http_response = httpx.Response(status, json=_MINIMAL_RESPONSE_JSON, headers=headers or {})
    http_response.request = httpx.Request("POST", url)
    return http_response


def _make_response(status: int = 200, headers: dict | None = None) -> httpx.Response:
    """Fake reply to a POST against the Azure ``/responses`` URL."""
    return _build_http_response(_RESPONSES_URL, status, headers)


def _make_response_for_url(url: str, headers: dict | None = None) -> httpx.Response:
    """Fake 200 reply to a POST against an arbitrary ``url``."""
    return _build_http_response(url, 200, headers)


def _sync_client() -> AzureOpenAI:
    """Synchronous Azure client configured with the fixed test endpoint and key."""
    return AzureOpenAI(
        api_version="2024-02-01",
        api_key="test-key",
        azure_endpoint="https://example-resource.azure.openai.com",
    )


def _async_client() -> AsyncAzureOpenAI:
    """Asynchronous Azure client configured with the fixed test endpoint and key."""
    return AsyncAzureOpenAI(
        api_version="2024-02-01",
        api_key="test-key",
        azure_endpoint="https://example-resource.azure.openai.com",
    )


def _post_options(url: str) -> FinalRequestOptions:
    """POST request options for ``url``, built without validation."""
    return FinalRequestOptions.construct(method="post", url=url)


def test_azure_responses_x_ms_served_model_promoted() -> None:
    """Sync, non-streaming: the header value replaces ``Response.model``."""
    http_response = _make_response(headers={"x-ms-served-model": "gpt-4o-2024-11-20"})
    parsed = _sync_client()._process_response(
        cast_to=_OAIResponse,
        options=_post_options("/responses"),
        response=http_response,
        stream=False,
        stream_cls=None,
    )
    assert isinstance(parsed, _OAIResponse)
    assert parsed.model == "gpt-4o-2024-11-20"


def test_azure_responses_x_ms_served_model_absent() -> None:
    """Sync, non-streaming: with no header the body's model is kept."""
    http_response = _make_response()
    parsed = _sync_client()._process_response(
        cast_to=_OAIResponse,
        options=_post_options("/responses"),
        response=http_response,
        stream=False,
        stream_cls=None,
    )
    assert isinstance(parsed, _OAIResponse)
    assert parsed.model == "gpt-4o"


def test_azure_non_responses_endpoint_not_affected() -> None:
    """Sync: the header is ignored when the request path is not under ``/responses``."""
    chat_url = "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions"
    http_response = _make_response_for_url(chat_url, headers={"x-ms-served-model": "should-be-ignored"})
    parsed = _sync_client()._process_response(
        cast_to=_OAIResponse,
        options=_post_options("/deployments/gpt-4/chat/completions"),
        response=http_response,
        stream=False,
        stream_cls=None,
    )
    assert isinstance(parsed, _OAIResponse)
    assert parsed.model == "gpt-4o"


async def test_async_azure_responses_x_ms_served_model_promoted() -> None:
    """Async, non-streaming: the header value replaces ``Response.model``."""
    http_response = _make_response(headers={"x-ms-served-model": "gpt-4o-2024-11-20"})
    parsed = await _async_client()._process_response(
        cast_to=_OAIResponse,
        options=_post_options("/responses"),
        response=http_response,
        stream=False,
        stream_cls=None,
    )
    assert isinstance(parsed, _OAIResponse)
    assert parsed.model == "gpt-4o-2024-11-20"


async def test_async_azure_responses_x_ms_served_model_absent() -> None:
    """Async, non-streaming: with no header the body's model is kept."""
    http_response = _make_response()
    parsed = await _async_client()._process_response(
        cast_to=_OAIResponse,
        options=_post_options("/responses"),
        response=http_response,
        stream=False,
        stream_cls=None,
    )
    assert isinstance(parsed, _OAIResponse)
    assert parsed.model == "gpt-4o"