diff --git a/src/stagehand/_client.py b/src/stagehand/_client.py index 3997b5e..f631485 100644 --- a/src/stagehand/_client.py +++ b/src/stagehand/_client.py @@ -60,6 +60,7 @@ def __init__( browserbase_api_key: str | None = None, browserbase_project_id: str | None = None, model_api_key: str | None = None, + model_base_url: str | None = None, server: Literal["remote", "local"] = "remote", _local_stagehand_binary_path: str | os.PathLike[str] | None = None, local_host: str = "127.0.0.1", @@ -121,6 +122,10 @@ def __init__( ) self.model_api_key = model_api_key + if model_base_url is None: + model_base_url = os.environ.get("MODEL_BASE_URL") + self.model_base_url = model_base_url + self._sea_server: SeaServerManager | None = None if server == "local": # We'll switch `base_url` to the started server before the first request. @@ -195,7 +200,7 @@ def qs(self) -> Querystring: @property @override def auth_headers(self) -> dict[str, str]: - return {**self._bb_api_key_auth, **self._bb_project_id_auth, **self._llm_model_api_key_auth} + return {**self._bb_api_key_auth, **self._bb_project_id_auth, **self._llm_model_api_key_auth, **self._llm_model_base_url_header} @property def _bb_api_key_auth(self) -> dict[str, str]: @@ -212,6 +217,11 @@ def _llm_model_api_key_auth(self) -> dict[str, str]: model_api_key = self.model_api_key return {"x-model-api-key": model_api_key} + @property + def _llm_model_base_url_header(self) -> dict[str, str]: + model_base_url = self.model_base_url + return {"x-model-base-url": model_base_url} if model_base_url else {} + @property @override def default_headers(self) -> dict[str, str | Omit]: @@ -229,6 +239,7 @@ def copy( browserbase_api_key: str | None = None, browserbase_project_id: str | None = None, model_api_key: str | None = None, + model_base_url: str | None = None, server: Literal["remote", "local"] | None = None, _local_stagehand_binary_path: str | os.PathLike[str] | None = None, local_host: str | None = None, @@ -274,6 +285,7 @@ def copy( browserbase_api_key=browserbase_api_key or self.browserbase_api_key, browserbase_project_id=browserbase_project_id or self.browserbase_project_id, model_api_key=model_api_key or self.model_api_key, + model_base_url=model_base_url or self.model_base_url, server=server or self._server_mode, _local_stagehand_binary_path=_local_stagehand_binary_path if _local_stagehand_binary_path is not None else self._local_stagehand_binary_path, local_host=local_host or self._local_host, @@ -348,6 +360,7 @@ def __init__( browserbase_api_key: str | None = None, browserbase_project_id: str | None = None, model_api_key: str | None = None, + model_base_url: str | None = None, server: Literal["remote", "local"] = "remote", _local_stagehand_binary_path: str | os.PathLike[str] | None = None, local_host: str = "127.0.0.1", @@ -409,6 +422,10 @@ def __init__( ) self.model_api_key = model_api_key + if model_base_url is None: + model_base_url = os.environ.get("MODEL_BASE_URL") + self.model_base_url = model_base_url + self._sea_server: SeaServerManager | None = None if server == "local": if base_url is None: @@ -482,7 +499,7 @@ def qs(self) -> Querystring: @property @override def auth_headers(self) -> dict[str, str]: - return {**self._bb_api_key_auth, **self._bb_project_id_auth, **self._llm_model_api_key_auth} + return {**self._bb_api_key_auth, **self._bb_project_id_auth, **self._llm_model_api_key_auth, **self._llm_model_base_url_header} @property def _bb_api_key_auth(self) -> dict[str, str]: @@ -499,6 +516,11 @@ def _llm_model_api_key_auth(self) -> dict[str, str]: model_api_key = self.model_api_key return {"x-model-api-key": model_api_key} + @property + def _llm_model_base_url_header(self) -> dict[str, str]: + model_base_url = self.model_base_url + return {"x-model-base-url": model_base_url} if model_base_url else {} + @property @override def default_headers(self) -> dict[str, str | Omit]: @@ -516,6 +538,7 @@ def copy( browserbase_api_key: str | None = None, browserbase_project_id: str | None = None, model_api_key: str | None = None, + model_base_url: str | None = None, server: Literal["remote", "local"] | None = None, _local_stagehand_binary_path: str | os.PathLike[str] | None = None, local_host: str | None = None, @@ -561,6 +584,7 @@ def copy( browserbase_api_key=browserbase_api_key or self.browserbase_api_key, browserbase_project_id=browserbase_project_id or self.browserbase_project_id, model_api_key=model_api_key or self.model_api_key, + model_base_url=model_base_url or self.model_base_url, server=server or self._server_mode, _local_stagehand_binary_path=_local_stagehand_binary_path if _local_stagehand_binary_path is not None else self._local_stagehand_binary_path, local_host=local_host or self._local_host,