diff --git a/src/stagehand/_client.py b/src/stagehand/_client.py index 3997b5e1..fd1f8300 100644 --- a/src/stagehand/_client.py +++ b/src/stagehand/_client.py @@ -24,7 +24,7 @@ from ._models import FinalRequestOptions from ._version import __version__ from ._streaming import Stream as Stream, AsyncStream as AsyncStream -from ._exceptions import APIStatusError, StagehandError +from ._exceptions import APIStatusError from ._base_client import ( DEFAULT_MAX_RETRIES, SyncAPIClient, @@ -52,7 +52,7 @@ class Stagehand(SyncAPIClient): # client options browserbase_api_key: str | None browserbase_project_id: str | None - model_api_key: str + model_api_key: str | None def __init__( self, @@ -115,10 +115,6 @@ def __init__( if model_api_key is None: model_api_key = os.environ.get("MODEL_API_KEY") - if model_api_key is None: - raise StagehandError( - "The model_api_key client option must be set either by passing model_api_key to the client or by setting the MODEL_API_KEY environment variable" - ) self.model_api_key = model_api_key self._sea_server: SeaServerManager | None = None @@ -210,7 +206,7 @@ def _bb_project_id_auth(self) -> dict[str, str]: @property def _llm_model_api_key_auth(self) -> dict[str, str]: model_api_key = self.model_api_key - return {"x-model-api-key": model_api_key} + return {"x-model-api-key": model_api_key} if model_api_key else {} @property @override @@ -273,9 +269,11 @@ def copy( return self.__class__( browserbase_api_key=browserbase_api_key or self.browserbase_api_key, browserbase_project_id=browserbase_project_id or self.browserbase_project_id, - model_api_key=model_api_key or self.model_api_key, + model_api_key=model_api_key if model_api_key is not None else self.model_api_key, server=server or self._server_mode, - _local_stagehand_binary_path=_local_stagehand_binary_path if _local_stagehand_binary_path is not None else self._local_stagehand_binary_path, + _local_stagehand_binary_path=_local_stagehand_binary_path + if _local_stagehand_binary_path is not None + else self._local_stagehand_binary_path, local_host=local_host or self._local_host, local_port=local_port if local_port is not None else self._local_port, local_headless=local_headless if local_headless is not None else self._local_headless, @@ -340,7 +338,7 @@ class AsyncStagehand(AsyncAPIClient): # client options browserbase_api_key: str | None browserbase_project_id: str | None - model_api_key: str + model_api_key: str | None def __init__( self, @@ -403,10 +401,6 @@ def __init__( if model_api_key is None: model_api_key = os.environ.get("MODEL_API_KEY") - if model_api_key is None: - raise StagehandError( - "The model_api_key client option must be set either by passing model_api_key to the client or by setting the MODEL_API_KEY environment variable" - ) self.model_api_key = model_api_key self._sea_server: SeaServerManager | None = None @@ -497,7 +491,7 @@ def _bb_project_id_auth(self) -> dict[str, str]: @property def _llm_model_api_key_auth(self) -> dict[str, str]: model_api_key = self.model_api_key - return {"x-model-api-key": model_api_key} + return {"x-model-api-key": model_api_key} if model_api_key else {} @property @override @@ -560,9 +554,11 @@ def copy( return self.__class__( browserbase_api_key=browserbase_api_key or self.browserbase_api_key, browserbase_project_id=browserbase_project_id or self.browserbase_project_id, - model_api_key=model_api_key or self.model_api_key, + model_api_key=model_api_key if model_api_key is not None else self.model_api_key, server=server or self._server_mode, - _local_stagehand_binary_path=_local_stagehand_binary_path if _local_stagehand_binary_path is not None else self._local_stagehand_binary_path, + _local_stagehand_binary_path=_local_stagehand_binary_path + if _local_stagehand_binary_path is not None + else self._local_stagehand_binary_path, local_host=local_host or self._local_host, local_port=local_port if local_port is not None else self._local_port, local_headless=local_headless if local_headless is not None else self._local_headless, diff --git a/src/stagehand/resources/sessions.py b/src/stagehand/resources/sessions.py index 09ed74c3..c9cce3e5 100644 --- a/src/stagehand/resources/sessions.py +++ b/src/stagehand/resources/sessions.py @@ -915,6 +915,7 @@ def start( browser: session_start_params.Browser | Omit = omit, browserbase_session_create_params: session_start_params.BrowserbaseSessionCreateParams | Omit = omit, browserbase_session_id: str | Omit = omit, + model_client_options: session_start_params.ModelClientOptions | Omit = omit, dom_settle_timeout_ms: float | Omit = omit, experimental: bool | Omit = omit, self_heal: bool | Omit = omit, @@ -976,6 +977,7 @@ def start( "browser": browser, "browserbase_session_create_params": browserbase_session_create_params, "browserbase_session_id": browserbase_session_id, + "model_client_options": model_client_options, "dom_settle_timeout_ms": dom_settle_timeout_ms, "experimental": experimental, "self_heal": self_heal, @@ -1867,6 +1869,7 @@ async def start( browser: session_start_params.Browser | Omit = omit, browserbase_session_create_params: session_start_params.BrowserbaseSessionCreateParams | Omit = omit, browserbase_session_id: str | Omit = omit, + model_client_options: session_start_params.ModelClientOptions | Omit = omit, dom_settle_timeout_ms: float | Omit = omit, experimental: bool | Omit = omit, self_heal: bool | Omit = omit, @@ -1928,6 +1931,7 @@ async def start( "browser": browser, "browserbase_session_create_params": browserbase_session_create_params, "browserbase_session_id": browserbase_session_id, + "model_client_options": model_client_options, "dom_settle_timeout_ms": dom_settle_timeout_ms, "experimental": experimental, "self_heal": self_heal, diff --git a/src/stagehand/resources/sessions_helpers.py b/src/stagehand/resources/sessions_helpers.py index 8eafd75e..139ac84b 100644 --- a/src/stagehand/resources/sessions_helpers.py +++ b/src/stagehand/resources/sessions_helpers.py @@ -2,6 +2,7 @@ from __future__ import annotations +from typing import Any from typing_extensions import Literal, override import httpx @@ -27,6 +28,26 @@ from ..types.session_start_response import SessionStartResponse +def _has_explicit_aws_credentials(model_config: dict[str, Any]) -> bool: + return any(model_config.get(key) for key in ("access_key_id", "secret_access_key", "session_token")) + + +def _build_default_model_config( + *, + model_name: str, + model_client_options: session_start_params.ModelClientOptions | Omit, + fallback_api_key: str | None, +) -> dict[str, Any]: + model_config: dict[str, Any] = {"model_name": model_name} + if isinstance(model_client_options, dict): + model_config.update(model_client_options) + + if fallback_api_key and "api_key" not in model_config and not _has_explicit_aws_credentials(model_config): + model_config["api_key"] = fallback_api_key + + return model_config + + class SessionsResourceWithHelpersRawResponse(SessionsResourceWithRawResponse): def __init__(self, sessions: SessionsResourceWithHelpers) -> None: # type: ignore[name-defined] super().__init__(sessions) @@ -71,6 +92,7 @@ def start( browser: session_start_params.Browser | Omit = omit, browserbase_session_create_params: session_start_params.BrowserbaseSessionCreateParams | Omit = omit, browserbase_session_id: str | Omit = omit, + model_client_options: session_start_params.ModelClientOptions | Omit = omit, dom_settle_timeout_ms: float | Omit = omit, experimental: bool | Omit = omit, self_heal: bool | Omit = omit, @@ -89,6 +111,7 @@ def start( browser=browser, browserbase_session_create_params=browserbase_session_create_params, browserbase_session_id=browserbase_session_id, + model_client_options=model_client_options, dom_settle_timeout_ms=dom_settle_timeout_ms, experimental=experimental, self_heal=self_heal, @@ -101,7 +124,17 @@ def start( extra_body=extra_body, timeout=timeout, ) - return Session(self._client, start_response.data.session_id, data=start_response.data, success=start_response.success) + return Session( + self._client, + start_response.data.session_id, + data=start_response.data, + success=start_response.success, + default_model=_build_default_model_config( + model_name=model_name, + model_client_options=model_client_options, + fallback_api_key=self._client.model_api_key, + ), + ) class AsyncSessionsResourceWithHelpers(AsyncSessionsResource): @@ -124,6 +157,7 @@ async def start( browser: session_start_params.Browser | Omit = omit, browserbase_session_create_params: session_start_params.BrowserbaseSessionCreateParams | Omit = omit, browserbase_session_id: str | Omit = omit, + model_client_options: session_start_params.ModelClientOptions | Omit = omit, dom_settle_timeout_ms: float | Omit = omit, experimental: bool | Omit = omit, self_heal: bool | Omit = omit, @@ -142,6 +176,7 @@ async def start( browser=browser, browserbase_session_create_params=browserbase_session_create_params, browserbase_session_id=browserbase_session_id, + model_client_options=model_client_options, dom_settle_timeout_ms=dom_settle_timeout_ms, experimental=experimental, self_heal=self_heal, @@ -154,4 +189,14 @@ async def start( extra_body=extra_body, timeout=timeout, ) - return AsyncSession(self._client, start_response.data.session_id, data=start_response.data, success=start_response.success) + return AsyncSession( + self._client, + start_response.data.session_id, + data=start_response.data, + success=start_response.success, + default_model=_build_default_model_config( + model_name=model_name, + model_client_options=model_client_options, + fallback_api_key=self._client.model_api_key, + ), + ) diff --git a/src/stagehand/session.py b/src/stagehand/session.py index 1224cb08..852322f3 100644 --- a/src/stagehand/session.py +++ b/src/stagehand/session.py @@ -49,9 +49,7 @@ def _extract_frame_id_from_playwright_page(page: Any) -> str: new_cdp_session = getattr(context, "new_cdp_session", None) if not callable(new_cdp_session): - raise StagehandError( - "page must be a Playwright Page; expected page.context.new_cdp_session(...) to exist" - ) + raise StagehandError("page must be a Playwright Page; expected page.context.new_cdp_session(...) to exist") pw_context = cast(_PlaywrightContext, context) cdp = pw_context.new_cdp_session(page) @@ -87,9 +85,7 @@ async def _extract_frame_id_from_playwright_page_async(page: Any) -> str: new_cdp_session = getattr(context, "new_cdp_session", None) if not callable(new_cdp_session): - raise StagehandError( - "page must be a Playwright Page; expected page.context.new_cdp_session(...) to exist" - ) + raise StagehandError("page must be a Playwright Page; expected page.context.new_cdp_session(...) to exist") pw_context = cast(_PlaywrightContext, context) cdp = pw_context.new_cdp_session(page) @@ -127,15 +123,78 @@ async def _maybe_inject_frame_id_async(params: dict[str, Any], page: Any | None) return {**params, "frame_id": await _extract_frame_id_from_playwright_page_async(page)} +def _merge_default_model(params: dict[str, Any], default_model: dict[str, Any] | None) -> dict[str, Any]: + if not default_model: + return params + + options = params.get("options") + if isinstance(options, dict): + options = cast(dict[str, Any], options) + if options.get("model") is not None: + return params + return { + **params, + "options": { + **options, + "model": dict(default_model), + }, + } + + if options: + return params + + return { + **params, + "options": { + "model": dict(default_model), + }, + } + + +def _merge_default_agent_model(params: dict[str, Any], default_model: dict[str, Any] | None) -> dict[str, Any]: + if not default_model: + return params + + agent_config = params.get("agent_config") + if isinstance(agent_config, dict): + agent_config = cast(dict[str, Any], agent_config) + if agent_config.get("model") is not None: + return params + return { + **params, + "agent_config": { + **agent_config, + "model": dict(default_model), + }, + } + + if agent_config: + return params + + return { + **params, + "agent_config": { + "model": dict(default_model), + }, + } + + class Session(SessionStartResponse): """A Stagehand session bound to a specific `session_id`.""" - def __init__(self, client: Stagehand, id: str, data: SessionStartResponseData, success: bool) -> None: + def __init__( + self, + client: Stagehand, + id: str, + data: SessionStartResponseData, + success: bool, + default_model: dict[str, Any] | None = None, + ) -> None: # Must call super().__init__() first to initialize Pydantic's __pydantic_extra__ before setting attributes super().__init__(data=data, success=success) self._client = client + self._default_model = default_model self.id = id - def navigate( self, @@ -166,15 +225,16 @@ def act( timeout: float | httpx.Timeout | None | NotGiven = not_given, **params: Unpack[session_act_params.SessionActParamsNonStreaming], ) -> SessionActResponse: + request_params = _merge_default_model(dict(params), self._default_model) return cast( SessionActResponse, self._client.sessions.act( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **_maybe_inject_frame_id(dict(params), page), + id=self.id, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + **_maybe_inject_frame_id(request_params, page), ), ) @@ -188,15 +248,16 @@ def observe( timeout: float | httpx.Timeout | None | NotGiven = not_given, **params: Unpack[session_observe_params.SessionObserveParamsNonStreaming], ) -> SessionObserveResponse: + request_params = _merge_default_model(dict(params), self._default_model) return cast( SessionObserveResponse, self._client.sessions.observe( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **_maybe_inject_frame_id(dict(params), page), + id=self.id, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + **_maybe_inject_frame_id(request_params, page), ), ) @@ -210,15 +271,16 @@ def extract( timeout: float | httpx.Timeout | None | NotGiven = not_given, **params: Unpack[session_extract_params.SessionExtractParamsNonStreaming], ) -> SessionExtractResponse: + request_params = _merge_default_model(dict(params), self._default_model) return cast( SessionExtractResponse, self._client.sessions.extract( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **_maybe_inject_frame_id(dict(params), page), + id=self.id, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + **_maybe_inject_frame_id(request_params, page), ), ) @@ -232,15 +294,16 @@ def execute( timeout: float | httpx.Timeout | None | NotGiven = not_given, **params: Unpack[session_execute_params.SessionExecuteParamsNonStreaming], ) -> SessionExecuteResponse: + request_params = _merge_default_agent_model(dict(params), self._default_model) return cast( SessionExecuteResponse, self._client.sessions.execute( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **_maybe_inject_frame_id(dict(params), page), + id=self.id, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + **_maybe_inject_frame_id(request_params, page), ), ) @@ -266,10 +329,18 @@ def end( class AsyncSession(SessionStartResponse): """Async variant of `Session`.""" - def __init__(self, client: AsyncStagehand, id: str, data: SessionStartResponseData, success: bool) -> None: + def __init__( + self, + client: AsyncStagehand, + id: str, + data: SessionStartResponseData, + success: bool, + default_model: dict[str, Any] | None = None, + ) -> None: # Must call super().__init__() first to initialize Pydantic's __pydantic_extra__ before setting attributes super().__init__(data=data, success=success) self._client = client + self._default_model = default_model self.id = id async def navigate( @@ -301,15 +372,16 @@ async def act( timeout: float | httpx.Timeout | None | NotGiven = not_given, **params: Unpack[session_act_params.SessionActParamsNonStreaming], ) -> SessionActResponse: + request_params = _merge_default_model(dict(params), self._default_model) return cast( SessionActResponse, await self._client.sessions.act( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **(await _maybe_inject_frame_id_async(dict(params), page)), + id=self.id, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + **(await _maybe_inject_frame_id_async(request_params, page)), ), ) @@ -323,15 +395,16 @@ async def observe( timeout: float | httpx.Timeout | None | NotGiven = not_given, **params: Unpack[session_observe_params.SessionObserveParamsNonStreaming], ) -> SessionObserveResponse: + request_params = _merge_default_model(dict(params), self._default_model) return cast( SessionObserveResponse, await self._client.sessions.observe( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **(await _maybe_inject_frame_id_async(dict(params), page)), + id=self.id, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + **(await _maybe_inject_frame_id_async(request_params, page)), ), ) @@ -345,15 +418,16 @@ async def extract( timeout: float | httpx.Timeout | None | NotGiven = not_given, **params: Unpack[session_extract_params.SessionExtractParamsNonStreaming], ) -> SessionExtractResponse: + request_params = _merge_default_model(dict(params), self._default_model) return cast( SessionExtractResponse, await self._client.sessions.extract( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **(await _maybe_inject_frame_id_async(dict(params), page)), + id=self.id, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + **(await _maybe_inject_frame_id_async(request_params, page)), ), ) @@ -367,15 +441,16 @@ async def execute( timeout: float | httpx.Timeout | None | NotGiven = not_given, **params: Unpack[session_execute_params.SessionExecuteParamsNonStreaming], ) -> SessionExecuteResponse: + request_params = _merge_default_agent_model(dict(params), self._default_model) return cast( SessionExecuteResponse, await self._client.sessions.execute( - id=self.id, - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - **(await _maybe_inject_frame_id_async(dict(params), page)), + id=self.id, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + **(await _maybe_inject_frame_id_async(request_params, page)), ), ) diff --git a/src/stagehand/types/model_config_param.py b/src/stagehand/types/model_config_param.py index 4699e677..7b787084 100644 --- a/src/stagehand/types/model_config_param.py +++ b/src/stagehand/types/model_config_param.py @@ -16,8 +16,23 @@ class ModelConfigParam(TypedDict, total=False): api_key: Annotated[str, PropertyInfo(alias="apiKey")] """API key for the model provider""" + access_key_id: Annotated[str, PropertyInfo(alias="accessKeyId")] + """AWS access key ID for Bedrock""" + base_url: Annotated[str, PropertyInfo(alias="baseURL")] """Base URL for the model provider""" - provider: Literal["openai", "anthropic", "google", "microsoft"] + headers: dict[str, str] + """Additional headers for the model provider""" + + provider: Literal["openai", "anthropic", "google", "microsoft", "bedrock"] """AI provider for the model (or provide a baseURL endpoint instead)""" + + region: str + """AWS region for Bedrock""" + + secret_access_key: Annotated[str, PropertyInfo(alias="secretAccessKey")] + """AWS secret access key for Bedrock""" + + session_token: Annotated[str, PropertyInfo(alias="sessionToken")] + """AWS session token for Bedrock""" diff --git a/src/stagehand/types/session_execute_params.py b/src/stagehand/types/session_execute_params.py index d1afa802..86c367b5 100644 --- a/src/stagehand/types/session_execute_params.py +++ b/src/stagehand/types/session_execute_params.py @@ -62,7 +62,7 @@ class AgentConfig(TypedDict, total=False): 'anthropic/claude-4.5-opus') """ - provider: Literal["openai", "anthropic", "google", "microsoft"] + provider: Literal["openai", "anthropic", "google", "microsoft", "bedrock"] """AI provider for the agent (legacy, use model: openai/gpt-5-nano instead)""" system_prompt: Annotated[str, PropertyInfo(alias="systemPrompt")] diff --git a/src/stagehand/types/session_start_params.py b/src/stagehand/types/session_start_params.py index 2df12abd..d8afb2d4 100644 --- a/src/stagehand/types/session_start_params.py +++ b/src/stagehand/types/session_start_params.py @@ -24,6 +24,7 @@ "BrowserbaseSessionCreateParamsProxiesProxyConfigListBrowserbaseProxyConfig", "BrowserbaseSessionCreateParamsProxiesProxyConfigListBrowserbaseProxyConfigGeolocation", "BrowserbaseSessionCreateParamsProxiesProxyConfigListExternalProxyConfig", + "ModelClientOptions", ] @@ -47,6 +48,9 @@ class SessionStartParams(TypedDict, total=False): browserbase_session_id: Annotated[str, PropertyInfo(alias="browserbaseSessionID")] """Existing Browserbase session ID to resume""" + model_client_options: Annotated[ModelClientOptions, PropertyInfo(alias="modelClientOptions")] + """Provider-specific model client options such as Bedrock auth and region""" + dom_settle_timeout_ms: Annotated[float, PropertyInfo(alias="domSettleTimeoutMs")] """Timeout in ms to wait for DOM to settle""" @@ -68,6 +72,32 @@ class SessionStartParams(TypedDict, total=False): """Whether to stream the response via SSE""" +class ModelClientOptions(TypedDict, total=False): + api_key: Annotated[str, PropertyInfo(alias="apiKey")] + """API key for the model provider""" + + access_key_id: Annotated[str, PropertyInfo(alias="accessKeyId")] + """AWS access key ID for Bedrock""" + + base_url: Annotated[str, PropertyInfo(alias="baseURL")] + """Base URL for the model provider""" + + headers: Dict[str, str] + """Additional headers for the model provider""" + + provider: Literal["openai", "anthropic", "google", "microsoft", "bedrock"] + """AI provider for the model""" + + region: str + """AWS region for Bedrock""" + + secret_access_key: Annotated[str, PropertyInfo(alias="secretAccessKey")] + """AWS secret access key for Bedrock""" + + session_token: Annotated[str, PropertyInfo(alias="sessionToken")] + """AWS session token for Bedrock""" + + class BrowserLaunchOptionsProxy(TypedDict, total=False): server: Required[str] diff --git a/tests/test_client.py b/tests/test_client.py index 95758e1e..845490ce 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -24,7 +24,7 @@ from stagehand._utils import asyncify from stagehand._models import BaseModel, FinalRequestOptions from stagehand._streaming import Stream, AsyncStream -from stagehand._exceptions import APIStatusError, StagehandError, APITimeoutError, APIResponseValidationError +from stagehand._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError from stagehand._base_client import ( DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, @@ -464,22 +464,23 @@ def test_validate_headers(self) -> None: request = client._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-model-api-key") == model_api_key - with pytest.raises(StagehandError): - with update_env( - **{ - "BROWSERBASE_API_KEY": Omit(), - "BROWSERBASE_PROJECT_ID": Omit(), - "MODEL_API_KEY": Omit(), - } - ): - client2 = Stagehand( - base_url=base_url, - browserbase_api_key=None, - browserbase_project_id=None, - model_api_key=None, - _strict_response_validation=True, - ) - client2.sessions.start(model_name="openai/gpt-5-nano") + with update_env( + **{ + "BROWSERBASE_API_KEY": Omit(), + "BROWSERBASE_PROJECT_ID": Omit(), + "MODEL_API_KEY": Omit(), + } + ): + client2 = Stagehand( + base_url=base_url, + browserbase_api_key=None, + browserbase_project_id=None, + model_api_key=None, + _strict_response_validation=True, + ) + request2 = client2._build_request(FinalRequestOptions(method="get", url="/foo")) + assert request2.headers.get("x-model-api-key") is None + client2.close() def test_default_query_option(self) -> None: client = Stagehand( @@ -1512,22 +1513,22 @@ def test_validate_headers(self) -> None: request = client._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-model-api-key") == model_api_key - with pytest.raises(StagehandError): - with update_env( - **{ - "BROWSERBASE_API_KEY": Omit(), - "BROWSERBASE_PROJECT_ID": Omit(), - "MODEL_API_KEY": Omit(), - } - ): - client2 = AsyncStagehand( - base_url=base_url, - browserbase_api_key=None, - browserbase_project_id=None, - model_api_key=None, - _strict_response_validation=True, - ) - _ = client2 + with update_env( + **{ + "BROWSERBASE_API_KEY": Omit(), + "BROWSERBASE_PROJECT_ID": Omit(), + "MODEL_API_KEY": Omit(), + } + ): + client2 = AsyncStagehand( + base_url=base_url, + browserbase_api_key=None, + browserbase_project_id=None, + model_api_key=None, + _strict_response_validation=True, + ) + request2 = client2._build_request(FinalRequestOptions(method="get", url="/foo")) + assert request2.headers.get("x-model-api-key") is None async def test_default_query_option(self) -> None: client = AsyncStagehand( diff --git a/tests/test_sessions_create_helper.py b/tests/test_sessions_create_helper.py index 7b5bf47f..8dd78f73 100644 --- a/tests/test_sessions_create_helper.py +++ b/tests/test_sessions_create_helper.py @@ -47,6 +47,68 @@ def test_sessions_create_returns_bound_session(respx_mock: MockRouter, client: S assert "frameId" not in request_body +@pytest.mark.respx(base_url=base_url) +def test_sessions_create_preserves_default_model_config(respx_mock: MockRouter) -> None: + session_id = "00000000-0000-0000-0000-000000000000" + client = Stagehand( + base_url=base_url, + browserbase_api_key="My Browserbase API Key", + browserbase_project_id="My Browserbase Project ID", + model_api_key=None, + ) + + start_route = respx_mock.post("/v1/sessions/start").mock( + return_value=httpx.Response( + 200, + json={ + "success": True, + "data": {"available": True, "sessionId": session_id}, + }, + ) + ) + + extract_route = respx_mock.post(f"/v1/sessions/{session_id}/extract").mock( + return_value=httpx.Response( + 200, + json={"success": True, "data": {"result": {"title": "Example"}}}, + ) + ) + + session = client.sessions.start( + model_name="bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0", + model_client_options={ + "api_key": "bedrock-bearer-token", + "region": "us-east-1", + }, + ) + + start_call = cast(Call, start_route.calls[0]) + start_request_body = json.loads(start_call.request.content) + assert start_request_body == { + "modelName": "bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0", + "modelClientOptions": { + "apiKey": "bedrock-bearer-token", + "region": "us-east-1", + }, + } + assert start_call.request.headers.get("x-model-api-key") is None + + session.extract( + instruction="extract the page title", + schema={"type": "object", "properties": {"title": {"type": "string"}}}, + ) + + extract_call = cast(Call, extract_route.calls[0]) + extract_request_body = json.loads(extract_call.request.content) + assert extract_request_body["options"]["model"] == { + "modelName": "bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0", + "apiKey": "bedrock-bearer-token", + "region": "us-east-1", + } + + client.close() + + @pytest.mark.respx(base_url=base_url) async def test_async_sessions_create_returns_bound_session( respx_mock: MockRouter, async_client: AsyncStagehand @@ -78,3 +140,73 @@ async def test_async_sessions_create_returns_bound_session( first_call = cast(Call, navigate_route.calls[0]) request_body = json.loads(first_call.request.content) assert "frameId" not in request_body + + +@pytest.mark.respx(base_url=base_url) +async def test_async_sessions_create_preserves_default_model_config( + respx_mock: MockRouter, +) -> None: + session_id = "00000000-0000-0000-0000-000000000000" + client = AsyncStagehand( + base_url=base_url, + browserbase_api_key="My Browserbase API Key", + browserbase_project_id="My Browserbase Project ID", + model_api_key=None, + ) + + start_route = respx_mock.post("/v1/sessions/start").mock( + return_value=httpx.Response( + 200, + json={ + "success": True, + "data": {"available": True, "sessionId": session_id}, + }, + ) + ) + + extract_route = respx_mock.post(f"/v1/sessions/{session_id}/extract").mock( + return_value=httpx.Response( + 200, + json={"success": True, "data": {"result": {"title": "Example"}}}, + ) + ) + + session = await client.sessions.start( + model_name="bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0", + model_client_options={ + "access_key_id": "AKIAIOSFODNN7EXAMPLE", + "secret_access_key": "secret", + "session_token": "session-token", + "region": "us-east-1", + }, + ) + + start_call = cast(Call, start_route.calls[0]) + start_request_body = json.loads(start_call.request.content) + assert start_request_body == { + "modelName": "bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0", + "modelClientOptions": { + "accessKeyId": "AKIAIOSFODNN7EXAMPLE", + "secretAccessKey": "secret", + "sessionToken": "session-token", + "region": "us-east-1", + }, + } + assert start_call.request.headers.get("x-model-api-key") is None + + await session.extract( + instruction="extract the page title", + schema={"type": "object", "properties": {"title": {"type": "string"}}}, + ) + + extract_call = cast(Call, extract_route.calls[0]) + extract_request_body = json.loads(extract_call.request.content) + assert extract_request_body["options"]["model"] == { + "modelName": "bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0", + "accessKeyId": "AKIAIOSFODNN7EXAMPLE", + "secretAccessKey": "secret", + "sessionToken": "session-token", + "region": "us-east-1", + } + + await client.close()