Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 13 additions & 17 deletions src/stagehand/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ._models import FinalRequestOptions
from ._version import __version__
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
from ._exceptions import APIStatusError, StagehandError
from ._exceptions import APIStatusError
from ._base_client import (
DEFAULT_MAX_RETRIES,
SyncAPIClient,
Expand Down Expand Up @@ -52,7 +52,7 @@ class Stagehand(SyncAPIClient):
# client options
browserbase_api_key: str | None
browserbase_project_id: str | None
model_api_key: str
model_api_key: str | None

def __init__(
self,
Expand Down Expand Up @@ -115,10 +115,6 @@ def __init__(

if model_api_key is None:
model_api_key = os.environ.get("MODEL_API_KEY")
if model_api_key is None:
raise StagehandError(
"The model_api_key client option must be set either by passing model_api_key to the client or by setting the MODEL_API_KEY environment variable"
)
self.model_api_key = model_api_key

self._sea_server: SeaServerManager | None = None
Expand Down Expand Up @@ -210,7 +206,7 @@ def _bb_project_id_auth(self) -> dict[str, str]:
@property
def _llm_model_api_key_auth(self) -> dict[str, str]:
model_api_key = self.model_api_key
return {"x-model-api-key": model_api_key}
return {"x-model-api-key": model_api_key} if model_api_key else {}

@property
@override
Expand Down Expand Up @@ -273,9 +269,11 @@ def copy(
return self.__class__(
browserbase_api_key=browserbase_api_key or self.browserbase_api_key,
browserbase_project_id=browserbase_project_id or self.browserbase_project_id,
model_api_key=model_api_key or self.model_api_key,
model_api_key=model_api_key if model_api_key is not None else self.model_api_key,
server=server or self._server_mode,
_local_stagehand_binary_path=_local_stagehand_binary_path if _local_stagehand_binary_path is not None else self._local_stagehand_binary_path,
_local_stagehand_binary_path=_local_stagehand_binary_path
if _local_stagehand_binary_path is not None
else self._local_stagehand_binary_path,
local_host=local_host or self._local_host,
local_port=local_port if local_port is not None else self._local_port,
local_headless=local_headless if local_headless is not None else self._local_headless,
Expand Down Expand Up @@ -340,7 +338,7 @@ class AsyncStagehand(AsyncAPIClient):
# client options
browserbase_api_key: str | None
browserbase_project_id: str | None
model_api_key: str
model_api_key: str | None

def __init__(
self,
Expand Down Expand Up @@ -403,10 +401,6 @@ def __init__(

if model_api_key is None:
model_api_key = os.environ.get("MODEL_API_KEY")
if model_api_key is None:
raise StagehandError(
"The model_api_key client option must be set either by passing model_api_key to the client or by setting the MODEL_API_KEY environment variable"
)
self.model_api_key = model_api_key

self._sea_server: SeaServerManager | None = None
Expand Down Expand Up @@ -497,7 +491,7 @@ def _bb_project_id_auth(self) -> dict[str, str]:
@property
def _llm_model_api_key_auth(self) -> dict[str, str]:
model_api_key = self.model_api_key
return {"x-model-api-key": model_api_key}
return {"x-model-api-key": model_api_key} if model_api_key else {}

@property
@override
Expand Down Expand Up @@ -560,9 +554,11 @@ def copy(
return self.__class__(
browserbase_api_key=browserbase_api_key or self.browserbase_api_key,
browserbase_project_id=browserbase_project_id or self.browserbase_project_id,
model_api_key=model_api_key or self.model_api_key,
model_api_key=model_api_key if model_api_key is not None else self.model_api_key,
server=server or self._server_mode,
_local_stagehand_binary_path=_local_stagehand_binary_path if _local_stagehand_binary_path is not None else self._local_stagehand_binary_path,
_local_stagehand_binary_path=_local_stagehand_binary_path
if _local_stagehand_binary_path is not None
else self._local_stagehand_binary_path,
local_host=local_host or self._local_host,
local_port=local_port if local_port is not None else self._local_port,
local_headless=local_headless if local_headless is not None else self._local_headless,
Expand Down
4 changes: 4 additions & 0 deletions src/stagehand/resources/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,7 @@ def start(
browser: session_start_params.Browser | Omit = omit,
browserbase_session_create_params: session_start_params.BrowserbaseSessionCreateParams | Omit = omit,
browserbase_session_id: str | Omit = omit,
model_client_options: session_start_params.ModelClientOptions | Omit = omit,
dom_settle_timeout_ms: float | Omit = omit,
experimental: bool | Omit = omit,
self_heal: bool | Omit = omit,
Expand Down Expand Up @@ -976,6 +977,7 @@ def start(
"browser": browser,
"browserbase_session_create_params": browserbase_session_create_params,
"browserbase_session_id": browserbase_session_id,
"model_client_options": model_client_options,
"dom_settle_timeout_ms": dom_settle_timeout_ms,
"experimental": experimental,
"self_heal": self_heal,
Expand Down Expand Up @@ -1867,6 +1869,7 @@ async def start(
browser: session_start_params.Browser | Omit = omit,
browserbase_session_create_params: session_start_params.BrowserbaseSessionCreateParams | Omit = omit,
browserbase_session_id: str | Omit = omit,
model_client_options: session_start_params.ModelClientOptions | Omit = omit,
dom_settle_timeout_ms: float | Omit = omit,
experimental: bool | Omit = omit,
self_heal: bool | Omit = omit,
Expand Down Expand Up @@ -1928,6 +1931,7 @@ async def start(
"browser": browser,
"browserbase_session_create_params": browserbase_session_create_params,
"browserbase_session_id": browserbase_session_id,
"model_client_options": model_client_options,
"dom_settle_timeout_ms": dom_settle_timeout_ms,
"experimental": experimental,
"self_heal": self_heal,
Expand Down
49 changes: 47 additions & 2 deletions src/stagehand/resources/sessions_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from typing import Any
from typing_extensions import Literal, override

import httpx
Expand All @@ -27,6 +28,26 @@
from ..types.session_start_response import SessionStartResponse


def _has_explicit_aws_credentials(model_config: dict[str, Any]) -> bool:
return any(model_config.get(key) for key in ("access_key_id", "secret_access_key", "session_token"))


def _build_default_model_config(
*,
model_name: str,
model_client_options: session_start_params.ModelClientOptions | Omit,
fallback_api_key: str | None,
) -> dict[str, Any]:
model_config: dict[str, Any] = {"model_name": model_name}
if isinstance(model_client_options, dict):
model_config.update(model_client_options)

if fallback_api_key and "api_key" not in model_config and not _has_explicit_aws_credentials(model_config):
model_config["api_key"] = fallback_api_key

return model_config


class SessionsResourceWithHelpersRawResponse(SessionsResourceWithRawResponse):
def __init__(self, sessions: SessionsResourceWithHelpers) -> None: # type: ignore[name-defined]
super().__init__(sessions)
Expand Down Expand Up @@ -71,6 +92,7 @@ def start(
browser: session_start_params.Browser | Omit = omit,
browserbase_session_create_params: session_start_params.BrowserbaseSessionCreateParams | Omit = omit,
browserbase_session_id: str | Omit = omit,
model_client_options: session_start_params.ModelClientOptions | Omit = omit,
dom_settle_timeout_ms: float | Omit = omit,
experimental: bool | Omit = omit,
self_heal: bool | Omit = omit,
Expand All @@ -89,6 +111,7 @@ def start(
browser=browser,
browserbase_session_create_params=browserbase_session_create_params,
browserbase_session_id=browserbase_session_id,
model_client_options=model_client_options,
dom_settle_timeout_ms=dom_settle_timeout_ms,
experimental=experimental,
self_heal=self_heal,
Expand All @@ -101,7 +124,17 @@ def start(
extra_body=extra_body,
timeout=timeout,
)
return Session(self._client, start_response.data.session_id, data=start_response.data, success=start_response.success)
return Session(
self._client,
start_response.data.session_id,
data=start_response.data,
success=start_response.success,
default_model=_build_default_model_config(
model_name=model_name,
model_client_options=model_client_options,
fallback_api_key=self._client.model_api_key,
),
)


class AsyncSessionsResourceWithHelpers(AsyncSessionsResource):
Expand All @@ -124,6 +157,7 @@ async def start(
browser: session_start_params.Browser | Omit = omit,
browserbase_session_create_params: session_start_params.BrowserbaseSessionCreateParams | Omit = omit,
browserbase_session_id: str | Omit = omit,
model_client_options: session_start_params.ModelClientOptions | Omit = omit,
dom_settle_timeout_ms: float | Omit = omit,
experimental: bool | Omit = omit,
self_heal: bool | Omit = omit,
Expand All @@ -142,6 +176,7 @@ async def start(
browser=browser,
browserbase_session_create_params=browserbase_session_create_params,
browserbase_session_id=browserbase_session_id,
model_client_options=model_client_options,
dom_settle_timeout_ms=dom_settle_timeout_ms,
experimental=experimental,
self_heal=self_heal,
Expand All @@ -154,4 +189,14 @@ async def start(
extra_body=extra_body,
timeout=timeout,
)
return AsyncSession(self._client, start_response.data.session_id, data=start_response.data, success=start_response.success)
return AsyncSession(
self._client,
start_response.data.session_id,
data=start_response.data,
success=start_response.success,
default_model=_build_default_model_config(
model_name=model_name,
model_client_options=model_client_options,
fallback_api_key=self._client.model_api_key,
),
)
Loading
Loading