diff --git a/agentrun/sandbox/__sandbox_async_template.py b/agentrun/sandbox/__sandbox_async_template.py index 780c65d..2cc43f7 100644 --- a/agentrun/sandbox/__sandbox_async_template.py +++ b/agentrun/sandbox/__sandbox_async_template.py @@ -187,6 +187,7 @@ async def create_async( nas_config=nas_config, oss_mount_config=oss_mount_config, polar_fs_config=polar_fs_config, + config=config, ) # 根据 template 类型转换为对应的子类实例 @@ -216,35 +217,42 @@ async def create_async( return sandbox @classmethod - async def stop_by_id_async(cls, sandbox_id: str): + async def stop_by_id_async( + cls, sandbox_id: str, config: Optional[Config] = None + ): """通过 ID 停止 Sandbox(异步) Args: sandbox_id: Sandbox ID - config: 配置对象 + config: 配置对象 / Config object Returns: Sandbox: Sandbox 对象 """ if sandbox_id is None: raise ValueError("sandbox_id is required") - # todo 后续适配后使用 stop() - return await cls.__get_client().stop_sandbox_async(sandbox_id) + return await cls.__get_client().stop_sandbox_async( + sandbox_id, config=config + ) @classmethod - async def delete_by_id_async(cls, sandbox_id: str): + async def delete_by_id_async( + cls, sandbox_id: str, config: Optional[Config] = None + ): """通过 ID 删除 Sandbox(异步) Args: sandbox_id: Sandbox ID - config: 配置对象 + config: 配置对象 / Config object Returns: Sandbox: Sandbox 对象 """ if sandbox_id is None: raise ValueError("sandbox_id is required") - return await cls.__get_client().delete_sandbox_async(sandbox_id) + return await cls.__get_client().delete_sandbox_async( + sandbox_id, config=config + ) @classmethod async def list_async( @@ -476,16 +484,17 @@ async def get_async(self): if self.sandbox_id is None: raise ValueError("sandbox_id is required to get a Sandbox") - return await self.connect_async(self.sandbox_id) + return await self.connect_async(self.sandbox_id, config=self._config) async def delete_async(self): if self.sandbox_id is None: raise ValueError("sandbox_id is required to delete a Sandbox") - return await self.delete_by_id_async(self.sandbox_id) + return await self.delete_by_id_async( + self.sandbox_id, config=self._config + ) async def stop_async(self): if self.sandbox_id is None: raise ValueError("sandbox_id is required to stop a Sandbox") - # todo 后续适配后使用 stop() - return await self.stop_by_id_async(self.sandbox_id) + return await self.stop_by_id_async(self.sandbox_id, config=self._config) diff --git a/agentrun/sandbox/sandbox.py b/agentrun/sandbox/sandbox.py index a504531..e0607b9 100644 --- a/agentrun/sandbox/sandbox.py +++ b/agentrun/sandbox/sandbox.py @@ -257,6 +257,7 @@ async def create_async( nas_config=nas_config, oss_mount_config=oss_mount_config, polar_fs_config=polar_fs_config, + config=config, ) # 根据 template 类型转换为对应的子类实例 @@ -332,6 +333,7 @@ def create( nas_config=nas_config, oss_mount_config=oss_mount_config, polar_fs_config=polar_fs_config, + config=config, ) # 根据 template 类型转换为对应的子类实例 @@ -361,66 +363,72 @@ def create( return sandbox @classmethod - async def stop_by_id_async(cls, sandbox_id: str): + async def stop_by_id_async( + cls, sandbox_id: str, config: Optional[Config] = None + ): """通过 ID 停止 Sandbox(异步) Args: sandbox_id: Sandbox ID - config: 配置对象 + config: 配置对象 / Config object Returns: Sandbox: Sandbox 对象 """ if sandbox_id is None: raise ValueError("sandbox_id is required") - # todo 后续适配后使用 stop() - return await cls.__get_client().stop_sandbox_async(sandbox_id) + return await cls.__get_client().stop_sandbox_async( + sandbox_id, config=config + ) @classmethod - def stop_by_id(cls, sandbox_id: str): + def stop_by_id(cls, sandbox_id: str, config: Optional[Config] = None): """通过 ID 停止 Sandbox(同步) Args: sandbox_id: Sandbox ID - config: 配置对象 + config: 配置对象 / Config object Returns: Sandbox: Sandbox 对象 """ if sandbox_id is None: raise ValueError("sandbox_id is required") - # todo 后续适配后使用 stop() - return cls.__get_client().stop_sandbox(sandbox_id) + return cls.__get_client().stop_sandbox(sandbox_id, config=config) @classmethod - async def delete_by_id_async(cls, sandbox_id: str): + async def delete_by_id_async( + cls, sandbox_id: str, config: Optional[Config] = None + ): """通过 ID 删除 Sandbox(异步) Args: sandbox_id: Sandbox ID - config: 配置对象 + config: 配置对象 / Config object Returns: Sandbox: Sandbox 对象 """ if sandbox_id is None: raise ValueError("sandbox_id is required") - return await cls.__get_client().delete_sandbox_async(sandbox_id) + return await cls.__get_client().delete_sandbox_async( + sandbox_id, config=config + ) @classmethod - def delete_by_id(cls, sandbox_id: str): + def delete_by_id(cls, sandbox_id: str, config: Optional[Config] = None): """通过 ID 删除 Sandbox(同步) Args: sandbox_id: Sandbox ID - config: 配置对象 + config: 配置对象 / Config object Returns: Sandbox: Sandbox 对象 """ if sandbox_id is None: raise ValueError("sandbox_id is required") - return cls.__get_client().delete_sandbox(sandbox_id) + return cls.__get_client().delete_sandbox(sandbox_id, config=config) @classmethod async def list_async( @@ -866,34 +874,34 @@ async def get_async(self): if self.sandbox_id is None: raise ValueError("sandbox_id is required to get a Sandbox") - return await self.connect_async(self.sandbox_id) + return await self.connect_async(self.sandbox_id, config=self._config) def get(self): if self.sandbox_id is None: raise ValueError("sandbox_id is required to get a Sandbox") - return self.connect(self.sandbox_id) + return self.connect(self.sandbox_id, config=self._config) async def delete_async(self): if self.sandbox_id is None: raise ValueError("sandbox_id is required to delete a Sandbox") - return await self.delete_by_id_async(self.sandbox_id) + return await self.delete_by_id_async( + self.sandbox_id, config=self._config + ) def delete(self): if self.sandbox_id is None: raise ValueError("sandbox_id is required to delete a Sandbox") - return self.delete_by_id(self.sandbox_id) + return self.delete_by_id(self.sandbox_id, config=self._config) async def stop_async(self): if self.sandbox_id is None: raise ValueError("sandbox_id is required to stop a Sandbox") - # todo 后续适配后使用 stop() - return await self.stop_by_id_async(self.sandbox_id) + return await self.stop_by_id_async(self.sandbox_id, config=self._config) def stop(self): if self.sandbox_id is None: raise ValueError("sandbox_id is required to stop a Sandbox") - # todo 后续适配后使用 stop() - return self.stop_by_id(self.sandbox_id) + return self.stop_by_id(self.sandbox_id, config=self._config) diff --git a/tests/unittests/sandbox/test_sandbox.py b/tests/unittests/sandbox/test_sandbox.py index dd11c90..5c42471 100644 --- a/tests/unittests/sandbox/test_sandbox.py +++ b/tests/unittests/sandbox/test_sandbox.py @@ -1111,3 +1111,316 @@ async def test_stop_async_calls_stop_by_id( sb = Sandbox(sandbox_id="sb-1") result = await sb.stop_async() assert result.sandbox_id == "sb-1" + + +# ==================== Config 透传测试 ==================== + + +class TestSandboxConfigPassthrough: + + @patch("agentrun.sandbox.client.SandboxControlAPI") + @patch("agentrun.sandbox.client.SandboxDataAPI") + def test_stop_by_id_passes_config( + self, mock_data_api_class, mock_control_api_class + ): + mock_data_api = MagicMock() + mock_data_api.stop_sandbox.return_value = { + "code": "SUCCESS", + "data": {"sandboxId": "sb-1"}, + } + mock_data_api_class.return_value = mock_data_api + + my_config = Config(headers={"X-Test": "val"}) + Sandbox.stop_by_id("sb-1", config=my_config) + mock_data_api.stop_sandbox.assert_called_once_with( + "sb-1", config=my_config + ) + + @patch("agentrun.sandbox.client.SandboxControlAPI") + @patch("agentrun.sandbox.client.SandboxDataAPI") + @pytest.mark.asyncio + async def test_stop_by_id_async_passes_config( + self, mock_data_api_class, mock_control_api_class + ): + mock_data_api = MagicMock() + mock_data_api.stop_sandbox_async = AsyncMock( + return_value={ + "code": "SUCCESS", + "data": {"sandboxId": "sb-1"}, + } + ) + mock_data_api_class.return_value = mock_data_api + + my_config = Config(headers={"X-Test": "val"}) + await Sandbox.stop_by_id_async("sb-1", config=my_config) + mock_data_api.stop_sandbox_async.assert_called_once_with( + "sb-1", config=my_config + ) + + @patch("agentrun.sandbox.client.SandboxControlAPI") + @patch("agentrun.sandbox.client.SandboxDataAPI") + def test_delete_by_id_passes_config( + self, mock_data_api_class, mock_control_api_class + ): + mock_data_api = MagicMock() + mock_data_api.delete_sandbox.return_value = { + "code": "SUCCESS", + "data": {"sandboxId": "sb-1"}, + } + mock_data_api_class.return_value = mock_data_api + + my_config = Config(headers={"X-Test": "val"}) + Sandbox.delete_by_id("sb-1", config=my_config) + mock_data_api.delete_sandbox.assert_called_once_with( + "sb-1", config=my_config + ) + + @patch("agentrun.sandbox.client.SandboxControlAPI") + @patch("agentrun.sandbox.client.SandboxDataAPI") + @pytest.mark.asyncio + async def test_delete_by_id_async_passes_config( + self, mock_data_api_class, mock_control_api_class + ): + mock_data_api = MagicMock() + mock_data_api.delete_sandbox_async = AsyncMock( + return_value={ + "code": "SUCCESS", + "data": {"sandboxId": "sb-1"}, + } + ) + mock_data_api_class.return_value = mock_data_api + + my_config = Config(headers={"X-Test": "val"}) + await Sandbox.delete_by_id_async("sb-1", config=my_config) + mock_data_api.delete_sandbox_async.assert_called_once_with( + "sb-1", config=my_config + ) + + @patch("agentrun.sandbox.client.SandboxControlAPI") + @patch("agentrun.sandbox.client.SandboxDataAPI") + def test_create_passes_config_to_create_sandbox( + self, mock_data_api_class, mock_control_api_class + ): + mock_control_api = MagicMock() + mock_control_api.get_template.return_value = MockTemplateData() + mock_control_api_class.return_value = mock_control_api + + mock_data_api = MagicMock() + mock_data_api.create_sandbox.return_value = { + "code": "SUCCESS", + "data": { + "sandboxId": "sandbox-ci-123", + "templateName": "test-template", + }, + } + mock_data_api_class.return_value = mock_data_api + + my_config = Config(headers={"X-Test": "val"}) + result = Sandbox.create( + template_type=TemplateType.CODE_INTERPRETER, + template_name="test-template", + config=my_config, + ) + assert isinstance(result, CodeInterpreterSandbox) + call_kwargs = mock_data_api.create_sandbox.call_args + assert call_kwargs.kwargs.get("config") is my_config + + @patch("agentrun.sandbox.client.SandboxControlAPI") + @patch("agentrun.sandbox.client.SandboxDataAPI") + @pytest.mark.asyncio + async def test_create_async_passes_config_to_create_sandbox( + self, mock_data_api_class, mock_control_api_class + ): + mock_control_api = MagicMock() + mock_control_api.get_template_async = AsyncMock( + return_value=MockTemplateData() + ) + mock_control_api_class.return_value = mock_control_api + + mock_data_api = MagicMock() + mock_data_api.create_sandbox_async = AsyncMock( + return_value={ + "code": "SUCCESS", + "data": { + "sandboxId": "sandbox-ci-123", + "templateName": "test-template", + }, + } + ) + mock_data_api_class.return_value = mock_data_api + + my_config = Config(headers={"X-Test": "val"}) + result = await Sandbox.create_async( + template_type=TemplateType.CODE_INTERPRETER, + template_name="test-template", + config=my_config, + ) + assert isinstance(result, CodeInterpreterSandbox) + call_kwargs = mock_data_api.create_sandbox_async.call_args + assert call_kwargs.kwargs.get("config") is my_config + + @patch("agentrun.sandbox.client.SandboxControlAPI") + @patch("agentrun.sandbox.client.SandboxDataAPI") + def test_instance_stop_passes_self_config( + self, mock_data_api_class, mock_control_api_class + ): + mock_data_api = MagicMock() + mock_data_api.stop_sandbox.return_value = { + "code": "SUCCESS", + "data": {"sandboxId": "sb-1"}, + } + mock_data_api_class.return_value = mock_data_api + + my_config = Config(headers={"X-Test": "val"}) + sb = Sandbox(sandbox_id="sb-1") + sb._config = my_config + sb.stop() + mock_data_api.stop_sandbox.assert_called_once_with( + "sb-1", config=my_config + ) + + @patch("agentrun.sandbox.client.SandboxControlAPI") + @patch("agentrun.sandbox.client.SandboxDataAPI") + @pytest.mark.asyncio + async def test_instance_stop_async_passes_self_config( + self, mock_data_api_class, mock_control_api_class + ): + mock_data_api = MagicMock() + mock_data_api.stop_sandbox_async = AsyncMock( + return_value={ + "code": "SUCCESS", + "data": {"sandboxId": "sb-1"}, + } + ) + mock_data_api_class.return_value = mock_data_api + + my_config = Config(headers={"X-Test": "val"}) + sb = Sandbox(sandbox_id="sb-1") + sb._config = my_config + await sb.stop_async() + mock_data_api.stop_sandbox_async.assert_called_once_with( + "sb-1", config=my_config + ) + + @patch("agentrun.sandbox.client.SandboxControlAPI") + @patch("agentrun.sandbox.client.SandboxDataAPI") + def test_instance_delete_passes_self_config( + self, mock_data_api_class, mock_control_api_class + ): + mock_data_api = MagicMock() + mock_data_api.delete_sandbox.return_value = { + "code": "SUCCESS", + "data": {"sandboxId": "sb-1"}, + } + mock_data_api_class.return_value = mock_data_api + + my_config = Config(headers={"X-Test": "val"}) + sb = Sandbox(sandbox_id="sb-1") + sb._config = my_config + sb.delete() + mock_data_api.delete_sandbox.assert_called_once_with( + "sb-1", config=my_config + ) + + @patch("agentrun.sandbox.client.SandboxControlAPI") + @patch("agentrun.sandbox.client.SandboxDataAPI") + @pytest.mark.asyncio + async def test_instance_delete_async_passes_self_config( + self, mock_data_api_class, mock_control_api_class + ): + mock_data_api = MagicMock() + mock_data_api.delete_sandbox_async = AsyncMock( + return_value={ + "code": "SUCCESS", + "data": {"sandboxId": "sb-1"}, + } + ) + mock_data_api_class.return_value = mock_data_api + + my_config = Config(headers={"X-Test": "val"}) + sb = Sandbox(sandbox_id="sb-1") + sb._config = my_config + await sb.delete_async() + mock_data_api.delete_sandbox_async.assert_called_once_with( + "sb-1", config=my_config + ) + + @patch("agentrun.sandbox.client.SandboxControlAPI") + @patch("agentrun.sandbox.client.SandboxDataAPI") + def test_instance_get_passes_self_config( + self, mock_data_api_class, mock_control_api_class + ): + mock_control_api = MagicMock() + mock_control_api.get_template.return_value = MockTemplateData() + mock_control_api_class.return_value = mock_control_api + + mock_data_api = MagicMock() + mock_data_api.get_sandbox.return_value = { + "code": "SUCCESS", + "data": {"sandboxId": "sb-1", "templateName": "tpl"}, + } + mock_data_api_class.return_value = mock_data_api + + my_config = Config(headers={"X-Test": "val"}) + sb = Sandbox(sandbox_id="sb-1") + sb._config = my_config + sb.get() + mock_data_api.get_sandbox.assert_called_once_with( + "sb-1", config=my_config + ) + + @patch("agentrun.sandbox.client.SandboxControlAPI") + @patch("agentrun.sandbox.client.SandboxDataAPI") + @pytest.mark.asyncio + async def test_instance_get_async_passes_self_config( + self, mock_data_api_class, mock_control_api_class + ): + mock_control_api = MagicMock() + mock_control_api.get_template_async = AsyncMock( + return_value=MockTemplateData() + ) + mock_control_api_class.return_value = mock_control_api + + mock_data_api = MagicMock() + mock_data_api.get_sandbox_async = AsyncMock( + return_value={ + "code": "SUCCESS", + "data": {"sandboxId": "sb-1", "templateName": "tpl"}, + } + ) + mock_data_api_class.return_value = mock_data_api + + my_config = Config(headers={"X-Test": "val"}) + sb = Sandbox(sandbox_id="sb-1") + sb._config = my_config + await sb.get_async() + mock_data_api.get_sandbox_async.assert_called_once_with( + "sb-1", config=my_config + ) + + @patch("agentrun.sandbox.client.SandboxControlAPI") + @patch("agentrun.sandbox.client.SandboxDataAPI") + def test_create_stores_config_on_instance( + self, mock_data_api_class, mock_control_api_class + ): + mock_control_api = MagicMock() + mock_control_api.get_template.return_value = MockTemplateData() + mock_control_api_class.return_value = mock_control_api + + mock_data_api = MagicMock() + mock_data_api.create_sandbox.return_value = { + "code": "SUCCESS", + "data": { + "sandboxId": "sandbox-ci-123", + "templateName": "test-template", + }, + } + mock_data_api_class.return_value = mock_data_api + + my_config = Config(headers={"X-Test": "val"}) + result = Sandbox.create( + template_type=TemplateType.CODE_INTERPRETER, + template_name="test-template", + config=my_config, + ) + assert result._config is my_config