Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 78 additions & 2 deletions src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import os
import inspect
from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload
from typing import Any, Type, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload
from typing_extensions import Self, override

import httpx

from ..auth import WorkloadIdentity
from .._types import NOT_GIVEN, Omit, Query, Headers, Timeout, NotGiven
from .._types import NOT_GIVEN, Omit, Query, Headers, Timeout, NotGiven, ResponseT
from .._utils import is_given, is_mapping
from .._client import OpenAI, AsyncOpenAI
from .._compat import model_copy
Expand Down Expand Up @@ -412,6 +412,44 @@ def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL
url = realtime_url.copy_with(params={**query})
return url, auth_headers

@override
def _process_response(
self,
*,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
response: httpx.Response,
stream: bool,
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
retries_taken: int = 0,
) -> ResponseT:
result = super()._process_response(
cast_to=cast_to,
options=options,
response=response,
stream=stream,
stream_cls=stream_cls,
retries_taken=retries_taken,
)
if options.url.startswith("/responses"):
served_model = response.headers.get("x-ms-served-model", "").strip()
if served_model:
from ..types.responses.response import Response as _ResponseType

if not stream and isinstance(result, _ResponseType):
result.model = served_model
elif stream and isinstance(result, Stream):
_orig_iter = result._iterator

def _patched_iter(orig: Any = _orig_iter, model: str = served_model): # type: ignore[return]
for event in orig:
if hasattr(event, "response") and isinstance(event.response, _ResponseType):
event.response.model = model
yield event

result._iterator = _patched_iter()
return result


class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
@overload
Expand Down Expand Up @@ -733,3 +771,41 @@ async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[htt

url = realtime_url.copy_with(params={**query})
return url, auth_headers

@override
async def _process_response(
self,
*,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
response: httpx.Response,
stream: bool,
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
retries_taken: int = 0,
) -> ResponseT:
result = await super()._process_response(
cast_to=cast_to,
options=options,
response=response,
stream=stream,
stream_cls=stream_cls,
retries_taken=retries_taken,
)
if options.url.startswith("/responses"):
served_model = response.headers.get("x-ms-served-model", "").strip()
if served_model:
from ..types.responses.response import Response as _ResponseType

if not stream and isinstance(result, _ResponseType):
result.model = served_model
elif stream and isinstance(result, AsyncStream):
_orig_iter = result._iterator

async def _patched_iter(orig: Any = _orig_iter, model: str = served_model): # type: ignore[return]
async for event in orig:
if hasattr(event, "response") and isinstance(event.response, _ResponseType):
event.response.model = model
yield event

result._iterator = _patched_iter()
return result
124 changes: 124 additions & 0 deletions tests/lib/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from openai._utils import SensitiveHeadersFilter, is_dict
from openai._models import FinalRequestOptions
from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
from openai.types.responses.response import Response as _OAIResponse

Client = Union[AzureOpenAI, AsyncAzureOpenAI]

Expand Down Expand Up @@ -953,3 +954,126 @@ def test_client_sets_base_url(client: Client) -> None:
)
)
assert req.url == "https://example-resource.azure.openai.com/openai/models?api-version=2024-02-01"


# ---------------------------------------------------------------------------
# Tests for x-ms-served-model header promotion into Response.model (#3271)
# ---------------------------------------------------------------------------

_MINIMAL_RESPONSE_JSON = {
"id": "resp_test001",
"created_at": 1700000000.0,
"model": "gpt-4o",
"object": "response",
"output": [],
"parallel_tool_calls": False,
"tool_choice": "auto",
"tools": [],
}

_RESPONSES_URL = "https://example-resource.azure.openai.com/openai/responses?api-version=2024-02-01"


def _make_response(status: int = 200, headers: dict | None = None) -> httpx.Response:
req = httpx.Request("POST", _RESPONSES_URL)
resp = httpx.Response(status, json=_MINIMAL_RESPONSE_JSON, headers=headers or {})
resp.request = req
return resp


def _make_response_for_url(url: str, headers: dict | None = None) -> httpx.Response:
req = httpx.Request("POST", url)
resp = httpx.Response(200, json=_MINIMAL_RESPONSE_JSON, headers=headers or {})
resp.request = req
return resp


def test_azure_responses_x_ms_served_model_promoted() -> None:
"""x-ms-served-model header is promoted into Response.model for non-streaming calls."""
client = AzureOpenAI(
api_version="2024-02-01",
api_key="test-key",
azure_endpoint="https://example-resource.azure.openai.com",
)
result = client._process_response(
cast_to=_OAIResponse,
options=FinalRequestOptions.construct(method="post", url="/responses"),
response=_make_response(headers={"x-ms-served-model": "gpt-4o-2024-11-20"}),
stream=False,
stream_cls=None,
)
assert isinstance(result, _OAIResponse)
assert result.model == "gpt-4o-2024-11-20"


def test_azure_responses_x_ms_served_model_absent() -> None:
"""Without x-ms-served-model, Response.model is unchanged."""
client = AzureOpenAI(
api_version="2024-02-01",
api_key="test-key",
azure_endpoint="https://example-resource.azure.openai.com",
)
result = client._process_response(
cast_to=_OAIResponse,
options=FinalRequestOptions.construct(method="post", url="/responses"),
response=_make_response(),
stream=False,
stream_cls=None,
)
assert isinstance(result, _OAIResponse)
assert result.model == "gpt-4o"


def test_azure_non_responses_endpoint_not_affected() -> None:
"""x-ms-served-model on a non-/responses URL should not be promoted."""
client = AzureOpenAI(
api_version="2024-02-01",
api_key="test-key",
azure_endpoint="https://example-resource.azure.openai.com",
)
chat_url = "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions"
result = client._process_response(
cast_to=_OAIResponse,
options=FinalRequestOptions.construct(method="post", url="/deployments/gpt-4/chat/completions"),
response=_make_response_for_url(chat_url, headers={"x-ms-served-model": "should-be-ignored"}),
stream=False,
stream_cls=None,
)
assert isinstance(result, _OAIResponse)
assert result.model == "gpt-4o"


async def test_async_azure_responses_x_ms_served_model_promoted() -> None:
"""Async client: x-ms-served-model is promoted into Response.model."""
client = AsyncAzureOpenAI(
api_version="2024-02-01",
api_key="test-key",
azure_endpoint="https://example-resource.azure.openai.com",
)
result = await client._process_response(
cast_to=_OAIResponse,
options=FinalRequestOptions.construct(method="post", url="/responses"),
response=_make_response(headers={"x-ms-served-model": "gpt-4o-2024-11-20"}),
stream=False,
stream_cls=None,
)
assert isinstance(result, _OAIResponse)
assert result.model == "gpt-4o-2024-11-20"


async def test_async_azure_responses_x_ms_served_model_absent() -> None:
"""Async client: without header, Response.model is unchanged."""
client = AsyncAzureOpenAI(
api_version="2024-02-01",
api_key="test-key",
azure_endpoint="https://example-resource.azure.openai.com",
)
result = await client._process_response(
cast_to=_OAIResponse,
options=FinalRequestOptions.construct(method="post", url="/responses"),
response=_make_response(),
stream=False,
stream_cls=None,
)
assert isinstance(result, _OAIResponse)
assert result.model == "gpt-4o"