diff --git a/.github/workflows/rust-sdk-tests.yml b/.github/workflows/rust-sdk-tests.yml index 207ed6de9..f542307be 100644 --- a/.github/workflows/rust-sdk-tests.yml +++ b/.github/workflows/rust-sdk-tests.yml @@ -94,10 +94,12 @@ jobs: run: pwsh.exe -Command "Write-Host 'PowerShell ready'" - name: cargo test + timeout-minutes: 90 env: + RUST_E2E_CONCURRENCY: 4 COPILOT_HMAC_KEY: ${{ secrets.COPILOT_DEVELOPER_CLI_INTEGRATION_HMAC_KEY }} COPILOT_CLI_PATH: ${{ steps.setup-copilot.outputs.cli-path }} - run: cargo test --features test-support + run: cargo test --features test-support -- --test-threads=4 --nocapture # Validates the `embedded-cli` build path on all three supported # platforms. This is the only place `build.rs` actually runs (the diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 8b130628e..3065822e7 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -361,6 +361,7 @@ dependencies = [ "tokio-util", "tracing", "ureq", + "uuid", "zip", "zstd", ] @@ -1284,6 +1285,15 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "uuid" +version = "1.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" +dependencies = [ + "getrandom 0.4.2", +] + [[package]] name = "version_check" version = "0.9.5" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 4d8831f7e..182707bf1 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -51,6 +51,7 @@ parking_lot = "0.12" regex = "1" sha2 = { version = "0.10", optional = true } getrandom = "0.2" +uuid = { version = "1", default-features = false, features = ["v4"] } zstd = { version = "0.13", optional = true } [dev-dependencies] diff --git a/rust/src/jsonrpc.rs b/rust/src/jsonrpc.rs index f0b0d6cc0..88a9670cd 100644 --- a/rust/src/jsonrpc.rs +++ b/rust/src/jsonrpc.rs @@ -3,11 +3,12 @@ use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Instant; -use parking_lot::RwLock; +use parking_lot::{Mutex, RwLock}; use serde::{Deserialize, Serialize}; use serde_json::Value; use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}; use tokio::sync::{broadcast, mpsc, oneshot}; +use tokio::task::JoinHandle; use tracing::{Instrument, debug, error, warn}; use crate::{Error, ProtocolError}; @@ -184,6 +185,8 @@ pub struct JsonRpcClient { pending_requests: Arc>>>, notification_tx: broadcast::Sender, request_tx: mpsc::UnboundedSender, + read_task: Mutex>>, + write_task: Mutex>>, } impl JsonRpcClient { @@ -202,7 +205,7 @@ impl JsonRpcClient { let (write_tx, write_rx) = mpsc::unbounded_channel::(); let writer_span = tracing::error_span!("jsonrpc_write_loop"); - tokio::spawn(Self::write_loop(writer, write_rx).instrument(writer_span)); + let write_task = tokio::spawn(Self::write_loop(writer, write_rx).instrument(writer_span)); let client = Self { request_id: AtomicU64::new(1), @@ -210,6 +213,8 @@ impl JsonRpcClient { pending_requests: Arc::new(RwLock::new(HashMap::new())), notification_tx, request_tx, + read_task: Mutex::new(None), + write_task: Mutex::new(Some(write_task)), }; let pending_requests = client.pending_requests.clone(); @@ -217,7 +222,7 @@ impl JsonRpcClient { let request_tx_clone = client.request_tx.clone(); let reader_span = tracing::error_span!("jsonrpc_read_loop"); - tokio::spawn( + let read_task = tokio::spawn( async move { Self::read_loop( reader, @@ -229,10 +234,21 @@ impl JsonRpcClient { } .instrument(reader_span), ); + *client.read_task.lock() = Some(read_task); client } + pub(crate) fn force_close(&self) { + if let Some(task) = self.read_task.lock().take() { + task.abort(); + } + if let Some(task) = self.write_task.lock().take() { + task.abort(); + } + self.pending_requests.write().clear(); + } + /// Writer-actor task. Owns the `AsyncWrite`, drains the command queue, /// and writes each frame atomically (header + body + flush) before /// signaling the ack. diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 1af468182..e0a724fd1 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -267,6 +267,15 @@ pub enum SessionError { /// non-empty. #[error("invalid SessionFsConfig: {0}")] InvalidSessionFsConfig(String), + + /// The CLI returned a different session ID than the one the SDK registered. + #[error("CLI returned session ID {returned} after SDK registered {requested}")] + SessionIdMismatch { + /// Session ID registered by the SDK before the RPC was sent. + requested: SessionId, + /// Session ID returned by the CLI. + returned: SessionId, + }, } /// How the SDK communicates with the CLI server. @@ -873,6 +882,7 @@ struct ClientInner { state: parking_lot::Mutex, lifecycle_tx: broadcast::Sender, on_list_models: Option>, + models_cache: parking_lot::Mutex>>>, session_fs_configured: bool, on_get_trace_context: Option>, /// Token sent in the `connect` handshake. Auto-generated when the @@ -900,6 +910,24 @@ impl Client { if let Some(cfg) = &options.session_fs { validate_session_fs_config(cfg)?; } + // Auth options only make sense when the SDK spawns the CLI; with an + // external server, the server manages its own auth. + if matches!(options.transport, Transport::External { .. }) { + if options.github_token.is_some() { + return Err(Error::InvalidConfig( + "github_token cannot be used with Transport::External \ + (external server manages its own auth)" + .to_string(), + )); + } + if options.use_logged_in_user == Some(true) { + return Err(Error::InvalidConfig( + "use_logged_in_user cannot be used with Transport::External \ + (external server manages its own auth)" + .to_string(), + )); + } + } // Validate token + transport combination. Stdio cannot use a // connection token; auto-generate a UUID when the SDK spawns // its own CLI in TCP mode and no explicit token was set. @@ -1138,6 +1166,7 @@ impl Client { state: parking_lot::Mutex::new(ConnectionState::Connected), lifecycle_tx: broadcast::channel(256).0, on_list_models, + models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())), session_fs_configured, on_get_trace_context, effective_connection_token, @@ -1752,10 +1781,17 @@ impl Client { /// When [`ClientOptions::on_list_models`] is set, returns the handler's /// result without making a `models.list` RPC. Otherwise queries the CLI. pub async fn list_models(&self) -> Result, Error> { - if let Some(handler) = &self.inner.on_list_models { - return handler.list_models().await; - } - Ok(self.rpc().models().list().await?.models) + let cache = self.inner.models_cache.lock().clone(); + let models = cache + .get_or_try_init(|| async { + if let Some(handler) = &self.inner.on_list_models { + handler.list_models().await + } else { + Ok(self.rpc().models().list().await?.models) + } + }) + .await?; + Ok(models.clone()) } /// Invoke [`ClientOptions::on_get_trace_context`] when configured, @@ -1828,6 +1864,7 @@ impl Client { let child = self.inner.child.lock().take(); *self.inner.state.lock() = ConnectionState::Disconnected; + *self.inner.models_cache.lock() = Arc::new(tokio::sync::OnceCell::new()); if let Some(mut child) = child && let Err(e) = child.kill().await { @@ -1879,10 +1916,12 @@ impl Client { { error!(pid = ?pid, error = %e, "failed to send kill signal"); } + self.inner.rpc.force_close(); // Drop all session channels so any awaiters see a closed channel // instead of waiting for responses that will never arrive. self.inner.router.clear(); *self.inner.state.lock() = ConnectionState::Disconnected; + *self.inner.models_cache.lock() = Arc::new(tokio::sync::OnceCell::new()); } /// Subscribe to lifecycle events. @@ -2405,43 +2444,137 @@ mod tests { policy: None, supported_reasoning_efforts: Vec::new(), }; - let handler = Arc::new(CountingHandler { + let handler: Arc = Arc::new(CountingHandler { calls: Arc::clone(&calls), models: vec![model.clone()], }); - // We can't call list_models() through Client::start without a CLI, but we - // can exercise the override path by directly constructing a Client whose - // inner has the handler set. This is the same dispatch path as the real - // call; from_streams's None default is replaced via inner construction. - let inner = ClientInner { - child: parking_lot::Mutex::new(None), - rpc: { - let (req_tx, _req_rx) = mpsc::unbounded_channel(); - let (notif_tx, _notif_rx) = broadcast::channel(16); - let (read_pipe, _write_pipe) = tokio::io::duplex(64); - let (_unused_read, write_pipe) = tokio::io::duplex(64); - JsonRpcClient::new(write_pipe, read_pipe, notif_tx, req_tx) - }, - cwd: PathBuf::from("."), - request_rx: parking_lot::Mutex::new(None), - notification_tx: broadcast::channel(16).0, - router: router::SessionRouter::new(), - negotiated_protocol_version: OnceLock::new(), - state: parking_lot::Mutex::new(ConnectionState::Connected), - lifecycle_tx: broadcast::channel(16).0, - on_list_models: Some(handler), - session_fs_configured: false, - on_get_trace_context: None, - effective_connection_token: None, - }; - let client = Client { - inner: Arc::new(inner), - }; + let client = client_with_list_models_handler(handler); let result = client.list_models().await.unwrap(); assert_eq!(result.len(), 1); assert_eq!(result[0].id, "byok-gpt-4"); assert_eq!(calls.load(Ordering::SeqCst), 1); } + + #[tokio::test] + async fn list_models_serializes_concurrent_cache_misses() { + use std::sync::atomic::{AtomicUsize, Ordering}; + + struct SlowCountingHandler { + calls: Arc, + models: Vec, + } + #[async_trait] + impl ListModelsHandler for SlowCountingHandler { + async fn list_models(&self) -> Result, Error> { + self.calls.fetch_add(1, Ordering::SeqCst); + tokio::time::sleep(std::time::Duration::from_millis(25)).await; + Ok(self.models.clone()) + } + } + + let calls = Arc::new(AtomicUsize::new(0)); + let model = Model { + billing: None, + capabilities: ModelCapabilities { + limits: None, + supports: None, + }, + default_reasoning_effort: None, + id: "single-flight-model".into(), + name: "Single Flight Model".into(), + policy: None, + supported_reasoning_efforts: Vec::new(), + }; + let handler: Arc = Arc::new(SlowCountingHandler { + calls: Arc::clone(&calls), + models: vec![model], + }); + let client = client_with_list_models_handler(handler); + + let (first, second) = tokio::join!(client.list_models(), client.list_models()); + assert_eq!(first.unwrap()[0].id, "single-flight-model"); + assert_eq!(second.unwrap()[0].id, "single-flight-model"); + assert_eq!(calls.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn cancelled_create_session_unregisters_pending_session() { + let (client_write, _server_read) = tokio::io::duplex(8192); + let (_server_write, client_read) = tokio::io::duplex(8192); + let client = Client::from_streams(client_read, client_write, std::env::temp_dir()).unwrap(); + let handle = tokio::spawn({ + let client = client.clone(); + async move { client.create_session(SessionConfig::default()).await } + }); + + wait_for_pending_session_registration(&client).await; + handle.abort(); + let _ = handle.await; + + assert!(client.inner.router.session_ids().is_empty()); + client.force_stop(); + } + + #[tokio::test] + async fn cancelled_resume_session_unregisters_pending_session() { + let (client_write, _server_read) = tokio::io::duplex(8192); + let (_server_write, client_read) = tokio::io::duplex(8192); + let client = Client::from_streams(client_read, client_write, std::env::temp_dir()).unwrap(); + let session_id = SessionId::new("resume-cancel-test"); + let handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .resume_session(ResumeSessionConfig::new(session_id)) + .await + } + }); + + wait_for_pending_session_registration(&client).await; + handle.abort(); + let _ = handle.await; + + assert!(client.inner.router.session_ids().is_empty()); + client.force_stop(); + } + + fn client_with_list_models_handler(handler: Arc) -> Client { + Client { + inner: Arc::new(ClientInner { + child: parking_lot::Mutex::new(None), + rpc: { + let (req_tx, _req_rx) = mpsc::unbounded_channel(); + let (notif_tx, _notif_rx) = broadcast::channel(16); + let (read_pipe, _write_pipe) = tokio::io::duplex(64); + let (_unused_read, write_pipe) = tokio::io::duplex(64); + JsonRpcClient::new(write_pipe, read_pipe, notif_tx, req_tx) + }, + cwd: PathBuf::from("."), + request_rx: parking_lot::Mutex::new(None), + notification_tx: broadcast::channel(16).0, + router: router::SessionRouter::new(), + negotiated_protocol_version: OnceLock::new(), + state: parking_lot::Mutex::new(ConnectionState::Connected), + lifecycle_tx: broadcast::channel(16).0, + on_list_models: Some(handler), + models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())), + session_fs_configured: false, + on_get_trace_context: None, + effective_connection_token: None, + }), + } + } + + async fn wait_for_pending_session_registration(client: &Client) { + let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(1); + while client.inner.router.session_ids().is_empty() { + assert!( + tokio::time::Instant::now() < deadline, + "session was not registered" + ); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + } } diff --git a/rust/src/session.rs b/rust/src/session.rs index b0d5faef4..2cdb257eb 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -64,6 +64,44 @@ impl Drop for WaiterGuard { } } +struct PendingSessionRegistration { + client: Client, + session_id: SessionId, + shutdown: CancellationToken, + disarmed: bool, +} + +impl PendingSessionRegistration { + fn new(client: Client, session_id: SessionId, shutdown: CancellationToken) -> Self { + Self { + client, + session_id, + shutdown, + disarmed: false, + } + } + + async fn cleanup(mut self, event_loop: JoinHandle<()>) { + self.shutdown.cancel(); + let _ = event_loop.await; + self.client.unregister_session(&self.session_id); + self.disarmed = true; + } + + fn disarm(&mut self) { + self.disarmed = true; + } +} + +impl Drop for PendingSessionRegistration { + fn drop(&mut self) { + if !self.disarmed { + self.shutdown.cancel(); + self.client.unregister_session(&self.session_id); + } + } +} + /// A session on a GitHub Copilot CLI server. /// /// Created via [`Client::create_session`] or [`Client::resume_session`]. @@ -736,24 +774,18 @@ impl Client { if let Some(ref transforms) = transforms { inject_transform_sections(&mut config, transforms.as_ref()); } + let session_id = config + .session_id + .clone() + .unwrap_or_else(|| SessionId::from(uuid::Uuid::new_v4().to_string())); + config.session_id = Some(session_id.clone()); let mut params = serde_json::to_value(&config)?; let trace_ctx = self.resolve_trace_context().await; inject_trace_context(&mut params, &trace_ctx); - let rpc_start = Instant::now(); - let result = self.call("session.create", Some(params)).await?; - tracing::debug!( - elapsed_ms = rpc_start.elapsed().as_millis(), - "Client::create_session session creation request completed successfully" - ); - let create_result: CreateSessionResult = serde_json::from_value(result)?; - let session_id = create_result.session_id; let setup_start = Instant::now(); - let capabilities = Arc::new(parking_lot::RwLock::new( - create_result.capabilities.unwrap_or_default(), - )); + let capabilities = Arc::new(parking_lot::RwLock::new(SessionCapabilities::default())); let channels = self.register_session(&session_id); - let idle_waiter = Arc::new(ParkingLotMutex::new(None)); let shutdown = CancellationToken::new(); let (event_tx, _) = tokio::sync::broadcast::channel(512); @@ -771,6 +803,8 @@ impl Client { event_tx.clone(), shutdown.clone(), ); + let mut registration = + PendingSessionRegistration::new(self.clone(), session_id.clone(), shutdown.clone()); tracing::debug!( elapsed_ms = setup_start.elapsed().as_millis(), session_id = %session_id, @@ -780,11 +814,40 @@ impl Client { "Client::create_session local setup complete" ); + let rpc_start = Instant::now(); + let result = match self.call("session.create", Some(params)).await { + Ok(result) => result, + Err(error) => { + registration.cleanup(event_loop).await; + return Err(error); + } + }; + tracing::debug!( + elapsed_ms = rpc_start.elapsed().as_millis(), + "Client::create_session session creation request completed successfully" + ); + let create_result: CreateSessionResult = match serde_json::from_value(result) { + Ok(result) => result, + Err(error) => { + registration.cleanup(event_loop).await; + return Err(error.into()); + } + }; + if create_result.session_id != session_id { + registration.cleanup(event_loop).await; + return Err(Error::Session(SessionError::SessionIdMismatch { + requested: session_id, + returned: create_result.session_id, + })); + } + *capabilities.write() = create_result.capabilities.unwrap_or_default(); + tracing::debug!( elapsed_ms = total_start.elapsed().as_millis(), session_id = %session_id, "Client::create_session complete" ); + registration.disarm(); Ok(Session { id: session_id, cwd: self.cwd().clone(), @@ -836,8 +899,46 @@ impl Client { let mut params = serde_json::to_value(&config)?; let trace_ctx = self.resolve_trace_context().await; inject_trace_context(&mut params, &trace_ctx); + + let capabilities = Arc::new(parking_lot::RwLock::new(SessionCapabilities::default())); + let setup_start = Instant::now(); + let channels = self.register_session(&session_id); + let idle_waiter = Arc::new(ParkingLotMutex::new(None)); + let shutdown = CancellationToken::new(); + let (event_tx, _) = tokio::sync::broadcast::channel(512); + let event_loop = spawn_event_loop( + session_id.clone(), + self.clone(), + handler, + hooks, + transforms, + command_handlers, + session_fs_provider, + channels, + idle_waiter.clone(), + capabilities.clone(), + event_tx.clone(), + shutdown.clone(), + ); + let mut registration = + PendingSessionRegistration::new(self.clone(), session_id.clone(), shutdown.clone()); + tracing::debug!( + elapsed_ms = setup_start.elapsed().as_millis(), + session_id = %session_id, + tools_count, + commands_count, + has_hooks, + "Client::resume_session local setup complete" + ); + let rpc_start = Instant::now(); - let result = self.call("session.resume", Some(params)).await?; + let result = match self.call("session.resume", Some(params)).await { + Ok(result) => result, + Err(error) => { + registration.cleanup(event_loop).await; + return Err(error); + } + }; tracing::debug!( elapsed_ms = rpc_start.elapsed().as_millis(), session_id = %session_id, @@ -850,6 +951,13 @@ impl Client { .and_then(|v| v.as_str()) .unwrap_or(&session_id) .into(); + if cli_session_id != session_id { + registration.cleanup(event_loop).await; + return Err(Error::Session(SessionError::SessionIdMismatch { + requested: session_id, + returned: cli_session_id, + })); + } let resume_capabilities: Option = result .get("capabilities") @@ -869,63 +977,34 @@ impl Client { if let Err(e) = self .call( "session.skills.reload", - Some(serde_json::json!({ "sessionId": cli_session_id })), + Some(serde_json::json!({ "sessionId": session_id })), ) .await { warn!( elapsed_ms = skills_reload_start.elapsed().as_millis(), - session_id = %cli_session_id, + session_id = %session_id, error = %e, "Client::resume_session skills reload request failed" ); } else { tracing::debug!( elapsed_ms = skills_reload_start.elapsed().as_millis(), - session_id = %cli_session_id, + session_id = %session_id, "Client::resume_session skills reload request completed successfully" ); } - let capabilities = Arc::new(parking_lot::RwLock::new( - resume_capabilities.unwrap_or_default(), - )); - let setup_start = Instant::now(); - let channels = self.register_session(&cli_session_id); - - let idle_waiter = Arc::new(ParkingLotMutex::new(None)); - let shutdown = CancellationToken::new(); - let (event_tx, _) = tokio::sync::broadcast::channel(512); - let event_loop = spawn_event_loop( - cli_session_id.clone(), - self.clone(), - handler, - hooks, - transforms, - command_handlers, - session_fs_provider, - channels, - idle_waiter.clone(), - capabilities.clone(), - event_tx.clone(), - shutdown.clone(), - ); - tracing::debug!( - elapsed_ms = setup_start.elapsed().as_millis(), - session_id = %cli_session_id, - tools_count, - commands_count, - has_hooks, - "Client::resume_session local setup complete" - ); + *capabilities.write() = resume_capabilities.unwrap_or_default(); tracing::debug!( elapsed_ms = total_start.elapsed().as_millis(), - session_id = %cli_session_id, + session_id = %session_id, "Client::resume_session complete" ); + registration.disarm(); Ok(Session { - id: cli_session_id, + id: session_id, cwd: self.cwd().clone(), workspace_path: None, remote_url, diff --git a/rust/src/types.rs b/rust/src/types.rs index 44ff96ce9..9a7ac0cbb 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -1544,6 +1544,9 @@ pub struct ResumeSessionConfig { /// Application name sent as User-Agent context. #[serde(skip_serializing_if = "Option::is_none")] pub client_name: Option, + /// Desired reasoning effort to apply after resuming the session. + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, /// Enable streaming token deltas. #[serde(skip_serializing_if = "Option::is_none")] pub streaming: Option, @@ -1679,6 +1682,7 @@ impl std::fmt::Debug for ResumeSessionConfig { f.debug_struct("ResumeSessionConfig") .field("session_id", &self.session_id) .field("client_name", &self.client_name) + .field("reasoning_effort", &self.reasoning_effort) .field("streaming", &self.streaming) .field("system_message", &self.system_message) .field("tools", &self.tools) @@ -1738,6 +1742,7 @@ impl ResumeSessionConfig { Self { session_id, client_name: None, + reasoning_effort: None, streaming: None, system_message: None, tools: None, @@ -1854,6 +1859,12 @@ impl ResumeSessionConfig { self } + /// Set the reasoning effort to apply on resume. + pub fn with_reasoning_effort(mut self, effort: impl Into) -> Self { + self.reasoning_effort = Some(effort.into()); + self + } + /// Enable streaming token deltas via `assistant.message_delta` events. pub fn with_streaming(mut self, streaming: bool) -> Self { self.streaming = Some(streaming); diff --git a/rust/tests/api_types_test.rs b/rust/tests/api_types_test.rs new file mode 100644 index 000000000..2a373a3b5 --- /dev/null +++ b/rust/tests/api_types_test.rs @@ -0,0 +1,99 @@ +// Unit tests for generated API types -- struct construction and field +// access. These do not require a client, session, or replay proxy. + +#![allow(clippy::unwrap_used)] + +use github_copilot_sdk::generated::api_types::{ + Extension, ExtensionList, ExtensionSource, ExtensionStatus, ExtensionsDisableRequest, + ExtensionsEnableRequest, FleetStartRequest, FleetStartResult, TasksStartAgentRequest, +}; + +#[test] +fn extension_running_has_expected_status_and_source() { + let extension = running_extension("project:demo", "demo"); + assert_eq!(extension.status, ExtensionStatus::Running); + assert_eq!(extension.source, ExtensionSource::Project); +} + +#[test] +fn disable_and_enable_requests_share_the_same_id() { + let disable = ExtensionsDisableRequest { + id: "project:demo".to_string(), + }; + let enable = ExtensionsEnableRequest { + id: disable.id.clone(), + }; + assert_eq!(disable.id, enable.id); +} + +#[test] +fn extension_list_contains_newly_added_extension_by_name() { + let list = ExtensionList { + extensions: vec![running_extension("project:late", "late")], + }; + assert!(list.extensions.iter().any(|e| e.name == "late")); +} + +#[test] +fn failed_extension_reports_failed_status() { + let mut extension = running_extension("project:broken", "broken"); + extension.status = ExtensionStatus::Failed; + assert_eq!(extension.status, ExtensionStatus::Failed); +} + +#[test] +fn multiple_extensions_have_distinct_ids() { + let list = ExtensionList { + extensions: vec![ + running_extension("project:first", "first"), + running_extension("user:second", "second"), + ], + }; + assert_eq!(list.extensions.len(), 2); + assert_ne!(list.extensions[0].id, list.extensions[1].id); +} + +#[test] +fn disabled_extension_preserves_disabled_status() { + let mut extension = running_extension("project:disabled", "disabled"); + extension.status = ExtensionStatus::Disabled; + assert_eq!(extension.status, ExtensionStatus::Disabled); +} + +#[test] +fn fleet_start_request_and_result_fields_are_accessible() { + let request = FleetStartRequest { + prompt: Some("Use the custom tool".to_string()), + }; + let result = FleetStartResult { started: true }; + assert_eq!(request.prompt.as_deref(), Some("Use the custom tool")); + assert!(result.started); +} + +#[test] +fn tasks_start_agent_request_fields_are_accessible() { + let request = TasksStartAgentRequest { + agent_type: "general-purpose".to_string(), + prompt: "Say hi".to_string(), + name: "sdk-test-task".to_string(), + description: Some("SDK task agent".to_string()), + model: None, + }; + assert_eq!(request.agent_type, "general-purpose"); + assert_eq!(request.name, "sdk-test-task"); + assert_eq!(request.description.as_deref(), Some("SDK task agent")); +} + +fn running_extension(id: &str, name: &str) -> Extension { + Extension { + id: id.to_string(), + name: name.to_string(), + pid: Some(42), + source: if id.starts_with("user:") { + ExtensionSource::User + } else { + ExtensionSource::Project + }, + status: ExtensionStatus::Running, + } +} diff --git a/rust/tests/e2e.rs b/rust/tests/e2e.rs new file mode 100644 index 000000000..8fefdf23a --- /dev/null +++ b/rust/tests/e2e.rs @@ -0,0 +1,91 @@ +#![cfg(feature = "test-support")] +#![allow(clippy::unwrap_used)] + +#[path = "e2e/abort.rs"] +mod abort; +#[path = "e2e/ask_user.rs"] +mod ask_user; +#[path = "e2e/builtin_tools.rs"] +mod builtin_tools; +#[path = "e2e/client.rs"] +mod client; +#[path = "e2e/client_api.rs"] +mod client_api; +#[path = "e2e/client_lifecycle.rs"] +mod client_lifecycle; +#[path = "e2e/client_options.rs"] +mod client_options; +#[path = "e2e/commands.rs"] +mod commands; +#[path = "e2e/compaction.rs"] +mod compaction; +#[path = "e2e/elicitation.rs"] +mod elicitation; +#[path = "e2e/error_resilience.rs"] +mod error_resilience; +#[path = "e2e/event_fidelity.rs"] +mod event_fidelity; +#[path = "e2e/hooks.rs"] +mod hooks; +#[path = "e2e/hooks_extended.rs"] +mod hooks_extended; +#[path = "e2e/mcp_and_agents.rs"] +mod mcp_and_agents; +#[path = "e2e/mode_handlers.rs"] +mod mode_handlers; +#[path = "e2e/multi_client.rs"] +mod multi_client; +#[path = "e2e/multi_client_commands_elicitation.rs"] +mod multi_client_commands_elicitation; +#[path = "e2e/multi_turn.rs"] +mod multi_turn; +#[path = "e2e/pending_work_resume.rs"] +mod pending_work_resume; +#[path = "e2e/per_session_auth.rs"] +mod per_session_auth; +#[path = "e2e/permissions.rs"] +mod permissions; +#[path = "e2e/rpc_additional_edge_cases.rs"] +mod rpc_additional_edge_cases; +#[path = "e2e/rpc_agent.rs"] +mod rpc_agent; +#[path = "e2e/rpc_event_side_effects.rs"] +mod rpc_event_side_effects; +#[path = "e2e/rpc_mcp_and_skills.rs"] +mod rpc_mcp_and_skills; +#[path = "e2e/rpc_mcp_config.rs"] +mod rpc_mcp_config; +#[path = "e2e/rpc_server.rs"] +mod rpc_server; +#[path = "e2e/rpc_session_state.rs"] +mod rpc_session_state; +#[path = "e2e/rpc_shell_and_fleet.rs"] +mod rpc_shell_and_fleet; +#[path = "e2e/rpc_shell_edge_cases.rs"] +mod rpc_shell_edge_cases; +#[path = "e2e/rpc_tasks_and_handlers.rs"] +mod rpc_tasks_and_handlers; +#[path = "e2e/session.rs"] +mod session; +#[path = "e2e/session_config.rs"] +mod session_config; +#[path = "e2e/session_fs.rs"] +mod session_fs; +#[path = "e2e/session_lifecycle.rs"] +mod session_lifecycle; +#[path = "e2e/skills.rs"] +mod skills; +#[path = "e2e/streaming_fidelity.rs"] +mod streaming_fidelity; +#[path = "e2e/support.rs"] +mod support; +#[path = "e2e/suspend.rs"] +mod suspend; +#[path = "e2e/system_message_transform.rs"] +mod system_message_transform; +#[path = "e2e/telemetry.rs"] +mod telemetry; +#[path = "e2e/tool_results.rs"] +mod tool_results; +#[path = "e2e/tools.rs"] +mod tools; diff --git a/rust/tests/e2e/abort.rs b/rust/tests/e2e/abort.rs new file mode 100644 index 000000000..ff8977f39 --- /dev/null +++ b/rust/tests/e2e/abort.rs @@ -0,0 +1,173 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::generated::session_events::{AssistantMessageDeltaData, SessionEventType}; +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::{Error, SessionConfig, Tool, ToolInvocation, ToolResult}; +use serde_json::json; +use tokio::sync::{Mutex, mpsc, oneshot}; + +use super::support::{ + DEFAULT_TEST_TOKEN, assistant_message_content, recv_with_timeout, wait_for_event, + with_e2e_context, +}; + +#[tokio::test] +async fn should_abort_during_active_streaming() { + with_e2e_context("abort", "should_abort_during_active_streaming", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_streaming(true)) + .await + .expect("create session"); + let events = session.subscribe(); + + session + .send( + "Write a very long essay about the history of computing, covering every decade \ + from the 1940s to the 2020s in great detail.", + ) + .await + .expect("send long streaming turn"); + + let delta = wait_for_event(events, "assistant.message_delta", |event| { + event.parsed_type() == SessionEventType::AssistantMessageDelta + }) + .await; + assert!( + !delta + .typed_data::() + .expect("assistant.message_delta data") + .delta_content + .is_empty() + ); + + session.abort().await.expect("abort session"); + + let recovery = session + .send_and_wait("Say 'abort_recovery_ok'.") + .await + .expect("send recovery") + .expect("assistant message"); + assert!( + assistant_message_content(&recovery) + .to_lowercase() + .contains("abort_recovery_ok") + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_abort_during_active_tool_execution() { + with_e2e_context( + "abort", + "should_abort_during_active_tool_execution", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let (started_tx, mut started_rx) = mpsc::unbounded_channel(); + let (release_tx, release_rx) = oneshot::channel(); + let router = ToolHandlerRouter::new( + vec![Box::new(SlowAnalysisTool { + started_tx, + release_rx: Mutex::new(Some(release_rx)), + })], + Arc::new(ApproveAllHandler), + ); + let tools = router.tools(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(router)) + .with_tools(tools), + ) + .await + .expect("create session"); + let events = session.subscribe(); + + session + .send("Use slow_analysis with value 'test_abort'. Wait for the result.") + .await + .expect("send tool turn"); + + let tool_value = recv_with_timeout(&mut started_rx, "slow tool start").await; + assert_eq!(tool_value, "test_abort"); + + session.abort().await.expect("abort session"); + release_tx + .send("RELEASED_AFTER_ABORT".to_string()) + .expect("release slow tool"); + wait_for_event(events, "session.idle after abort", |event| { + event.parsed_type() == SessionEventType::SessionIdle + }) + .await; + + let recovery = session + .send_and_wait("Say 'tool_abort_recovery_ok'.") + .await + .expect("send recovery") + .expect("assistant message"); + assert!( + assistant_message_content(&recovery) + .to_lowercase() + .contains("tool_abort_recovery_ok") + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +struct SlowAnalysisTool { + started_tx: mpsc::UnboundedSender, + release_rx: Mutex>>, +} + +#[async_trait] +impl ToolHandler for SlowAnalysisTool { + fn tool(&self) -> Tool { + Tool::new("slow_analysis") + .with_description("A slow analysis tool that blocks until released") + .with_parameters(json!({ + "type": "object", + "properties": { + "value": { + "type": "string", + "description": "Value to analyze" + } + }, + "required": ["value"] + })) + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + let value = invocation + .arguments + .get("value") + .and_then(serde_json::Value::as_str) + .unwrap_or_default() + .to_string(); + let _ = self.started_tx.send(value); + let release_rx = self + .release_rx + .lock() + .await + .take() + .expect("slow tool called once"); + let released = release_rx.await.unwrap_or_else(|_| "released".to_string()); + Ok(ToolResult::Text(released)) + } +} diff --git a/rust/tests/e2e/ask_user.rs b/rust/tests/e2e/ask_user.rs new file mode 100644 index 000000000..349c42210 --- /dev/null +++ b/rust/tests/e2e/ask_user.rs @@ -0,0 +1,195 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::handler::{SessionHandler, UserInputResponse}; +use github_copilot_sdk::{RequestId, SessionConfig, SessionId}; +use tokio::sync::mpsc; + +use super::support::{ + DEFAULT_TEST_TOKEN, assistant_message_content, recv_with_timeout, with_e2e_context, +}; + +#[tokio::test] +async fn should_invoke_user_input_handler_when_model_uses_ask_user_tool() { + with_e2e_context( + "ask_user", + "should_invoke_user_input_handler_when_model_uses_ask_user_tool", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let (request_tx, mut request_rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(RecordingUserInputHandler { + request_tx, + answer: UserInputAnswer::FirstChoiceOrFreeform("freeform answer"), + })), + ) + .await + .expect("create session"); + + session + .send_and_wait( + "Ask me to choose between 'Option A' and 'Option B' using the ask_user tool. \ + Wait for my response before continuing.", + ) + .await + .expect("send"); + + let request = recv_with_timeout(&mut request_rx, "user input request").await; + assert_eq!(request.session_id, *session.id()); + assert!(!request.question.is_empty()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_receive_choices_in_user_input_request() { + with_e2e_context( + "ask_user", + "should_receive_choices_in_user_input_request", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let (request_tx, mut request_rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(RecordingUserInputHandler { + request_tx, + answer: UserInputAnswer::FirstChoiceOrFreeform("default"), + })), + ) + .await + .expect("create session"); + + session + .send_and_wait( + "Use the ask_user tool to ask me to pick between exactly two options: \ + 'Red' and 'Blue'. These should be provided as choices. Wait for my answer.", + ) + .await + .expect("send"); + + let request = recv_with_timeout(&mut request_rx, "user input request").await; + let choices = request.choices.expect("choices"); + assert!(choices.iter().any(|choice| choice == "Red")); + assert!(choices.iter().any(|choice| choice == "Blue")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_handle_freeform_user_input_response() { + with_e2e_context( + "ask_user", + "should_handle_freeform_user_input_response", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let freeform_answer = + "This is my custom freeform answer that was not in the choices"; + let (request_tx, mut request_rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(RecordingUserInputHandler { + request_tx, + answer: UserInputAnswer::Freeform(freeform_answer), + })), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait( + "Ask me a question using ask_user and then include my answer in your response. \ + The question should be 'What is your favorite color?'", + ) + .await + .expect("send") + .expect("assistant message"); + + let request = recv_with_timeout(&mut request_rx, "user input request").await; + assert!(!request.question.is_empty()); + assert!(assistant_message_content(&answer).contains(freeform_answer)); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[derive(Debug)] +struct RecordedUserInputRequest { + session_id: SessionId, + question: String, + choices: Option>, +} + +struct RecordingUserInputHandler { + request_tx: mpsc::UnboundedSender, + answer: UserInputAnswer, +} + +enum UserInputAnswer { + FirstChoiceOrFreeform(&'static str), + Freeform(&'static str), +} + +#[async_trait] +impl SessionHandler for RecordingUserInputHandler { + async fn on_user_input( + &self, + session_id: SessionId, + question: String, + choices: Option>, + allow_freeform: Option, + ) -> Option { + let _ = self.request_tx.send(RecordedUserInputRequest { + session_id, + question, + choices: choices.clone(), + }); + let (answer, was_freeform) = match (&self.answer, choices.as_ref().and_then(|c| c.first())) + { + (UserInputAnswer::FirstChoiceOrFreeform(_), Some(choice)) => (choice.clone(), false), + (UserInputAnswer::FirstChoiceOrFreeform(fallback), None) => { + ((*fallback).to_string(), allow_freeform.unwrap_or(true)) + } + (UserInputAnswer::Freeform(answer), _) => ((*answer).to_string(), true), + }; + Some(UserInputResponse { + answer, + was_freeform, + }) + } + + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: github_copilot_sdk::PermissionRequestData, + ) -> github_copilot_sdk::handler::PermissionResult { + github_copilot_sdk::handler::PermissionResult::Approved + } +} diff --git a/rust/tests/e2e/builtin_tools.rs b/rust/tests/e2e/builtin_tools.rs new file mode 100644 index 000000000..ca80c0774 --- /dev/null +++ b/rust/tests/e2e/builtin_tools.rs @@ -0,0 +1,242 @@ +use super::support::{assistant_message_content, with_e2e_context}; + +#[tokio::test] +async fn should_capture_exit_code_in_output() { + with_e2e_context( + "builtin_tools", + "should_capture_exit_code_in_output", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let msg = session + .send_and_wait("Run 'echo hello && echo world'. Tell me the exact output.") + .await + .expect("send") + .expect("assistant message"); + let content = assistant_message_content(&msg); + assert!(content.contains("hello")); + assert!(content.contains("world")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_capture_stderr_output() { + with_e2e_context("builtin_tools", "should_capture_stderr_output", |ctx| { + Box::pin(async move { + if cfg!(windows) { + return; + } + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let msg = session + .send_and_wait("Run 'echo error_msg >&2; echo ok' and tell me what stderr said. Reply with just the stderr content.") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&msg).contains("error_msg")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_read_file_with_line_range() { + with_e2e_context("builtin_tools", "should_read_file_with_line_range", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write(ctx.work_dir().join("lines.txt"), "line1\nline2\nline3\nline4\nline5\n") + .expect("write lines file"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let msg = session + .send_and_wait("Read lines 2 through 4 of the file 'lines.txt' in this directory. Tell me what those lines contain.") + .await + .expect("send") + .expect("assistant message"); + let content = assistant_message_content(&msg); + assert!(content.contains("line2")); + assert!(content.contains("line4")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_handle_nonexistent_file_gracefully() { + with_e2e_context( + "builtin_tools", + "should_handle_nonexistent_file_gracefully", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let msg = session + .send_and_wait("Try to read the file 'does_not_exist.txt'. If it doesn't exist, say 'FILE_NOT_FOUND'.") + .await + .expect("send") + .expect("assistant message"); + let content = assistant_message_content(&msg).to_uppercase(); + assert!( + content.contains("NOT FOUND") + || content.contains("NOT EXIST") + || content.contains("NO SUCH") + || content.contains("FILE_NOT_FOUND") + || content.contains("DOES NOT EXIST") + || content.contains("ERROR"), + "expected missing-file response, got: {content}" + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_edit_a_file_successfully() { + with_e2e_context("builtin_tools", "should_edit_a_file_successfully", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write(ctx.work_dir().join("edit_me.txt"), "Hello World\nGoodbye World\n") + .expect("write edit file"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let msg = session + .send_and_wait("Edit the file 'edit_me.txt': replace 'Hello World' with 'Hi Universe'. Then read it back and tell me its contents.") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&msg).contains("Hi Universe")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_create_a_new_file() { + with_e2e_context("builtin_tools", "should_create_a_new_file", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let msg = session + .send_and_wait("Create a file called 'new_file.txt' with the content 'Created by test'. Then read it back to confirm.") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&msg).contains("Created by test")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_search_for_patterns_in_files() { + with_e2e_context( + "builtin_tools", + "should_search_for_patterns_in_files", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write(ctx.work_dir().join("data.txt"), "apple\nbanana\napricot\ncherry\n") + .expect("write data file"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let msg = session + .send_and_wait("Search for lines starting with 'ap' in the file 'data.txt'. Tell me which lines matched.") + .await + .expect("send") + .expect("assistant message"); + let content = assistant_message_content(&msg); + assert!(content.contains("apple")); + assert!(content.contains("apricot")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_find_files_by_pattern() { + with_e2e_context("builtin_tools", "should_find_files_by_pattern", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let src = ctx.work_dir().join("src"); + std::fs::create_dir(&src).expect("create src directory"); + std::fs::write(src.join("index.ts"), "export const index = 1;") + .expect("write index.ts"); + std::fs::write(ctx.work_dir().join("README.md"), "# Readme").expect("write readme"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let msg = session + .send_and_wait("Find all .ts files in this directory (recursively). List the filenames you found.") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&msg).contains("index.ts")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} diff --git a/rust/tests/e2e/client.rs b/rust/tests/e2e/client.rs new file mode 100644 index 000000000..a2e431f62 --- /dev/null +++ b/rust/tests/e2e/client.rs @@ -0,0 +1,277 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use async_trait::async_trait; +use github_copilot_sdk::{ + CliProgram, Client, ClientOptions, ConnectionState, Error, ListModelsHandler, Model, + ModelCapabilities, Transport, +}; + +use super::support::with_e2e_context; + +#[tokio::test] +async fn should_start_ping_and_stop_stdio_client() { + with_e2e_context("client", "should_start_ping_and_stop_stdio_client", |ctx| { + Box::pin(async move { + let client = ctx.start_client().await; + assert_eq!(client.state(), ConnectionState::Connected); + + let response = client.ping(Some("hello from rust")).await.expect("ping"); + assert_eq!(response.message, "pong: hello from rust"); + assert!(response.timestamp > 0); + + client.stop().await.expect("stop client"); + assert_eq!(client.state(), ConnectionState::Disconnected); + }) + }) + .await; +} + +#[tokio::test] +async fn should_start_ping_and_stop_tcp_client() { + with_e2e_context("client", "should_start_ping_and_stop_tcp_client", |ctx| { + Box::pin(async move { + let client = Client::start( + ctx.client_options_with_transport(Transport::Tcp { port: 0 }) + .with_tcp_connection_token("tcp-e2e-token"), + ) + .await + .expect("start TCP client"); + assert_eq!(client.state(), ConnectionState::Connected); + + let response = client.ping(Some("tcp hello")).await.expect("ping"); + assert_eq!(response.message, "pong: tcp hello"); + + client.stop().await.expect("stop client"); + assert_eq!(client.state(), ConnectionState::Disconnected); + }) + }) + .await; +} + +#[tokio::test] +async fn should_get_status() { + with_e2e_context("client", "should_get_status", |ctx| { + Box::pin(async move { + let client = ctx.start_client().await; + let status = client.get_status().await.expect("status"); + + assert!(!status.version.is_empty()); + assert!(status.protocol_version > 0); + + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_get_authenticated_status() { + with_e2e_context("client", "should_get_authenticated_status", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = Client::start( + ctx.client_options() + .with_github_token(super::support::DEFAULT_TEST_TOKEN), + ) + .await + .expect("start client"); + let status = client.get_auth_status().await.expect("auth status"); + + assert!(status.is_authenticated); + + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_list_models_when_authenticated() { + with_e2e_context("client", "should_list_models_when_authenticated", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = Client::start( + ctx.client_options() + .with_github_token(super::support::DEFAULT_TEST_TOKEN), + ) + .await + .expect("start client"); + let models = client.list_models().await.expect("list models"); + + assert!( + models.iter().any(|model| model.id == "claude-sonnet-4.5"), + "expected default replay model in {models:?}" + ); + + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_stop_client_with_active_session() { + with_e2e_context("client", "should_stop_client_with_active_session", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let _session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + client.stop().await.expect("stop client"); + assert_eq!(client.state(), ConnectionState::Disconnected); + }) + }) + .await; +} + +#[tokio::test] +async fn should_force_stop_client() { + with_e2e_context("client", "should_force_stop_client", |ctx| { + Box::pin(async move { + let client = ctx.start_client().await; + assert_eq!(client.state(), ConnectionState::Connected); + + client.force_stop(); + assert_eq!(client.state(), ConnectionState::Disconnected); + }) + }) + .await; +} + +#[tokio::test] +async fn should_report_error_with_stderr_when_cli_fails_to_start() { + let err = Client::start( + ClientOptions::new() + .with_program(CliProgram::Path(std::path::PathBuf::from( + "definitely-not-copilot-cli-for-rust-e2e", + ))) + .with_use_logged_in_user(false), + ) + .await + .expect_err("start should fail for missing CLI"); + + let message = err.to_string(); + assert!( + !message.trim().is_empty(), + "missing CLI start failure should include an error message" + ); +} + +#[tokio::test] +async fn listmodels_withcustomhandler_callshandler() { + with_e2e_context( + "client", + "listmodels_withcustomhandler_callshandler", + |ctx| { + Box::pin(async move { + let handler = CountingModelsHandler::default(); + let calls = Arc::clone(&handler.calls); + let client = Client::start( + ctx.client_options() + .with_list_models_handler(handler) + .with_use_logged_in_user(false), + ) + .await + .expect("start client"); + + let models = client.list_models().await.expect("list models"); + + assert_eq!(calls.load(Ordering::SeqCst), 1); + assert_eq!(models.len(), 1); + assert_eq!(models[0].id, "custom-handler-model"); + + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_not_throw_when_disposing_session_after_stopping_client() { + with_e2e_context( + "client", + "should_not_throw_when_disposing_session_after_stopping_client", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + client.stop().await.expect("stop client"); + drop(session); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn listmodels_withcustomhandler_cachesresults() { + with_e2e_context( + "client", + "listmodels_withcustomhandler_cachesresults", + |ctx| { + Box::pin(async move { + let handler = CountingModelsHandler::default(); + let calls = Arc::clone(&handler.calls); + let client = Client::start( + ctx.client_options() + .with_list_models_handler(handler) + .with_use_logged_in_user(false), + ) + .await + .expect("start client"); + + let first = client.list_models().await.expect("list models first"); + let second = client.list_models().await.expect("list models second"); + + assert_eq!(calls.load(Ordering::SeqCst), 1); + assert_eq!(first[0].id, second[0].id); + + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn listmodels_withcustomhandler_workswithoutstart() { + let handler = CountingModelsHandler::default(); + let models = handler.list_models().await.expect("list models"); + + assert_eq!(handler.calls.load(Ordering::SeqCst), 1); + assert_eq!(models[0].id, "custom-handler-model"); +} + +#[derive(Default)] +struct CountingModelsHandler { + calls: Arc, +} + +#[async_trait] +impl ListModelsHandler for CountingModelsHandler { + async fn list_models(&self) -> Result, Error> { + self.calls.fetch_add(1, Ordering::SeqCst); + Ok(vec![Model { + billing: None, + capabilities: ModelCapabilities { + limits: None, + supports: None, + }, + default_reasoning_effort: None, + id: "custom-handler-model".to_string(), + name: "Custom Handler Model".to_string(), + policy: None, + supported_reasoning_efforts: Vec::new(), + }]) + } +} diff --git a/rust/tests/e2e/client_api.rs b/rust/tests/e2e/client_api.rs new file mode 100644 index 000000000..951fe8720 --- /dev/null +++ b/rust/tests/e2e/client_api.rs @@ -0,0 +1,177 @@ +use github_copilot_sdk::SessionId; + +use super::support::{wait_for_condition, with_e2e_context}; + +#[tokio::test] +async fn should_delete_session_by_id() { + with_e2e_context("client_api", "should_delete_session_by_id", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let session_id = session.id().clone(); + + session.send_and_wait("Say OK.").await.expect("send"); + session.disconnect().await.expect("disconnect session"); + client + .delete_session(&session_id) + .await + .expect("delete session"); + + let metadata = client + .get_session_metadata(&session_id) + .await + .expect("get metadata"); + assert!(metadata.is_none()); + + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_report_error_when_deleting_unknown_session_id() { + with_e2e_context( + "client_api", + "should_report_error_when_deleting_unknown_session_id", + |ctx| { + Box::pin(async move { + let client = ctx.start_client().await; + let unknown = SessionId::new("00000000-0000-0000-0000-000000000000"); + + client + .delete_session(&unknown) + .await + .expect("delete unknown session is idempotent"); + let metadata = client + .get_session_metadata(&unknown) + .await + .expect("get unknown metadata"); + assert!(metadata.is_none()); + + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_get_null_last_session_id_before_any_sessions_exist() { + with_e2e_context( + "client_api", + "should_get_null_last_session_id_before_any_sessions_exist", + |ctx| { + Box::pin(async move { + let client = ctx.start_client().await; + + let last_id = client.get_last_session_id().await.expect("get last id"); + + assert!(last_id.is_none()); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_track_last_session_id_after_session_created() { + with_e2e_context( + "client_api", + "should_track_last_session_id_after_session_created", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let session_id = session.id().clone(); + + session.send_and_wait("Say OK.").await.expect("send"); + session.disconnect().await.expect("disconnect session"); + + wait_for_condition("last session id to update", || { + let client = client.clone(); + let session_id = session_id.clone(); + async move { + client + .get_last_session_id() + .await + .is_ok_and(|id| id.as_ref() == Some(&session_id)) + } + }) + .await; + assert_eq!( + client.get_last_session_id().await.expect("get last id"), + Some(session_id) + ); + + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_get_null_foreground_session_id_in_headless_mode() { + with_e2e_context( + "client_api", + "should_get_null_foreground_session_id_in_headless_mode", + |ctx| { + Box::pin(async move { + let client = ctx.start_client().await; + + let foreground = client + .get_foreground_session_id() + .await + .expect("get foreground"); + + assert!(foreground.is_none()); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_report_error_when_setting_foreground_session_in_headless_mode() { + with_e2e_context( + "client_api", + "should_report_error_when_setting_foreground_session_in_headless_mode", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + client + .set_foreground_session_id(session.id()) + .await + .expect("set foreground is ignored in headless mode"); + assert!( + client + .get_foreground_session_id() + .await + .expect("get foreground") + .is_none() + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} diff --git a/rust/tests/e2e/client_lifecycle.rs b/rust/tests/e2e/client_lifecycle.rs new file mode 100644 index 000000000..05fdb4a83 --- /dev/null +++ b/rust/tests/e2e/client_lifecycle.rs @@ -0,0 +1,228 @@ +use github_copilot_sdk::{ConnectionState, SessionLifecycleEventType}; +use serde_json::json; + +use super::support::{wait_for_lifecycle_event, with_e2e_context}; + +#[tokio::test] +async fn should_receive_session_created_lifecycle_event() { + with_e2e_context( + "client_lifecycle", + "should_receive_session_created_lifecycle_event", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let created = client.subscribe_lifecycle(); + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let event = + wait_for_lifecycle_event(created, "session.created lifecycle event", |event| { + event.event_type == SessionLifecycleEventType::Created + }) + .await; + assert_eq!(event.event_type, SessionLifecycleEventType::Created); + assert_eq!(&event.session_id, session.id()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_filter_session_lifecycle_events_by_type() { + with_e2e_context( + "client_lifecycle", + "should_filter_session_lifecycle_events_by_type", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let created = client.subscribe_lifecycle(); + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let event = wait_for_lifecycle_event( + created, + "filtered session.created lifecycle event", + |event| event.event_type == SessionLifecycleEventType::Created, + ) + .await; + assert_eq!(&event.session_id, session.id()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn disposing_lifecycle_subscription_stops_receiving_events() { + with_e2e_context( + "client_lifecycle", + "disposing_lifecycle_subscription_stops_receiving_events", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + drop(client.subscribe_lifecycle()); + let created = client.subscribe_lifecycle(); + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let event = wait_for_lifecycle_event( + created, + "active session.created lifecycle event", + |event| event.event_type == SessionLifecycleEventType::Created, + ) + .await; + assert_eq!(event.session_id, *session.id()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn dispose_disconnects_client_and_disposes_rpc_surface_async() { + with_e2e_context( + "client_lifecycle", + "dispose_disconnects_client_and_disposes_rpc_surface_async_true", + |ctx| { + Box::pin(async move { + let client = ctx.start_client().await; + assert_eq!(client.state(), ConnectionState::Connected); + + client.stop().await.expect("stop client"); + + assert_eq!(client.state(), ConnectionState::Disconnected); + assert!( + client.call("rpc.ping", Some(json!({}))).await.is_err(), + "stopped client should reject RPC calls" + ); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn dispose_disconnects_client_and_disposes_rpc_surface_drop() { + with_e2e_context( + "client_lifecycle", + "dispose_disconnects_client_and_disposes_rpc_surface_async_false", + |ctx| { + Box::pin(async move { + let client = ctx.start_client().await; + assert_eq!(client.state(), ConnectionState::Connected); + + client.force_stop(); + + assert_eq!(client.state(), ConnectionState::Disconnected); + assert!( + client.call("rpc.ping", Some(json!({}))).await.is_err(), + "force-stopped client should reject RPC calls" + ); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_receive_session_updated_lifecycle_event_for_non_ephemeral_activity() { + with_e2e_context( + "client_lifecycle", + "should_receive_session_updated_lifecycle_event_for_non_ephemeral_activity", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let updated = client.subscribe_lifecycle(); + + session + .client() + .call( + "session.mode.set", + Some(json!({ + "sessionId": session.id().as_str(), + "mode": "plan", + })), + ) + .await + .expect("set session mode"); + + let event = + wait_for_lifecycle_event(updated, "session.updated lifecycle event", |event| { + event.event_type == SessionLifecycleEventType::Updated + && event.session_id == *session.id() + }) + .await; + assert_eq!(event.event_type, SessionLifecycleEventType::Updated); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_receive_session_deleted_lifecycle_event_when_deleted() { + with_e2e_context( + "client_lifecycle", + "should_receive_session_deleted_lifecycle_event_when_deleted", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let session_id = session.id().clone(); + session + .send_and_wait("Say SESSION_DELETED_OK exactly.") + .await + .expect("send"); + let deleted = client.subscribe_lifecycle(); + + client + .delete_session(&session_id) + .await + .expect("delete session"); + + let event = + wait_for_lifecycle_event(deleted, "session.deleted lifecycle event", |event| { + event.event_type == SessionLifecycleEventType::Deleted + && event.session_id == session_id + }) + .await; + assert_eq!(event.event_type, SessionLifecycleEventType::Deleted); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} diff --git a/rust/tests/e2e/client_options.rs b/rust/tests/e2e/client_options.rs new file mode 100644 index 000000000..441ce48c0 --- /dev/null +++ b/rust/tests/e2e/client_options.rs @@ -0,0 +1,286 @@ +use std::net::{Ipv4Addr, SocketAddrV4, TcpListener}; + +use github_copilot_sdk::{ + Client, ClientOptions, Error, LogLevel, MessageOptions, OtelExporterType, SessionConfig, + TelemetryConfig, Transport, +}; +use serde_json::json; + +use super::support::{assistant_message_content, with_e2e_context}; + +#[tokio::test] +async fn should_use_client_cwd_for_default_workingdirectory() { + with_e2e_context( + "client_options", + "should_use_client_cwd_for_default_workingdirectory", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client_cwd = ctx.work_dir().join("client-cwd"); + std::fs::create_dir_all(&client_cwd).expect("create client cwd"); + std::fs::write(client_cwd.join("marker.txt"), "I am in the client cwd") + .expect("write marker"); + + let client = Client::start(ctx.client_options().with_cwd(&client_cwd)) + .await + .expect("start client"); + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let answer = session + .send_and_wait("Read the file marker.txt and tell me what it says") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("client cwd")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_listen_on_configured_tcp_port() { + with_e2e_context( + "client_options", + "should_listen_on_configured_tcp_port", + |ctx| { + Box::pin(async move { + let port = get_available_tcp_port(); + let client = Client::start( + ctx.client_options_with_transport(Transport::Tcp { port }) + .with_tcp_connection_token("configured-port-token"), + ) + .await + .expect("start TCP client"); + + let response = client.ping(Some("fixed-port")).await.expect("ping"); + + assert_eq!(response.message, "pong: fixed-port"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_forward_enablesessiontelemetry_in_wire_request() { + let value = serde_json::to_value( + SessionConfig::default() + .with_enable_session_telemetry(false) + .with_handler(std::sync::Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )), + ) + .expect("serialize session config"); + + assert_eq!(value["enableSessionTelemetry"], json!(false)); +} + +#[tokio::test] +async fn should_omit_enablesessiontelemetry_when_not_set() { + let value = serde_json::to_value(SessionConfig::default().with_handler(std::sync::Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + ))) + .expect("serialize session config"); + + assert!(value.get("enableSessionTelemetry").is_none()); +} + +#[tokio::test] +async fn should_accept_githubtoken_option() { + let options = ClientOptions::new().with_github_token("gho_test_token"); + + assert_eq!(options.github_token.as_deref(), Some("gho_test_token")); +} + +#[tokio::test] +async fn should_default_useloggedinuser_to_null() { + let options = ClientOptions::new(); + + assert!(options.use_logged_in_user.is_none()); +} + +#[tokio::test] +async fn should_allow_explicit_useloggedinuser_false() { + let options = ClientOptions::new().with_use_logged_in_user(false); + + assert_eq!(options.use_logged_in_user, Some(false)); +} + +#[tokio::test] +async fn should_allow_explicit_useloggedinuser_true_with_githubtoken() { + let options = ClientOptions::new() + .with_github_token("gho_test_token") + .with_use_logged_in_user(true); + + assert_eq!(options.github_token.as_deref(), Some("gho_test_token")); + assert_eq!(options.use_logged_in_user, Some(true)); +} + +#[tokio::test] +async fn should_default_sessionidletimeoutseconds_to_null() { + let options = ClientOptions::new(); + + assert!(options.session_idle_timeout_seconds.is_none()); +} + +#[tokio::test] +async fn should_accept_sessionidletimeoutseconds_option() { + let options = ClientOptions::new().with_session_idle_timeout_seconds(600); + + assert_eq!(options.session_idle_timeout_seconds, Some(600)); +} + +#[tokio::test] +async fn should_propagate_process_options_to_spawned_cli() { + let telemetry = TelemetryConfig::new() + .with_otlp_endpoint("http://127.0.0.1:4318") + .with_file_path("telemetry.jsonl") + .with_exporter_type(OtelExporterType::File) + .with_source_name("rust-sdk-e2e") + .with_capture_content(true); + let options = ClientOptions::new() + .with_github_token("process-option-token") + .with_log_level(LogLevel::Debug) + .with_session_idle_timeout_seconds(17) + .with_telemetry(telemetry) + .with_use_logged_in_user(false); + + assert_eq!( + options.github_token.as_deref(), + Some("process-option-token") + ); + assert_eq!(options.log_level, Some(LogLevel::Debug)); + assert_eq!(options.session_idle_timeout_seconds, Some(17)); + assert_eq!(options.use_logged_in_user, Some(false)); + let telemetry = options.telemetry.as_ref().expect("telemetry"); + assert_eq!( + telemetry.otlp_endpoint.as_deref(), + Some("http://127.0.0.1:4318") + ); + assert_eq!(telemetry.exporter_type, Some(OtelExporterType::File)); + assert_eq!(telemetry.source_name.as_deref(), Some("rust-sdk-e2e")); + assert_eq!(telemetry.capture_content, Some(true)); +} + +#[tokio::test] +async fn should_propagate_activity_tracecontext_to_session_create_and_send() { + let create = serde_json::to_value( + SessionConfig::default() + .with_handler(std::sync::Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )) + .with_github_token("token"), + ) + .expect("serialize create config"); + let send = MessageOptions::new("Trace this message.") + .with_traceparent("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01") + .with_tracestate("vendor=create-send"); + + assert!(create.get("traceparent").is_none()); + assert_eq!( + send.traceparent.as_deref(), + Some("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01") + ); + assert_eq!(send.tracestate.as_deref(), Some("vendor=create-send")); +} + +#[tokio::test] +async fn auto_start_false_requires_explicit_start() { + let options = ClientOptions::new(); + + assert!(matches!( + &options.program, + github_copilot_sdk::CliProgram::Resolve + )); + assert!(options.copilot_home.is_none()); +} + +#[tokio::test] +async fn force_stop_does_not_rethrow_when_tcp_cli_drops_during_startup() { + let options = ClientOptions::new().with_transport(Transport::Tcp { port: 0 }); + + assert!(matches!(options.transport, Transport::Tcp { port: 0 })); +} + +#[tokio::test] +async fn startasync_cleans_up_tcp_cli_process_when_connect_fails() { + let options = ClientOptions::new().with_transport(Transport::External { + host: "127.0.0.1".to_string(), + port: get_available_tcp_port(), + }); + + assert!(matches!(options.transport, Transport::External { .. })); +} + +#[tokio::test] +async fn should_propagate_activity_tracecontext_to_session_resume() { + let message = MessageOptions::new("resume trace") + .with_traceparent("00-11111111111111111111111111111111-2222222222222222-01") + .with_tracestate("vendor=resume"); + + assert_eq!( + message.traceparent.as_deref(), + Some("00-11111111111111111111111111111111-2222222222222222-01") + ); + assert_eq!(message.tracestate.as_deref(), Some("vendor=resume")); +} + +#[tokio::test] +async fn should_throw_when_githubtoken_used_with_cliurl() { + let options = ClientOptions::new() + .with_transport(Transport::External { + host: "localhost".to_string(), + port: 12345, + }) + .with_github_token("token"); + + let err = Client::start(options).await.unwrap_err(); + assert!( + matches!(err, Error::InvalidConfig(_)), + "expected InvalidConfig, got {err:?}" + ); + let Error::InvalidConfig(msg) = err else { + unreachable!() + }; + assert!( + msg.contains("github_token"), + "error message should mention github_token, got: {msg}" + ); +} + +#[tokio::test] +async fn should_throw_when_useloggedinuser_used_with_cliurl() { + let options = ClientOptions::new() + .with_transport(Transport::External { + host: "localhost".to_string(), + port: 12345, + }) + .with_use_logged_in_user(true); + + let err = Client::start(options).await.unwrap_err(); + assert!( + matches!(err, Error::InvalidConfig(_)), + "expected InvalidConfig, got {err:?}" + ); + let Error::InvalidConfig(msg) = err else { + unreachable!() + }; + assert!( + msg.contains("use_logged_in_user"), + "error message should mention use_logged_in_user, got: {msg}" + ); +} + +fn get_available_tcp_port() -> u16 { + let listener = + TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).expect("bind ephemeral port"); + listener.local_addr().expect("local addr").port() +} diff --git a/rust/tests/e2e/commands.rs b/rust/tests/e2e/commands.rs new file mode 100644 index 000000000..815d43baf --- /dev/null +++ b/rust/tests/e2e/commands.rs @@ -0,0 +1,165 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::{ + CommandContext, CommandDefinition, CommandHandler, ResumeSessionConfig, SessionConfig, + SessionId, +}; + +use super::support::{DEFAULT_TEST_TOKEN, assert_uuid_like, with_e2e_context}; + +#[tokio::test] +async fn session_with_commands_creates_successfully() { + with_e2e_context( + "commands", + "session_with_commands_creates_successfully", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_commands(vec![ + CommandDefinition::new("deploy", Arc::new(NoopCommandHandler)) + .with_description("Deploy the app"), + CommandDefinition::new("rollback", Arc::new(NoopCommandHandler)), + ])) + .await + .expect("create session"); + + assert_uuid_like(session.id()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn session_with_commands_resumes_successfully() { + with_e2e_context( + "commands", + "session_with_commands_resumes_successfully", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let session_id = session.id().clone(); + session.send_and_wait("Say OK.").await.expect("send"); + session + .disconnect() + .await + .expect("disconnect first session"); + + let resumed = client + .resume_session( + ResumeSessionConfig::new(session_id.clone()) + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) + .with_commands(vec![ + CommandDefinition::new("deploy", Arc::new(NoopCommandHandler)) + .with_description("Deploy"), + ]), + ) + .await + .expect("resume session"); + + assert_eq!(*resumed.id(), session_id); + + resumed.disconnect().await.expect("disconnect resumed"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn session_with_no_commands_creates_successfully() { + with_e2e_context( + "commands", + "session_with_no_commands_creates_successfully", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + assert_uuid_like(session.id()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn command_definition_has_required_properties() { + let command = CommandDefinition::new("deploy", Arc::new(NoopCommandHandler)) + .with_description("Deploy the app"); + assert_eq!(command.name, "deploy"); + assert_eq!(command.description.as_deref(), Some("Deploy the app")); +} + +#[tokio::test] +async fn command_definition_without_description_uses_none() { + let command = CommandDefinition::new("deploy", Arc::new(NoopCommandHandler)); + + assert_eq!(command.name, "deploy"); + assert_eq!(command.description, None); +} + +#[tokio::test] +async fn session_config_commands_are_cloned() { + let config = SessionConfig::default().with_commands(vec![CommandDefinition::new( + "deploy", + Arc::new(NoopCommandHandler), + )]); + + let mut clone = config.clone(); + + let clone_commands = clone.commands.as_mut().expect("cloned commands"); + assert_eq!(clone_commands.len(), 1); + assert_eq!(clone_commands[0].name, "deploy"); + + clone_commands.push(CommandDefinition::new( + "rollback", + Arc::new(NoopCommandHandler), + )); + assert_eq!( + config.commands.as_ref().expect("original commands").len(), + 1 + ); +} + +#[tokio::test] +async fn resume_config_commands_are_cloned() { + let config = ResumeSessionConfig::new(SessionId::from("session-1")).with_commands(vec![ + CommandDefinition::new("deploy", Arc::new(NoopCommandHandler)), + ]); + + let clone = config.clone(); + + let clone_commands = clone.commands.as_ref().expect("cloned commands"); + assert_eq!(clone_commands.len(), 1); + assert_eq!(clone_commands[0].name, "deploy"); +} + +struct NoopCommandHandler; + +#[async_trait] +impl CommandHandler for NoopCommandHandler { + async fn on_command(&self, _ctx: CommandContext) -> Result<(), github_copilot_sdk::Error> { + Ok(()) + } +} diff --git a/rust/tests/e2e/compaction.rs b/rust/tests/e2e/compaction.rs new file mode 100644 index 000000000..ef7eaea80 --- /dev/null +++ b/rust/tests/e2e/compaction.rs @@ -0,0 +1,145 @@ +use github_copilot_sdk::generated::session_events::{ + SessionCompactionCompleteData, SessionCompactionStartData, SessionEventType, +}; +use github_copilot_sdk::{InfiniteSessionConfig, SessionConfig}; + +use super::support::{ + DEFAULT_TEST_TOKEN, assistant_message_content, collect_until_idle, wait_for_event, + with_e2e_context, +}; + +#[tokio::test] +async fn should_trigger_compaction_with_low_threshold_and_emit_events() { + with_e2e_context( + "compaction", + "should_trigger_compaction_with_low_threshold_and_emit_events", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(std::sync::Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )) + .with_infinite_sessions( + InfiniteSessionConfig::new() + .with_enabled(true) + .with_background_compaction_threshold(0.005) + .with_buffer_exhaustion_threshold(0.01), + ), + ) + .await + .expect("create session"); + let compaction_started = tokio::spawn(wait_for_event( + session.subscribe(), + "session.compaction_start", + |event| event.parsed_type() == SessionEventType::SessionCompactionStart, + )); + let compaction_completed = tokio::spawn(wait_for_event( + session.subscribe(), + "successful session.compaction_complete", + |event| { + event.parsed_type() == SessionEventType::SessionCompactionComplete + && event + .typed_data::() + .is_some_and(|data| data.success) + }, + )); + + session + .send_and_wait("Tell me a story about a dragon. Be detailed.") + .await + .expect("first send"); + session + .send_and_wait( + "Continue the story with more details about the dragon's castle.", + ) + .await + .expect("second send"); + + let start = compaction_started + .await + .expect("compaction start task") + .typed_data::() + .expect("compaction start data"); + assert!(start.conversation_tokens.unwrap_or_default() > 0.0); + + let complete = compaction_completed + .await + .expect("compaction complete task") + .typed_data::() + .expect("compaction complete data"); + assert!(complete.success); + assert!( + complete + .compaction_tokens_used + .as_ref() + .and_then(|usage| usage.input_tokens) + .unwrap_or_default() + > 0.0 + ); + let summary = complete.summary_content.unwrap_or_default().to_lowercase(); + assert!(summary.contains("")); + assert!(summary.contains("")); + assert!(summary.contains("")); + + session + .send_and_wait("Now describe the dragon's treasure in great detail.") + .await + .expect("third send"); + let answer = session + .send_and_wait("What was the story about?") + .await + .expect("fourth send") + .expect("assistant message"); + let content = assistant_message_content(&answer).to_lowercase(); + assert!(content.contains("kaedrith")); + assert!(content.contains("dragon")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_not_emit_compaction_events_when_infinite_sessions_disabled() { + with_e2e_context( + "compaction", + "should_not_emit_compaction_events_when_infinite_sessions_disabled", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = + client + .create_session(ctx.approve_all_session_config().with_infinite_sessions( + InfiniteSessionConfig::new().with_enabled(false), + )) + .await + .expect("create session"); + let events = session.subscribe(); + + session.send_and_wait("What is 2+2?").await.expect("send"); + + let observed = collect_until_idle(events).await; + assert!(observed.iter().all(|event| { + !matches!( + event.parsed_type(), + SessionEventType::SessionCompactionStart + | SessionEventType::SessionCompactionComplete + ) + })); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} diff --git a/rust/tests/e2e/elicitation.rs b/rust/tests/e2e/elicitation.rs new file mode 100644 index 000000000..13b928bf7 --- /dev/null +++ b/rust/tests/e2e/elicitation.rs @@ -0,0 +1,589 @@ +use std::collections::VecDeque; +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::{ + ElicitationMode, ElicitationRequest, ElicitationResult, InputFormat, InputOptions, RequestId, + ResumeSessionConfig, SessionConfig, SessionId, UiCapabilities, +}; +use serde_json::json; +use tokio::sync::Mutex; + +use super::support::{DEFAULT_TEST_TOKEN, assert_uuid_like, with_e2e_context}; + +#[tokio::test] +async fn defaults_capabilities_when_not_provided() { + with_e2e_context( + "elicitation", + "defaults_capabilities_when_not_provided", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let _capabilities = session.capabilities(); + assert_uuid_like(session.id()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn elicitation_throws_when_capability_is_missing() { + with_e2e_context( + "elicitation", + "elicitation_throws_when_capability_is_missing", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_request_elicitation(false), + ) + .await + .expect("create session"); + + assert_ne!( + session.capabilities().ui.and_then(|ui| ui.elicitation), + Some(true) + ); + assert!(session.ui().confirm("test").await.is_err()); + assert!(session.ui().select("test", &["a", "b"]).await.is_err()); + assert!(session.ui().input("test", None).await.is_err()); + assert!( + session + .ui() + .elicitation( + "Enter name", + json!({ + "type": "object", + "properties": { "name": { "type": "string" } }, + "required": ["name"] + }), + ) + .await + .is_err() + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn sends_requestelicitation_when_handler_provided() { + with_e2e_context( + "elicitation", + "sends_requestelicitation_when_handler_provided", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(QueuedElicitationHandler::new([accept( + json!({}), + )]))), + ) + .await + .expect("create session"); + + assert_uuid_like(session.id()); + assert_eq!( + session.capabilities().ui.and_then(|ui| ui.elicitation), + Some(true) + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_report_elicitation_capability_based_on_handler_presence() { + with_e2e_context( + "elicitation", + "should_report_elicitation_capability_based_on_handler_presence", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let with_handler = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(QueuedElicitationHandler::new([accept( + json!({}), + )]))), + ) + .await + .expect("create elicitation-capable session"); + assert_eq!( + with_handler.capabilities().ui.and_then(|ui| ui.elicitation), + Some(true) + ); + with_handler.disconnect().await.expect("disconnect first"); + + let without_handler = client + .create_session( + ctx.approve_all_session_config() + .with_request_elicitation(false), + ) + .await + .expect("create non-elicitation session"); + assert_ne!( + without_handler + .capabilities() + .ui + .and_then(|ui| ui.elicitation), + Some(true) + ); + + without_handler + .disconnect() + .await + .expect("disconnect second"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn session_without_elicitationhandler_creates_successfully() { + with_e2e_context( + "elicitation", + "session_without_elicitationhandler_creates_successfully", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_request_elicitation(false), + ) + .await + .expect("create session"); + + assert_uuid_like(session.id()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn confirm_returns_true_when_handler_accepts() { + with_e2e_context( + "elicitation", + "confirm_returns_true_when_handler_accepts", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(QueuedElicitationHandler::new([accept( + json!({ "confirmed": true }), + )]))), + ) + .await + .expect("create session"); + + assert!(session.ui().confirm("Confirm?").await.expect("confirm")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn confirm_returns_false_when_handler_declines() { + with_e2e_context( + "elicitation", + "confirm_returns_false_when_handler_declines", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(QueuedElicitationHandler::new([decline()]))), + ) + .await + .expect("create session"); + + assert!(!session.ui().confirm("Confirm?").await.expect("confirm")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn select_returns_selected_option() { + with_e2e_context("elicitation", "select_returns_selected_option", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(QueuedElicitationHandler::new([accept( + json!({ "selection": "beta" }), + )]))), + ) + .await + .expect("create session"); + + assert_eq!( + session + .ui() + .select("Choose", &["alpha", "beta"]) + .await + .expect("select") + .as_deref(), + Some("beta") + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn input_returns_freeform_value() { + with_e2e_context("elicitation", "input_returns_freeform_value", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(QueuedElicitationHandler::new([accept( + json!({ "value": "typed value" }), + )]))), + ) + .await + .expect("create session"); + let options = InputOptions { + title: Some("Value"), + description: Some("A value to test"), + min_length: Some(1), + max_length: Some(20), + default: Some("default"), + ..InputOptions::default() + }; + + assert_eq!( + session + .ui() + .input("Enter value", Some(&options)) + .await + .expect("input") + .as_deref(), + Some("typed value") + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn elicitation_returns_all_action_shapes() { + with_e2e_context( + "elicitation", + "elicitation_returns_all_action_shapes", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(QueuedElicitationHandler::new([ + accept(json!({ "name": "Mona" })), + decline(), + cancel(), + ]))), + ) + .await + .expect("create session"); + let schema = json!({ + "type": "object", + "properties": { "name": { "type": "string" } }, + "required": ["name"] + }); + + let accepted = session + .ui() + .elicitation("Name?", schema.clone()) + .await + .expect("accepted elicitation"); + let declined = session + .ui() + .elicitation("Name?", schema.clone()) + .await + .expect("declined elicitation"); + let cancelled = session + .ui() + .elicitation("Name?", schema) + .await + .expect("cancelled elicitation"); + + assert_eq!(accepted.action, "accept"); + assert_eq!( + accepted + .content + .and_then(|content| content.get("name").cloned()), + Some(json!("Mona")) + ); + assert_eq!(declined.action, "decline"); + assert_eq!(cancelled.action, "cancel"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn session_capabilities_types_are_properly_structured() { + let capabilities = github_copilot_sdk::SessionCapabilities { + ui: Some(UiCapabilities { + elicitation: Some(true), + }), + }; + + assert_eq!( + capabilities.ui.as_ref().and_then(|ui| ui.elicitation), + Some(true) + ); + + let empty = github_copilot_sdk::SessionCapabilities::default(); + assert!(empty.ui.is_none()); +} + +#[tokio::test] +async fn elicitation_schema_types_are_properly_structured() { + let schema = json!({ + "type": "object", + "properties": { + "name": { "type": "string", "minLength": 1 }, + "confirmed": { "type": "boolean", "default": true }, + }, + "required": ["name"], + }); + + assert_eq!(schema["type"], "object"); + assert_eq!( + schema["properties"].as_object().expect("properties").len(), + 2 + ); + assert_eq!(schema["required"].as_array().expect("required").len(), 1); +} + +#[tokio::test] +async fn elicitation_params_types_are_properly_structured() { + let request = ElicitationRequest { + message: "Enter your name".to_string(), + requested_schema: Some(json!({ + "type": "object", + "properties": { "name": { "type": "string" } }, + })), + mode: Some(ElicitationMode::Form), + elicitation_source: None, + url: None, + }; + + assert_eq!(request.message, "Enter your name"); + assert!(request.requested_schema.is_some()); + assert_eq!(request.mode, Some(ElicitationMode::Form)); +} + +#[tokio::test] +async fn elicitation_result_types_are_properly_structured() { + let result = accept(json!({ "name": "Alice" })); + + assert_eq!(result.action, "accept"); + assert_eq!( + result + .content + .as_ref() + .and_then(|content| content.get("name")), + Some(&json!("Alice")) + ); + + let declined = decline(); + assert_eq!(declined.action, "decline"); + assert!(declined.content.is_none()); +} + +#[tokio::test] +async fn input_options_has_all_properties() { + let options = InputOptions { + title: Some("Email Address"), + description: Some("Enter your email"), + min_length: Some(5), + max_length: Some(100), + format: Some(InputFormat::Email), + default: Some("user@example.com"), + }; + + assert_eq!(options.title, Some("Email Address")); + assert_eq!(options.description, Some("Enter your email")); + assert_eq!(options.min_length, Some(5)); + assert_eq!(options.max_length, Some(100)); + assert_eq!(options.format.map(|format| format.as_str()), Some("email")); + assert_eq!(options.default, Some("user@example.com")); +} + +#[tokio::test] +async fn elicitation_context_has_all_properties() { + let context = ElicitationRequest { + message: "Pick a color".to_string(), + requested_schema: Some(json!({ + "type": "object", + "properties": { + "color": { "type": "string", "enum": ["red", "blue"] }, + }, + })), + mode: Some(ElicitationMode::Form), + elicitation_source: Some("mcp-server".to_string()), + url: None, + }; + + assert_eq!(context.message, "Pick a color"); + assert!(context.requested_schema.is_some()); + assert_eq!(context.mode, Some(ElicitationMode::Form)); + assert_eq!(context.elicitation_source.as_deref(), Some("mcp-server")); + assert!(context.url.is_none()); +} + +#[tokio::test] +async fn session_config_onelicitationrequest_is_cloned() { + let handler: Arc = Arc::new(QueuedElicitationHandler::new([cancel()])); + let config = SessionConfig::default().with_handler(handler); + + let clone = config.clone(); + + assert!(Arc::ptr_eq( + config.handler.as_ref().expect("original handler"), + clone.handler.as_ref().expect("cloned handler") + )); +} + +#[tokio::test] +async fn resume_config_onelicitationrequest_is_cloned() { + let handler: Arc = Arc::new(QueuedElicitationHandler::new([cancel()])); + let config = ResumeSessionConfig::new(SessionId::from("session-1")).with_handler(handler); + + let clone = config.clone(); + + assert!(Arc::ptr_eq( + config.handler.as_ref().expect("original handler"), + clone.handler.as_ref().expect("cloned handler") + )); +} + +struct QueuedElicitationHandler { + responses: Mutex>, +} + +impl QueuedElicitationHandler { + fn new(responses: impl IntoIterator) -> Self { + Self { + responses: Mutex::new(responses.into_iter().collect()), + } + } +} + +#[async_trait] +impl SessionHandler for QueuedElicitationHandler { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: github_copilot_sdk::PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Approved + } + + async fn on_elicitation( + &self, + _session_id: SessionId, + _request_id: RequestId, + _request: ElicitationRequest, + ) -> ElicitationResult { + self.responses + .lock() + .await + .pop_front() + .expect("queued elicitation response") + } +} + +fn accept(content: serde_json::Value) -> ElicitationResult { + ElicitationResult { + action: "accept".to_string(), + content: Some(content), + } +} + +fn decline() -> ElicitationResult { + ElicitationResult { + action: "decline".to_string(), + content: None, + } +} + +fn cancel() -> ElicitationResult { + ElicitationResult { + action: "cancel".to_string(), + content: None, + } +} diff --git a/rust/tests/e2e/error_resilience.rs b/rust/tests/e2e/error_resilience.rs new file mode 100644 index 000000000..3dc7cbc7c --- /dev/null +++ b/rust/tests/e2e/error_resilience.rs @@ -0,0 +1,101 @@ +use github_copilot_sdk::{ResumeSessionConfig, SessionId}; + +use super::support::with_e2e_context; + +#[tokio::test] +async fn should_throw_when_sending_to_disconnected_session() { + with_e2e_context( + "error_resilience", + "should_throw_when_sending_to_disconnected_session", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + session.disconnect().await.expect("disconnect session"); + + assert!(session.send_and_wait("Hello").await.is_err()); + + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_throw_when_getting_messages_from_disconnected_session() { + with_e2e_context( + "error_resilience", + "should_throw_when_getting_messages_from_disconnected_session", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + session.disconnect().await.expect("disconnect session"); + + assert!(session.get_messages().await.is_err()); + + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_handle_double_abort_without_error() { + with_e2e_context( + "error_resilience", + "should_handle_double_abort_without_error", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session.abort().await.expect("first abort"); + session.abort().await.expect("second abort"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_throw_when_resuming_non_existent_session() { + with_e2e_context( + "error_resilience", + "should_throw_when_resuming_non_existent_session", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + + let config = + ResumeSessionConfig::new(SessionId::new("non-existent-session-id-12345")) + .with_handler(std::sync::Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )) + .with_github_token(super::support::DEFAULT_TEST_TOKEN); + assert!(client.resume_session(config).await.is_err()); + + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} diff --git a/rust/tests/e2e/event_fidelity.rs b/rust/tests/e2e/event_fidelity.rs new file mode 100644 index 000000000..c23ae6eff --- /dev/null +++ b/rust/tests/e2e/event_fidelity.rs @@ -0,0 +1,368 @@ +use github_copilot_sdk::generated::session_events::{ + AssistantMessageData, AssistantUsageData, SessionEventType, SessionUsageInfoData, + ToolExecutionCompleteData, ToolExecutionStartData, UserMessageData, +}; + +use super::support::{collect_until_idle, event_types, with_e2e_context}; + +#[tokio::test] +async fn should_include_valid_fields_on_all_events() { + with_e2e_context( + "event_fidelity", + "should_include_valid_fields_on_all_events", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let events = session.subscribe(); + + session + .send_and_wait("What is 5+5? Reply with just the number.") + .await + .expect("send"); + + let observed = collect_until_idle(events).await; + for event in &observed { + assert!(!event.id.is_empty(), "event id should be set"); + assert!(!event.timestamp.is_empty(), "event timestamp should be set"); + } + let user = observed + .iter() + .find(|event| event.parsed_type() == SessionEventType::UserMessage) + .and_then(|event| event.typed_data::()) + .expect("user.message"); + assert!(!user.content.is_empty()); + let assistant = observed + .iter() + .find(|event| event.parsed_type() == SessionEventType::AssistantMessage) + .and_then(|event| event.typed_data::()) + .expect("assistant.message"); + assert!(!assistant.message_id.is_empty()); + assert!(!assistant.content.is_empty()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_emit_tool_execution_events_with_correct_fields() { + with_e2e_context( + "event_fidelity", + "should_emit_tool_execution_events_with_correct_fields", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write(ctx.work_dir().join("data.txt"), "test data") + .expect("write data file"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let events = session.subscribe(); + + session + .send_and_wait("Read the file 'data.txt'.") + .await + .expect("send"); + + let observed = collect_until_idle(events).await; + let start = observed + .iter() + .find(|event| event.parsed_type() == SessionEventType::ToolExecutionStart) + .and_then(|event| event.typed_data::()) + .expect("tool.execution_start"); + assert!(!start.tool_call_id.is_empty()); + assert!(!start.tool_name.is_empty()); + let complete = observed + .iter() + .find(|event| event.parsed_type() == SessionEventType::ToolExecutionComplete) + .and_then(|event| event.typed_data::()) + .expect("tool.execution_complete"); + assert!(!complete.tool_call_id.is_empty()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_emit_assistant_usage_event_after_model_call() { + with_e2e_context( + "event_fidelity", + "should_emit_assistant_usage_event_after_model_call", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let events = session.subscribe(); + + session + .send_and_wait("What is 5+5? Reply with just the number.") + .await + .expect("send"); + + let observed = collect_until_idle(events).await; + let usage = observed + .iter() + .rev() + .find(|event| event.parsed_type() == SessionEventType::AssistantUsage) + .and_then(|event| event.typed_data::()) + .expect("assistant.usage"); + assert!(!usage.model.is_empty()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_emit_session_usage_info_event_after_model_call() { + with_e2e_context( + "event_fidelity", + "should_emit_session_usage_info_event_after_model_call", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let events = session.subscribe(); + + session + .send_and_wait("What is 5+5? Reply with just the number.") + .await + .expect("send"); + + let observed = collect_until_idle(events).await; + let usage = observed + .iter() + .rev() + .find(|event| event.parsed_type() == SessionEventType::SessionUsageInfo) + .and_then(|event| event.typed_data::()) + .expect("session.usage_info"); + assert!(usage.current_tokens > 0.0); + assert!(usage.messages_length > 0.0); + assert!(usage.token_limit > 0.0); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_emit_pending_messages_modified_event_when_message_queue_changes() { + with_e2e_context( + "event_fidelity", + "should_emit_pending_messages_modified_event_when_message_queue_changes", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let events = session.subscribe(); + + session + .send("What is 9+9? Reply with just the number.") + .await + .expect("send"); + + let observed = collect_until_idle(events).await; + assert!( + observed + .iter() + .any(|event| event.parsed_type() + == SessionEventType::PendingMessagesModified) + ); + let answer = observed + .iter() + .rev() + .find(|event| event.parsed_type() == SessionEventType::AssistantMessage) + .and_then(|event| event.typed_data::()) + .expect("assistant.message"); + assert!(answer.content.contains("18")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_emit_events_in_correct_order_for_tool_using_conversation() { + with_e2e_context( + "event_fidelity", + "should_emit_events_in_correct_order_for_tool_using_conversation", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write(ctx.work_dir().join("hello.txt"), "Hello World") + .expect("write hello file"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let events = session.subscribe(); + + session + .send_and_wait("Read the file 'hello.txt' and tell me its contents.") + .await + .expect("send"); + + let observed = collect_until_idle(events).await; + let types = event_types(&observed); + let user = types + .iter() + .position(|event_type| *event_type == "user.message") + .expect("user.message"); + let assistant = types + .iter() + .rposition(|event_type| *event_type == "assistant.message") + .expect("assistant.message"); + let idle = types + .iter() + .rposition(|event_type| *event_type == "session.idle") + .expect("session.idle"); + assert!(user < assistant); + assert_eq!(idle, types.len() - 1); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_emit_assistant_message_with_messageid() { + with_e2e_context( + "event_fidelity", + "should_emit_assistant_message_with_messageid", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let events = session.subscribe(); + + session.send_and_wait("Say 'pong'.").await.expect("send"); + + let observed = collect_until_idle(events).await; + let assistant = observed + .iter() + .find(|event| event.parsed_type() == SessionEventType::AssistantMessage) + .and_then(|event| event.typed_data::()) + .expect("assistant.message"); + assert!(!assistant.message_id.is_empty()); + assert!(assistant.content.contains("pong")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_preserve_message_order_in_getmessages_after_tool_use() { + with_e2e_context( + "event_fidelity", + "should_preserve_message_order_in_getmessages_after_tool_use", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write(ctx.work_dir().join("order.txt"), "ORDER_CONTENT_42") + .expect("write order file"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .send_and_wait("Read the file 'order.txt' and tell me what the number is.") + .await + .expect("send"); + + let messages = session.get_messages().await.expect("get messages"); + let types = event_types(&messages); + let session_start = types + .iter() + .position(|event_type| *event_type == "session.start") + .expect("session.start"); + let user = types + .iter() + .position(|event_type| *event_type == "user.message") + .expect("user.message"); + let tool_start = types + .iter() + .position(|event_type| *event_type == "tool.execution_start") + .expect("tool.execution_start"); + let tool_complete = types + .iter() + .position(|event_type| *event_type == "tool.execution_complete") + .expect("tool.execution_complete"); + let assistant = types + .iter() + .rposition(|event_type| *event_type == "assistant.message") + .expect("assistant.message"); + assert!(session_start < user); + assert!(user < tool_start); + assert!(tool_start < tool_complete); + assert!(tool_complete < assistant); + + let user_data = messages + .iter() + .find(|event| event.parsed_type() == SessionEventType::UserMessage) + .and_then(|event| event.typed_data::()) + .expect("user.message"); + assert!(user_data.content.contains("order.txt")); + let assistant_data = messages + .iter() + .rev() + .find(|event| event.parsed_type() == SessionEventType::AssistantMessage) + .and_then(|event| event.typed_data::()) + .expect("assistant.message"); + assert!(assistant_data.content.contains("42")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} diff --git a/rust/tests/e2e/hooks.rs b/rust/tests/e2e/hooks.rs new file mode 100644 index 000000000..d41dee621 --- /dev/null +++ b/rust/tests/e2e/hooks.rs @@ -0,0 +1,215 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::hooks::{ + HookContext, PostToolUseInput, PreToolUseInput, PreToolUseOutput, SessionHooks, +}; +use tokio::sync::mpsc; + +use super::support::{recv_with_timeout, with_e2e_context}; + +#[tokio::test] +async fn should_invoke_pretooluse_hook_when_model_runs_a_tool() { + with_e2e_context( + "hooks", + "should_invoke_pretooluse_hook_when_model_runs_a_tool", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write(ctx.work_dir().join("hello.txt"), "Hello from the test!") + .expect("write hello"); + let (pre_tx, mut pre_rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_hooks(Arc::new( + RecordingHooks { + pre_tx: Some(pre_tx), + post_tx: None, + deny: false, + }, + ))) + .await + .expect("create session"); + + session + .send_and_wait("Read the contents of hello.txt and tell me what it says") + .await + .expect("send"); + + let input = recv_with_timeout(&mut pre_rx, "preToolUse hook").await; + assert_eq!(input.0, *session.id()); + assert!(!input.1.is_empty()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_invoke_posttooluse_hook_after_model_runs_a_tool() { + with_e2e_context( + "hooks", + "should_invoke_posttooluse_hook_after_model_runs_a_tool", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write(ctx.work_dir().join("world.txt"), "World from the test!") + .expect("write world"); + let (post_tx, mut post_rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_hooks(Arc::new( + RecordingHooks { + pre_tx: None, + post_tx: Some(post_tx), + deny: false, + }, + ))) + .await + .expect("create session"); + + session + .send_and_wait("Read the contents of world.txt and tell me what it says") + .await + .expect("send"); + + let input = recv_with_timeout(&mut post_rx, "postToolUse hook").await; + assert_eq!(input.0, *session.id()); + assert!(!input.1.is_empty()); + assert!(input.2); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_invoke_both_pretooluse_and_posttooluse_hooks_for_single_tool_call() { + with_e2e_context( + "hooks", + "should_invoke_both_pretooluse_and_posttooluse_hooks_for_single_tool_call", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write(ctx.work_dir().join("both.txt"), "Testing both hooks!") + .expect("write both"); + let (pre_tx, mut pre_rx) = mpsc::unbounded_channel(); + let (post_tx, mut post_rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_hooks(Arc::new( + RecordingHooks { + pre_tx: Some(pre_tx), + post_tx: Some(post_tx), + deny: false, + }, + ))) + .await + .expect("create session"); + + session + .send_and_wait("Read the contents of both.txt") + .await + .expect("send"); + + let pre = recv_with_timeout(&mut pre_rx, "preToolUse hook").await; + let post = recv_with_timeout(&mut post_rx, "postToolUse hook").await; + assert_eq!(pre.0, *session.id()); + assert_eq!(post.0, *session.id()); + assert_eq!(pre.1, post.1); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_deny_tool_execution_when_pretooluse_returns_deny() { + with_e2e_context( + "hooks", + "should_deny_tool_execution_when_pretooluse_returns_deny", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let original_content = "Original content that should not be modified"; + let protected_path = ctx.work_dir().join("protected.txt"); + std::fs::write(&protected_path, original_content).expect("write protected"); + let (pre_tx, mut pre_rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_hooks(Arc::new( + RecordingHooks { + pre_tx: Some(pre_tx), + post_tx: None, + deny: true, + }, + ))) + .await + .expect("create session"); + + session + .send_and_wait("Edit protected.txt and replace 'Original' with 'Modified'") + .await + .expect("send"); + + let pre = recv_with_timeout(&mut pre_rx, "preToolUse hook").await; + assert_eq!(pre.0, *session.id()); + assert_eq!( + std::fs::read_to_string(protected_path).expect("read protected"), + original_content + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +struct RecordingHooks { + pre_tx: Option>, + post_tx: Option>, + deny: bool, +} + +#[async_trait] +impl SessionHooks for RecordingHooks { + async fn on_pre_tool_use( + &self, + input: PreToolUseInput, + ctx: HookContext, + ) -> Option { + if let Some(pre_tx) = &self.pre_tx { + let _ = pre_tx.send((ctx.session_id, input.tool_name)); + } + Some(PreToolUseOutput { + permission_decision: Some(if self.deny { "deny" } else { "allow" }.to_string()), + ..PreToolUseOutput::default() + }) + } + + async fn on_post_tool_use( + &self, + input: PostToolUseInput, + ctx: HookContext, + ) -> Option { + if let Some(post_tx) = &self.post_tx { + let _ = post_tx.send(( + ctx.session_id, + input.tool_name, + !input.tool_result.is_null(), + )); + } + None + } +} diff --git a/rust/tests/e2e/hooks_extended.rs b/rust/tests/e2e/hooks_extended.rs new file mode 100644 index 000000000..3b11ddee1 --- /dev/null +++ b/rust/tests/e2e/hooks_extended.rs @@ -0,0 +1,563 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::hooks::{ + ErrorOccurredInput, ErrorOccurredOutput, HookContext, PostToolUseInput, PostToolUseOutput, + PreToolUseInput, PreToolUseOutput, SessionEndInput, SessionEndOutput, SessionHooks, + SessionStartInput, SessionStartOutput, UserPromptSubmittedInput, UserPromptSubmittedOutput, +}; +use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::{Error, SessionConfig, Tool, ToolInvocation, ToolResult}; +use serde_json::json; +use tokio::sync::mpsc; + +use super::support::{assistant_message_content, recv_with_timeout, with_e2e_context}; + +#[tokio::test] +async fn should_invoke_onsessionstart_hook_on_new_session() { + with_e2e_context( + "hooks_extended", + "should_invoke_onsessionstart_hook_on_new_session", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let (tx, mut rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_hooks(Arc::new(RecordingHooks::session_start(tx, None))), + ) + .await + .expect("create session"); + + session.send_and_wait("Say hi").await.expect("send"); + let input = recv_with_timeout(&mut rx, "sessionStart hook").await; + assert_eq!(input.source, "new"); + assert!(input.timestamp > 0); + assert!(!input.cwd.as_os_str().is_empty()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_invoke_onuserpromptsubmitted_hook_when_sending_a_message() { + with_e2e_context( + "hooks_extended", + "should_invoke_onuserpromptsubmitted_hook_when_sending_a_message", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let (tx, mut rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_hooks(Arc::new(RecordingHooks::user_prompt(tx, None))), + ) + .await + .expect("create session"); + + session.send_and_wait("Say hello").await.expect("send"); + let input = recv_with_timeout(&mut rx, "userPromptSubmitted hook").await; + assert!(input.prompt.contains("Say hello")); + assert!(input.timestamp > 0); + assert!(!input.cwd.as_os_str().is_empty()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_invoke_onsessionend_hook_when_session_is_disconnected() { + with_e2e_context( + "hooks_extended", + "should_invoke_onsessionend_hook_when_session_is_disconnected", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let (tx, mut rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_hooks(Arc::new(RecordingHooks::session_end(tx, None))), + ) + .await + .expect("create session"); + + session.send_and_wait("Say hi").await.expect("send"); + session.disconnect().await.expect("disconnect session"); + let input = recv_with_timeout(&mut rx, "sessionEnd hook").await; + assert!(input.timestamp > 0); + assert!(!input.cwd.as_os_str().is_empty()); + + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_invoke_onerroroccurred_hook_when_error_occurs() { + with_e2e_context( + "hooks_extended", + "should_invoke_onerroroccurred_hook_when_error_occurs", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let (tx, mut rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_hooks(Arc::new(RecordingHooks::error(tx, None))), + ) + .await + .expect("create session"); + + session.send_and_wait("Say hi").await.expect("send"); + assert!(rx.try_recv().is_err()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_invoke_userpromptsubmitted_hook_and_modify_prompt() { + with_e2e_context( + "hooks_extended", + "should_invoke_userpromptsubmitted_hook_and_modify_prompt", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let (tx, mut rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_hooks(Arc::new( + RecordingHooks::user_prompt( + tx, + Some(UserPromptSubmittedOutput { + modified_prompt: Some( + "Reply with exactly: HOOKED_PROMPT".to_string(), + ), + ..UserPromptSubmittedOutput::default() + }), + ), + ))) + .await + .expect("create session"); + + let answer = session + .send_and_wait("Say something else") + .await + .expect("send") + .expect("assistant message"); + let input = recv_with_timeout(&mut rx, "userPromptSubmitted hook").await; + assert!(input.prompt.contains("Say something else")); + assert!(assistant_message_content(&answer).contains("HOOKED_PROMPT")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_invoke_sessionstart_hook() { + with_e2e_context("hooks_extended", "should_invoke_sessionstart_hook", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let (tx, mut rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_hooks(Arc::new( + RecordingHooks::session_start( + tx, + Some(SessionStartOutput { + additional_context: Some("Session start hook context.".to_string()), + ..SessionStartOutput::default() + }), + ), + ))) + .await + .expect("create session"); + + session.send_and_wait("Say hi").await.expect("send"); + let input = recv_with_timeout(&mut rx, "sessionStart hook").await; + assert_eq!(input.source, "new"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_invoke_sessionend_hook() { + with_e2e_context("hooks_extended", "should_invoke_sessionend_hook", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let (tx, mut rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_hooks(Arc::new( + RecordingHooks::session_end( + tx, + Some(SessionEndOutput { + session_summary: Some("session ended".to_string()), + ..SessionEndOutput::default() + }), + ), + ))) + .await + .expect("create session"); + + session.send_and_wait("Say bye").await.expect("send"); + session.disconnect().await.expect("disconnect session"); + let input = recv_with_timeout(&mut rx, "sessionEnd hook").await; + assert!(input.timestamp > 0); + + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_register_erroroccurred_hook() { + with_e2e_context( + "hooks_extended", + "should_register_erroroccurred_hook", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let (tx, mut rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_hooks(Arc::new( + RecordingHooks::error( + tx, + Some(ErrorOccurredOutput { + error_handling: Some("skip".to_string()), + ..ErrorOccurredOutput::default() + }), + ), + ))) + .await + .expect("create session"); + + session.send_and_wait("Say hi").await.expect("send"); + assert!(rx.try_recv().is_err()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_allow_pretooluse_to_return_modifiedargs_and_suppressoutput() { + with_e2e_context( + "hooks_extended", + "should_allow_pretooluse_to_return_modifiedargs_and_suppressoutput", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let (tx, mut rx) = mpsc::unbounded_channel(); + let router = ToolHandlerRouter::new( + vec![Box::new(EchoValueTool)], + Arc::new(ApproveAllHandler), + ); + let tools = router.tools(); + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default() + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(router)) + .with_tools(tools) + .with_hooks(Arc::new(RecordingHooks::pre_tool(tx))), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait( + "Call echo_value with value 'original', then reply with the result.", + ) + .await + .expect("send") + .expect("assistant message"); + let mut saw_echo = false; + while let Ok(input) = rx.try_recv() { + saw_echo |= input.tool_name == "echo_value"; + } + assert!(saw_echo, "expected preToolUse hook for echo_value"); + assert!(assistant_message_content(&answer).contains("modified by hook")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_allow_posttooluse_to_return_modifiedresult() { + with_e2e_context( + "hooks_extended", + "should_allow_posttooluse_to_return_modifiedresult", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let (tx, mut rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_available_tools(["report_intent"]) + .with_hooks(Arc::new(RecordingHooks::post_tool(tx))), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait( + "Call the report_intent tool with intent 'Testing post hook', then reply done.", + ) + .await + .expect("send") + .expect("assistant message"); + let mut saw_report_intent = false; + while let Ok(input) = rx.try_recv() { + saw_report_intent |= input.tool_name == "report_intent"; + } + assert!(saw_report_intent, "expected postToolUse hook for report_intent"); + assert_eq!(assistant_message_content(&answer), "Done."); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[derive(Default)] +struct RecordingHooks { + session_start: Option>, + session_start_output: Option, + session_end: Option>, + session_end_output: Option, + user_prompt: Option>, + user_prompt_output: Option, + error: Option>, + error_output: Option, + pre_tool: Option>, + post_tool: Option>, +} + +impl RecordingHooks { + fn session_start( + tx: mpsc::UnboundedSender, + output: Option, + ) -> Self { + Self { + session_start: Some(tx), + session_start_output: output, + ..Self::default() + } + } + + fn session_end( + tx: mpsc::UnboundedSender, + output: Option, + ) -> Self { + Self { + session_end: Some(tx), + session_end_output: output, + ..Self::default() + } + } + + fn user_prompt( + tx: mpsc::UnboundedSender, + output: Option, + ) -> Self { + Self { + user_prompt: Some(tx), + user_prompt_output: output, + ..Self::default() + } + } + + fn error( + tx: mpsc::UnboundedSender, + output: Option, + ) -> Self { + Self { + error: Some(tx), + error_output: output, + ..Self::default() + } + } + + fn pre_tool(tx: mpsc::UnboundedSender) -> Self { + Self { + pre_tool: Some(tx), + ..Self::default() + } + } + + fn post_tool(tx: mpsc::UnboundedSender) -> Self { + Self { + post_tool: Some(tx), + ..Self::default() + } + } +} + +#[async_trait] +impl SessionHooks for RecordingHooks { + async fn on_session_start( + &self, + input: SessionStartInput, + ctx: HookContext, + ) -> Option { + assert!(!ctx.session_id.as_str().is_empty()); + if let Some(tx) = &self.session_start { + let _ = tx.send(input); + } + self.session_start_output.clone() + } + + async fn on_session_end( + &self, + input: SessionEndInput, + ctx: HookContext, + ) -> Option { + assert!(!ctx.session_id.as_str().is_empty()); + if let Some(tx) = &self.session_end { + let _ = tx.send(input); + } + self.session_end_output.clone() + } + + async fn on_user_prompt_submitted( + &self, + input: UserPromptSubmittedInput, + ctx: HookContext, + ) -> Option { + assert!(!ctx.session_id.as_str().is_empty()); + if let Some(tx) = &self.user_prompt { + let _ = tx.send(input); + } + self.user_prompt_output.clone() + } + + async fn on_error_occurred( + &self, + input: ErrorOccurredInput, + ctx: HookContext, + ) -> Option { + assert!(!ctx.session_id.as_str().is_empty()); + assert!( + ["model_call", "tool_execution", "system", "user_input"] + .contains(&input.error_context.as_str()) + ); + if let Some(tx) = &self.error { + let _ = tx.send(input); + } + self.error_output.clone() + } + + async fn on_pre_tool_use( + &self, + input: PreToolUseInput, + _ctx: HookContext, + ) -> Option { + let output = if input.tool_name == "echo_value" { + PreToolUseOutput { + permission_decision: Some("allow".to_string()), + modified_args: Some(json!({ "value": "modified by hook" })), + suppress_output: Some(false), + ..PreToolUseOutput::default() + } + } else { + PreToolUseOutput { + permission_decision: Some("allow".to_string()), + ..PreToolUseOutput::default() + } + }; + if let Some(tx) = &self.pre_tool { + let _ = tx.send(input); + } + Some(output) + } + + async fn on_post_tool_use( + &self, + input: PostToolUseInput, + _ctx: HookContext, + ) -> Option { + let output = (input.tool_name == "report_intent").then(|| PostToolUseOutput { + modified_result: Some(json!("modified by post hook")), + suppress_output: Some(false), + ..PostToolUseOutput::default() + }); + if let Some(tx) = &self.post_tool { + let _ = tx.send(input); + } + output + } +} + +struct EchoValueTool; + +#[async_trait] +impl ToolHandler for EchoValueTool { + fn tool(&self) -> Tool { + Tool::new("echo_value") + .with_description("Echoes the supplied value") + .with_parameters(json!({ + "type": "object", + "properties": { + "value": { "type": "string" } + }, + "required": ["value"] + })) + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + Ok(ToolResult::Text( + invocation + .arguments + .get("value") + .and_then(serde_json::Value::as_str) + .unwrap_or_default() + .to_string(), + )) + } +} diff --git a/rust/tests/e2e/mcp_and_agents.rs b/rust/tests/e2e/mcp_and_agents.rs new file mode 100644 index 000000000..ab74d73bb --- /dev/null +++ b/rust/tests/e2e/mcp_and_agents.rs @@ -0,0 +1,389 @@ +use std::collections::HashMap; + +use github_copilot_sdk::{ + CustomAgentConfig, McpServerConfig, McpStdioServerConfig, ResumeSessionConfig, +}; + +use super::support::{assistant_message_content, with_e2e_context}; + +#[tokio::test] +async fn accept_mcp_server_config_on_create() { + with_e2e_context( + "mcp_and_agents", + "accept_mcp_server_config_on_create", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_mcp_servers(test_mcp_servers("hello")), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait("What is 2+2?") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains('4')); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn accept_mcp_server_config_on_resume() { + with_e2e_context( + "mcp_and_agents", + "accept_mcp_server_config_on_resume", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session1 = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create first session"); + let session_id = session1.id().clone(); + session1 + .send_and_wait("What is 1+1?") + .await + .expect("send first"); + session1.disconnect().await.expect("disconnect first"); + + let session2 = client + .resume_session( + ResumeSessionConfig::new(session_id.clone()) + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(std::sync::Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )) + .with_mcp_servers(test_mcp_servers("hello")), + ) + .await + .expect("resume session"); + assert_eq!(session2.id(), &session_id); + + let answer = session2 + .send_and_wait("What is 3+3?") + .await + .expect("send resumed") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains('6')); + + session2.disconnect().await.expect("disconnect resumed"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn accept_custom_agent_config_on_create() { + with_e2e_context( + "mcp_and_agents", + "accept_custom_agent_config_on_create", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_custom_agents([test_agent("test-agent", "Test Agent")]), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait("What is 5+5?") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("10")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn accept_custom_agent_config_on_resume() { + with_e2e_context( + "mcp_and_agents", + "accept_custom_agent_config_on_resume", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session1 = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create first session"); + let session_id = session1.id().clone(); + session1 + .send_and_wait("What is 1+1?") + .await + .expect("send first"); + session1.disconnect().await.expect("disconnect first"); + + let session2 = client + .resume_session( + ResumeSessionConfig::new(session_id.clone()) + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(std::sync::Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )) + .with_custom_agents([test_agent("resume-agent", "Resume Agent")]), + ) + .await + .expect("resume session"); + assert_eq!(session2.id(), &session_id); + + let answer = session2 + .send_and_wait("What is 6+6?") + .await + .expect("send resumed") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("12")); + + session2.disconnect().await.expect("disconnect resumed"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_handle_multiple_mcp_servers() { + with_e2e_context( + "mcp_and_agents", + "should_handle_multiple_mcp_servers", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_mcp_servers(multiple_mcp_servers()), + ) + .await + .expect("create session"); + + assert!(!session.id().as_str().is_empty()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_handle_custom_agent_with_tools_configuration() { + with_e2e_context( + "mcp_and_agents", + "should_handle_custom_agent_with_tools_configuration", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let agent = test_agent("tool-agent", "Tool Agent").with_tools(["bash", "edit"]); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_custom_agents([agent])) + .await + .expect("create session"); + + let listed = session.rpc().agent().list().await.expect("list agents"); + assert!(listed.agents.iter().any(|agent| agent.name == "tool-agent")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_handle_custom_agent_with_mcp_servers() { + with_e2e_context( + "mcp_and_agents", + "should_handle_custom_agent_with_mcp_servers", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let agent = test_agent("mcp-agent", "MCP Agent") + .with_mcp_servers(test_mcp_servers("agent-mcp")); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_custom_agents([agent])) + .await + .expect("create session"); + + let listed = session.rpc().agent().list().await.expect("list agents"); + assert!(listed.agents.iter().any(|agent| agent.name == "mcp-agent")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_handle_multiple_custom_agents() { + with_e2e_context( + "mcp_and_agents", + "should_handle_multiple_custom_agents", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_custom_agents([ + test_agent("agent1", "Agent One"), + test_agent("agent2", "Agent Two").with_infer(false), + ])) + .await + .expect("create session"); + + let listed = session.rpc().agent().list().await.expect("list agents"); + assert!(listed.agents.iter().any(|agent| agent.name == "agent1")); + assert!(listed.agents.iter().any(|agent| agent.name == "agent2")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_accept_both_mcp_servers_and_custom_agents() { + with_e2e_context( + "mcp_and_agents", + "should_accept_both_mcp_servers_and_custom_agents", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_mcp_servers(test_mcp_servers("session-mcp")) + .with_custom_agents([test_agent("combined-agent", "Combined Agent")]), + ) + .await + .expect("create session"); + + let agents = session.rpc().agent().list().await.expect("list agents"); + assert!( + agents + .agents + .iter() + .any(|agent| agent.name == "combined-agent") + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_pass_literal_env_values_to_mcp_server_subprocess() { + let config = McpStdioServerConfig { + command: echo_command(), + args: echo_args("env"), + env: HashMap::from([("MCP_LITERAL".to_string(), "literal-value".to_string())]), + ..McpStdioServerConfig::default() + }; + + assert_eq!( + config.env.get("MCP_LITERAL"), + Some(&"literal-value".to_string()) + ); +} + +#[tokio::test] +async fn should_round_trip_mcp_server_elicitation_request() { + let payload = serde_json::json!({ + "action": "accept", + "content": { "value": "selected" } + }); + + assert_eq!(payload["action"], "accept"); + assert_eq!(payload["content"]["value"], "selected"); +} + +fn test_agent(name: &str, display_name: &str) -> CustomAgentConfig { + CustomAgentConfig::new(name, "You are a helpful test agent.") + .with_display_name(display_name) + .with_description("A test agent for SDK testing") + .with_infer(true) +} + +fn multiple_mcp_servers() -> HashMap { + let mut servers = test_mcp_servers("server1"); + servers.insert( + "server2".to_string(), + McpServerConfig::Stdio(McpStdioServerConfig { + tools: vec!["*".to_string()], + command: echo_command(), + args: echo_args("server2"), + ..McpStdioServerConfig::default() + }), + ); + servers +} + +fn test_mcp_servers(message: &str) -> HashMap { + HashMap::from([( + "test-server".to_string(), + McpServerConfig::Stdio(McpStdioServerConfig { + tools: vec!["*".to_string()], + command: echo_command(), + args: echo_args(message), + ..McpStdioServerConfig::default() + }), + )]) +} + +#[cfg(windows)] +fn echo_command() -> String { + "cmd".to_string() +} + +#[cfg(not(windows))] +fn echo_command() -> String { + "echo".to_string() +} + +#[cfg(windows)] +fn echo_args(message: &str) -> Vec { + vec!["/C".to_string(), "echo".to_string(), message.to_string()] +} + +#[cfg(not(windows))] +fn echo_args(message: &str) -> Vec { + vec![message.to_string()] +} diff --git a/rust/tests/e2e/mode_handlers.rs b/rust/tests/e2e/mode_handlers.rs new file mode 100644 index 000000000..53f7be255 --- /dev/null +++ b/rust/tests/e2e/mode_handlers.rs @@ -0,0 +1,279 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::generated::session_events::{ + AutoModeSwitchCompletedData, AutoModeSwitchRequestedData, ExitPlanModeCompletedData, + ExitPlanModeRequestedData, SessionEventType, SessionModelChangeData, +}; +use github_copilot_sdk::handler::{AutoModeSwitchResponse, ExitPlanModeResult, SessionHandler}; +use github_copilot_sdk::{ExitPlanModeData, SessionConfig, SessionId}; +use serde_json::json; +use tokio::sync::mpsc; + +use super::support::{ + recv_with_timeout, wait_for_event, wait_for_event_allowing_rate_limit, with_e2e_context, +}; + +const MODE_HANDLER_TOKEN: &str = "mode-handler-token"; +const PLAN_SUMMARY: &str = "Greeting file implementation plan"; +const PLAN_PROMPT: &str = "Create a brief implementation plan for adding a greeting.txt file, then request approval with exit_plan_mode."; +const AUTO_MODE_PROMPT: &str = + "Explain that auto mode recovered from a rate limit in one short sentence."; + +#[derive(Debug)] +struct ModeHandler { + requests: mpsc::UnboundedSender<(SessionId, ExitPlanModeData)>, +} + +#[derive(Debug)] +struct AutoModeHandler { + requests: mpsc::UnboundedSender<(SessionId, Option, Option)>, +} + +#[async_trait] +impl SessionHandler for ModeHandler { + async fn on_exit_plan_mode( + &self, + session_id: SessionId, + data: ExitPlanModeData, + ) -> ExitPlanModeResult { + let _ = self.requests.send((session_id, data)); + ExitPlanModeResult { + approved: true, + selected_action: Some("interactive".to_string()), + feedback: Some("Approved by the Rust E2E test".to_string()), + } + } +} + +#[async_trait] +impl SessionHandler for AutoModeHandler { + async fn on_auto_mode_switch( + &self, + session_id: SessionId, + error_code: Option, + retry_after_seconds: Option, + ) -> AutoModeSwitchResponse { + let _ = self + .requests + .send((session_id, error_code, retry_after_seconds)); + AutoModeSwitchResponse::Yes + } +} + +#[tokio::test] +async fn should_invoke_exit_plan_mode_handler_when_model_uses_tool() { + with_e2e_context( + "mode_handlers", + "should_invoke_exit_plan_mode_handler_when_model_uses_tool", + |ctx| { + Box::pin(async move { + ctx.set_copilot_user_by_token(MODE_HANDLER_TOKEN); + let client = ctx.start_client().await; + let (request_tx, mut request_rx) = mpsc::unbounded_channel(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(MODE_HANDLER_TOKEN) + .with_handler(Arc::new(ModeHandler { + requests: request_tx, + })) + .approve_all_permissions(), + ) + .await + .expect("create session"); + + let requested_event = tokio::spawn(wait_for_event( + session.subscribe(), + "exit_plan_mode.requested event", + |event| { + event.parsed_type() == SessionEventType::ExitPlanModeRequested + && event + .typed_data::() + .is_some_and(|data| data.summary == PLAN_SUMMARY) + }, + )); + let completed_event = tokio::spawn(wait_for_event( + session.subscribe(), + "exit_plan_mode.completed event", + |event| { + event.parsed_type() == SessionEventType::ExitPlanModeCompleted + && event + .typed_data::() + .is_some_and(|data| { + data.approved == Some(true) + && data.selected_action.as_deref() == Some("interactive") + }) + }, + )); + let idle_event = tokio::spawn(wait_for_event( + session.subscribe(), + "session.idle event", + |event| event.parsed_type() == SessionEventType::SessionIdle, + )); + + let send_result = session + .client() + .call( + "session.send", + Some(json!({ + "sessionId": session.id().as_str(), + "prompt": PLAN_PROMPT, + "mode": "plan", + })), + ) + .await + .expect("send plan-mode prompt"); + assert!( + send_result.get("messageId").is_some(), + "expected messageId in send result" + ); + + let (session_id, request) = + recv_with_timeout(&mut request_rx, "exit-plan-mode request").await; + assert_eq!(session_id, session.id().clone()); + assert_eq!(request.summary, PLAN_SUMMARY); + assert_eq!( + request.actions, + ["interactive", "autopilot", "exit_only"].map(str::to_string) + ); + assert_eq!(request.recommended_action, "interactive"); + + let requested = requested_event.await.expect("requested task"); + let requested_data = requested + .typed_data::() + .expect("typed requested event"); + assert_eq!(requested_data.summary, request.summary); + assert_eq!(requested_data.actions, request.actions); + assert_eq!( + requested_data.recommended_action, + request.recommended_action + ); + + let completed = completed_event.await.expect("completed task"); + let completed_data = completed + .typed_data::() + .expect("typed completed event"); + assert_eq!(completed_data.approved, Some(true)); + assert_eq!( + completed_data.selected_action.as_deref(), + Some("interactive") + ); + assert_eq!( + completed_data.feedback.as_deref(), + Some("Approved by the Rust E2E test") + ); + idle_event.await.expect("idle task"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_invoke_auto_mode_switch_handler_when_rate_limited() { + with_e2e_context( + "mode_handlers", + "should_invoke_auto_mode_switch_handler_when_rate_limited", + |ctx| { + Box::pin(async move { + ctx.set_copilot_user_by_token(MODE_HANDLER_TOKEN); + let client = ctx.start_client().await; + let (request_tx, mut request_rx) = mpsc::unbounded_channel(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(MODE_HANDLER_TOKEN) + .with_handler(Arc::new(AutoModeHandler { + requests: request_tx, + })) + .approve_all_permissions(), + ) + .await + .expect("create session"); + + let requested_event = tokio::spawn(wait_for_event_allowing_rate_limit( + session.subscribe(), + "auto_mode_switch.requested event", + |event| { + event.parsed_type() == SessionEventType::AutoModeSwitchRequested + && event + .typed_data::() + .is_some_and(|data| { + data.error_code.as_deref() == Some("user_weekly_rate_limited") + && data.retry_after_seconds == Some(1.0) + }) + }, + )); + let completed_event = tokio::spawn(wait_for_event_allowing_rate_limit( + session.subscribe(), + "auto_mode_switch.completed event", + |event| { + event.parsed_type() == SessionEventType::AutoModeSwitchCompleted + && event + .typed_data::() + .is_some_and(|data| data.response == "yes") + }, + )); + let model_change_event = + tokio::spawn(wait_for_event_allowing_rate_limit( + session.subscribe(), + "rate-limit auto-mode model change", + |event| { + event.parsed_type() == SessionEventType::SessionModelChange + && event.typed_data::().is_some_and( + |data| data.cause.as_deref() == Some("rate_limit_auto_switch"), + ) + }, + )); + let idle_event = tokio::spawn(wait_for_event_allowing_rate_limit( + session.subscribe(), + "session.idle after auto-mode switch", + |event| event.parsed_type() == SessionEventType::SessionIdle, + )); + + let message_id = session + .send(AUTO_MODE_PROMPT) + .await + .expect("send auto-mode-switch prompt"); + assert!(!message_id.is_empty(), "expected message ID"); + + let (session_id, error_code, retry_after_seconds) = + recv_with_timeout(&mut request_rx, "auto-mode-switch request").await; + assert_eq!(session_id, session.id().clone()); + assert_eq!(error_code.as_deref(), Some("user_weekly_rate_limited")); + assert_eq!(retry_after_seconds, Some(1.0)); + + let requested = requested_event.await.expect("requested task"); + let requested_data = requested + .typed_data::() + .expect("typed requested event"); + assert_eq!(requested_data.error_code, error_code); + assert_eq!(requested_data.retry_after_seconds, retry_after_seconds); + + let completed = completed_event.await.expect("completed task"); + let completed_data = completed + .typed_data::() + .expect("typed completed event"); + assert_eq!(completed_data.response, "yes"); + + let model_change = model_change_event.await.expect("model change task"); + let model_change_data = model_change + .typed_data::() + .expect("typed model change event"); + assert_eq!( + model_change_data.cause.as_deref(), + Some("rate_limit_auto_switch") + ); + idle_event.await.expect("idle task"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} diff --git a/rust/tests/e2e/multi_client.rs b/rust/tests/e2e/multi_client.rs new file mode 100644 index 000000000..5f5260e7c --- /dev/null +++ b/rust/tests/e2e/multi_client.rs @@ -0,0 +1,593 @@ +use std::net::TcpListener; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; + +use async_trait::async_trait; +use github_copilot_sdk::generated::session_events::{ + PermissionCompletedData, PermissionResult as EventPermissionResult, SessionEventType, +}; +use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::{ + Client, PermissionRequestData, RequestId, ResumeSessionConfig, SessionConfig, SessionEvent, + SessionId, Tool, ToolInvocation, ToolResult, Transport, +}; +use serde_json::json; + +use super::support::{ + DEFAULT_TEST_TOKEN, E2eContext, assistant_message_content, wait_for_event, with_e2e_context, +}; + +const SHARED_TOKEN: &str = "rust-multi-client-shared-token"; + +#[tokio::test] +async fn both_clients_see_tool_request_and_completion_events() { + with_e2e_context( + "rust_multi_client", + "both_clients_see_tool_request_and_completion_events", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let port = free_tcp_port(); + let server = start_tcp_server(ctx, port).await; + let session1 = server + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(selective_handler(vec![EchoTool::new( + "magic_number", + "seed", + "MAGIC_", + "_42", + )])) + .with_tools([EchoTool::tool_definition("magic_number", "seed")]) + .with_available_tools(["magic_number"]), + ) + .await + .expect("create session"); + let client2 = start_external_client(ctx, port).await; + let session2 = client2 + .resume_session( + resume_config(session1.id().clone()) + .with_handler(selective_handler(Vec::new())), + ) + .await + .expect("resume session"); + + let client1_requested = + wait_for_event(session1.subscribe(), "client1 tool request", |event| { + event.parsed_type() == SessionEventType::ExternalToolRequested + }); + let client2_requested = + wait_for_event(session2.subscribe(), "client2 tool request", |event| { + event.parsed_type() == SessionEventType::ExternalToolRequested + }); + let client1_completed = + wait_for_event(session1.subscribe(), "client1 tool completion", |event| { + event.parsed_type() == SessionEventType::ExternalToolCompleted + }); + let client2_completed = + wait_for_event(session2.subscribe(), "client2 tool completion", |event| { + event.parsed_type() == SessionEventType::ExternalToolCompleted + }); + + let answer = session1 + .send_and_wait( + "Use the magic_number tool with seed 'hello' and tell me the result", + ) + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("MAGIC_hello_42")); + let _ = tokio::join!( + client1_requested, + client2_requested, + client1_completed, + client2_completed + ); + + session2 + .disconnect() + .await + .expect("disconnect second session"); + client2.force_stop(); + session1 + .disconnect() + .await + .expect("disconnect first session"); + server.stop().await.expect("stop server client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn one_client_approves_permission_and_both_see_the_result() { + with_e2e_context( + "rust_multi_client", + "one_client_approves_permission_and_both_see_the_result", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let port = free_tcp_port(); + let server = start_tcp_server(ctx, port).await; + let permission_requests = Arc::new(AtomicUsize::new(0)); + let session1 = server + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(permission_handler_with_counter( + PermissionResult::Approved, + Arc::clone(&permission_requests), + )), + ) + .await + .expect("create session"); + let client2 = start_external_client(ctx, port).await; + let session2 = client2 + .resume_session( + resume_config(session1.id().clone()) + .with_request_permission(false) + .with_handler(permission_handler(PermissionResult::NoResult)), + ) + .await + .expect("resume session"); + + let client1_requested = wait_for_event( + session1.subscribe(), + "client1 permission request", + |event| event.parsed_type() == SessionEventType::PermissionRequested, + ); + let client2_requested = wait_for_event( + session2.subscribe(), + "client2 permission request", + |event| event.parsed_type() == SessionEventType::PermissionRequested, + ); + let client1_completed = wait_for_event( + session1.subscribe(), + "client1 permission approved", + is_permission_approved, + ); + let client2_completed = wait_for_event( + session2.subscribe(), + "client2 permission approved", + is_permission_approved, + ); + + let answer = session1 + .send_and_wait( + "Create a file called hello.txt containing the text 'hello world'", + ) + .await + .expect("send") + .expect("assistant message"); + assert!(!assistant_message_content(&answer).is_empty()); + assert!( + permission_requests.load(Ordering::SeqCst) > 0, + "expected client 1 to handle at least one permission request" + ); + let _ = tokio::join!( + client1_requested, + client2_requested, + client1_completed, + client2_completed + ); + + session2 + .disconnect() + .await + .expect("disconnect second session"); + client2.force_stop(); + session1 + .disconnect() + .await + .expect("disconnect first session"); + server.stop().await.expect("stop server client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn one_client_rejects_permission_and_both_see_the_result() { + with_e2e_context( + "rust_multi_client", + "one_client_rejects_permission_and_both_see_the_result", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let protected_file = ctx.work_dir().join("protected.txt"); + std::fs::write(&protected_file, "protected content").expect("write protected file"); + let port = free_tcp_port(); + let server = start_tcp_server(ctx, port).await; + let session1 = server + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(permission_handler(PermissionResult::Denied)), + ) + .await + .expect("create session"); + let client2 = start_external_client(ctx, port).await; + let session2 = client2 + .resume_session( + resume_config(session1.id().clone()) + .with_request_permission(false) + .with_handler(permission_handler(PermissionResult::NoResult)), + ) + .await + .expect("resume session"); + + let client1_requested = wait_for_event( + session1.subscribe(), + "client1 permission request", + |event| event.parsed_type() == SessionEventType::PermissionRequested, + ); + let client2_requested = wait_for_event( + session2.subscribe(), + "client2 permission request", + |event| event.parsed_type() == SessionEventType::PermissionRequested, + ); + let client1_completed = wait_for_event( + session1.subscribe(), + "client1 permission denied", + is_permission_denied, + ); + let client2_completed = wait_for_event( + session2.subscribe(), + "client2 permission denied", + is_permission_denied, + ); + + session1 + .send_and_wait("Edit protected.txt and replace 'protected' with 'hacked'.") + .await + .expect("send"); + let content = + std::fs::read_to_string(&protected_file).expect("read protected file"); + assert_eq!(content, "protected content"); + let _ = tokio::join!( + client1_requested, + client2_requested, + client1_completed, + client2_completed + ); + + session2 + .disconnect() + .await + .expect("disconnect second session"); + client2.force_stop(); + session1 + .disconnect() + .await + .expect("disconnect first session"); + server.stop().await.expect("stop server client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn two_clients_register_different_tools_and_agent_uses_both() { + with_e2e_context( + "rust_multi_client", + "two_clients_register_different_tools_and_agent_uses_both", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let port = free_tcp_port(); + let server = start_tcp_server(ctx, port).await; + let session1 = server + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(selective_handler(vec![EchoTool::new( + "city_lookup", + "countryCode", + "CITY_FOR_", + "", + )])) + .with_tools([EchoTool::tool_definition("city_lookup", "countryCode")]) + .with_available_tools(["city_lookup", "currency_lookup"]), + ) + .await + .expect("create session"); + let client2 = start_external_client(ctx, port).await; + let session2 = client2 + .resume_session( + resume_config(session1.id().clone()) + .with_handler(selective_handler(vec![EchoTool::new( + "currency_lookup", + "countryCode", + "CURRENCY_FOR_", + "", + )])) + .with_tools([EchoTool::tool_definition("currency_lookup", "countryCode")]) + .with_available_tools(["city_lookup", "currency_lookup"]), + ) + .await + .expect("resume session"); + + let city = session1 + .send_and_wait( + "Use the city_lookup tool with countryCode 'US' and tell me the result.", + ) + .await + .expect("send city") + .expect("city answer"); + assert!(assistant_message_content(&city).contains("CITY_FOR_US")); + let currency = session1 + .send_and_wait( + "Now use the currency_lookup tool with countryCode 'US' and tell me the result.", + ) + .await + .expect("send currency") + .expect("currency answer"); + assert!(assistant_message_content(¤cy).contains("CURRENCY_FOR_US")); + + session2.disconnect().await.expect("disconnect second session"); + client2.force_stop(); + session1.disconnect().await.expect("disconnect first session"); + server.stop().await.expect("stop server client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn disconnecting_client_removes_its_tools() { + with_e2e_context( + "rust_multi_client", + "disconnecting_client_removes_its_tools", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let port = free_tcp_port(); + let server = start_tcp_server(ctx, port).await; + let session1 = server + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(selective_handler(vec![EchoTool::new( + "stable_tool", + "input", + "STABLE_", + "", + )])) + .with_tools([EchoTool::tool_definition("stable_tool", "input")]) + .with_available_tools(["stable_tool", "ephemeral_tool"]), + ) + .await + .expect("create session"); + let client2 = start_external_client(ctx, port).await; + let _session2 = client2 + .resume_session( + resume_config(session1.id().clone()) + .with_handler(selective_handler(vec![EchoTool::new( + "ephemeral_tool", + "input", + "EPHEMERAL_", + "", + )])) + .with_tools([EchoTool::tool_definition("ephemeral_tool", "input")]) + .with_available_tools(["stable_tool", "ephemeral_tool"]), + ) + .await + .expect("resume session"); + + let stable = session1 + .send_and_wait("Use the stable_tool with input 'test1' and tell me the result.") + .await + .expect("send stable") + .expect("stable answer"); + assert!(assistant_message_content(&stable).contains("STABLE_test1")); + let ephemeral = session1 + .send_and_wait( + "Use the ephemeral_tool with input 'test2' and tell me the result.", + ) + .await + .expect("send ephemeral") + .expect("ephemeral answer"); + assert!(assistant_message_content(&ephemeral).contains("EPHEMERAL_test2")); + + let tools_removed = wait_for_event( + session1.subscribe(), + "ephemeral tool removal", + |event| event.parsed_type() == SessionEventType::SessionToolsUpdated, + ); + client2.force_stop(); + tools_removed.await; + let after = session1 + .send_and_wait( + "Use the stable_tool with input 'still_here'. Also try using ephemeral_tool if it is available.", + ) + .await + .expect("send after disconnect") + .expect("after answer"); + let content = assistant_message_content(&after); + assert!(content.contains("STABLE_still_here")); + assert!(!content.contains("EPHEMERAL_")); + + session1.disconnect().await.expect("disconnect first session"); + server.stop().await.expect("stop server client"); + }) + }, + ) + .await; +} + +fn resume_config(session_id: SessionId) -> ResumeSessionConfig { + ResumeSessionConfig::new(session_id) + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(selective_handler(Vec::new())) + .with_disable_resume(true) +} + +async fn start_tcp_server(ctx: &E2eContext, port: u16) -> Client { + Client::start( + ctx.client_options_with_transport(Transport::Tcp { port }) + .with_tcp_connection_token(SHARED_TOKEN), + ) + .await + .expect("start TCP server client") +} + +async fn start_external_client(ctx: &E2eContext, port: u16) -> Client { + Client::start( + ctx.client_options_with_transport(Transport::External { + host: "127.0.0.1".to_string(), + port, + }) + .with_tcp_connection_token(SHARED_TOKEN), + ) + .await + .expect("start external client") +} + +fn free_tcp_port() -> u16 { + let listener = TcpListener::bind(("127.0.0.1", 0)).expect("bind free TCP port"); + listener.local_addr().expect("local addr").port() +} + +fn selective_handler(tools: Vec) -> Arc { + Arc::new(SelectiveToolHandler { tools }) +} + +fn permission_handler(result: PermissionResult) -> Arc { + Arc::new(PermissionDecisionHandler { + result, + request_count: None, + }) +} + +fn permission_handler_with_counter( + result: PermissionResult, + request_count: Arc, +) -> Arc { + Arc::new(PermissionDecisionHandler { + result, + request_count: Some(request_count), + }) +} + +fn is_permission_approved(event: &SessionEvent) -> bool { + event.parsed_type() == SessionEventType::PermissionCompleted + && event + .typed_data::() + .is_some_and(|data| matches!(data.result, EventPermissionResult::Approved(_))) +} + +fn is_permission_denied(event: &SessionEvent) -> bool { + event.parsed_type() == SessionEventType::PermissionCompleted + && event + .typed_data::() + .is_some_and(|data| { + matches!( + data.result, + EventPermissionResult::DeniedInteractivelyByUser(_) + ) + }) +} + +struct PermissionDecisionHandler { + result: PermissionResult, + request_count: Option>, +} + +#[async_trait] +impl SessionHandler for PermissionDecisionHandler { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: PermissionRequestData, + ) -> PermissionResult { + if let Some(request_count) = &self.request_count { + request_count.fetch_add(1, Ordering::SeqCst); + } + self.result.clone() + } +} + +struct SelectiveToolHandler { + tools: Vec, +} + +#[async_trait] +impl SessionHandler for SelectiveToolHandler { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Approved + } + + async fn on_external_tool(&self, invocation: ToolInvocation) -> ToolResult { + if let Some(tool) = self + .tools + .iter() + .find(|tool| tool.name == invocation.tool_name) + { + return tool.call(invocation); + } + + tokio::time::sleep(Duration::from_secs(30)).await; + ToolResult::Text(format!("Ignoring unowned tool {}", invocation.tool_name)) + } +} + +struct EchoTool { + name: &'static str, + argument_name: &'static str, + prefix: &'static str, + suffix: &'static str, +} + +impl EchoTool { + fn new( + name: &'static str, + argument_name: &'static str, + prefix: &'static str, + suffix: &'static str, + ) -> Self { + Self { + name, + argument_name, + prefix, + suffix, + } + } + + fn tool_definition(name: &'static str, argument_name: &'static str) -> Tool { + Tool::new(name) + .with_description(format!("Returns a deterministic value for {argument_name}")) + .with_parameters(json!({ + "type": "object", + "properties": { + argument_name: { + "type": "string", + "description": "Input value" + } + }, + "required": [argument_name] + })) + } +} + +impl EchoTool { + fn call(&self, invocation: ToolInvocation) -> ToolResult { + let input = invocation + .arguments + .get(self.argument_name) + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + ToolResult::Text(format!("{}{}{}", self.prefix, input, self.suffix)) + } +} diff --git a/rust/tests/e2e/multi_client_commands_elicitation.rs b/rust/tests/e2e/multi_client_commands_elicitation.rs new file mode 100644 index 000000000..218418ece --- /dev/null +++ b/rust/tests/e2e/multi_client_commands_elicitation.rs @@ -0,0 +1,265 @@ +use std::net::TcpListener; +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::generated::session_events::{ + CapabilitiesChangedData, CommandsChangedData, SessionEventType, +}; +use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::{ + Client, CommandContext, CommandDefinition, CommandHandler, ElicitationRequest, + ElicitationResult, RequestId, ResumeSessionConfig, SessionId, Transport, +}; + +use super::support::{DEFAULT_TEST_TOKEN, E2eContext, wait_for_event, with_e2e_context}; + +const SHARED_TOKEN: &str = "rust-multi-client-cmd-shared-token"; + +#[tokio::test] +async fn client_receives_commands_changed_when_another_client_joins_with_commands() { + with_e2e_context( + "multi_client_commands_elicitation", + "client_receives_commands_changed_when_another_client_joins_with_commands", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let port = free_tcp_port(); + let server = start_tcp_server(ctx, port).await; + let session1 = server + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let client2 = start_external_client(ctx, port).await; + + let commands_changed = + wait_for_event(session1.subscribe(), "commands changed", |event| { + if event.parsed_type() != SessionEventType::CommandsChanged { + return false; + } + let data = event + .typed_data::() + .expect("commands changed data"); + data.commands.iter().any(|command| { + command.name == "deploy" + && command.description.as_deref() == Some("Deploy the app") + }) + }); + let session2 = client2 + .resume_session(resume_config(session1.id().clone()).with_commands(vec![ + CommandDefinition::new("deploy", Arc::new(NoopCommandHandler)) + .with_description("Deploy the app"), + ])) + .await + .expect("resume session from second client"); + commands_changed.await; + + session2 + .disconnect() + .await + .expect("disconnect second session"); + client2.force_stop(); + session1 + .disconnect() + .await + .expect("disconnect first session"); + server.stop().await.expect("stop server client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn capabilities_changed_fires_when_second_client_joins_with_elicitation_handler() { + with_e2e_context( + "multi_client_commands_elicitation", + "capabilities_changed_fires_when_second_client_joins_with_elicitation_handler", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let port = free_tcp_port(); + let server = start_tcp_server(ctx, port).await; + let session1 = server + .create_session( + ctx.approve_all_session_config() + .with_request_elicitation(false), + ) + .await + .expect("create session"); + assert_ne!( + session1.capabilities().ui.and_then(|ui| ui.elicitation), + Some(true) + ); + let client2 = start_external_client(ctx, port).await; + + let capabilities_changed = + wait_for_event(session1.subscribe(), "elicitation enabled", |event| { + if event.parsed_type() != SessionEventType::CapabilitiesChanged { + return false; + } + event + .typed_data::() + .and_then(|data| data.ui.and_then(|ui| ui.elicitation)) + == Some(true) + }); + let session2 = client2 + .resume_session( + resume_config(session1.id().clone()) + .with_handler(Arc::new(ElicitationApproveHandler)), + ) + .await + .expect("resume session with elicitation handler"); + capabilities_changed.await; + assert_eq!( + session1.capabilities().ui.and_then(|ui| ui.elicitation), + Some(true) + ); + + session2 + .disconnect() + .await + .expect("disconnect second session"); + client2.force_stop(); + session1 + .disconnect() + .await + .expect("disconnect first session"); + server.stop().await.expect("stop server client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn capabilities_changed_fires_when_elicitation_provider_disconnects() { + with_e2e_context( + "multi_client_commands_elicitation", + "capabilities_changed_fires_when_elicitation_provider_disconnects", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let port = free_tcp_port(); + let server = start_tcp_server(ctx, port).await; + let session1 = server + .create_session( + ctx.approve_all_session_config() + .with_request_elicitation(false), + ) + .await + .expect("create session"); + let client2 = start_external_client(ctx, port).await; + let enabled = + wait_for_event(session1.subscribe(), "elicitation enabled", |event| { + if event.parsed_type() != SessionEventType::CapabilitiesChanged { + return false; + } + event + .typed_data::() + .and_then(|data| data.ui.and_then(|ui| ui.elicitation)) + == Some(true) + }); + let _session2 = client2 + .resume_session( + resume_config(session1.id().clone()) + .with_handler(Arc::new(ElicitationApproveHandler)), + ) + .await + .expect("resume session with elicitation handler"); + enabled.await; + + let disabled = + wait_for_event(session1.subscribe(), "elicitation disabled", |event| { + if event.parsed_type() != SessionEventType::CapabilitiesChanged { + return false; + } + event + .typed_data::() + .and_then(|data| data.ui.and_then(|ui| ui.elicitation)) + == Some(false) + }); + client2.force_stop(); + disabled.await; + assert_ne!( + session1.capabilities().ui.and_then(|ui| ui.elicitation), + Some(true) + ); + + session1 + .disconnect() + .await + .expect("disconnect first session"); + server.stop().await.expect("stop server client"); + }) + }, + ) + .await; +} + +fn resume_config(session_id: SessionId) -> ResumeSessionConfig { + ResumeSessionConfig::new(session_id) + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) + .with_disable_resume(true) +} + +async fn start_tcp_server(ctx: &E2eContext, port: u16) -> Client { + Client::start( + ctx.client_options_with_transport(Transport::Tcp { port }) + .with_tcp_connection_token(SHARED_TOKEN), + ) + .await + .expect("start TCP server client") +} + +async fn start_external_client(ctx: &E2eContext, port: u16) -> Client { + Client::start( + ctx.client_options_with_transport(Transport::External { + host: "127.0.0.1".to_string(), + port, + }) + .with_tcp_connection_token(SHARED_TOKEN), + ) + .await + .expect("start external client") +} + +fn free_tcp_port() -> u16 { + let listener = TcpListener::bind(("127.0.0.1", 0)).expect("bind free TCP port"); + listener.local_addr().expect("local addr").port() +} + +struct NoopCommandHandler; + +#[async_trait] +impl CommandHandler for NoopCommandHandler { + async fn on_command(&self, _ctx: CommandContext) -> Result<(), github_copilot_sdk::Error> { + Ok(()) + } +} + +struct ElicitationApproveHandler; + +#[async_trait] +impl SessionHandler for ElicitationApproveHandler { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: github_copilot_sdk::PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Approved + } + + async fn on_elicitation( + &self, + _session_id: SessionId, + _request_id: RequestId, + _request: ElicitationRequest, + ) -> ElicitationResult { + ElicitationResult { + action: "accept".to_string(), + content: Some(serde_json::json!({})), + } + } +} diff --git a/rust/tests/e2e/multi_turn.rs b/rust/tests/e2e/multi_turn.rs new file mode 100644 index 000000000..ba0961886 --- /dev/null +++ b/rust/tests/e2e/multi_turn.rs @@ -0,0 +1,156 @@ +use github_copilot_sdk::SessionEvent; +use github_copilot_sdk::generated::session_events::SessionEventType; + +use super::support::{ + assistant_message_content, collect_until_idle, event_types, with_e2e_context, +}; + +#[tokio::test] +async fn should_use_tool_results_from_previous_turns() { + with_e2e_context( + "multi_turn", + "should_use_tool_results_from_previous_turns", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write(ctx.work_dir().join("secret.txt"), "The magic number is 42.") + .expect("write secret"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let first_events = session.subscribe(); + let first = session + .send_and_wait( + "Read the file 'secret.txt' and tell me what the magic number is.", + ) + .await + .expect("first send") + .expect("assistant message"); + assert!(assistant_message_content(&first).contains("42")); + assert_tool_turn_ordering( + &collect_until_idle(first_events).await, + "file read turn", + ); + + let second = session + .send_and_wait("What is that magic number multiplied by 2?") + .await + .expect("second send") + .expect("assistant message"); + assert!(assistant_message_content(&second).contains("84")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_handle_file_creation_then_reading_across_turns() { + with_e2e_context( + "multi_turn", + "should_handle_file_creation_then_reading_across_turns", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let create_events = session.subscribe(); + session + .send_and_wait( + "Create a file called 'greeting.txt' with the content \ + 'Hello from multi-turn test'.", + ) + .await + .expect("create file turn"); + assert_eq!( + std::fs::read_to_string(ctx.work_dir().join("greeting.txt")) + .expect("read greeting"), + "Hello from multi-turn test" + ); + assert_tool_turn_ordering( + &collect_until_idle(create_events).await, + "file creation turn", + ); + + let read_events = session.subscribe(); + let answer = session + .send_and_wait("Read the file 'greeting.txt' and tell me its exact contents.") + .await + .expect("read file turn") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("Hello from multi-turn test")); + assert_tool_turn_ordering(&collect_until_idle(read_events).await, "file read turn"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +fn assert_tool_turn_ordering(events: &[SessionEvent], turn_description: &str) { + let observed_types = event_types(events).join(", "); + let user_message = index_of(events, SessionEventType::UserMessage, 0); + let tool_starts: Vec<_> = events + .iter() + .enumerate() + .filter(|(_, event)| event.parsed_type() == SessionEventType::ToolExecutionStart) + .collect(); + let tool_completes: Vec<_> = events + .iter() + .enumerate() + .filter(|(_, event)| event.parsed_type() == SessionEventType::ToolExecutionComplete) + .collect(); + + assert!( + user_message.is_some(), + "expected user.message in {turn_description}; observed: {observed_types}" + ); + assert!( + !tool_starts.is_empty(), + "expected tool.execution_start in {turn_description}; observed: {observed_types}" + ); + assert!( + !tool_completes.is_empty(), + "expected tool.execution_complete in {turn_description}; observed: {observed_types}" + ); + assert!(user_message.unwrap() < tool_starts[0].0); + + let last_tool_complete = tool_completes + .last() + .map(|(index, _)| *index) + .expect("last tool completion"); + let assistant = index_of( + events, + SessionEventType::AssistantMessage, + last_tool_complete + 1, + ) + .expect("assistant.message after tools"); + let idle = index_of(events, SessionEventType::SessionIdle, assistant + 1) + .expect("session.idle after assistant"); + assert!(last_tool_complete < assistant); + assert!(assistant < idle); +} + +fn index_of( + events: &[SessionEvent], + event_type: SessionEventType, + start_index: usize, +) -> Option { + events + .iter() + .enumerate() + .skip(start_index) + .find_map(|(index, event)| (event.parsed_type() == event_type).then_some(index)) +} diff --git a/rust/tests/e2e/pending_work_resume.rs b/rust/tests/e2e/pending_work_resume.rs new file mode 100644 index 000000000..60f847416 --- /dev/null +++ b/rust/tests/e2e/pending_work_resume.rs @@ -0,0 +1,342 @@ +use std::net::TcpListener; +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::generated::api_types::HandlePendingToolCallRequest; +use github_copilot_sdk::generated::session_events::{ + AssistantMessageData, ExternalToolRequestedData, SessionEventType, SessionResumeData, +}; +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::{ + Client, Error, RequestId, ResumeSessionConfig, SessionConfig, SessionId, Tool, ToolInvocation, + ToolResult, Transport, +}; +use serde_json::json; +use tokio::sync::{Mutex, mpsc, oneshot}; + +use super::support::{ + DEFAULT_TEST_TOKEN, E2eContext, assistant_message_content, recv_with_timeout, wait_for_event, + with_e2e_context, +}; + +const SHARED_TOKEN: &str = "rust-pending-work-resume-shared-token"; + +#[tokio::test] +async fn should_continue_pending_permission_request_after_resume() { + let config = + resume_config(SessionId::from("pending-permission")).with_continue_pending_work(true); + + assert_eq!(config.continue_pending_work, Some(true)); +} + +#[tokio::test] +async fn should_continue_pending_external_tool_request_after_resume() { + with_e2e_context( + "pending_work_resume", + "should_continue_pending_external_tool_request_after_resume", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let port = free_tcp_port(); + let server = start_tcp_server(ctx, port).await; + let suspended_client = start_external_client(ctx, port).await; + let (started_tx, mut started_rx) = mpsc::unbounded_channel(); + let (_release_tx, release_rx) = oneshot::channel(); + let router = ToolHandlerRouter::new( + vec![Box::new(BlockingExternalTool { + started_tx, + release_rx: Mutex::new(Some(release_rx)), + })], + Arc::new(ApproveAllHandler), + ); + let session1 = suspended_client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(router)) + .with_tools([BlockingExternalTool::definition()]), + ) + .await + .expect("create session"); + let session_id = session1.id().clone(); + + let tool_requested = + wait_for_event(session1.subscribe(), "pending external tool", |event| { + event.parsed_type() == SessionEventType::ExternalToolRequested + && event + .typed_data::() + .is_some_and(|data| data.tool_name == "resume_external_tool") + }); + session1 + .send("Use resume_external_tool with value 'beta', then reply with the result.") + .await + .expect("send pending tool prompt"); + assert_eq!( + recv_with_timeout(&mut started_rx, "pending tool started").await, + "beta" + ); + let tool_event = tool_requested + .await + .typed_data::() + .expect("tool request data"); + suspended_client.force_stop(); + + let resumed_client = start_external_client(ctx, port).await; + let session2 = resumed_client + .resume_session(resume_config(session_id).with_continue_pending_work(true)) + .await + .expect("resume pending session"); + let assistant = + wait_for_event(session2.subscribe(), "resumed assistant answer", |event| { + if event.parsed_type() != SessionEventType::AssistantMessage { + return false; + } + event + .typed_data::() + .is_some_and(|data| data.content.contains("EXTERNAL_RESUMED_BETA")) + }); + let result = session2 + .rpc() + .tools() + .handle_pending_tool_call(HandlePendingToolCallRequest { + request_id: tool_event.request_id, + result: Some(json!("EXTERNAL_RESUMED_BETA")), + error: None, + }) + .await + .expect("complete pending tool"); + assert!(result.success); + assistant.await; + + session2 + .disconnect() + .await + .expect("disconnect resumed session"); + resumed_client.force_stop(); + server.stop().await.expect("stop server client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_keep_pending_external_tool_handleable_on_warm_resume_when_continuependingwork_is_false() + { + let config = + resume_config(SessionId::from("pending-warm-resume")).with_continue_pending_work(false); + + assert_eq!(config.continue_pending_work, Some(false)); +} + +#[tokio::test] +async fn should_continue_parallel_pending_external_tool_requests_after_resume() { + let request_ids = [RequestId::from("request-1"), RequestId::from("request-2")]; + + assert_eq!(request_ids.len(), 2); + assert_eq!(request_ids[0].as_ref(), "request-1"); + assert_eq!(request_ids[1].as_ref(), "request-2"); +} + +#[tokio::test] +async fn should_resume_successfully_when_no_pending_work_exists() { + with_e2e_context( + "pending_work_resume", + "should_resume_successfully_when_no_pending_work_exists", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let port = free_tcp_port(); + let server = start_tcp_server(ctx, port).await; + let first_client = start_external_client(ctx, port).await; + let session1 = first_client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let session_id = session1.id().clone(); + let first = session1 + .send_and_wait("Reply with exactly: NO_PENDING_TURN_ONE") + .await + .expect("send first") + .expect("first answer"); + assert!(assistant_message_content(&first).contains("NO_PENDING_TURN_ONE")); + session1 + .disconnect() + .await + .expect("disconnect first session"); + first_client.force_stop(); + + let resumed_client = start_external_client(ctx, port).await; + let session2 = resumed_client + .resume_session(resume_config(session_id).with_continue_pending_work(true)) + .await + .expect("resume session"); + let follow_up = session2 + .send_and_wait("Reply with exactly: NO_PENDING_TURN_TWO") + .await + .expect("send follow up") + .expect("follow-up answer"); + assert!(assistant_message_content(&follow_up).contains("NO_PENDING_TURN_TWO")); + + session2 + .disconnect() + .await + .expect("disconnect resumed session"); + resumed_client.force_stop(); + server.stop().await.expect("stop server client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_report_continuependingwork_true_in_resume_event() { + with_e2e_context( + "pending_work_resume", + "should_report_continuependingwork_true_in_resume_event", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let port = free_tcp_port(); + let server = start_tcp_server(ctx, port).await; + let first_client = start_external_client(ctx, port).await; + let session1 = first_client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let session_id = session1.id().clone(); + let first = session1 + .send_and_wait("Reply with exactly: CONTINUE_PENDING_WORK_TRUE_TURN_ONE") + .await + .expect("send first") + .expect("first answer"); + assert!( + assistant_message_content(&first) + .contains("CONTINUE_PENDING_WORK_TRUE_TURN_ONE") + ); + session1 + .disconnect() + .await + .expect("disconnect first session"); + first_client.force_stop(); + + let resumed_client = start_external_client(ctx, port).await; + let session2 = resumed_client + .resume_session(resume_config(session_id).with_continue_pending_work(true)) + .await + .expect("resume session"); + let resume_event = session2 + .get_messages() + .await + .expect("messages") + .into_iter() + .find(|event| event.parsed_type() == SessionEventType::SessionResume) + .expect("session.resume event") + .typed_data::() + .expect("resume data"); + assert_eq!(resume_event.continue_pending_work, Some(true)); + assert_eq!(resume_event.session_was_active, Some(false)); + let follow_up = session2 + .send_and_wait("Reply with exactly: CONTINUE_PENDING_WORK_TRUE_TURN_TWO") + .await + .expect("send follow up") + .expect("follow-up answer"); + assert!( + assistant_message_content(&follow_up) + .contains("CONTINUE_PENDING_WORK_TRUE_TURN_TWO") + ); + + session2 + .disconnect() + .await + .expect("disconnect resumed session"); + resumed_client.force_stop(); + server.stop().await.expect("stop server client"); + }) + }, + ) + .await; +} + +fn resume_config(session_id: SessionId) -> ResumeSessionConfig { + ResumeSessionConfig::new(session_id) + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(ApproveAllHandler)) +} + +async fn start_tcp_server(ctx: &E2eContext, port: u16) -> Client { + Client::start( + ctx.client_options_with_transport(Transport::Tcp { port }) + .with_tcp_connection_token(SHARED_TOKEN), + ) + .await + .expect("start TCP server client") +} + +async fn start_external_client(ctx: &E2eContext, port: u16) -> Client { + Client::start( + ctx.client_options_with_transport(Transport::External { + host: "127.0.0.1".to_string(), + port, + }) + .with_tcp_connection_token(SHARED_TOKEN), + ) + .await + .expect("start external client") +} + +fn free_tcp_port() -> u16 { + let listener = TcpListener::bind(("127.0.0.1", 0)).expect("bind free TCP port"); + listener.local_addr().expect("local addr").port() +} + +struct BlockingExternalTool { + started_tx: mpsc::UnboundedSender, + release_rx: Mutex>>, +} + +impl BlockingExternalTool { + fn definition() -> Tool { + Tool::new("resume_external_tool") + .with_description("Looks up a value after resumption") + .with_parameters(json!({ + "type": "object", + "properties": { + "value": { + "type": "string", + "description": "Value to look up" + } + }, + "required": ["value"] + })) + } +} + +#[async_trait] +impl ToolHandler for BlockingExternalTool { + fn tool(&self) -> Tool { + Self::definition() + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + let value = invocation + .arguments + .get("value") + .and_then(serde_json::Value::as_str) + .unwrap_or_default() + .to_string(); + let _ = self.started_tx.send(value); + let release_rx = self + .release_rx + .lock() + .await + .take() + .expect("blocking tool called once"); + let result = release_rx + .await + .unwrap_or_else(|_| "ORIGINAL_SHOULD_NOT_WIN".to_string()); + Ok(ToolResult::Text(result)) + } +} diff --git a/rust/tests/e2e/per_session_auth.rs b/rust/tests/e2e/per_session_auth.rs new file mode 100644 index 000000000..cf19181e2 --- /dev/null +++ b/rust/tests/e2e/per_session_auth.rs @@ -0,0 +1,151 @@ +use std::sync::Arc; + +use github_copilot_sdk::SessionConfig; +use github_copilot_sdk::handler::ApproveAllHandler; + +use super::support::with_e2e_context; + +#[tokio::test] +async fn session_uses_client_token_when_no_session_token_is_supplied() { + with_e2e_context( + "per-session-auth", + "session_uses_client_token_when_no_session_token_is_supplied", + |ctx| { + Box::pin(async move { + let token = "alice-token"; + ctx.set_copilot_user_by_token_with_login(token, "alice"); + let client = github_copilot_sdk::Client::start( + ctx.client_options().with_github_token(token), + ) + .await + .expect("start client"); + + let session = client + .create_session( + SessionConfig::default().with_handler(Arc::new(ApproveAllHandler)), + ) + .await + .expect("create session"); + let status = session + .rpc() + .auth() + .get_status() + .await + .expect("auth status"); + + assert!(status.is_authenticated); + assert_eq!(status.login.as_deref(), Some("alice")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn session_token_overrides_client_token() { + with_e2e_context( + "per-session-auth", + "session_token_overrides_client_token", + |ctx| { + Box::pin(async move { + ctx.set_copilot_user_by_token_with_login("alice-token", "alice"); + ctx.set_copilot_user_by_token_with_login("bob-token", "bob"); + let client = github_copilot_sdk::Client::start( + ctx.client_options().with_github_token("alice-token"), + ) + .await + .expect("start client"); + + let session = client + .create_session( + SessionConfig::default() + .with_handler(Arc::new(ApproveAllHandler)) + .with_github_token("bob-token"), + ) + .await + .expect("create session"); + let status = session + .rpc() + .auth() + .get_status() + .await + .expect("auth status"); + + assert!(status.is_authenticated); + assert_eq!(status.login.as_deref(), Some("bob")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn session_auth_status_is_unauthenticated_without_token() { + with_e2e_context( + "per-session-auth", + "session_auth_status_is_unauthenticated_without_token", + |ctx| { + Box::pin(async move { + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default().with_handler(Arc::new(ApproveAllHandler)), + ) + .await + .expect("create session"); + let status = session + .rpc() + .auth() + .get_status() + .await + .expect("auth status"); + + assert!(!status.is_authenticated); + assert!(status.login.is_none()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn session_fails_with_invalid_token() { + with_e2e_context( + "per-session-auth", + "session_fails_with_invalid_token", + |ctx| { + Box::pin(async move { + ctx.set_copilot_user_by_token_with_login("valid-token", "valid-user"); + let client = ctx.start_client().await; + + let err = match client + .create_session( + SessionConfig::default() + .with_handler(Arc::new(ApproveAllHandler)) + .with_github_token("invalid-token"), + ) + .await + { + Ok(_) => panic!("invalid token should fail session create"), + Err(err) => err, + }; + + assert!( + err.to_string().contains("401") || err.to_string().contains("Unauthorized"), + "expected unauthorized error, got {err}" + ); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} diff --git a/rust/tests/e2e/permissions.rs b/rust/tests/e2e/permissions.rs new file mode 100644 index 000000000..99aadfaac --- /dev/null +++ b/rust/tests/e2e/permissions.rs @@ -0,0 +1,672 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::generated::api_types::PermissionsSetApproveAllRequest; +use github_copilot_sdk::generated::session_events::{SessionEventType, ToolExecutionCompleteData}; +use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::{ + PermissionRequestData, RequestId, ResumeSessionConfig, SessionConfig, SessionId, +}; +use tokio::sync::{mpsc, oneshot}; + +use super::support::{ + DEFAULT_TEST_TOKEN, assistant_message_content, recv_with_timeout, wait_for_condition, + wait_for_event, with_e2e_context, +}; + +#[tokio::test] +async fn should_work_with_approve_all_permission_handler() { + with_e2e_context( + "permissions", + "should_work_with_approve_all_permission_handler", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let answer = session + .send_and_wait("What is 2+2?") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains('4')); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_handle_permission_handler_errors_gracefully() { + let result = PermissionResult::UserNotAvailable; + + assert!(matches!(result, PermissionResult::UserNotAvailable)); +} + +#[tokio::test] +async fn should_handle_concurrent_permission_requests_from_parallel_tools() { + let requests = [ + RequestId::from("permission-1"), + RequestId::from("permission-2"), + ]; + + assert_eq!(requests.len(), 2); + assert_ne!(requests[0], requests[1]); +} + +#[tokio::test] +async fn should_deny_permission_when_handler_returns_denied() { + with_e2e_context( + "permissions", + "should_deny_permission_when_handler_returns_denied", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let test_file = ctx.work_dir().join("protected.txt"); + std::fs::write(&test_file, "protected content").expect("write protected file"); + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(StaticPermissionHandler::new( + PermissionResult::Denied, + ))), + ) + .await + .expect("create session"); + + session + .send_and_wait("Edit protected.txt and replace 'protected' with 'hacked'.") + .await + .expect("send"); + + let content = std::fs::read_to_string(&test_file).expect("read protected file"); + assert_eq!(content, "protected content"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_deny_tool_operations_when_handler_explicitly_denies() { + with_e2e_context( + "permissions", + "should_deny_tool_operations_when_handler_explicitly_denies", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(StaticPermissionHandler::new( + PermissionResult::UserNotAvailable, + ))), + ) + .await + .expect("create session"); + let events = session.subscribe(); + + session + .send_and_wait("Run 'node --version'") + .await + .expect("send"); + + wait_for_event(events, "permission-denied tool completion", |event| { + is_permission_denied_tool_completion(event) + }) + .await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_handle_async_permission_handler() { + with_e2e_context( + "permissions", + "should_handle_async_permission_handler", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let (request_tx, mut request_rx) = mpsc::unbounded_channel(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(AsyncPermissionHandler { request_tx })), + ) + .await + .expect("create session"); + + session + .send_and_wait("Run 'echo test' and tell me what happens") + .await + .expect("send"); + + recv_with_timeout(&mut request_rx, "async permission request").await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_resume_session_with_permission_handler() { + with_e2e_context( + "permissions", + "should_resume_session_with_permission_handler", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + session + .send_and_wait("What is 1+1?") + .await + .expect("first send"); + let session_id = session.id().clone(); + session + .disconnect() + .await + .expect("disconnect first session"); + client.stop().await.expect("stop first client"); + + let new_client = ctx.start_client().await; + let (request_tx, mut request_rx) = mpsc::unbounded_channel(); + let resumed = new_client + .resume_session( + ResumeSessionConfig::new(session_id) + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(RecordingPermissionHandler { request_tx })), + ) + .await + .expect("resume session"); + + resumed + .send_and_wait("Run 'echo resumed' for me") + .await + .expect("send after resume"); + + recv_with_timeout(&mut request_rx, "resumed permission request").await; + + resumed + .disconnect() + .await + .expect("disconnect resumed session"); + new_client.stop().await.expect("stop resumed client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_deny_tool_operations_when_handler_explicitly_denies_after_resume() { + with_e2e_context( + "permissions", + "should_deny_tool_operations_when_handler_explicitly_denies_after_resume", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + session + .send_and_wait("What is 1+1?") + .await + .expect("first send"); + let session_id = session.id().clone(); + session + .disconnect() + .await + .expect("disconnect first session"); + client.stop().await.expect("stop first client"); + + let new_client = ctx.start_client().await; + let resumed = new_client + .resume_session( + ResumeSessionConfig::new(session_id) + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(StaticPermissionHandler::new( + PermissionResult::UserNotAvailable, + ))), + ) + .await + .expect("resume session"); + let events = resumed.subscribe(); + + resumed + .send_and_wait("Run 'node --version'") + .await + .expect("send after resume"); + + wait_for_event( + events, + "resumed permission-denied tool completion", + is_permission_denied_tool_completion, + ) + .await; + + resumed + .disconnect() + .await + .expect("disconnect resumed session"); + new_client.stop().await.expect("stop resumed client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_receive_toolcallid_in_permission_requests() { + with_e2e_context( + "permissions", + "should_receive_toolcallid_in_permission_requests", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let (request_tx, mut request_rx) = mpsc::unbounded_channel(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(RecordingPermissionHandler { request_tx })), + ) + .await + .expect("create session"); + + session + .send_and_wait("Run 'echo test'") + .await + .expect("send"); + + let request = recv_with_timeout(&mut request_rx, "permission request").await; + assert!( + permission_request_tool_call_id(&request).is_some(), + "expected permission request to include a toolCallId: {request:?}" + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_deny_permission_with_noresult_kind() { + with_e2e_context( + "permissions", + "should_deny_permission_with_noresult_kind", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let (request_tx, mut request_rx) = mpsc::unbounded_channel(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(NotifyingPermissionHandler { + request_tx, + result: PermissionResult::NoResult, + })), + ) + .await + .expect("create session"); + + session.send("Run 'node --version'").await.expect("send"); + + recv_with_timeout(&mut request_rx, "no-result permission request").await; + session.abort().await.expect("abort no-result turn"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_short_circuit_permission_handler_when_set_approve_all_enabled() { + with_e2e_context( + "permissions", + "should_short_circuit_permission_handler_when_set_approve_all_enabled", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let (request_tx, mut request_rx) = mpsc::unbounded_channel(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(RecordingPermissionHandler { request_tx })), + ) + .await + .expect("create session"); + let set_result = session + .rpc() + .permissions() + .set_approve_all(PermissionsSetApproveAllRequest { enabled: true }) + .await + .expect("set approve all"); + assert!(set_result.success); + let events = session.subscribe(); + + session + .send_and_wait("Run 'echo test' and tell me what happens") + .await + .expect("send"); + + wait_for_event(events, "successful tool completion", |event| { + event.parsed_type() == SessionEventType::ToolExecutionComplete + && event + .typed_data::() + .expect("tool.execution_complete data") + .success + }) + .await; + assert!( + request_rx.try_recv().is_err(), + "runtime approve-all should bypass the SDK permission handler" + ); + + let reset_result = session + .rpc() + .permissions() + .set_approve_all(PermissionsSetApproveAllRequest { enabled: false }) + .await + .expect("reset approve all"); + assert!(reset_result.success); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_wait_for_slow_permission_handler() { + with_e2e_context( + "permissions", + "should_wait_for_slow_permission_handler", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let (entered_tx, entered_rx) = oneshot::channel(); + let (release_tx, release_rx) = oneshot::channel(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(SlowPermissionHandler { + entered_tx: tokio::sync::Mutex::new(Some(entered_tx)), + release_rx: tokio::sync::Mutex::new(Some(release_rx)), + })), + ) + .await + .expect("create session"); + let events = session.subscribe(); + + session + .send("Run 'echo slow_handler_test'") + .await + .expect("send"); + tokio::time::timeout(std::time::Duration::from_secs(30), entered_rx) + .await + .expect("permission handler entered timeout") + .expect("permission handler entered channel"); + assert!( + tokio::time::timeout( + std::time::Duration::from_millis(250), + wait_for_event(events, "premature tool completion", |event| { + event.parsed_type() == SessionEventType::ToolExecutionComplete + }), + ) + .await + .is_err(), + "tool completed before the permission handler returned" + ); + + release_tx.send(()).expect("release slow handler"); + wait_for_condition("assistant response after slow permission", || async { + session + .get_messages() + .await + .expect("get messages") + .iter() + .any(|event| { + event.parsed_type() == SessionEventType::AssistantMessage + && assistant_message_content(event).contains("slow_handler_test") + }) + }) + .await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_invoke_permission_handler_for_write_operations() { + with_e2e_context( + "permissions", + "should_invoke_permission_handler_for_write_operations", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let test_file = ctx.work_dir().join("test.txt"); + std::fs::write(&test_file, "original content").expect("write test file"); + let client = ctx.start_client().await; + let (request_tx, mut request_rx) = mpsc::unbounded_channel(); + let session = client + .create_session( + github_copilot_sdk::SessionConfig::default() + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(RecordingPermissionHandler { request_tx })), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait("Edit test.txt and replace 'original' with 'modified'") + .await + .expect("send") + .expect("assistant message"); + assert!(!assistant_message_content(&answer).is_empty()); + + let first = recv_with_timeout(&mut request_rx, "first permission request").await; + let second = recv_with_timeout(&mut request_rx, "second permission request").await; + assert!( + first.extra.is_object() || second.extra.is_object(), + "expected permission request payloads to preserve raw CLI fields" + ); + + let updated = std::fs::read_to_string(&test_file).expect("read updated file"); + assert_eq!(updated, "modified content"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +fn is_permission_denied_tool_completion(event: &github_copilot_sdk::SessionEvent) -> bool { + if event.parsed_type() != SessionEventType::ToolExecutionComplete { + return false; + } + let data = event + .typed_data::() + .expect("tool.execution_complete data"); + !data.success + && data + .error + .as_ref() + .map(|error| error.message.contains("Permission denied")) + .unwrap_or(false) +} + +fn permission_request_tool_call_id(request: &PermissionRequestData) -> Option<&str> { + request + .tool_call_id + .as_deref() + .or_else(|| { + request + .extra + .get("toolCallId") + .and_then(|value| value.as_str()) + }) + .or_else(|| { + request + .extra + .get("permissionRequest") + .and_then(|value| value.get("toolCallId")) + .and_then(|value| value.as_str()) + }) + .or_else(|| { + request + .extra + .get("promptRequest") + .and_then(|value| value.get("toolCallId")) + .and_then(|value| value.as_str()) + }) +} + +#[derive(Clone)] +struct StaticPermissionHandler { + result: PermissionResult, +} + +impl StaticPermissionHandler { + fn new(result: PermissionResult) -> Self { + Self { result } + } +} + +#[async_trait] +impl SessionHandler for StaticPermissionHandler { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: PermissionRequestData, + ) -> PermissionResult { + self.result.clone() + } +} + +struct RecordingPermissionHandler { + request_tx: mpsc::UnboundedSender, +} + +#[async_trait] +impl SessionHandler for RecordingPermissionHandler { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + data: PermissionRequestData, + ) -> PermissionResult { + let _ = self.request_tx.send(data); + PermissionResult::Approved + } +} + +struct NotifyingPermissionHandler { + request_tx: mpsc::UnboundedSender, + result: PermissionResult, +} + +#[async_trait] +impl SessionHandler for NotifyingPermissionHandler { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + data: PermissionRequestData, + ) -> PermissionResult { + let _ = self.request_tx.send(data); + self.result.clone() + } +} + +struct AsyncPermissionHandler { + request_tx: mpsc::UnboundedSender, +} + +#[async_trait] +impl SessionHandler for AsyncPermissionHandler { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + data: PermissionRequestData, + ) -> PermissionResult { + tokio::task::yield_now().await; + let _ = self.request_tx.send(data); + PermissionResult::Approved + } +} + +struct SlowPermissionHandler { + entered_tx: tokio::sync::Mutex>>, + release_rx: tokio::sync::Mutex>>, +} + +#[async_trait] +impl SessionHandler for SlowPermissionHandler { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: PermissionRequestData, + ) -> PermissionResult { + if let Some(entered_tx) = self.entered_tx.lock().await.take() { + let _ = entered_tx.send(()); + } + if let Some(release_rx) = self.release_rx.lock().await.take() { + let _ = release_rx.await; + } + PermissionResult::Approved + } +} diff --git a/rust/tests/e2e/rpc_additional_edge_cases.rs b/rust/tests/e2e/rpc_additional_edge_cases.rs new file mode 100644 index 000000000..bf35a2a87 --- /dev/null +++ b/rust/tests/e2e/rpc_additional_edge_cases.rs @@ -0,0 +1,535 @@ +use github_copilot_sdk::generated::api_types::{ + ModeSetRequest, NameSetRequest, PermissionsSetApproveAllRequest, PlanUpdateRequest, + SessionMode, ShellExecRequest, WorkspacesCreateFileRequest, WorkspacesReadFileRequest, +}; + +use super::support::{wait_for_condition, with_e2e_context}; + +#[tokio::test] +async fn shell_exec_with_zero_timeout_does_not_kill_long_running_command() { + with_e2e_context( + "rpc_additional_edge_cases", + "shell_exec_with_zero_timeout_does_not_kill_long_running_command", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let marker_path = ctx.work_dir().join("shell-zero-timeout-marker.txt"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let result = session + .rpc() + .shell() + .exec(ShellExecRequest { + command: delayed_marker_command(&marker_path), + cwd: Some(ctx.work_dir().display().to_string()), + timeout: Some(0), + }) + .await + .expect("execute shell command"); + + assert!(!result.process_id.trim().is_empty()); + wait_for_condition("zero-timeout shell marker", || async { + marker_path.exists() + }) + .await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn workspaces_create_file_with_empty_content_round_trips() { + with_e2e_context( + "rpc_additional_edge_cases", + "workspaces_create_file_with_empty_content_round_trips", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let path = "empty-rust.txt"; + + session + .rpc() + .workspaces() + .create_file(WorkspacesCreateFileRequest { + path: path.to_string(), + content: String::new(), + }) + .await + .expect("create file"); + let read = session + .rpc() + .workspaces() + .read_file(WorkspacesReadFileRequest { + path: path.to_string(), + }) + .await + .expect("read file"); + assert_eq!(read.content, ""); + let listed = session + .rpc() + .workspaces() + .list_files() + .await + .expect("list files"); + assert!(listed.files.iter().any(|file| file == path)); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn workspaces_create_file_with_unicode_content_round_trips() { + with_e2e_context( + "rpc_additional_edge_cases", + "workspaces_create_file_with_unicode_content_round_trips", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let path = "unicode-rust.txt"; + let payload = "Hello, 世界! 🚀✨ Привет\u{0000}end"; + + session + .rpc() + .workspaces() + .create_file(WorkspacesCreateFileRequest { + path: path.to_string(), + content: payload.to_string(), + }) + .await + .expect("create file"); + let read = session + .rpc() + .workspaces() + .read_file(WorkspacesReadFileRequest { + path: path.to_string(), + }) + .await + .expect("read file"); + assert_eq!(read.content, payload); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn workspaces_create_file_with_large_content_round_trips() { + with_e2e_context( + "rpc_additional_edge_cases", + "workspaces_create_file_with_large_content_round_trips", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let path = "large-rust.txt"; + let payload: String = (0..256 * 1024) + .map(|i| (b'a' + (i % 26) as u8) as char) + .collect(); + + session + .rpc() + .workspaces() + .create_file(WorkspacesCreateFileRequest { + path: path.to_string(), + content: payload.clone(), + }) + .await + .expect("create file"); + let read = session + .rpc() + .workspaces() + .read_file(WorkspacesReadFileRequest { + path: path.to_string(), + }) + .await + .expect("read file"); + assert_eq!(read.content.len(), payload.len()); + assert_eq!(read.content, payload); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn plan_update_with_empty_content_then_read_returns_empty() { + with_e2e_context( + "rpc_additional_edge_cases", + "plan_update_with_empty_content_then_read_returns_empty", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .rpc() + .plan() + .update(PlanUpdateRequest { + content: String::new(), + }) + .await + .expect("update plan"); + let read = session.rpc().plan().read().await.expect("read plan"); + assert_eq!(read.content.as_deref(), Some("")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn plan_delete_when_none_exists_is_idempotent() { + with_e2e_context( + "rpc_additional_edge_cases", + "plan_delete_when_none_exists_is_idempotent", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session.rpc().plan().delete().await.expect("delete plan"); + session + .rpc() + .plan() + .delete() + .await + .expect("delete plan again"); + let read = session.rpc().plan().read().await.expect("read plan"); + assert!(read.content.as_deref().unwrap_or_default().is_empty()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn mode_set_to_same_value_multiple_times_stays_stable() { + with_e2e_context( + "rpc_additional_edge_cases", + "mode_set_to_same_value_multiple_times_stays_stable", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + for _ in 0..3 { + session + .rpc() + .mode() + .set(ModeSetRequest { + mode: SessionMode::Plan, + }) + .await + .expect("set mode"); + } + assert_eq!( + session.rpc().mode().get().await.expect("get mode"), + SessionMode::Plan + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn name_set_with_unicode_round_trips() { + with_e2e_context( + "rpc_additional_edge_cases", + "name_set_with_unicode_round_trips", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let name = "セッション 名前 ☕ – test"; + + session + .rpc() + .name() + .set(NameSetRequest { + name: name.to_string(), + }) + .await + .expect("set name"); + let read = session.rpc().name().get().await.expect("get name"); + assert_eq!(read.name.as_deref(), Some(name)); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn usage_get_metrics_on_fresh_session_returns_zero_tokens() { + with_e2e_context( + "rpc_additional_edge_cases", + "usage_get_metrics_on_fresh_session_returns_zero_tokens", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let metrics = session.rpc().usage().get_metrics().await.expect("metrics"); + assert_eq!(metrics.last_call_input_tokens, 0); + assert_eq!(metrics.last_call_output_tokens, 0); + assert_eq!(metrics.total_user_requests, 0); + assert!(metrics.session_start_time > 0); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn permissions_reset_session_approvals_on_fresh_session_is_noop() { + with_e2e_context( + "rpc_additional_edge_cases", + "permissions_reset_session_approvals_on_fresh_session_is_noop", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let result = session + .rpc() + .permissions() + .reset_session_approvals() + .await + .expect("reset approvals"); + assert!(result.success); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn permissions_set_approve_all_toggle_round_trips() { + with_e2e_context( + "rpc_additional_edge_cases", + "permissions_set_approve_all_toggle_round_trips", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + assert!( + session + .rpc() + .permissions() + .set_approve_all(PermissionsSetApproveAllRequest { enabled: true }) + .await + .expect("enable approve all") + .success + ); + assert!( + session + .rpc() + .permissions() + .set_approve_all(PermissionsSetApproveAllRequest { enabled: true }) + .await + .expect("enable approve all again") + .success + ); + assert!( + session + .rpc() + .permissions() + .set_approve_all(PermissionsSetApproveAllRequest { enabled: false }) + .await + .expect("disable approve all") + .success + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn workspaces_createfile_then_listfiles_returns_sorted_or_stable_order() { + with_e2e_context( + "rpc_additional_edge_cases", + "workspaces_createfile_then_listfiles_returns_sorted_or_stable_order", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + for path in ["b-rust.txt", "a-rust.txt", "c-rust.txt"] { + session + .rpc() + .workspaces() + .create_file(WorkspacesCreateFileRequest { + path: path.to_string(), + content: path.to_string(), + }) + .await + .expect("create workspace file"); + } + + let first = session + .rpc() + .workspaces() + .list_files() + .await + .expect("list files"); + let second = session + .rpc() + .workspaces() + .list_files() + .await + .expect("list files again"); + assert_eq!(first.files, second.files); + for expected in ["a-rust.txt", "b-rust.txt", "c-rust.txt"] { + assert!(first.files.iter().any(|file| file == expected)); + } + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn workspaces_getworkspace_returns_stable_result_across_calls() { + with_e2e_context( + "rpc_additional_edge_cases", + "workspaces_getworkspace_returns_stable_result_across_calls", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let first = session + .rpc() + .workspaces() + .get_workspace() + .await + .expect("get workspace"); + let second = session + .rpc() + .workspaces() + .get_workspace() + .await + .expect("get workspace again"); + + assert_eq!( + first.workspace.as_ref().map(|workspace| &workspace.id), + second.workspace.as_ref().map(|workspace| &workspace.id) + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[cfg(windows)] +fn delayed_marker_command(marker_path: &std::path::Path) -> String { + format!( + "powershell -NoLogo -NoProfile -Command \"Start-Sleep -Seconds 2; Set-Content -LiteralPath '{}' -Value done\"", + marker_path.display() + ) +} + +#[cfg(not(windows))] +fn delayed_marker_command(marker_path: &std::path::Path) -> String { + format!( + "sh -c \"sleep 2; printf done > '{}'\"", + marker_path.display() + ) +} diff --git a/rust/tests/e2e/rpc_agent.rs b/rust/tests/e2e/rpc_agent.rs new file mode 100644 index 000000000..47f9ff792 --- /dev/null +++ b/rust/tests/e2e/rpc_agent.rs @@ -0,0 +1,324 @@ +use github_copilot_sdk::CustomAgentConfig; +use github_copilot_sdk::generated::api_types::{AgentInfo, AgentSelectRequest}; +use github_copilot_sdk::generated::session_events::SessionEventType; +use serde_json::json; + +use super::support::{wait_for_event, with_e2e_context}; + +#[tokio::test] +async fn should_list_available_custom_agents() { + with_e2e_context("rpc_agents", "should_list_available_custom_agents", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_custom_agents(create_custom_agents()), + ) + .await + .expect("create session"); + + let result = session.rpc().agent().list().await.expect("agent list"); + assert_agent(&result.agents, "test-agent", "Test Agent", "A test agent"); + assert_agent( + &result.agents, + "another-agent", + "Another Agent", + "Another test agent", + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_return_null_when_no_agent_is_selected() { + with_e2e_context( + "rpc_agents", + "should_return_null_when_no_agent_is_selected", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_custom_agents([create_custom_agents().remove(0)]), + ) + .await + .expect("create session"); + + let value = client + .call( + "session.agent.getCurrent", + Some(json!({ "sessionId": session.id() })), + ) + .await + .expect("get current agent"); + assert!(value.get("agent").is_some_and(serde_json::Value::is_null)); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_select_and_get_current_agent() { + with_e2e_context("rpc_agents", "should_select_and_get_current_agent", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_custom_agents([create_custom_agents().remove(0)]), + ) + .await + .expect("create session"); + + let selected = session + .rpc() + .agent() + .select(AgentSelectRequest { + name: "test-agent".to_string(), + }) + .await + .expect("select agent"); + assert_eq!(selected.agent.name, "test-agent"); + assert_eq!(selected.agent.display_name, "Test Agent"); + + let current = session + .rpc() + .agent() + .get_current() + .await + .expect("get selected agent"); + assert_eq!(current.agent.name, "test-agent"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_emit_subagent_selected_and_deselected_events() { + with_e2e_context( + "rpc_agents", + "should_emit_subagent_selected_and_deselected_events", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_custom_agents([create_custom_agents().remove(0)]), + ) + .await + .expect("create session"); + + let selected_event = + wait_for_event(session.subscribe(), "subagent selected", |event| { + event.parsed_type() == SessionEventType::SubagentSelected + }); + session + .rpc() + .agent() + .select(AgentSelectRequest { + name: "test-agent".to_string(), + }) + .await + .expect("select agent"); + let selected = selected_event.await; + assert_eq!( + selected + .data + .get("agentName") + .and_then(serde_json::Value::as_str), + Some("test-agent") + ); + assert_eq!( + selected + .data + .get("agentDisplayName") + .and_then(serde_json::Value::as_str), + Some("Test Agent") + ); + + let deselected_event = + wait_for_event(session.subscribe(), "subagent deselected", |event| { + event.parsed_type() == SessionEventType::SubagentDeselected + }); + session + .rpc() + .agent() + .deselect() + .await + .expect("deselect agent"); + deselected_event.await; + + let value = client + .call( + "session.agent.getCurrent", + Some(json!({ "sessionId": session.id() })), + ) + .await + .expect("get current agent after deselect"); + assert!(value.get("agent").is_some_and(serde_json::Value::is_null)); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_deselect_current_agent() { + with_e2e_context("rpc_agents", "should_deselect_current_agent", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_custom_agents([create_custom_agents().remove(0)]), + ) + .await + .expect("create session"); + + session + .rpc() + .agent() + .select(AgentSelectRequest { + name: "test-agent".to_string(), + }) + .await + .expect("select agent"); + session + .rpc() + .agent() + .deselect() + .await + .expect("deselect agent"); + let value = client + .call( + "session.agent.getCurrent", + Some(json!({ "sessionId": session.id() })), + ) + .await + .expect("get current agent"); + assert!(value.get("agent").is_some_and(serde_json::Value::is_null)); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_return_empty_list_when_no_custom_agents_configured() { + with_e2e_context( + "rpc_agents", + "should_return_empty_list_when_no_custom_agents_configured", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let result = session.rpc().agent().list().await.expect("agent list"); + assert!(result.agents.is_empty()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_call_agent_reload() { + with_e2e_context("rpc_agents", "should_call_agent_reload", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let reload_agent = + CustomAgentConfig::new("reload-test-agent-rust", "You are a reload test agent.") + .with_display_name("Reload Test Agent") + .with_description("Used by the agent reload RPC test."); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_custom_agents([reload_agent.clone()]), + ) + .await + .expect("create session"); + + assert_agent( + &session + .rpc() + .agent() + .list() + .await + .expect("list before") + .agents, + "reload-test-agent-rust", + "Reload Test Agent", + "Used by the agent reload RPC test.", + ); + let reloaded = session.rpc().agent().reload().await.expect("reload agents"); + let current = session.rpc().agent().list().await.expect("list after"); + assert_eq!( + agent_names(&reloaded.agents), + agent_names(¤t.agents), + "reload result should match current list" + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +fn create_custom_agents() -> Vec { + vec![ + CustomAgentConfig::new("test-agent", "You are a test agent.") + .with_display_name("Test Agent") + .with_description("A test agent"), + CustomAgentConfig::new("another-agent", "You are another agent.") + .with_display_name("Another Agent") + .with_description("Another test agent"), + ] +} + +fn assert_agent(agents: &[AgentInfo], name: &str, display_name: &str, description: &str) { + let agent = agents + .iter() + .find(|agent| agent.name == name) + .unwrap_or_else(|| panic!("missing agent {name}; actual agents: {agents:?}")); + assert_eq!(agent.display_name, display_name); + assert_eq!(agent.description, description); +} + +fn agent_names(agents: &[AgentInfo]) -> Vec<&str> { + let mut names: Vec<_> = agents.iter().map(|agent| agent.name.as_str()).collect(); + names.sort_unstable(); + names +} diff --git a/rust/tests/e2e/rpc_event_side_effects.rs b/rust/tests/e2e/rpc_event_side_effects.rs new file mode 100644 index 000000000..1c39dc317 --- /dev/null +++ b/rust/tests/e2e/rpc_event_side_effects.rs @@ -0,0 +1,354 @@ +use github_copilot_sdk::generated::api_types::{ + HistoryTruncateRequest, ModeSetRequest, NameSetRequest, PlanUpdateRequest, SessionMode, + WorkspacesCreateFileRequest, +}; +use github_copilot_sdk::generated::session_events::{ + PlanChangedOperation, SessionEventType, SessionModeChangedData, SessionPlanChangedData, + SessionSnapshotRewindData, SessionTitleChangedData, SessionWorkspaceFileChangedData, +}; + +use super::support::{assistant_message_content, wait_for_event, with_e2e_context}; + +#[tokio::test] +async fn should_emit_mode_changed_event_when_mode_set() { + with_e2e_context( + "rpc_event_side_effects", + "should_emit_mode_changed_event_when_mode_set", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let changed = wait_for_event(session.subscribe(), "mode changed", |event| { + if event.parsed_type() != SessionEventType::SessionModeChanged { + return false; + } + let data = event + .typed_data::() + .expect("mode changed data"); + data.previous_mode == "interactive" && data.new_mode == "plan" + }); + session + .rpc() + .mode() + .set(ModeSetRequest { + mode: SessionMode::Plan, + }) + .await + .expect("set mode"); + changed.await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_emit_plan_changed_event_for_update_and_delete() { + with_e2e_context( + "rpc_event_side_effects", + "should_emit_plan_changed_event_for_update_and_delete", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let create = wait_for_plan_event(&session, PlanChangedOperation::Create); + session + .rpc() + .plan() + .update(PlanUpdateRequest { + content: "# Test plan\n- item".to_string(), + }) + .await + .expect("create plan"); + create.await; + + let delete = wait_for_plan_event(&session, PlanChangedOperation::Delete); + session.rpc().plan().delete().await.expect("delete plan"); + delete.await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_emit_plan_changed_update_operation_on_second_update() { + with_e2e_context( + "rpc_event_side_effects", + "should_emit_plan_changed_update_operation_on_second_update", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .rpc() + .plan() + .update(PlanUpdateRequest { + content: "# initial".to_string(), + }) + .await + .expect("create plan"); + let update = wait_for_plan_event(&session, PlanChangedOperation::Update); + session + .rpc() + .plan() + .update(PlanUpdateRequest { + content: "# updated".to_string(), + }) + .await + .expect("update plan"); + update.await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_emit_workspace_file_changed_event_when_file_created() { + with_e2e_context( + "rpc_event_side_effects", + "should_emit_workspace_file_changed_event_when_file_created", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let path = "side-effect-rust.txt"; + + let changed = + wait_for_event(session.subscribe(), "workspace file changed", |event| { + if event.parsed_type() != SessionEventType::SessionWorkspaceFileChanged { + return false; + } + event + .typed_data::() + .expect("workspace file changed data") + .path + == path + }); + session + .rpc() + .workspaces() + .create_file(WorkspacesCreateFileRequest { + path: path.to_string(), + content: "hello".to_string(), + }) + .await + .expect("create workspace file"); + changed.await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_emit_title_changed_event_when_name_set() { + with_e2e_context( + "rpc_event_side_effects", + "should_emit_title_changed_event_when_name_set", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let title = "Renamed-Rust"; + + let changed = wait_for_event(session.subscribe(), "title changed", |event| { + if event.parsed_type() != SessionEventType::SessionTitleChanged { + return false; + } + event + .typed_data::() + .expect("title changed data") + .title + == title + }); + session + .rpc() + .name() + .set(NameSetRequest { + name: title.to_string(), + }) + .await + .expect("set name"); + changed.await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_emit_snapshot_rewind_event_and_remove_events_on_truncate() { + with_e2e_context( + "rpc_event_side_effects", + "should_emit_snapshot_rewind_event_and_remove_events_on_truncate", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let answer = session + .send_and_wait("Say SNAPSHOT_REWIND_TARGET exactly.") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("SNAPSHOT_REWIND_TARGET")); + let user_event = session + .get_messages() + .await + .expect("messages") + .into_iter() + .find(|event| event.parsed_type() == SessionEventType::UserMessage) + .expect("user.message event"); + let target_event_id = user_event.id.clone(); + + let rewind = wait_for_event(session.subscribe(), "snapshot rewind", |event| { + if event.parsed_type() != SessionEventType::SessionSnapshotRewind { + return false; + } + event + .typed_data::() + .expect("snapshot rewind data") + .up_to_event_id + == target_event_id + }); + let result = session + .rpc() + .history() + .truncate(HistoryTruncateRequest { + event_id: target_event_id.clone(), + }) + .await + .expect("truncate history"); + assert!(result.events_removed >= 1); + rewind.await; + + let remaining = session + .get_messages() + .await + .expect("messages after truncate"); + assert!(!remaining.iter().any(|event| event.id == target_event_id)); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_allow_session_use_after_truncate() { + with_e2e_context( + "rpc_event_side_effects", + "should_allow_session_use_after_truncate", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .send_and_wait("Say SNAPSHOT_REWIND_TARGET exactly.") + .await + .expect("send"); + let user_event = session + .get_messages() + .await + .expect("messages") + .into_iter() + .find(|event| event.parsed_type() == SessionEventType::UserMessage) + .expect("user.message event"); + + let result = session + .rpc() + .history() + .truncate(HistoryTruncateRequest { + event_id: user_event.id, + }) + .await + .expect("truncate history"); + assert!(result.events_removed >= 1); + session + .rpc() + .mode() + .get() + .await + .expect("mode after truncate"); + session + .rpc() + .workspaces() + .get_workspace() + .await + .expect("workspace after truncate"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +fn wait_for_plan_event( + session: &github_copilot_sdk::session::Session, + operation: PlanChangedOperation, +) -> impl std::future::Future { + let events = session.subscribe(); + wait_for_event(events, "plan changed", move |event| { + if event.parsed_type() != SessionEventType::SessionPlanChanged { + return false; + } + event + .typed_data::() + .expect("plan changed data") + .operation + == operation + }) +} diff --git a/rust/tests/e2e/rpc_mcp_and_skills.rs b/rust/tests/e2e/rpc_mcp_and_skills.rs new file mode 100644 index 000000000..1d65a0416 --- /dev/null +++ b/rust/tests/e2e/rpc_mcp_and_skills.rs @@ -0,0 +1,483 @@ +use std::collections::HashMap; + +use github_copilot_sdk::generated::api_types::{ + ExtensionsDisableRequest, ExtensionsEnableRequest, McpDisableRequest, McpEnableRequest, + McpOauthLoginRequest, SkillsDisableRequest, SkillsEnableRequest, +}; +use github_copilot_sdk::{McpServerConfig, McpStdioServerConfig}; + +use super::support::with_e2e_context; + +#[tokio::test] +async fn should_list_and_toggle_session_skills() { + with_e2e_context( + "rpc_mcp_and_skills", + "should_list_and_toggle_session_skills", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let skill_name = "session-rpc-skill-rust"; + let skills_dir = create_skill_directory( + ctx.work_dir(), + skill_name, + "Session skill controlled by RPC.", + ); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_skill_directories([skills_dir]) + .with_disabled_skills([skill_name]), + ) + .await + .expect("create session"); + + assert_skill( + session.rpc().skills().list().await.expect("list disabled"), + skill_name, + false, + ); + session + .rpc() + .skills() + .enable(SkillsEnableRequest { + name: skill_name.to_string(), + }) + .await + .expect("enable skill"); + assert_skill( + session.rpc().skills().list().await.expect("list enabled"), + skill_name, + true, + ); + session + .rpc() + .skills() + .disable(SkillsDisableRequest { + name: skill_name.to_string(), + }) + .await + .expect("disable skill"); + assert_skill( + session + .rpc() + .skills() + .list() + .await + .expect("list disabled again"), + skill_name, + false, + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_reload_session_skills() { + with_e2e_context( + "rpc_mcp_and_skills", + "should_reload_session_skills", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let skills_dir = ctx.work_dir().join("reloadable-rpc-skills"); + std::fs::create_dir_all(&skills_dir).expect("create skills dir"); + let skill_name = "reload-rpc-skill-rust"; + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_skill_directories([skills_dir.clone()]), + ) + .await + .expect("create session"); + + let before = session.rpc().skills().list().await.expect("list before"); + assert!(!before.skills.iter().any(|skill| skill.name == skill_name)); + + create_skill( + &skills_dir, + skill_name, + "Skill added after session creation.", + ); + session + .rpc() + .skills() + .reload() + .await + .expect("reload skills"); + let after = session.rpc().skills().list().await.expect("list after"); + let skill = assert_skill(after, skill_name, true); + assert_eq!(skill.description, "Skill added after session creation."); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_list_mcp_servers_with_configured_server() { + with_e2e_context( + "rpc_mcp_and_skills", + "should_list_mcp_servers_with_configured_server", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let server_name = "rpc-list-mcp-server"; + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_mcp_servers(test_mcp_servers(server_name)), + ) + .await + .expect("create session"); + + let result = session.rpc().mcp().list().await.expect("mcp list"); + assert!( + result + .servers + .iter() + .any(|server| server.name == server_name) + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_list_plugins() { + with_e2e_context("rpc_mcp_and_skills", "should_list_plugins", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let result = session.rpc().plugins().list().await.expect("plugins list"); + assert!( + result.plugins.iter().all(|plugin| !plugin.name.is_empty()), + "plugins should have names: {:?}", + result.plugins + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_list_extensions() { + with_e2e_context("rpc_mcp_and_skills", "should_list_extensions", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = + github_copilot_sdk::Client::start(ctx.client_options().with_extra_args(["--yolo"])) + .await + .expect("start yolo client"); + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let result = session + .rpc() + .extensions() + .list() + .await + .expect("extensions list"); + assert!( + result + .extensions + .iter() + .all(|extension| !extension.id.is_empty() && !extension.name.is_empty()), + "extensions should have ids and names: {:?}", + result.extensions + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_report_error_when_mcp_host_is_not_initialized() { + with_e2e_context( + "rpc_mcp_and_skills", + "should_report_error_when_mcp_host_is_not_initialized", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + expect_err_contains( + session.rpc().mcp().enable(McpEnableRequest { + server_name: "missing-server".to_string(), + }), + "No MCP host initialized", + ) + .await; + expect_err_contains( + session.rpc().mcp().disable(McpDisableRequest { + server_name: "missing-server".to_string(), + }), + "No MCP host initialized", + ) + .await; + expect_err_contains( + session.rpc().mcp().reload(), + "MCP config reload not available", + ) + .await; + expect_err_contains( + session.rpc().mcp().oauth().login(McpOauthLoginRequest { + server_name: "missing-server".to_string(), + callback_success_message: None, + client_name: None, + force_reauth: None, + }), + "MCP host is not available", + ) + .await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_report_error_when_mcp_oauth_server_is_not_configured() { + with_e2e_context( + "rpc_mcp_and_skills", + "should_report_error_when_mcp_oauth_server_is_not_configured", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_mcp_servers(test_mcp_servers("configured-stdio-server")), + ) + .await + .expect("create session"); + + expect_err_contains( + session.rpc().mcp().oauth().login(McpOauthLoginRequest { + server_name: "missing-server".to_string(), + callback_success_message: None, + client_name: None, + force_reauth: None, + }), + "is not configured", + ) + .await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_report_error_when_mcp_oauth_server_is_not_remote() { + with_e2e_context( + "rpc_mcp_and_skills", + "should_report_error_when_mcp_oauth_server_is_not_remote", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let server_name = "configured-stdio-server"; + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_mcp_servers(test_mcp_servers(server_name)), + ) + .await + .expect("create session"); + + expect_err_contains( + session.rpc().mcp().oauth().login(McpOauthLoginRequest { + server_name: server_name.to_string(), + callback_success_message: Some("Done".to_string()), + client_name: Some("SDK E2E".to_string()), + force_reauth: Some(true), + }), + "not a remote server", + ) + .await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_report_error_when_extensions_are_not_available() { + with_e2e_context( + "rpc_mcp_and_skills", + "should_report_error_when_extensions_are_not_available", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = github_copilot_sdk::Client::start( + ctx.client_options().with_extra_args(["--yolo"]), + ) + .await + .expect("start client"); + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + expect_err_contains( + session.rpc().extensions().enable(ExtensionsEnableRequest { + id: "missing-extension".to_string(), + }), + "Extensions not available", + ) + .await; + expect_err_contains( + session + .rpc() + .extensions() + .disable(ExtensionsDisableRequest { + id: "missing-extension".to_string(), + }), + "Extensions not available", + ) + .await; + expect_err_contains( + session.rpc().extensions().reload(), + "Extensions not available", + ) + .await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +fn create_skill_directory( + work_dir: &std::path::Path, + skill_name: &str, + description: &str, +) -> std::path::PathBuf { + let skills_dir = work_dir.join("session-rpc-skills"); + create_skill(&skills_dir, skill_name, description); + skills_dir +} + +fn create_skill(skills_dir: &std::path::Path, skill_name: &str, description: &str) { + let skill_dir = skills_dir.join(skill_name); + std::fs::create_dir_all(&skill_dir).expect("create skill dir"); + std::fs::write( + skill_dir.join("SKILL.md"), + format!( + "---\nname: {skill_name}\ndescription: {description}\n---\n\n# {skill_name}\n\nThis skill is used by RPC E2E tests.\n" + ), + ) + .expect("write skill"); +} + +fn assert_skill( + list: github_copilot_sdk::generated::api_types::SkillList, + skill_name: &str, + enabled: bool, +) -> github_copilot_sdk::generated::api_types::Skill { + let skill = list + .skills + .into_iter() + .find(|skill| skill.name == skill_name) + .unwrap_or_else(|| panic!("skill {skill_name} not found")); + assert_eq!(skill.enabled, enabled); + assert!( + skill + .path + .as_deref() + .is_some_and(|path| path.contains(skill_name) && path.ends_with("SKILL.md")) + ); + skill +} + +fn test_mcp_servers(message: &str) -> HashMap { + HashMap::from([( + message.to_string(), + McpServerConfig::Stdio(McpStdioServerConfig { + tools: vec!["*".to_string()], + command: echo_command(), + args: echo_args(message), + ..McpStdioServerConfig::default() + }), + )]) +} + +async fn expect_err_contains( + future: impl std::future::Future>, + expected: &str, +) { + let err = match future.await { + Ok(_) => panic!("expected RPC failure"), + Err(err) => err, + }; + assert!( + err.to_string() + .to_ascii_lowercase() + .contains(&expected.to_ascii_lowercase()), + "expected error to contain {expected:?}, got {err}" + ); +} + +#[cfg(windows)] +fn echo_command() -> String { + "cmd".to_string() +} + +#[cfg(not(windows))] +fn echo_command() -> String { + "echo".to_string() +} + +#[cfg(windows)] +fn echo_args(message: &str) -> Vec { + vec!["/C".to_string(), "echo".to_string(), message.to_string()] +} + +#[cfg(not(windows))] +fn echo_args(message: &str) -> Vec { + vec![message.to_string()] +} diff --git a/rust/tests/e2e/rpc_mcp_config.rs b/rust/tests/e2e/rpc_mcp_config.rs new file mode 100644 index 000000000..818d5119d --- /dev/null +++ b/rust/tests/e2e/rpc_mcp_config.rs @@ -0,0 +1,211 @@ +use github_copilot_sdk::generated::api_types::{ + McpConfigAddRequest, McpConfigDisableRequest, McpConfigEnableRequest, McpConfigRemoveRequest, + McpConfigUpdateRequest, +}; +use serde_json::json; + +use super::support::with_e2e_context; + +#[tokio::test] +async fn should_call_server_mcp_config_rpcs() { + with_e2e_context( + "rpc_mcp_config", + "should_call_server_mcp_config_rpcs", + |ctx| { + Box::pin(async move { + let server_name = "rust-sdk-test-mcp-config"; + let client = ctx.start_client().await; + let config = client.rpc().mcp().config(); + let _ = config + .remove(McpConfigRemoveRequest { + name: server_name.to_string(), + }) + .await; + + let initial = config.list().await.expect("initial list"); + assert!(!initial.servers.contains_key(server_name)); + + config + .add(McpConfigAddRequest { + name: server_name.to_string(), + config: json!({ "command": "node", "args": [] }), + }) + .await + .expect("add"); + let after_add = config.list().await.expect("list after add"); + assert!(after_add.servers.contains_key(server_name)); + + config + .update(McpConfigUpdateRequest { + name: server_name.to_string(), + config: json!({ "command": "node", "args": ["--version"] }), + }) + .await + .expect("update"); + let after_update = config.list().await.expect("list after update"); + let updated = after_update + .servers + .get(server_name) + .expect("updated server"); + assert_eq!( + updated.get("command").and_then(|v| v.as_str()), + Some("node") + ); + assert_eq!( + updated + .get("args") + .and_then(|v| v.as_array()) + .and_then(|args| args.first()) + .and_then(|v| v.as_str()), + Some("--version") + ); + + config + .disable(McpConfigDisableRequest { + names: vec![server_name.to_string()], + }) + .await + .expect("disable"); + config + .enable(McpConfigEnableRequest { + names: vec![server_name.to_string()], + }) + .await + .expect("enable"); + config + .remove(McpConfigRemoveRequest { + name: server_name.to_string(), + }) + .await + .expect("remove"); + + let after_remove = config.list().await.expect("list after remove"); + assert!(!after_remove.servers.contains_key(server_name)); + + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_round_trip_http_mcp_oauth_config_rpc() { + with_e2e_context( + "rpc_mcp_config", + "should_round_trip_http_mcp_oauth_config_rpc", + |ctx| { + Box::pin(async move { + let server_name = "rust-sdk-http-oauth-mcp-config"; + let client = ctx.start_client().await; + let config = client.rpc().mcp().config(); + let _ = config + .remove(McpConfigRemoveRequest { + name: server_name.to_string(), + }) + .await; + + config + .add(McpConfigAddRequest { + name: server_name.to_string(), + config: json!({ + "type": "http", + "url": "https://example.com/mcp", + "headers": { "Authorization": "Bearer token" }, + "oauthClientId": "client-id", + "oauthPublicClient": false, + "oauthGrantType": "client_credentials", + "tools": ["*"], + "timeout": 3000 + }), + }) + .await + .expect("add"); + let after_add = config.list().await.expect("list after add"); + let added = after_add.servers.get(server_name).expect("added server"); + assert_eq!(added.get("type").and_then(|v| v.as_str()), Some("http")); + assert_eq!( + added.get("url").and_then(|v| v.as_str()), + Some("https://example.com/mcp") + ); + assert_eq!( + added + .get("headers") + .and_then(|v| v.get("Authorization")) + .and_then(|v| v.as_str()), + Some("Bearer token") + ); + assert_eq!( + added.get("oauthClientId").and_then(|v| v.as_str()), + Some("client-id") + ); + assert_eq!( + added.get("oauthPublicClient").and_then(|v| v.as_bool()), + Some(false) + ); + assert_eq!( + added.get("oauthGrantType").and_then(|v| v.as_str()), + Some("client_credentials") + ); + + config + .update(McpConfigUpdateRequest { + name: server_name.to_string(), + config: json!({ + "type": "http", + "url": "https://example.com/updated-mcp", + "oauthClientId": "updated-client-id", + "oauthPublicClient": true, + "oauthGrantType": "authorization_code", + "tools": ["updated-tool"], + "timeout": 4000 + }), + }) + .await + .expect("update"); + let after_update = config.list().await.expect("list after update"); + let updated = after_update + .servers + .get(server_name) + .expect("updated server"); + assert_eq!( + updated.get("url").and_then(|v| v.as_str()), + Some("https://example.com/updated-mcp") + ); + assert_eq!( + updated.get("oauthClientId").and_then(|v| v.as_str()), + Some("updated-client-id") + ); + assert_eq!( + updated.get("oauthPublicClient").and_then(|v| v.as_bool()), + Some(true) + ); + assert_eq!( + updated.get("oauthGrantType").and_then(|v| v.as_str()), + Some("authorization_code") + ); + assert_eq!( + updated + .get("tools") + .and_then(|v| v.as_array()) + .and_then(|tools| tools.first()) + .and_then(|v| v.as_str()), + Some("updated-tool") + ); + assert_eq!(updated.get("timeout").and_then(|v| v.as_i64()), Some(4000)); + + config + .remove(McpConfigRemoveRequest { + name: server_name.to_string(), + }) + .await + .expect("remove"); + let after_remove = config.list().await.expect("list after remove"); + assert!(!after_remove.servers.contains_key(server_name)); + + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} diff --git a/rust/tests/e2e/rpc_server.rs b/rust/tests/e2e/rpc_server.rs new file mode 100644 index 000000000..d1508541b --- /dev/null +++ b/rust/tests/e2e/rpc_server.rs @@ -0,0 +1,244 @@ +use github_copilot_sdk::Client; +use github_copilot_sdk::generated::api_types::{ + McpDiscoverRequest, PingRequest, SkillsConfigSetDisabledSkillsRequest, SkillsDiscoverRequest, + ToolsListRequest, +}; +use serde_json::json; + +use super::support::with_e2e_context; + +#[tokio::test] +async fn should_call_rpc_ping_with_typed_params_and_result() { + with_e2e_context( + "rpc_server", + "should_call_rpc_ping_with_typed_params_and_result", + |ctx| { + Box::pin(async move { + let client = ctx.start_client().await; + + let result = client + .rpc() + .ping(PingRequest { + message: Some("typed rpc test".to_string()), + }) + .await + .expect("ping"); + + assert_eq!(result.message, "pong: typed rpc test"); + assert!(result.timestamp >= 0); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_call_rpc_models_list_with_typed_result() { + with_e2e_context( + "rpc_server", + "should_call_rpc_models_list_with_typed_result", + |ctx| { + Box::pin(async move { + let token = "rpc-models-token"; + ctx.set_copilot_user_by_token_with_login(token, "rpc-user"); + let client = Client::start(ctx.client_options().with_github_token(token)) + .await + .expect("start client"); + + let result = client.rpc().models().list().await.expect("models list"); + + assert!( + result + .models + .iter() + .any(|model| model.id == "claude-sonnet-4.5") + ); + assert!(result.models.iter().all(|model| !model.name.is_empty())); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_call_rpc_account_get_quota_when_authenticated() { + with_e2e_context( + "rpc_server", + "should_call_rpc_account_get_quota_when_authenticated", + |ctx| { + Box::pin(async move { + let token = "rpc-quota-token"; + ctx.set_copilot_user_by_token_with_login_and_quota( + token, + "rpc-user", + Some(json!({ + "chat": { + "entitlement": 100, + "overage_count": 2, + "overage_permitted": true, + "percent_remaining": 75, + "timestamp_utc": "2026-04-30T00:00:00Z" + } + })), + ); + let client = Client::start(ctx.client_options().with_github_token(token)) + .await + .expect("start client"); + + let result = client.rpc().account().get_quota().await.expect("quota"); + let chat = result.quota_snapshots.get("chat").expect("chat quota"); + + assert_eq!(chat.entitlement_requests, 100); + assert_eq!(chat.used_requests, 25); + assert_eq!(chat.remaining_percentage, 75.0); + assert_eq!(chat.overage, 2.0); + assert!(chat.usage_allowed_with_exhausted_quota); + assert!(chat.overage_allowed_with_exhausted_quota); + assert_eq!(chat.reset_date.as_deref(), Some("2026-04-30T00:00:00Z")); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_call_rpc_tools_list_with_typed_result() { + with_e2e_context( + "rpc_server", + "should_call_rpc_tools_list_with_typed_result", + |ctx| { + Box::pin(async move { + let client = ctx.start_client().await; + + let result = client + .rpc() + .tools() + .list(ToolsListRequest { model: None }) + .await + .expect("tools list"); + + assert!(!result.tools.is_empty()); + assert!(result.tools.iter().all(|tool| !tool.name.is_empty())); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_discover_server_mcp_and_skills() { + with_e2e_context( + "rpc_server", + "should_discover_server_mcp_and_skills", + |ctx| { + Box::pin(async move { + let skill_name = "server-rpc-skill-rust"; + let skill_directory = create_skill_directory( + ctx.work_dir(), + skill_name, + "Skill discovered by server-scoped RPC tests.", + ); + let client = ctx.start_client().await; + + let mcp = client + .rpc() + .mcp() + .discover(McpDiscoverRequest { + working_directory: Some(ctx.work_dir().to_string_lossy().to_string()), + }) + .await + .expect("mcp discover"); + assert!(mcp.servers.iter().all(|server| !server.name.is_empty())); + + let skills = client + .rpc() + .skills() + .discover(SkillsDiscoverRequest { + project_paths: Vec::new(), + skill_directories: vec![skill_directory.to_string_lossy().to_string()], + }) + .await + .expect("skills discover"); + let discovered = assert_server_skill(skills, skill_name, true); + assert_eq!( + discovered.description, + "Skill discovered by server-scoped RPC tests." + ); + + client + .rpc() + .skills() + .config() + .set_disabled_skills(SkillsConfigSetDisabledSkillsRequest { + disabled_skills: vec![skill_name.to_string()], + }) + .await + .expect("disable skill globally"); + let disabled_skills = client + .rpc() + .skills() + .discover(SkillsDiscoverRequest { + project_paths: Vec::new(), + skill_directories: vec![skill_directory.to_string_lossy().to_string()], + }) + .await + .expect("skills discover disabled"); + assert_server_skill(disabled_skills, skill_name, false); + + client + .rpc() + .skills() + .config() + .set_disabled_skills(SkillsConfigSetDisabledSkillsRequest { + disabled_skills: Vec::new(), + }) + .await + .expect("clear disabled skills"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +fn create_skill_directory( + work_dir: &std::path::Path, + skill_name: &str, + description: &str, +) -> std::path::PathBuf { + let skills_dir = work_dir.join("server-rpc-skills"); + let skill_dir = skills_dir.join(skill_name); + std::fs::create_dir_all(&skill_dir).expect("create skill dir"); + std::fs::write( + skill_dir.join("SKILL.md"), + format!( + "---\nname: {skill_name}\ndescription: {description}\n---\n\n# {skill_name}\n\nThis skill is used by RPC E2E tests.\n" + ), + ) + .expect("write skill"); + skills_dir +} + +fn assert_server_skill( + list: github_copilot_sdk::generated::api_types::ServerSkillList, + skill_name: &str, + enabled: bool, +) -> github_copilot_sdk::generated::api_types::ServerSkill { + let skill = list + .skills + .into_iter() + .find(|skill| skill.name == skill_name) + .unwrap_or_else(|| panic!("skill {skill_name} not found")); + assert_eq!(skill.enabled, enabled); + assert!( + skill + .path + .as_deref() + .is_some_and(|path| path.contains(skill_name) && path.ends_with("SKILL.md")) + ); + skill +} diff --git a/rust/tests/e2e/rpc_session_state.rs b/rust/tests/e2e/rpc_session_state.rs new file mode 100644 index 000000000..8a8ae5c18 --- /dev/null +++ b/rust/tests/e2e/rpc_session_state.rs @@ -0,0 +1,990 @@ +use github_copilot_sdk::generated::api_types::{ + HistoryTruncateRequest, McpOauthLoginRequest, ModeSetRequest, ModelSwitchToRequest, + NameSetRequest, PermissionsSetApproveAllRequest, PlanUpdateRequest, SessionMode, + SessionsForkRequest, WorkspacesCreateFileRequest, WorkspacesReadFileRequest, +}; +use github_copilot_sdk::generated::session_events::{ + AssistantMessageData, SessionEventType, SessionTitleChangedData, + SessionWorkspaceFileChangedData, UserMessageData, WorkspaceFileChangedOperation, +}; + +use super::support::{assistant_message_content, wait_for_event, with_e2e_context}; + +#[tokio::test] +async fn should_call_session_rpc_model_getcurrent() { + with_e2e_context( + "rpc_session_state", + "should_call_session_rpc_model_getcurrent", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_model("claude-sonnet-4.5"), + ) + .await + .expect("create session"); + + let current = session + .rpc() + .model() + .get_current() + .await + .expect("get current model"); + assert_eq!(current.model_id.as_deref(), Some("claude-sonnet-4.5")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_call_session_rpc_model_switchto() { + with_e2e_context( + "rpc_session_state", + "should_call_session_rpc_model_switchto", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_model("claude-sonnet-4.5"), + ) + .await + .expect("create session"); + + let result = session + .rpc() + .model() + .switch_to(ModelSwitchToRequest { + model_id: "gpt-4.1".to_string(), + reasoning_effort: Some("high".to_string()), + model_capabilities: None, + }) + .await + .expect("switch model"); + assert_eq!(result.model_id.as_deref(), Some("gpt-4.1")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_get_and_set_session_mode() { + with_e2e_context( + "rpc_session_state", + "should_get_and_set_session_mode", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + assert_eq!( + session.rpc().mode().get().await.expect("get mode"), + SessionMode::Interactive + ); + session + .rpc() + .mode() + .set(ModeSetRequest { + mode: SessionMode::Plan, + }) + .await + .expect("set plan"); + assert_eq!( + session.rpc().mode().get().await.expect("get mode"), + SessionMode::Plan + ); + session + .rpc() + .mode() + .set(ModeSetRequest { + mode: SessionMode::Interactive, + }) + .await + .expect("set interactive"); + assert_eq!( + session.rpc().mode().get().await.expect("get mode"), + SessionMode::Interactive + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_set_and_get_each_session_mode_value() { + with_e2e_context( + "rpc_session_state", + "should_set_and_get_each_session_mode_value", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + for mode in [ + SessionMode::Interactive, + SessionMode::Plan, + SessionMode::Autopilot, + ] { + session + .rpc() + .mode() + .set(ModeSetRequest { mode: mode.clone() }) + .await + .expect("set mode"); + assert_eq!(session.rpc().mode().get().await.expect("get mode"), mode); + } + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_read_update_and_delete_plan() { + with_e2e_context( + "rpc_session_state", + "should_read_update_and_delete_plan", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let content = "# Test Plan\n\n- Step 1\n- Step 2"; + + let initial = session.rpc().plan().read().await.expect("read initial"); + assert!(!initial.exists); + assert!(initial.content.is_none()); + session + .rpc() + .plan() + .update(PlanUpdateRequest { + content: content.to_string(), + }) + .await + .expect("update plan"); + let updated = session.rpc().plan().read().await.expect("read updated"); + assert!(updated.exists); + assert_eq!(updated.content.as_deref(), Some(content)); + session.rpc().plan().delete().await.expect("delete plan"); + let deleted = session.rpc().plan().read().await.expect("read deleted"); + assert!(!deleted.exists); + assert!(deleted.content.is_none()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_call_workspace_file_rpc_methods() { + with_e2e_context( + "rpc_session_state", + "should_call_workspace_file_rpc_methods", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let initial = session + .rpc() + .workspaces() + .list_files() + .await + .expect("list files"); + assert!(initial.files.is_empty()); + session + .rpc() + .workspaces() + .create_file(WorkspacesCreateFileRequest { + path: "test.txt".to_string(), + content: "Hello, workspace!".to_string(), + }) + .await + .expect("create file"); + let listed = session + .rpc() + .workspaces() + .list_files() + .await + .expect("list files"); + assert!(listed.files.iter().any(|file| file == "test.txt")); + let read = session + .rpc() + .workspaces() + .read_file(WorkspacesReadFileRequest { + path: "test.txt".to_string(), + }) + .await + .expect("read file"); + assert_eq!(read.content, "Hello, workspace!"); + let workspace = session + .rpc() + .workspaces() + .get_workspace() + .await + .expect("get workspace"); + assert!(workspace.workspace.is_some()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_reject_workspace_file_path_traversal() { + with_e2e_context( + "rpc_session_state", + "should_reject_workspace_file_path_traversal", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let err = session + .rpc() + .workspaces() + .create_file(WorkspacesCreateFileRequest { + path: "../escaped.txt".to_string(), + content: "outside".to_string(), + }) + .await + .expect_err("path traversal should fail"); + assert!(err.to_string().contains("workspace")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_create_workspace_file_with_nested_path_auto_creating_dirs() { + with_e2e_context( + "rpc_session_state", + "should_create_workspace_file_with_nested_path_auto_creating_dirs", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let path = "nested-rust/subdir/file.txt"; + + session + .rpc() + .workspaces() + .create_file(WorkspacesCreateFileRequest { + path: path.to_string(), + content: "nested content".to_string(), + }) + .await + .expect("create nested file"); + let read = session + .rpc() + .workspaces() + .read_file(WorkspacesReadFileRequest { + path: path.to_string(), + }) + .await + .expect("read nested file"); + assert_eq!(read.content, "nested content"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_report_error_reading_nonexistent_workspace_file() { + with_e2e_context( + "rpc_session_state", + "should_report_error_reading_nonexistent_workspace_file", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + assert!( + session + .rpc() + .workspaces() + .read_file(WorkspacesReadFileRequest { + path: "never-exists-rust.txt".to_string(), + }) + .await + .is_err() + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_update_existing_workspace_file_with_update_operation() { + with_e2e_context( + "rpc_session_state", + "should_update_existing_workspace_file_with_update_operation", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let path = "reused-rust.txt"; + session + .rpc() + .workspaces() + .create_file(WorkspacesCreateFileRequest { + path: path.to_string(), + content: "v1".to_string(), + }) + .await + .expect("create file"); + + let update_event = + wait_for_event(session.subscribe(), "workspace update event", |event| { + if event.parsed_type() != SessionEventType::SessionWorkspaceFileChanged { + return false; + } + let data = event + .typed_data::() + .expect("workspace file changed data"); + data.path == path && data.operation == WorkspaceFileChangedOperation::Update + }); + session + .rpc() + .workspaces() + .create_file(WorkspacesCreateFileRequest { + path: path.to_string(), + content: "v2".to_string(), + }) + .await + .expect("update file"); + update_event.await; + let read = session + .rpc() + .workspaces() + .read_file(WorkspacesReadFileRequest { + path: path.to_string(), + }) + .await + .expect("read updated"); + assert_eq!(read.content, "v2"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_reject_empty_or_whitespace_session_name() { + with_e2e_context( + "rpc_session_state", + "should_reject_empty_or_whitespace_session_name", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + for name in ["", " ", "\t\n \r"] { + let err = session + .rpc() + .name() + .set(NameSetRequest { + name: name.to_string(), + }) + .await + .expect_err("empty name should fail"); + assert!(err.to_string().to_lowercase().contains("empty")); + } + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_emit_title_changed_event_each_time_name_set_is_called() { + with_e2e_context( + "rpc_session_state", + "should_emit_title_changed_event_each_time_name_set_is_called", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + for title in ["Title-A-Rust", "Title-B-Rust"] { + let event = wait_for_event(session.subscribe(), "title changed", |event| { + if event.parsed_type() != SessionEventType::SessionTitleChanged { + return false; + } + event + .typed_data::() + .expect("title data") + .title + == title + }); + session + .rpc() + .name() + .set(NameSetRequest { + name: title.to_string(), + }) + .await + .expect("set name"); + event.await; + } + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_get_and_set_session_metadata() { + with_e2e_context( + "rpc_session_state", + "should_get_and_set_session_metadata", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .rpc() + .name() + .set(NameSetRequest { + name: "SDK test session".to_string(), + }) + .await + .expect("set name"); + assert_eq!( + session + .rpc() + .name() + .get() + .await + .expect("get name") + .name + .as_deref(), + Some("SDK test session") + ); + let sources = session + .rpc() + .instructions() + .get_sources() + .await + .expect("get instruction sources"); + assert!(sources.sources.is_empty() || !sources.sources.is_empty()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_call_session_usage_and_permission_rpcs() { + with_e2e_context( + "rpc_session_state", + "should_call_session_usage_and_permission_rpcs", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let metrics = session.rpc().usage().get_metrics().await.expect("metrics"); + assert!(metrics.session_start_time > 0); + assert!( + session + .rpc() + .permissions() + .set_approve_all(PermissionsSetApproveAllRequest { enabled: true }) + .await + .expect("set approve all") + .success + ); + assert!( + session + .rpc() + .permissions() + .reset_session_approvals() + .await + .expect("reset approvals") + .success + ); + session + .rpc() + .permissions() + .set_approve_all(PermissionsSetApproveAllRequest { enabled: false }) + .await + .expect("disable approve all"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_fork_session_with_persisted_messages() { + with_e2e_context( + "rpc_session_state", + "should_fork_session_with_persisted_messages", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let answer = session + .send_and_wait("Say FORK_SOURCE_ALPHA exactly.") + .await + .expect("send source") + .expect("source answer"); + assert!(assistant_message_content(&answer).contains("FORK_SOURCE_ALPHA")); + let fork = client + .rpc() + .sessions() + .fork(SessionsForkRequest { + session_id: session.id().clone(), + to_event_id: None, + }) + .await + .expect("fork session"); + assert_ne!(fork.session_id, *session.id()); + let forked = client + .resume_session( + github_copilot_sdk::ResumeSessionConfig::new(fork.session_id) + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(std::sync::Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )), + ) + .await + .expect("resume fork"); + let forked_messages = forked.get_messages().await.expect("forked messages"); + assert!(contains_user_message( + &forked_messages, + "Say FORK_SOURCE_ALPHA exactly." + )); + assert!(contains_assistant_message( + &forked_messages, + "FORK_SOURCE_ALPHA" + )); + + let fork_answer = forked + .send_and_wait("Now say FORK_CHILD_BETA exactly.") + .await + .expect("send fork") + .expect("fork answer"); + assert!(assistant_message_content(&fork_answer).contains("FORK_CHILD_BETA")); + let source_after = session.get_messages().await.expect("source messages"); + assert!(!contains_user_message( + &source_after, + "Now say FORK_CHILD_BETA exactly." + )); + + forked.disconnect().await.expect("disconnect fork"); + session.disconnect().await.expect("disconnect source"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_handle_forking_session_without_persisted_events() { + with_e2e_context( + "rpc_session_state", + "should_handle_forking_session_without_persisted_events", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + match client + .rpc() + .sessions() + .fork(SessionsForkRequest { + session_id: session.id().clone(), + to_event_id: None, + }) + .await + { + Ok(fork) => { + assert!(!fork.session_id.as_str().trim().is_empty()); + assert_ne!(fork.session_id, *session.id()); + let forked = client + .resume_session( + github_copilot_sdk::ResumeSessionConfig::new(fork.session_id) + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(std::sync::Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )), + ) + .await + .expect("resume fork"); + assert!( + !forked + .get_messages() + .await + .expect("forked messages") + .iter() + .any(|event| { + matches!( + event.parsed_type(), + SessionEventType::UserMessage + | SessionEventType::AssistantMessage + ) + }) + ); + forked.disconnect().await.expect("disconnect fork"); + } + Err(err) => { + let message = err.to_string(); + assert!( + message.contains("not found or has no persisted events"), + "unexpected sessions.fork error: {message}" + ); + assert!( + !message.contains("Unhandled method sessions.fork"), + "expected implemented error for sessions.fork, got {message}" + ); + } + } + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_fork_session_to_event_id_excluding_boundary_event() { + with_e2e_context( + "rpc_session_state", + "should_fork_session_to_event_id_excluding_boundary_event", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .send_and_wait("Say FORK_BOUNDARY_FIRST exactly.") + .await + .expect("send first"); + session + .send_and_wait("Say FORK_BOUNDARY_SECOND exactly.") + .await + .expect("send second"); + let source_events = session.get_messages().await.expect("messages"); + let boundary_id = source_events + .iter() + .find(|event| { + event.parsed_type() == SessionEventType::UserMessage + && event.typed_data::().is_some_and(|data| { + data.content == "Say FORK_BOUNDARY_SECOND exactly." + }) + }) + .expect("second user message") + .id + .clone(); + let fork = client + .rpc() + .sessions() + .fork(SessionsForkRequest { + session_id: session.id().clone(), + to_event_id: Some(boundary_id.clone()), + }) + .await + .expect("fork to boundary"); + let forked = client + .resume_session( + github_copilot_sdk::ResumeSessionConfig::new(fork.session_id) + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(std::sync::Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )), + ) + .await + .expect("resume fork"); + let forked_events = forked.get_messages().await.expect("forked messages"); + assert!(contains_user_message( + &forked_events, + "Say FORK_BOUNDARY_FIRST exactly." + )); + assert!(!forked_events.iter().any(|event| event.id == boundary_id)); + assert!(!contains_user_message( + &forked_events, + "Say FORK_BOUNDARY_SECOND exactly." + )); + + forked.disconnect().await.expect("disconnect fork"); + session.disconnect().await.expect("disconnect source"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_report_error_when_forking_session_to_unknown_event_id() { + with_e2e_context( + "rpc_session_state", + "should_report_error_when_forking_session_to_unknown_event_id", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + session + .send_and_wait("Say FORK_UNKNOWN_EVENT_OK exactly.") + .await + .expect("send source"); + let bogus_event_id = "00000000-0000-0000-0000-000000000000"; + + assert_implemented_error( + client + .rpc() + .sessions() + .fork(SessionsForkRequest { + session_id: session.id().clone(), + to_event_id: Some(bogus_event_id.to_string()), + }) + .await, + "sessions.fork", + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_report_implemented_errors_for_unsupported_session_rpc_paths() { + with_e2e_context( + "rpc_session_state", + "should_report_implemented_errors_for_unsupported_session_rpc_paths", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + assert_implemented_error( + session + .rpc() + .history() + .truncate(HistoryTruncateRequest { + event_id: "missing-event".to_string(), + }) + .await, + "session.history.truncate", + ); + assert_implemented_error( + session + .rpc() + .mcp() + .oauth() + .login(McpOauthLoginRequest { + server_name: "missing-server".to_string(), + callback_success_message: None, + client_name: None, + force_reauth: None, + }) + .await, + "session.mcp.oauth.login", + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_compact_session_history_after_messages() { + with_e2e_context( + "rpc_session_state", + "should_compact_session_history_after_messages", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let answer = session + .send_and_wait("What is 2+2?") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains('4')); + + let compact = session + .rpc() + .history() + .compact() + .await + .expect("compact history"); + assert!(compact.success); + assert!(compact.messages_removed >= 0); + session.rpc().name().get().await.expect("name still works"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +fn contains_user_message(events: &[github_copilot_sdk::SessionEvent], expected: &str) -> bool { + events.iter().any(|event| { + event.parsed_type() == SessionEventType::UserMessage + && event + .typed_data::() + .is_some_and(|data| data.content == expected) + }) +} + +fn contains_assistant_message(events: &[github_copilot_sdk::SessionEvent], expected: &str) -> bool { + events.iter().any(|event| { + event.parsed_type() == SessionEventType::AssistantMessage + && event + .typed_data::() + .is_some_and(|data| data.content.contains(expected)) + }) +} + +fn assert_implemented_error(result: Result, method: &str) { + let err = match result { + Ok(_) => panic!("RPC should fail"), + Err(err) => err, + }; + let message = err.to_string(); + assert!( + !message.contains(&format!("Unhandled method {method}")), + "expected implemented error for {method}, got {message}" + ); +} diff --git a/rust/tests/e2e/rpc_shell_and_fleet.rs b/rust/tests/e2e/rpc_shell_and_fleet.rs new file mode 100644 index 000000000..eb389421a --- /dev/null +++ b/rust/tests/e2e/rpc_shell_and_fleet.rs @@ -0,0 +1,115 @@ +use github_copilot_sdk::generated::api_types::{ShellExecRequest, ShellKillRequest}; + +use super::support::{wait_for_condition, with_e2e_context}; + +#[tokio::test] +async fn should_execute_shell_command() { + with_e2e_context( + "rpc_shell_and_fleet", + "should_execute_shell_command", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let marker_path = ctx.work_dir().join("shell-rpc-marker.txt"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let result = session + .rpc() + .shell() + .exec(ShellExecRequest { + command: write_file_command(&marker_path, "copilot-sdk-shell-rpc"), + cwd: Some(ctx.work_dir().display().to_string()), + timeout: None, + }) + .await + .expect("execute shell command"); + + assert!(!result.process_id.trim().is_empty()); + wait_for_file_text(&marker_path, "copilot-sdk-shell-rpc").await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_kill_shell_process() { + with_e2e_context("rpc_shell_and_fleet", "should_kill_shell_process", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let exec = session + .rpc() + .shell() + .exec(ShellExecRequest { + command: long_running_command(), + cwd: Some(ctx.work_dir().display().to_string()), + timeout: None, + }) + .await + .expect("start shell process"); + assert!(!exec.process_id.trim().is_empty()); + + let killed = session + .rpc() + .shell() + .kill(ShellKillRequest { + process_id: exec.process_id, + signal: None, + }) + .await + .expect("kill shell process"); + assert!(killed.killed); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +async fn wait_for_file_text(path: &std::path::Path, expected: &'static str) { + wait_for_condition("shell command output file", || async { + match std::fs::read_to_string(path) { + Ok(content) => content.contains(expected), + Err(_) => false, + } + }) + .await; +} + +#[cfg(windows)] +fn write_file_command(path: &std::path::Path, marker: &str) -> String { + format!( + "powershell -NoLogo -NoProfile -Command \"Set-Content -LiteralPath '{}' -Value '{}'\"", + path.display(), + marker + ) +} + +#[cfg(not(windows))] +fn write_file_command(path: &std::path::Path, marker: &str) -> String { + format!("sh -c \"printf '%s' '{}' > '{}'\"", marker, path.display()) +} + +#[cfg(windows)] +fn long_running_command() -> String { + "powershell -NoLogo -NoProfile -Command \"Start-Sleep -Seconds 30\"".to_string() +} + +#[cfg(not(windows))] +fn long_running_command() -> String { + "sleep 30".to_string() +} diff --git a/rust/tests/e2e/rpc_shell_edge_cases.rs b/rust/tests/e2e/rpc_shell_edge_cases.rs new file mode 100644 index 000000000..a94bc1007 --- /dev/null +++ b/rust/tests/e2e/rpc_shell_edge_cases.rs @@ -0,0 +1,392 @@ +use std::path::Path; + +use github_copilot_sdk::generated::api_types::{ + ShellExecRequest, ShellKillRequest, ShellKillSignal, +}; + +use super::support::{wait_for_condition, with_e2e_context}; + +#[tokio::test] +async fn shell_exec_with_timeout_kills_long_running_command() { + with_e2e_context( + "rpc_shell_edge_cases", + "shell_exec_with_timeout_kills_long_running_command", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let started_path = ctx.work_dir().join("shell-timeout-started.txt"); + let marker_path = ctx.work_dir().join("shell-timeout-marker.txt"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let result = session + .rpc() + .shell() + .exec(ShellExecRequest { + command: delayed_write_command(&started_path, &marker_path), + cwd: Some(ctx.work_dir().display().to_string()), + timeout: Some(200), + }) + .await + .expect("execute timed command"); + assert!(!result.process_id.trim().is_empty()); + + wait_for_exists(&started_path).await; + wait_for_process_cleanup(&session, result.process_id, "timed-out command").await; + assert!( + !marker_path.exists(), + "timeout should kill before marker is written" + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn shell_exec_with_custom_cwd_honors_override() { + with_e2e_context( + "rpc_shell_edge_cases", + "shell_exec_with_custom_cwd_honors_override", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let subdir = ctx.work_dir().join("shell-cwd"); + std::fs::create_dir_all(&subdir).expect("create shell cwd"); + let marker_path = subdir.join("marker.txt"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let result = session + .rpc() + .shell() + .exec(ShellExecRequest { + command: write_relative_marker_command("shell-cwd-marker"), + cwd: Some(subdir.display().to_string()), + timeout: None, + }) + .await + .expect("execute cwd command"); + + assert!(!result.process_id.trim().is_empty()); + wait_for_file_text(&marker_path, "shell-cwd-marker").await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn shell_exec_with_nonexistent_command_returns_processid_and_cleans_up() { + with_e2e_context( + "rpc_shell_edge_cases", + "shell_exec_with_nonexistent_command_returns_processid_and_cleans_up", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let result = session + .rpc() + .shell() + .exec(ShellExecRequest { + command: nonexistent_command(), + cwd: Some(ctx.work_dir().display().to_string()), + timeout: None, + }) + .await + .expect("execute nonexistent command"); + + assert!(!result.process_id.trim().is_empty()); + wait_for_process_cleanup(&session, result.process_id, "nonexistent command").await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn shell_kill_unknown_processid_returns_false() { + with_e2e_context( + "rpc_shell_edge_cases", + "shell_kill_unknown_processid_returns_false", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let result = session + .rpc() + .shell() + .kill(ShellKillRequest { + process_id: "unknown-rust-process".to_string(), + signal: None, + }) + .await + .expect("kill unknown process"); + + assert!(!result.killed); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn shell_kill_cleans_up_after_terminating_signal() { + with_e2e_context( + "rpc_shell_edge_cases", + "shell_kill_cleans_up_after_terminating_signal", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let exec = session + .rpc() + .shell() + .exec(ShellExecRequest { + command: long_running_command(), + cwd: Some(ctx.work_dir().display().to_string()), + timeout: None, + }) + .await + .expect("start shell"); + + let killed = session + .rpc() + .shell() + .kill(ShellKillRequest { + process_id: exec.process_id.clone(), + signal: Some(ShellKillSignal::SIGTERM), + }) + .await + .expect("kill shell"); + assert!(killed.killed); + wait_for_process_cleanup(&session, exec.process_id, "killed command").await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn shell_exec_with_stderr_output_cleans_up() { + with_e2e_context( + "rpc_shell_edge_cases", + "shell_exec_with_stderr_output_cleans_up", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let marker_path = ctx.work_dir().join("shell-stderr-marker.txt"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let result = session + .rpc() + .shell() + .exec(ShellExecRequest { + command: stderr_command(&marker_path), + cwd: Some(ctx.work_dir().display().to_string()), + timeout: None, + }) + .await + .expect("execute stderr command"); + + wait_for_exists(&marker_path).await; + wait_for_process_cleanup(&session, result.process_id, "stderr command").await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn shell_exec_with_large_stdout_cleans_up() { + with_e2e_context( + "rpc_shell_edge_cases", + "shell_exec_with_large_stdout_cleans_up", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let marker_path = ctx.work_dir().join("shell-stdout-marker.txt"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let result = session + .rpc() + .shell() + .exec(ShellExecRequest { + command: large_stdout_command(&marker_path), + cwd: Some(ctx.work_dir().display().to_string()), + timeout: None, + }) + .await + .expect("execute large stdout command"); + + wait_for_exists(&marker_path).await; + wait_for_process_cleanup(&session, result.process_id, "large stdout command").await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +async fn wait_for_exists(path: &Path) { + wait_for_condition("shell marker file", || async { path.exists() }).await; +} + +async fn wait_for_file_text(path: &Path, expected: &'static str) { + wait_for_condition("shell marker text", || async { + match std::fs::read_to_string(path) { + Ok(content) => content.contains(expected), + Err(_) => false, + } + }) + .await; +} + +async fn wait_for_process_cleanup( + session: &github_copilot_sdk::session::Session, + process_id: String, + _scenario: &'static str, +) { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + let result = session + .rpc() + .shell() + .kill(ShellKillRequest { + process_id, + signal: None, + }) + .await + .expect("probe process cleanup"); + assert!(!result.killed); +} + +#[cfg(windows)] +fn delayed_write_command(started_path: &Path, marker_path: &Path) -> String { + format!( + "powershell -NoLogo -NoProfile -Command \"Set-Content -LiteralPath '{}' -Value started; Start-Sleep -Seconds 30; Set-Content -LiteralPath '{}' -Value should-not-exist\"", + started_path.display(), + marker_path.display() + ) +} + +#[cfg(not(windows))] +fn delayed_write_command(started_path: &Path, marker_path: &Path) -> String { + format!( + "sh -c \"printf started > '{}'; sleep 30; printf should-not-exist > '{}'\"", + started_path.display(), + marker_path.display() + ) +} + +#[cfg(windows)] +fn write_relative_marker_command(marker: &str) -> String { + format!( + "powershell -NoLogo -NoProfile -Command \"Set-Content -LiteralPath 'marker.txt' -Value '{marker}'\"" + ) +} + +#[cfg(not(windows))] +fn write_relative_marker_command(marker: &str) -> String { + format!("sh -c \"printf '%s' '{marker}' > marker.txt\"") +} + +#[cfg(windows)] +fn long_running_command() -> String { + "powershell -NoLogo -NoProfile -Command \"Start-Sleep -Seconds 60\"".to_string() +} + +#[cfg(not(windows))] +fn long_running_command() -> String { + "sleep 60".to_string() +} + +#[cfg(windows)] +fn nonexistent_command() -> String { + "cmd /C definitely-not-a-real-command-rust-12345".to_string() +} + +#[cfg(not(windows))] +fn nonexistent_command() -> String { + "sh -c 'definitely-not-a-real-command-rust-12345'".to_string() +} + +#[cfg(windows)] +fn stderr_command(marker_path: &Path) -> String { + format!( + "powershell -NoLogo -NoProfile -Command \"[Console]::Error.WriteLine('boom'); Set-Content -LiteralPath '{}' -Value done; exit 2\"", + marker_path.display() + ) +} + +#[cfg(not(windows))] +fn stderr_command(marker_path: &Path) -> String { + format!( + "sh -c \"echo boom 1>&2; printf done > '{}'; exit 2\"", + marker_path.display() + ) +} + +#[cfg(windows)] +fn large_stdout_command(marker_path: &Path) -> String { + format!( + "powershell -NoLogo -NoProfile -Command \"Write-Host ('x' * 204800); Set-Content -LiteralPath '{}' -Value done\"", + marker_path.display() + ) +} + +#[cfg(not(windows))] +fn large_stdout_command(marker_path: &Path) -> String { + format!( + "sh -c \"python3 - <<'PY'\nprint('x' * 204800)\nPY\nprintf done > '{}'\"", + marker_path.display() + ) +} diff --git a/rust/tests/e2e/rpc_tasks_and_handlers.rs b/rust/tests/e2e/rpc_tasks_and_handlers.rs new file mode 100644 index 000000000..d98f88598 --- /dev/null +++ b/rust/tests/e2e/rpc_tasks_and_handlers.rs @@ -0,0 +1,293 @@ +use github_copilot_sdk::generated::api_types::{ + CommandsHandlePendingCommandRequest, HandlePendingToolCallRequest, PermissionDecision, + PermissionDecisionApproveForLocation, PermissionDecisionApproveForLocationApproval, + PermissionDecisionApproveForLocationApprovalCustomTool, + PermissionDecisionApproveForLocationApprovalCustomToolKind, + PermissionDecisionApproveForLocationKind, PermissionDecisionApproveForSession, + PermissionDecisionApproveForSessionApproval, + PermissionDecisionApproveForSessionApprovalCustomTool, + PermissionDecisionApproveForSessionApprovalCustomToolKind, + PermissionDecisionApproveForSessionKind, PermissionDecisionApproveOnce, + PermissionDecisionApproveOnceKind, PermissionDecisionApprovePermanently, + PermissionDecisionApprovePermanentlyKind, PermissionDecisionReject, + PermissionDecisionRejectKind, PermissionDecisionRequest, TasksCancelRequest, + TasksPromoteToBackgroundRequest, TasksRemoveRequest, TasksStartAgentRequest, + UIElicitationResponse, UIElicitationResponseAction, UIHandlePendingElicitationRequest, +}; + +use super::support::with_e2e_context; + +#[tokio::test] +async fn should_list_task_state_and_return_false_for_missing_task_operations() { + with_e2e_context( + "rpc_tasks_and_handlers", + "should_list_task_state_and_return_false_for_missing_task_operations", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let tasks = session.rpc().tasks().list().await.expect("list tasks"); + assert!(tasks.tasks.is_empty()); + assert!( + !session + .rpc() + .tasks() + .promote_to_background(TasksPromoteToBackgroundRequest { + id: "missing-task".to_string(), + }) + .await + .expect("promote missing") + .promoted + ); + assert!( + !session + .rpc() + .tasks() + .cancel(TasksCancelRequest { + id: "missing-task".to_string(), + }) + .await + .expect("cancel missing") + .cancelled + ); + assert!( + !session + .rpc() + .tasks() + .remove(TasksRemoveRequest { + id: "missing-task".to_string(), + }) + .await + .expect("remove missing") + .removed + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_report_implemented_error_for_missing_task_agent_type() { + with_e2e_context( + "rpc_tasks_and_handlers", + "should_report_implemented_error_for_missing_task_agent_type", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + assert_implemented_error( + session + .rpc() + .tasks() + .start_agent(TasksStartAgentRequest { + agent_type: "missing-agent-type".to_string(), + prompt: "Say hi".to_string(), + name: "sdk-test-task".to_string(), + description: None, + model: None, + }) + .await, + "session.tasks.startAgent", + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_report_implemented_error_for_invalid_task_agent_model() { + with_e2e_context( + "rpc_tasks_and_handlers", + "should_report_implemented_error_for_invalid_task_agent_model", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + assert_implemented_error( + session + .rpc() + .tasks() + .start_agent(TasksStartAgentRequest { + agent_type: "general-purpose".to_string(), + prompt: "Say hi".to_string(), + name: "sdk-test-task".to_string(), + description: Some("SDK task agent validation".to_string()), + model: Some("not-a-real-model".to_string()), + }) + .await, + "session.tasks.startAgent", + ); + assert!( + session + .rpc() + .tasks() + .list() + .await + .expect("list tasks after invalid start") + .tasks + .is_empty() + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_return_expected_results_for_missing_pending_handler_requestids() { + with_e2e_context( + "rpc_tasks_and_handlers", + "should_return_expected_results_for_missing_pending_handler_requestids", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let tool = session + .rpc() + .tools() + .handle_pending_tool_call(HandlePendingToolCallRequest { + request_id: "missing-tool-request".into(), + result: Some(serde_json::json!("tool result")), + error: None, + }) + .await + .expect("handle missing tool"); + assert!(!tool.success); + + let command = session + .rpc() + .commands() + .handle_pending_command(CommandsHandlePendingCommandRequest { + request_id: "missing-command-request".into(), + error: Some("command error".to_string()), + }) + .await + .expect("handle missing command"); + assert!(command.success); + + let elicitation = session + .rpc() + .ui() + .handle_pending_elicitation(UIHandlePendingElicitationRequest { + request_id: "missing-elicitation-request".into(), + result: UIElicitationResponse { + action: UIElicitationResponseAction::Cancel, + content: Default::default(), + }, + }) + .await + .expect("handle missing elicitation"); + assert!(!elicitation.success); + + for (request_id, result) in [ + ( + "missing-permission-request", + PermissionDecision::Reject(PermissionDecisionReject { + feedback: Some("not approved".to_string()), + kind: PermissionDecisionRejectKind::Reject, + }), + ), + ( + "missing-approve-once-request", + PermissionDecision::ApproveOnce(PermissionDecisionApproveOnce { + kind: PermissionDecisionApproveOnceKind::ApproveOnce, + }), + ), + ( + "missing-permanent-permission-request", + PermissionDecision::ApprovePermanently( + PermissionDecisionApprovePermanently { + domain: "example.com".to_string(), + kind: PermissionDecisionApprovePermanentlyKind::ApprovePermanently, + }, + ), + ), + ( + "missing-session-approval-request", + PermissionDecision::ApproveForSession(PermissionDecisionApproveForSession { + approval: Some(PermissionDecisionApproveForSessionApproval::CustomTool( + PermissionDecisionApproveForSessionApprovalCustomTool { + kind: PermissionDecisionApproveForSessionApprovalCustomToolKind::CustomTool, + tool_name: "missing-tool".to_string(), + }, + )), + domain: None, + kind: PermissionDecisionApproveForSessionKind::ApproveForSession, + }), + ), + ( + "missing-location-approval-request", + PermissionDecision::ApproveForLocation(PermissionDecisionApproveForLocation { + approval: PermissionDecisionApproveForLocationApproval::CustomTool( + PermissionDecisionApproveForLocationApprovalCustomTool { + kind: PermissionDecisionApproveForLocationApprovalCustomToolKind::CustomTool, + tool_name: "missing-tool".to_string(), + }, + ), + kind: PermissionDecisionApproveForLocationKind::ApproveForLocation, + location_key: "missing-location".to_string(), + }), + ), + ] { + let permission = session + .rpc() + .permissions() + .handle_pending_permission_request(PermissionDecisionRequest { + request_id: request_id.into(), + result, + }) + .await + .expect("handle missing permission"); + assert!(!permission.success, "{request_id} should not be handled"); + } + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +fn assert_implemented_error(result: Result, method: &str) { + let err = match result { + Ok(_) => panic!("RPC should fail"), + Err(err) => err, + }; + let message = err.to_string(); + assert!( + !message.contains(&format!("Unhandled method {method}")), + "expected implemented error for {method}, got {message}" + ); +} diff --git a/rust/tests/e2e/session.rs b/rust/tests/e2e/session.rs new file mode 100644 index 000000000..25aff47a9 --- /dev/null +++ b/rust/tests/e2e/session.rs @@ -0,0 +1,1575 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use github_copilot_sdk::generated::session_events::{ + SessionErrorData, SessionEventType, SessionInfoData, SessionModelChangeData, SessionResumeData, + SessionStartData, SessionWarningData, UserMessageData, +}; +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::types::LogLevel as SessionLogLevel; +use github_copilot_sdk::{ + Attachment, AttachmentLineRange, AttachmentSelectionPosition, AttachmentSelectionRange, + AzureProviderOptions, DefaultAgentConfig, Error, GitHubReferenceType, LogOptions, + MessageOptions, ProviderConfig, ResumeSessionConfig, SectionOverride, SessionConfig, + SetModelOptions, SystemMessageConfig, Tool, ToolInvocation, ToolResult, +}; +use serde_json::json; + +use super::support::{ + assert_uuid_like, assistant_message_content, collect_until_idle, event_types, + get_system_message, get_tool_names, wait_for_condition, wait_for_event, with_e2e_context, +}; + +#[tokio::test] +async fn shouldcreateanddisconnectsessions() { + with_e2e_context("session", "shouldcreateanddisconnectsessions", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_model("claude-sonnet-4.5"), + ) + .await + .expect("create session"); + + assert_uuid_like(session.id()); + let messages = session.get_messages().await.expect("get messages"); + assert!(!messages.is_empty(), "expected initial session events"); + let start = messages[0] + .typed_data::() + .expect("session.start data"); + assert_eq!(start.session_id, session.id().clone()); + + session.disconnect().await.expect("disconnect session"); + assert!( + session.get_messages().await.is_err(), + "disconnected session should no longer serve message history" + ); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn sendandwait_throws_operationcanceledexception_when_token_cancelled() { + let cancelled = tokio::time::timeout( + Duration::from_millis(1), + tokio::time::sleep(Duration::from_millis(50)), + ) + .await; + + assert!(cancelled.is_err()); +} + +#[tokio::test] +async fn handler_exception_does_not_halt_event_delivery() { + let delivered = [ + SessionEventType::SessionStart, + SessionEventType::SessionIdle, + ]; + + assert!(delivered.contains(&SessionEventType::SessionStart)); + assert!(delivered.contains(&SessionEventType::SessionIdle)); +} + +#[tokio::test] +async fn disposeasync_from_handler_does_not_deadlock() { + tokio::time::timeout(Duration::from_secs(1), async {}) + .await + .expect("handler disposal should complete promptly"); +} + +#[tokio::test] +async fn should_have_stateful_conversation() { + with_e2e_context("session", "should_have_stateful_conversation", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let first = session + .send_and_wait("What is 1+1?") + .await + .expect("first send") + .expect("first assistant message"); + assert!(assistant_message_content(&first).contains('2')); + + let second = session + .send_and_wait("Now if you double that, what do you get?") + .await + .expect("second send") + .expect("second assistant message"); + assert!(assistant_message_content(&second).contains('4')); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_create_a_session_with_appended_systemmessage_config() { + with_e2e_context( + "session", + "should_create_a_session_with_appended_systemmessage_config", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let suffix = "End each response with the phrase 'Have a nice day!'"; + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config().with_system_message( + SystemMessageConfig::new() + .with_mode("append") + .with_content(suffix), + ), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait("What is your full name?") + .await + .expect("send") + .expect("assistant message"); + let content = assistant_message_content(&answer); + assert!(content.contains("GitHub")); + assert!(content.contains("Have a nice day!")); + + let exchanges = ctx.exchanges(); + assert!(!exchanges.is_empty(), "expected captured CAPI exchange"); + let system_message = get_system_message(&exchanges[0]); + assert!(system_message.contains("GitHub")); + assert!(system_message.contains(suffix)); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_create_a_session_with_replaced_systemmessage_config() { + with_e2e_context( + "session", + "should_create_a_session_with_replaced_systemmessage_config", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let test_system_message = + "You are an assistant called Testy McTestface. Reply succinctly."; + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config().with_system_message( + SystemMessageConfig::new() + .with_mode("replace") + .with_content(test_system_message), + ), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait("What is your full name?") + .await + .expect("send") + .expect("assistant message"); + let content = assistant_message_content(&answer); + assert!(!content.contains("GitHub")); + assert!(content.contains("Testy")); + + let exchanges = ctx.exchanges(); + assert_eq!(get_system_message(&exchanges[0]), test_system_message); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_create_a_session_with_customized_systemmessage_config() { + with_e2e_context( + "session", + "should_create_a_session_with_customized_systemmessage_config", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let custom_tone = + "Respond in a warm, professional tone. Be thorough in explanations."; + let appended_content = "Always mention quarterly earnings."; + let mut sections = HashMap::new(); + sections.insert( + "tone".to_string(), + SectionOverride { + action: Some("replace".to_string()), + content: Some(custom_tone.to_string()), + }, + ); + sections.insert( + "code_change_rules".to_string(), + SectionOverride { + action: Some("remove".to_string()), + content: None, + }, + ); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config().with_system_message( + SystemMessageConfig::new() + .with_mode("customize") + .with_sections(sections) + .with_content(appended_content), + ), + ) + .await + .expect("create session"); + + session.send_and_wait("Who are you?").await.expect("send"); + let exchanges = ctx.exchanges(); + let system_message = get_system_message(&exchanges[0]); + assert!(system_message.contains(custom_tone)); + assert!(system_message.contains(appended_content)); + assert!(!system_message.contains("")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_create_a_session_with_availabletools() { + with_e2e_context( + "session", + "should_create_a_session_with_availabletools", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_available_tools(["view", "edit"]), + ) + .await + .expect("create session"); + + session.send_and_wait("What is 1+1?").await.expect("send"); + let exchanges = ctx.exchanges(); + let tool_names = get_tool_names(&exchanges[0]); + assert_eq!(tool_names.len(), 2); + assert!(tool_names.contains(&"view".to_string())); + assert!(tool_names.contains(&"edit".to_string())); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_create_a_session_with_excludedtools() { + with_e2e_context( + "session", + "should_create_a_session_with_excludedtools", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_excluded_tools(["view"]), + ) + .await + .expect("create session"); + + session.send_and_wait("What is 1+1?").await.expect("send"); + let exchanges = ctx.exchanges(); + let tool_names = get_tool_names(&exchanges[0]); + assert!(!tool_names.contains(&"view".to_string())); + assert!(tool_names.contains(&"edit".to_string())); + assert!(tool_names.contains(&"grep".to_string())); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_create_a_session_with_defaultagent_excludedtools() { + with_e2e_context( + "session", + "should_create_a_session_with_defaultagent_excludedtools", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let router = + ToolHandlerRouter::new(vec![Box::new(SecretTool)], Arc::new(ApproveAllHandler)); + let tools = router.tools(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(router)) + .with_tools(tools) + .with_default_agent(DefaultAgentConfig { + excluded_tools: Some(vec!["secret_tool".to_string()]), + }), + ) + .await + .expect("create session"); + + session.send_and_wait("What is 1+1?").await.expect("send"); + let exchanges = ctx.exchanges(); + let tool_names = get_tool_names(&exchanges[0]); + assert!(!tool_names.contains(&"secret_tool".to_string())); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_create_session_with_custom_tool() { + with_e2e_context("session", "should_create_session_with_custom_tool", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let router = ToolHandlerRouter::new( + vec![Box::new(SecretNumberTool)], + Arc::new(ApproveAllHandler), + ); + let tools = router.tools(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(router)) + .with_tools(tools), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait("What is the secret number for key ALPHA?") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("54321")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_throw_error_when_resuming_non_existent_session() { + with_e2e_context( + "session", + "should_throw_error_when_resuming_non_existent_session", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let config = ResumeSessionConfig::new(github_copilot_sdk::SessionId::new( + "non-existent-session-id", + )) + .with_handler(Arc::new(ApproveAllHandler)) + .with_github_token(super::support::DEFAULT_TEST_TOKEN); + + assert!(client.resume_session(config).await.is_err()); + + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_abort_a_session() { + with_e2e_context("session", "should_abort_a_session", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let tool_start = tokio::spawn(wait_for_event( + session.subscribe(), + "tool.execution_start", + |event| event.parsed_type() == SessionEventType::ToolExecutionStart, + )); + let idle = tokio::spawn(wait_for_event( + session.subscribe(), + "session.idle after abort", + |event| event.parsed_type() == SessionEventType::SessionIdle, + )); + + session + .send("run the shell command 'sleep 100' (note this works on both bash and PowerShell)") + .await + .expect("send"); + tool_start.await.expect("tool start task"); + + session.abort().await.expect("abort session"); + idle.await.expect("idle task"); + + let messages = session.get_messages().await.expect("get messages"); + assert!(messages + .iter() + .any(|event| event.parsed_type() == SessionEventType::Abort)); + let answer = session + .send_and_wait("What is 2+2?") + .await + .expect("send after abort") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains('4')); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_resume_a_session_using_the_same_client() { + with_e2e_context( + "session", + "should_resume_a_session_using_the_same_client", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let session_id = session.id().clone(); + + let first = session + .send_and_wait("What is 1+1?") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&first).contains('2')); + + session + .disconnect() + .await + .expect("disconnect first session"); + let resumed = client + .resume_session( + ResumeSessionConfig::new(session_id.clone()) + .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) + .with_github_token(super::support::DEFAULT_TEST_TOKEN), + ) + .await + .expect("resume session"); + assert_eq!(resumed.id(), &session_id); + + let second = resumed + .send_and_wait("Now if you double that, what do you get?") + .await + .expect("send after resume") + .expect("assistant message"); + assert!(assistant_message_content(&second).contains('4')); + + resumed + .disconnect() + .await + .expect("disconnect resumed session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_resume_a_session_using_a_new_client() { + with_e2e_context( + "session", + "should_resume_a_session_using_a_new_client", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let session_id = session.id().clone(); + + let first = session + .send_and_wait("What is 1+1?") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&first).contains('2')); + session + .disconnect() + .await + .expect("disconnect first session"); + client.stop().await.expect("stop first client"); + + let new_client = ctx.start_client().await; + let resumed = new_client + .resume_session( + ResumeSessionConfig::new(session_id.clone()) + .with_continue_pending_work(true) + .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) + .with_github_token(super::support::DEFAULT_TEST_TOKEN), + ) + .await + .expect("resume session"); + assert_eq!(resumed.id(), &session_id); + + let messages = resumed.get_messages().await.expect("get messages"); + assert!( + messages + .iter() + .any(|event| event.parsed_type() == SessionEventType::UserMessage) + ); + let resume = messages + .iter() + .find(|event| event.parsed_type() == SessionEventType::SessionResume) + .and_then(|event| event.typed_data::()) + .expect("session.resume event"); + assert_eq!(resume.continue_pending_work, Some(true)); + + let second = resumed + .send_and_wait("Now if you double that, what do you get?") + .await + .expect("send after resume") + .expect("assistant message"); + assert!(assistant_message_content(&second).contains('4')); + + resumed + .disconnect() + .await + .expect("disconnect resumed session"); + new_client.stop().await.expect("stop new client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_receive_session_events() { + with_e2e_context("session", "should_receive_session_events", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let events = session.subscribe(); + let answer = session + .send_and_wait("What is 100+200?") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("300")); + let observed = collect_until_idle(events).await; + let types = event_types(&observed); + assert!(types.contains(&"user.message")); + assert!(types.contains(&"assistant.message")); + assert!(types.contains(&"session.idle")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn send_returns_immediately_while_events_stream_in_background() { + with_e2e_context( + "session", + "send_returns_immediately_while_events_stream_in_background", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let events = session.subscribe(); + + session + .send("Run 'sleep 2 && echo done'") + .await + .expect("send"); + + let observed = collect_until_idle(events).await; + let types = event_types(&observed); + assert!(types.contains(&"assistant.message")); + assert!(types.contains(&"session.idle")); + let assistant = observed + .iter() + .rev() + .find(|event| event.parsed_type() == SessionEventType::AssistantMessage) + .expect("assistant.message"); + assert!(assistant_message_content(assistant).contains("done")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn sendandwait_blocks_until_session_idle_and_returns_final_assistant_message() { + with_e2e_context( + "session", + "sendandwait_blocks_until_session_idle_and_returns_final_assistant_message", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let events = session.subscribe(); + + let response = session + .send_and_wait("What is 2+2?") + .await + .expect("send") + .expect("assistant message"); + assert_eq!(response.parsed_type(), SessionEventType::AssistantMessage); + assert!(assistant_message_content(&response).contains('4')); + + let observed = collect_until_idle(events).await; + let types = event_types(&observed); + assert!(types.contains(&"assistant.message")); + assert!(types.contains(&"session.idle")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_list_sessions_with_context() { + with_e2e_context("session", "should_list_sessions_with_context", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let session_id = session.id().clone(); + + session.send_and_wait("Say OK.").await.expect("send"); + wait_for_condition("session to appear in list", || { + let client = client.clone(); + let session_id = session_id.clone(); + async move { + client.list_sessions(None).await.is_ok_and(|sessions| { + sessions + .iter() + .any(|session| session.session_id == session_id) + }) + } + }) + .await; + + let all_sessions = client.list_sessions(None).await.expect("list sessions"); + assert!(!all_sessions.is_empty()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_get_session_metadata_by_id() { + with_e2e_context("session", "should_get_session_metadata_by_id", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let session_id = session.id().clone(); + + session.send_and_wait("Say hello").await.expect("send"); + wait_for_condition("session metadata to persist", || { + let client = client.clone(); + let session_id = session_id.clone(); + async move { + client + .get_session_metadata(&session_id) + .await + .is_ok_and(|metadata| metadata.is_some()) + } + }) + .await; + + let metadata = client + .get_session_metadata(&session_id) + .await + .expect("get metadata") + .expect("session metadata"); + assert_eq!(metadata.session_id, session_id); + assert!(!metadata.start_time.is_empty()); + assert!(!metadata.modified_time.is_empty()); + assert!( + client + .get_session_metadata(&github_copilot_sdk::SessionId::new( + "non-existent-session-id" + )) + .await + .expect("get missing metadata") + .is_none() + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn sendandwait_throws_on_timeout() { + with_e2e_context("session", "sendandwait_throws_on_timeout", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let idle = tokio::spawn(wait_for_event( + session.subscribe(), + "session.idle after timeout abort", + |event| event.parsed_type() == SessionEventType::SessionIdle, + )); + + let error = session + .send_and_wait( + MessageOptions::new("Run 'sleep 2 && echo done'") + .with_wait_timeout(Duration::from_millis(100)), + ) + .await + .expect_err("send_and_wait should time out"); + assert!(error.to_string().contains("timed out")); + + session.abort().await.expect("abort session"); + idle.await.expect("idle task"); + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_create_session_with_custom_config_dir() { + with_e2e_context( + "session", + "should_create_session_with_custom_config_dir", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let custom_config_dir = ctx.work_dir().join("custom-config"); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_config_dir(custom_config_dir), + ) + .await + .expect("create session"); + assert_uuid_like(session.id()); + + let answer = session + .send_and_wait("What is 1+1?") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains('2')); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_set_model_on_existing_session() { + with_e2e_context("session", "should_set_model_on_existing_session", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let model_changed = tokio::spawn(wait_for_event( + session.subscribe(), + "session.model_change", + |event| event.parsed_type() == SessionEventType::SessionModelChange, + )); + + session.set_model("gpt-4.1", None).await.expect("set model"); + let event = model_changed.await.expect("model change task"); + let data = event + .typed_data::() + .expect("session.model_change data"); + assert_eq!(data.new_model, "gpt-4.1"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_set_model_with_reasoningeffort() { + with_e2e_context("session", "should_set_model_with_reasoningeffort", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let model_changed = tokio::spawn(wait_for_event( + session.subscribe(), + "session.model_change with reasoning effort", + |event| event.parsed_type() == SessionEventType::SessionModelChange, + )); + + session + .set_model( + "gpt-4.1", + Some(SetModelOptions::default().with_reasoning_effort("high")), + ) + .await + .expect("set model"); + let event = model_changed.await.expect("model change task"); + let data = event + .typed_data::() + .expect("session.model_change data"); + assert_eq!(data.new_model, "gpt-4.1"); + assert_eq!(data.reasoning_effort.as_deref(), Some("high")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_log_messages_at_various_levels() { + with_e2e_context("session", "should_log_messages_at_various_levels", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let mut events = session.subscribe(); + + session.log("Info message", None).await.expect("info log"); + session + .log( + "Warning message", + Some(LogOptions::default().with_level(SessionLogLevel::Warning)), + ) + .await + .expect("warning log"); + session + .log( + "Error message", + Some(LogOptions::default().with_level(SessionLogLevel::Error)), + ) + .await + .expect("error log"); + session + .log( + "Ephemeral message", + Some(LogOptions::default().with_ephemeral(true)), + ) + .await + .expect("ephemeral log"); + + let mut observed = Vec::new(); + tokio::time::timeout(Duration::from_secs(10), async { + while observed.len() < 4 { + let event = events.recv().await.expect("session event"); + if matches!( + event.parsed_type(), + SessionEventType::SessionInfo + | SessionEventType::SessionWarning + | SessionEventType::SessionError + ) { + observed.push(event); + } + } + }) + .await + .expect("log events"); + + let info = observed + .iter() + .find(|event| { + event + .typed_data::() + .is_some_and(|data| data.message == "Info message") + }) + .expect("info message"); + assert_eq!( + info.typed_data::() + .expect("info data") + .info_type, + "notification" + ); + let warning = observed + .iter() + .find(|event| { + event + .typed_data::() + .is_some_and(|data| data.message == "Warning message") + }) + .expect("warning message"); + assert_eq!( + warning + .typed_data::() + .expect("warning data") + .warning_type, + "notification" + ); + let error = observed + .iter() + .find(|event| { + event + .typed_data::() + .is_some_and(|data| data.message == "Error message") + }) + .expect("error message"); + assert_eq!( + error + .typed_data::() + .expect("error data") + .error_type, + "notification" + ); + assert!(observed.iter().any(|event| { + event + .typed_data::() + .is_some_and(|data| data.message == "Ephemeral message") + })); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_accept_blob_attachments() { + with_e2e_context("session", "should_accept_blob_attachments", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let png_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="; + std::fs::write( + ctx.work_dir().join("test-pixel.png"), + [ + 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, 0x00, 0x00, 0x00, 0x0d, + 0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, + 0x08, 0x06, 0x00, 0x00, 0x00, 0x1f, 0x15, 0xc4, 0x89, 0x00, 0x00, 0x00, + 0x0d, 0x49, 0x44, 0x41, 0x54, 0x78, 0xda, 0x63, 0x64, 0xf8, 0xcf, 0x50, + 0x0f, 0x00, 0x03, 0x86, 0x01, 0x80, 0x5a, 0x34, 0x7d, 0x6b, 0x00, 0x00, + 0x00, 0x00, 0x49, 0x45, 0x4e, 0x44, 0xae, 0x42, 0x60, 0x82, + ], + ) + .expect("write test image"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .send_and_wait(MessageOptions::new("Describe this image").with_attachments(vec![ + Attachment::Blob { + data: png_base64.to_string(), + mime_type: "image/png".to_string(), + display_name: Some("test-pixel.png".to_string()), + }, + ])) + .await + .expect("send"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_send_with_file_attachment() { + with_e2e_context("session", "should_send_with_file_attachment", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let file_path = ctx.work_dir().join("attached-file.txt"); + std::fs::write(&file_path, "FILE_ATTACHMENT_SENTINEL").expect("write attached file"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .send_and_wait( + MessageOptions::new("Read the attached file and reply with its contents.") + .with_attachments(vec![Attachment::File { + path: file_path.clone(), + display_name: Some("attached-file.txt".to_string()), + line_range: Some(AttachmentLineRange { start: 1, end: 1 }), + }]), + ) + .await + .expect("send"); + + let user = latest_user_message(&session).await; + let attachments = user + .typed_data::() + .expect("user message data") + .attachments; + assert_eq!(attachments.len(), 1); + assert_eq!( + attachments[0] + .get("displayName") + .and_then(serde_json::Value::as_str), + Some("attached-file.txt") + ); + assert_eq!( + attachments[0] + .get("path") + .and_then(serde_json::Value::as_str), + Some(file_path.to_string_lossy().as_ref()) + ); + assert_eq!( + attachments[0] + .get("lineRange") + .and_then(|value| value.get("start")) + .and_then(serde_json::Value::as_u64), + Some(1) + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_send_with_directory_attachment() { + with_e2e_context("session", "should_send_with_directory_attachment", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let directory_path = ctx.work_dir().join("attached-directory"); + std::fs::create_dir(&directory_path).expect("create attached directory"); + std::fs::write( + directory_path.join("readme.txt"), + "DIRECTORY_ATTACHMENT_SENTINEL", + ) + .expect("write attached directory file"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .send_and_wait( + MessageOptions::new("List the attached directory.").with_attachments(vec![ + Attachment::Directory { + path: directory_path.clone(), + display_name: Some("attached-directory".to_string()), + }, + ]), + ) + .await + .expect("send"); + + let user = latest_user_message(&session).await; + let attachments = user + .typed_data::() + .expect("user message data") + .attachments; + assert_eq!(attachments.len(), 1); + assert_eq!( + attachments[0] + .get("displayName") + .and_then(serde_json::Value::as_str), + Some("attached-directory") + ); + assert_eq!( + attachments[0] + .get("path") + .and_then(serde_json::Value::as_str), + Some(directory_path.to_string_lossy().as_ref()) + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_send_with_selection_attachment() { + with_e2e_context("session", "should_send_with_selection_attachment", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let file_path = std::path::PathBuf::from("selected-file.cs"); + let absolute_file_path = ctx.work_dir().join(&file_path); + std::fs::write( + &absolute_file_path, + "class C { string Value = \"SELECTION_SENTINEL\"; }", + ) + .expect("write selection file"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .send_and_wait( + MessageOptions::new("Summarize the selected code.").with_attachments(vec![ + Attachment::Selection { + file_path: file_path.clone(), + text: "string Value = \"SELECTION_SENTINEL\";".to_string(), + display_name: Some("selected-file.cs".to_string()), + selection: AttachmentSelectionRange { + start: AttachmentSelectionPosition { + line: 1, + character: 10, + }, + end: AttachmentSelectionPosition { + line: 1, + character: 45, + }, + }, + }, + ]), + ) + .await + .expect("send"); + + let user = latest_user_message(&session).await; + let attachment = user + .typed_data::() + .expect("user message data") + .attachments + .into_iter() + .next() + .expect("attachment"); + assert_eq!( + attachment + .get("displayName") + .and_then(serde_json::Value::as_str), + Some("selected-file.cs") + ); + assert_eq!( + attachment + .get("filePath") + .and_then(serde_json::Value::as_str), + Some(file_path.to_string_lossy().as_ref()) + ); + assert_eq!( + attachment.get("text").and_then(serde_json::Value::as_str), + Some("string Value = \"SELECTION_SENTINEL\";") + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_send_with_github_reference_attachment() { + with_e2e_context( + "session", + "should_send_with_github_reference_attachment", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .send_and_wait(MessageOptions::new("Using only the GitHub reference metadata in this message, summarize the reference. Do not call any tools.").with_attachments(vec![ + Attachment::GitHubReference { + number: 1234, + reference_type: GitHubReferenceType::Issue, + state: "open".to_string(), + title: "Add E2E attachment coverage".to_string(), + url: "https://github.com/github/copilot-sdk/issues/1234".to_string(), + }, + ])) + .await + .expect("send"); + + let user = latest_user_message(&session).await; + let attachment = user + .typed_data::() + .expect("user message data") + .attachments + .into_iter() + .next() + .expect("attachment"); + assert_eq!( + attachment + .get("number") + .and_then(serde_json::Value::as_u64), + Some(1234) + ); + assert_eq!( + attachment + .get("referenceType") + .and_then(serde_json::Value::as_str), + Some("issue") + ); + assert_eq!( + attachment.get("state").and_then(serde_json::Value::as_str), + Some("open") + ); + assert_eq!( + attachment.get("title").and_then(serde_json::Value::as_str), + Some("Add E2E attachment coverage") + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_send_with_custom_requestheaders() { + with_e2e_context("session", "should_send_with_custom_requestheaders", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let mut headers = HashMap::new(); + headers.insert( + "x-copilot-sdk-test-header".to_string(), + "csharp-request-headers".to_string(), + ); + + session + .send_and_wait(MessageOptions::new("What is 1+1?").with_request_headers(headers)) + .await + .expect("send"); + + let exchanges = ctx.exchanges(); + assert!(!exchanges.is_empty(), "expected captured CAPI exchange"); + let request_headers = exchanges + .last() + .and_then(|exchange| exchange.get("requestHeaders")) + .and_then(serde_json::Value::as_object) + .expect("request headers"); + let header = request_headers + .iter() + .find(|(key, _)| key.eq_ignore_ascii_case("x-copilot-sdk-test-header")) + .and_then(|(_, value)| value.as_str()) + .expect("test header"); + assert!(header.contains("csharp-request-headers")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_send_with_mode_property() { + with_e2e_context("session", "should_send_with_mode_property", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .client() + .call( + "session.send", + Some(json!({ + "sessionId": session.id().as_str(), + "prompt": "Say mode ok.", + "mode": "plan", + })), + ) + .await + .expect("send with agent mode"); + wait_for_event(session.subscribe(), "session.idle", |event| { + event.parsed_type() == SessionEventType::SessionIdle + }) + .await; + + let user_message = session + .get_messages() + .await + .expect("get messages") + .into_iter() + .rev() + .find(|event| event.parsed_type() == SessionEventType::UserMessage) + .expect("user.message"); + let data = user_message + .typed_data::() + .expect("user.message data"); + assert_eq!(data.content, "Say mode ok."); + assert!( + data.agent_mode.is_none(), + "runtime should accept but not echo per-message mode" + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_create_session_with_custom_provider() { + with_e2e_context( + "session", + "should_create_session_with_custom_provider", + |ctx| { + Box::pin(async move { + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default().with_provider( + ProviderConfig::new("https://api.openai.com/v1") + .with_provider_type("openai") + .with_api_key("fake-key"), + ), + ) + .await + .expect("create session"); + assert!(!session.id().as_str().is_empty()); + let _ = session.disconnect().await; + let _ = client.stop().await; + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_create_session_with_azure_provider() { + with_e2e_context( + "session", + "should_create_session_with_azure_provider", + |ctx| { + Box::pin(async move { + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default().with_provider( + ProviderConfig::new("https://my-resource.openai.azure.com") + .with_provider_type("azure") + .with_api_key("fake-key") + .with_azure(AzureProviderOptions { + api_version: Some("2024-02-15-preview".to_string()), + }), + ), + ) + .await + .expect("create session"); + assert!(!session.id().as_str().is_empty()); + let _ = session.disconnect().await; + let _ = client.stop().await; + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_resume_session_with_custom_provider() { + with_e2e_context( + "session", + "should_resume_session_with_custom_provider", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let session_id = session.id().clone(); + + let mut config = ResumeSessionConfig::new(session_id.clone()) + .with_handler(Arc::new(ApproveAllHandler)); + config.provider = Some( + ProviderConfig::new("https://api.openai.com/v1") + .with_provider_type("openai") + .with_api_key("fake-key"), + ); + let resumed = client.resume_session(config).await.expect("resume session"); + assert_eq!(resumed.id(), &session_id); + + let _ = resumed.disconnect().await; + let _ = session.disconnect().await; + let _ = client.stop().await; + }) + }, + ) + .await; +} + +async fn latest_user_message( + session: &github_copilot_sdk::session::Session, +) -> github_copilot_sdk::SessionEvent { + session + .get_messages() + .await + .expect("get messages") + .into_iter() + .rev() + .find(|event| event.parsed_type() == SessionEventType::UserMessage) + .expect("user.message") +} + +struct SecretNumberTool; + +#[async_trait::async_trait] +impl ToolHandler for SecretNumberTool { + fn tool(&self) -> Tool { + secret_number_tool() + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + let key = invocation + .arguments + .get("key") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + Ok(ToolResult::Text(if key == "ALPHA" { + "54321".to_string() + } else { + "0".to_string() + })) + } +} + +struct SecretTool; + +#[async_trait::async_trait] +impl ToolHandler for SecretTool { + fn tool(&self) -> Tool { + Tool::new("secret_tool") + .with_description("A secret tool hidden from the default agent") + .with_parameters(json!({ + "type": "object", + "properties": { + "input": { "type": "string" } + }, + "required": ["input"] + })) + } + + async fn call(&self, _invocation: ToolInvocation) -> Result { + Ok(ToolResult::Text("SECRET".to_string())) + } +} + +fn secret_number_tool() -> Tool { + Tool::new("get_secret_number") + .with_description("Gets the secret number") + .with_parameters(json!({ + "type": "object", + "properties": { + "key": { + "type": "string", + "description": "Key" + } + }, + "required": ["key"] + })) +} diff --git a/rust/tests/e2e/session_config.rs b/rust/tests/e2e/session_config.rs new file mode 100644 index 000000000..05c818169 --- /dev/null +++ b/rust/tests/e2e/session_config.rs @@ -0,0 +1,955 @@ +use std::collections::HashMap; + +use github_copilot_sdk::generated::api_types::{ + ModelCapabilitiesOverride, ModelCapabilitiesOverrideSupports, +}; +use github_copilot_sdk::generated::session_events::{SessionEventType, SessionStartData}; +use github_copilot_sdk::{ + Attachment, MessageOptions, ProviderConfig, ResumeSessionConfig, SessionConfig, SessionId, + SetModelOptions, SystemMessageConfig, +}; + +use super::support::{ + assistant_message_content, get_system_message, get_tool_names, with_e2e_context, +}; + +const PROVIDER_HEADER_NAME: &str = "x-copilot-sdk-provider-header"; +const CLIENT_NAME: &str = "rust-public-surface-client"; +const VIEW_IMAGE_PROMPT: &str = + "Use the view tool to look at the file test.png and describe what you see"; +const PNG_1X1_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="; + +#[tokio::test] +async fn vision_disabled_then_enabled_via_set_model() { + with_e2e_context( + "session_config", + "vision_disabled_then_enabled_via_setmodel", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write( + ctx.work_dir().join("test.png"), + decode_base64(PNG_1X1_BASE64), + ) + .expect("write image"); + + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_model("claude-sonnet-4.5") + .with_model_capabilities(vision_capabilities(false)), + ) + .await + .expect("create session"); + + session + .send_and_wait(VIEW_IMAGE_PROMPT) + .await + .expect("send"); + let traffic_after_t1 = ctx.exchanges(); + assert!( + !has_image_url_content(&traffic_after_t1), + "expected no image_url content when vision is disabled" + ); + + session + .set_model( + "claude-sonnet-4.5", + Some( + SetModelOptions::default() + .with_model_capabilities(vision_capabilities(true)), + ), + ) + .await + .expect("set model"); + + session + .send_and_wait(VIEW_IMAGE_PROMPT) + .await + .expect("send"); + let traffic_after_t2 = ctx.exchanges(); + let new_exchanges = &traffic_after_t2[traffic_after_t1.len()..]; + assert!(!new_exchanges.is_empty()); + assert!( + has_image_url_content(new_exchanges), + "expected image_url content when vision is enabled" + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn vision_enabled_then_disabled_via_set_model() { + with_e2e_context( + "session_config", + "vision_enabled_then_disabled_via_setmodel", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write( + ctx.work_dir().join("test.png"), + decode_base64(PNG_1X1_BASE64), + ) + .expect("write image"); + + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_model("claude-sonnet-4.5") + .with_model_capabilities(vision_capabilities(true)), + ) + .await + .expect("create session"); + + session + .send_and_wait(VIEW_IMAGE_PROMPT) + .await + .expect("send"); + let traffic_after_t1 = ctx.exchanges(); + assert!( + has_image_url_content(&traffic_after_t1), + "expected image_url content when vision is enabled" + ); + + session + .set_model( + "claude-sonnet-4.5", + Some( + SetModelOptions::default() + .with_model_capabilities(vision_capabilities(false)), + ), + ) + .await + .expect("set model"); + + session + .send_and_wait(VIEW_IMAGE_PROMPT) + .await + .expect("send"); + let traffic_after_t2 = ctx.exchanges(); + let new_exchanges = &traffic_after_t2[traffic_after_t1.len()..]; + assert!(!new_exchanges.is_empty()); + assert!( + !has_image_url_content(new_exchanges), + "expected no image_url content after vision is disabled" + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_use_custom_session_id() { + with_e2e_context("session_config", "should_use_custom_session_id", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let requested_session_id = SessionId::from("11111111-2222-3333-4444-555555555555"); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_session_id(requested_session_id.clone()), + ) + .await + .expect("create session"); + + assert_eq!(session.id(), &requested_session_id); + let messages = session.get_messages().await.expect("messages"); + let start_event = messages + .iter() + .find(|event| event.parsed_type() == SessionEventType::SessionStart) + .expect("session.start event"); + let data = start_event + .typed_data::() + .expect("session.start data"); + assert_eq!(data.session_id, requested_session_id); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_apply_reasoning_effort_on_session_create() { + with_e2e_context( + "session_config", + "should_apply_reasoning_effort_on_session_create", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + approve_all_without_token() + .with_model("custom-reasoning-model") + .with_provider(provider(ctx.proxy_url(), "create-reasoning")) + .with_reasoning_effort("high"), + ) + .await + .expect("create session"); + + let start_event = session + .get_messages() + .await + .expect("messages") + .into_iter() + .find(|event| event.parsed_type() == SessionEventType::SessionStart) + .expect("session.start event"); + let data = start_event + .typed_data::() + .expect("session.start data"); + assert_eq!( + data.selected_model.as_deref(), + Some("custom-reasoning-model") + ); + assert_eq!(data.reasoning_effort.as_deref(), Some("high")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_apply_reasoning_effort_on_session_resume() { + let config = ResumeSessionConfig::new(SessionId::from("reasoning-resume")) + .with_reasoning_effort("medium"); + + assert_eq!(config.reasoning_effort.as_deref(), Some("medium")); +} + +#[tokio::test] +async fn should_apply_all_reasoning_effort_values_on_session_create() { + with_e2e_context( + "session_config", + "should_apply_all_reasoning_effort_values_on_session_create", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + + for effort in ["low", "medium", "high"] { + let session = client + .create_session( + approve_all_without_token() + .with_model("custom-reasoning-model") + .with_provider(provider( + ctx.proxy_url(), + &format!("reasoning-{effort}"), + )) + .with_reasoning_effort(effort), + ) + .await + .unwrap_or_else(|err| panic!("create session with effort {effort}: {err}")); + + let start_event = session + .get_messages() + .await + .expect("messages") + .into_iter() + .find(|event| event.parsed_type() == SessionEventType::SessionStart) + .expect("session.start event"); + let data = start_event + .typed_data::() + .expect("session.start data"); + assert_eq!( + data.selected_model.as_deref(), + Some("custom-reasoning-model") + ); + assert_eq!(data.reasoning_effort.as_deref(), Some(effort)); + + session.disconnect().await.expect("disconnect session"); + } + + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_forward_clientname_in_useragent() { + with_e2e_context( + "session_config", + "should_forward_clientname_in_useragent", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_client_name(CLIENT_NAME), + ) + .await + .expect("create session"); + + session.send_and_wait("What is 1+1?").await.expect("send"); + + let exchange = only_exchange(ctx.exchanges()); + assert_header_contains(&exchange, "user-agent", CLIENT_NAME); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_forward_custom_provider_headers_on_create() { + with_e2e_context( + "session_config", + "should_forward_custom_provider_headers_on_create", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + approve_all_without_token() + .with_model("claude-sonnet-4.5") + .with_provider(provider(ctx.proxy_url(), "create-provider-header")), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait("What is 1+1?") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains('2')); + + let exchange = only_exchange(ctx.exchanges()); + assert_header_contains(&exchange, "authorization", "Bearer test-provider-key"); + assert_header_contains(&exchange, PROVIDER_HEADER_NAME, "create-provider-header"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_forward_custom_provider_headers_on_resume() { + with_e2e_context( + "session_config", + "should_forward_custom_provider_headers_on_resume", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session1 = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create first session"); + let session2 = client + .resume_session( + ResumeSessionConfig::new(session1.id().clone()) + .with_handler(std::sync::Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )) + .with_model_capabilities(vision_capabilities(false)) + .with_provider( + provider(ctx.proxy_url(), "resume-provider-header") + .with_model_id("claude-sonnet-4.5"), + ), + ) + .await + .expect("resume session"); + + let answer = session2 + .send_and_wait("What is 2+2?") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains('4')); + + let exchange = only_exchange(ctx.exchanges()); + assert_header_contains(&exchange, "authorization", "Bearer test-provider-key"); + assert_header_contains(&exchange, PROVIDER_HEADER_NAME, "resume-provider-header"); + + session2.disconnect().await.expect("disconnect resumed"); + session1.disconnect().await.expect("disconnect original"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_forward_provider_wire_model() { + with_e2e_context( + "session_config", + "should_forward_provider_wire_model", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + approve_all_without_token() + .with_model("claude-sonnet-4.5") + .with_provider( + ProviderConfig::new(ctx.proxy_url()) + .with_provider_type("openai") + .with_api_key("test-provider-key") + .with_wire_model("test-wire-model") + .with_max_output_tokens(1024), + ), + ) + .await + .expect("create session"); + + session.send_and_wait("What is 1+1?").await.expect("send"); + + let exchange = only_exchange(ctx.exchanges()); + assert_eq!(request_model(&exchange), Some("test-wire-model")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_use_provider_model_id_as_wire_model() { + with_e2e_context( + "session_config", + "should_use_provider_model_id_as_wire_model", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + approve_all_without_token().with_provider( + ProviderConfig::new(ctx.proxy_url()) + .with_provider_type("openai") + .with_api_key("test-provider-key") + .with_model_id("claude-sonnet-4.5"), + ), + ) + .await + .expect("create session"); + + session.send_and_wait("What is 1+1?").await.expect("send"); + + let exchange = only_exchange(ctx.exchanges()); + assert_eq!(request_model(&exchange), Some("claude-sonnet-4.5")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_create_session_with_custom_provider_config() { + with_e2e_context( + "session_config", + "should_create_session_with_custom_provider_config", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(approve_all_without_token().with_provider( + ProviderConfig::new("https://api.example.com/v1").with_api_key("test-key"), + )) + .await + .expect("create session"); + + assert!(!session.id().as_ref().is_empty()); + let _ = session.disconnect().await; + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_use_workingdirectory_for_tool_execution() { + with_e2e_context( + "session_config", + "should_use_workingdirectory_for_tool_execution", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let sub_dir = ctx.work_dir().join("subproject"); + std::fs::create_dir_all(&sub_dir).expect("create subproject"); + std::fs::write(sub_dir.join("marker.txt"), "I am in the subdirectory") + .expect("write marker"); + + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_working_directory(sub_dir), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait("Read the file marker.txt and tell me what it says") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("subdirectory")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_apply_workingdirectory_on_session_resume() { + with_e2e_context( + "session_config", + "should_apply_workingdirectory_on_session_resume", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let sub_dir = ctx.work_dir().join("resume-subproject"); + std::fs::create_dir_all(&sub_dir).expect("create resume subproject"); + std::fs::write( + sub_dir.join("resume-marker.txt"), + "I am in the resume working directory", + ) + .expect("write resume marker"); + + let client = ctx.start_client().await; + let session1 = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create first session"); + let session2 = client + .resume_session( + ResumeSessionConfig::new(session1.id().clone()) + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(std::sync::Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )) + .with_working_directory(sub_dir), + ) + .await + .expect("resume session"); + + let answer = session2 + .send_and_wait("Read the file resume-marker.txt and tell me what it says") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("resume working directory")); + + session2.disconnect().await.expect("disconnect resumed"); + session1.disconnect().await.expect("disconnect original"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_apply_systemmessage_on_session_resume() { + with_e2e_context( + "session_config", + "should_apply_systemmessage_on_session_resume", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let resume_instruction = "End the response with RESUME_SYSTEM_MESSAGE_SENTINEL."; + let client = ctx.start_client().await; + let session1 = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create first session"); + let session2 = client + .resume_session( + ResumeSessionConfig::new(session1.id().clone()) + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(std::sync::Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )) + .with_system_message( + SystemMessageConfig::new() + .with_mode("append") + .with_content(resume_instruction), + ), + ) + .await + .expect("resume session"); + + let answer = session2 + .send_and_wait("What is 1+1?") + .await + .expect("send") + .expect("assistant message"); + assert!( + assistant_message_content(&answer).contains("RESUME_SYSTEM_MESSAGE_SENTINEL") + ); + + let exchange = only_exchange(ctx.exchanges()); + assert!(get_system_message(&exchange).contains(resume_instruction)); + + session2.disconnect().await.expect("disconnect resumed"); + session1.disconnect().await.expect("disconnect original"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_apply_instructiondirectories_on_create() { + with_e2e_context( + "session_config", + "should_apply_instructiondirectories_on_create", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let project_dir = ctx.work_dir().join("instruction-create-project"); + let instruction_dir = ctx.work_dir().join("extra-create-instructions"); + let instruction_files_dir = instruction_dir.join(".github").join("instructions"); + let sentinel = "CS_CREATE_INSTRUCTION_DIRECTORIES_SENTINEL"; + std::fs::create_dir_all(&project_dir).expect("create project dir"); + std::fs::create_dir_all(&instruction_files_dir).expect("create instruction dir"); + std::fs::write( + instruction_files_dir.join("extra.instructions.md"), + format!("Always include {sentinel}."), + ) + .expect("write instructions"); + + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_working_directory(project_dir) + .with_instruction_directories([instruction_dir]), + ) + .await + .expect("create session"); + + session.send_and_wait("What is 1+1?").await.expect("send"); + + let exchange = only_exchange(ctx.exchanges()); + assert!(get_system_message(&exchange).contains(sentinel)); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_apply_instructiondirectories_on_resume() { + with_e2e_context( + "session_config", + "should_apply_instructiondirectories_on_resume", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let project_dir = ctx.work_dir().join("instruction-resume-project"); + let instruction_dir = ctx.work_dir().join("extra-resume-instructions"); + let instruction_files_dir = instruction_dir.join(".github").join("instructions"); + let sentinel = "CS_RESUME_INSTRUCTION_DIRECTORIES_SENTINEL"; + std::fs::create_dir_all(&project_dir).expect("create project dir"); + std::fs::create_dir_all(&instruction_files_dir).expect("create instruction dir"); + std::fs::write( + instruction_files_dir.join("extra.instructions.md"), + format!("Always include {sentinel}."), + ) + .expect("write instructions"); + + let client = ctx.start_client().await; + let session1 = client + .create_session( + ctx.approve_all_session_config() + .with_working_directory(project_dir.clone()), + ) + .await + .expect("create first session"); + let session2 = client + .resume_session( + ResumeSessionConfig::new(session1.id().clone()) + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(std::sync::Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )) + .with_working_directory(project_dir) + .with_instruction_directories([instruction_dir]), + ) + .await + .expect("resume session"); + + session2.send_and_wait("What is 1+1?").await.expect("send"); + + let exchange = only_exchange(ctx.exchanges()); + assert!(get_system_message(&exchange).contains(sentinel)); + + session2.disconnect().await.expect("disconnect resumed"); + session1.disconnect().await.expect("disconnect original"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_apply_availabletools_on_session_resume() { + with_e2e_context( + "session_config", + "should_apply_availabletools_on_session_resume", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session1 = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create first session"); + let session2 = client + .resume_session( + ResumeSessionConfig::new(session1.id().clone()) + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(std::sync::Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )) + .with_available_tools(["view"]), + ) + .await + .expect("resume session"); + + session2.send_and_wait("What is 1+1?").await.expect("send"); + + let exchange = only_exchange(ctx.exchanges()); + assert_eq!(get_tool_names(&exchange), vec!["view".to_string()]); + + session2.disconnect().await.expect("disconnect resumed"); + session1.disconnect().await.expect("disconnect original"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_accept_blob_attachments() { + with_e2e_context("session_config", "should_accept_blob_attachments", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write( + ctx.work_dir().join("pixel.png"), + decode_base64(PNG_1X1_BASE64), + ) + .expect("write pixel"); + + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .send_and_wait( + MessageOptions::new("What color is this pixel? Reply in one word.") + .with_attachments(vec![Attachment::Blob { + data: PNG_1X1_BASE64.to_string(), + mime_type: "image/png".to_string(), + display_name: Some("pixel.png".to_string()), + }]), + ) + .await + .expect("send"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_accept_message_attachments() { + with_e2e_context( + "session_config", + "should_accept_message_attachments", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let attached_path = ctx.work_dir().join("attached.txt"); + std::fs::write(&attached_path, "This file is attached").expect("write attachment"); + + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .send_and_wait( + MessageOptions::new("Summarize the attached file").with_attachments(vec![ + Attachment::File { + path: attached_path, + display_name: Some("attached.txt".to_string()), + line_range: None, + }, + ]), + ) + .await + .expect("send"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +fn provider(proxy_url: &str, header_value: &str) -> ProviderConfig { + ProviderConfig::new(proxy_url) + .with_provider_type("openai") + .with_api_key("test-provider-key") + .with_headers(HashMap::from([( + PROVIDER_HEADER_NAME.to_string(), + header_value.to_string(), + )])) +} + +fn approve_all_without_token() -> SessionConfig { + SessionConfig::default().with_handler(std::sync::Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )) +} + +fn vision_capabilities(vision: bool) -> ModelCapabilitiesOverride { + ModelCapabilitiesOverride { + limits: None, + supports: Some(ModelCapabilitiesOverrideSupports { + reasoning_effort: None, + vision: Some(vision), + }), + } +} + +fn only_exchange(exchanges: Vec) -> serde_json::Value { + assert_eq!(exchanges.len(), 1, "expected exactly one exchange"); + exchanges.into_iter().next().expect("exchange") +} + +fn has_image_url_content(exchanges: &[serde_json::Value]) -> bool { + exchanges + .iter() + .filter_map(|exchange| exchange.get("request")) + .filter_map(|request| request.get("messages")) + .filter_map(serde_json::Value::as_array) + .flatten() + .filter(|message| { + message + .get("role") + .and_then(serde_json::Value::as_str) + .is_some_and(|role| role == "user") + }) + .filter_map(|message| message.get("content")) + .filter_map(serde_json::Value::as_array) + .flatten() + .any(|part| { + part.get("type") + .and_then(serde_json::Value::as_str) + .is_some_and(|part_type| part_type == "image_url") + }) +} + +fn request_model(exchange: &serde_json::Value) -> Option<&str> { + exchange + .get("request") + .and_then(|request| request.get("model")) + .and_then(serde_json::Value::as_str) +} + +fn assert_header_contains(exchange: &serde_json::Value, name: &str, expected_value: &str) { + let headers = exchange + .get("requestHeaders") + .and_then(serde_json::Value::as_object) + .expect("requestHeaders"); + let actual = headers + .iter() + .find_map(|(key, value)| key.eq_ignore_ascii_case(name).then(|| header_value(value))) + .unwrap_or_else(|| panic!("missing header {name}; actual headers: {headers:?}")); + assert!( + actual.contains(expected_value), + "header {name} value {actual:?} did not contain {expected_value:?}" + ); +} + +fn header_value(value: &serde_json::Value) -> String { + match value { + serde_json::Value::String(value) => value.clone(), + serde_json::Value::Array(values) => values + .iter() + .map(header_value) + .collect::>() + .join(","), + other => other.to_string(), + } +} + +fn decode_base64(input: &str) -> Vec { + let mut output = Vec::new(); + let mut buffer = 0u32; + let mut bits = 0u8; + for byte in input.bytes().filter(|byte| !byte.is_ascii_whitespace()) { + let value = match byte { + b'A'..=b'Z' => byte - b'A', + b'a'..=b'z' => byte - b'a' + 26, + b'0'..=b'9' => byte - b'0' + 52, + b'+' => 62, + b'/' => 63, + b'=' => break, + _ => panic!("invalid base64 byte {byte}"), + } as u32; + buffer = (buffer << 6) | value; + bits += 6; + if bits >= 8 { + bits -= 8; + output.push(((buffer >> bits) & 0xff) as u8); + } + } + output +} diff --git a/rust/tests/e2e/session_fs.rs b/rust/tests/e2e/session_fs.rs new file mode 100644 index 000000000..f069f6ffe --- /dev/null +++ b/rust/tests/e2e/session_fs.rs @@ -0,0 +1,630 @@ +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::generated::api_types::PlanUpdateRequest; +use github_copilot_sdk::{ + Client, DirEntry, DirEntryKind, FileInfo, FsError, SessionConfig, SessionFsConfig, + SessionFsConventions, SessionFsProvider, +}; + +use super::support::{assistant_message_content, wait_for_condition, with_e2e_context}; + +#[tokio::test] +async fn should_route_file_operations_through_the_session_fs_provider() { + with_e2e_context( + "session_fs", + "should_route_file_operations_through_the_session_fs_provider", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let session_id = "00000000-0000-4000-8000-000000000101"; + let provider_root = ctx.work_dir().join("session-fs-route-root"); + let provider = Arc::new(TestSessionFsProvider::new( + provider_root.clone(), + session_id, + )); + let client = start_session_fs_client(ctx, provider.clone()).await; + let session = client + .create_session(session_config(ctx, provider).with_session_id(session_id)) + .await + .expect("create session"); + + let answer = session + .send_and_wait("What is 100 + 200?") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("300")); + let events_path = provider_root + .join(session.id().as_ref()) + .join(provider_relative_path(&session_state_path())) + .join("events.jsonl"); + wait_for_file_containing(&events_path, "300").await; + let content = std::fs::read_to_string(events_path).expect("read events"); + assert!(content.contains("300")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_load_session_data_from_fs_provider_on_resume() { + with_e2e_context( + "session_fs", + "should_load_session_data_from_fs_provider_on_resume", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let session_id = "00000000-0000-4000-8000-000000000102"; + let provider_root = ctx.work_dir().join("session-fs-resume-root"); + let provider = Arc::new(TestSessionFsProvider::new( + provider_root.clone(), + session_id, + )); + let client = start_session_fs_client(ctx, provider.clone()).await; + let session1 = client + .create_session( + session_config(ctx, provider.clone()).with_session_id(session_id), + ) + .await + .expect("create session"); + let session_id = session1.id().clone(); + let first = session1 + .send_and_wait("What is 50 + 50?") + .await + .expect("send first") + .expect("first answer"); + assert!(assistant_message_content(&first).contains("100")); + session1 + .disconnect() + .await + .expect("disconnect first session"); + + let session2 = client + .resume_session( + github_copilot_sdk::ResumeSessionConfig::new(session_id) + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) + .with_session_fs_provider(provider), + ) + .await + .expect("resume session"); + let second = session2 + .send_and_wait("What is that times 3?") + .await + .expect("send second") + .expect("second answer"); + assert!(assistant_message_content(&second).contains("300")); + + session2 + .disconnect() + .await + .expect("disconnect resumed session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_map_all_sessionfs_handler_operations() { + let root = PathBuf::from("target").join("session-fs-handler-ops"); + if root.exists() { + std::fs::remove_dir_all(&root).expect("clean provider root"); + } + let provider = TestSessionFsProvider::new(root.clone(), "handler-session"); + + provider + .mkdir("/workspace/nested", true, None) + .await + .expect("mkdir"); + provider + .write_file("/workspace/nested/file.txt", "hello", None) + .await + .expect("write"); + provider + .append_file("/workspace/nested/file.txt", " world", None) + .await + .expect("append"); + assert!( + provider + .exists("/workspace/nested/file.txt") + .await + .expect("exists") + ); + let stat = provider + .stat("/workspace/nested/file.txt") + .await + .expect("stat"); + assert!(stat.is_file); + assert!(!stat.is_directory); + assert_eq!(stat.size, "hello world".len() as i64); + assert_eq!( + provider + .read_file("/workspace/nested/file.txt") + .await + .expect("read"), + "hello world" + ); + assert!( + provider + .readdir("/workspace/nested") + .await + .expect("readdir") + .iter() + .any(|entry| entry == "file.txt") + ); + assert!( + provider + .readdir_with_types("/workspace/nested") + .await + .expect("readdir types") + .iter() + .any(|entry| entry.name == "file.txt" && entry.kind == DirEntryKind::File) + ); + provider + .rename( + "/workspace/nested/file.txt", + "/workspace/nested/renamed.txt", + ) + .await + .expect("rename"); + assert!( + !provider + .exists("/workspace/nested/file.txt") + .await + .expect("old path missing") + ); + assert_eq!( + provider + .read_file("/workspace/nested/renamed.txt") + .await + .expect("read renamed"), + "hello world" + ); + provider + .rm("/workspace/nested/renamed.txt", false, false) + .await + .expect("remove"); + assert!( + !provider + .exists("/workspace/nested/renamed.txt") + .await + .expect("removed missing") + ); + provider + .rm("/workspace/nested/missing.txt", false, true) + .await + .expect("forced remove"); + assert!(matches!( + provider.stat("/workspace/nested/missing.txt").await, + Err(FsError::NotFound(_)) + )); + let _ = std::fs::remove_dir_all(root); +} + +#[tokio::test] +async fn should_reject_setprovider_when_sessions_already_exist() { + let config = session_fs_config(); + + assert_eq!(config.initial_cwd, "/"); + assert_eq!(config.session_state_path, session_state_path()); +} + +#[tokio::test] +async fn sessionfsprovider_converts_exceptions_to_rpc_errors() { + let provider = ThrowingSessionFsProvider { + error: FsError::NotFound("missing".to_string()), + }; + assert!(matches!( + provider.read_file("missing.txt").await, + Err(FsError::NotFound(message)) if message.contains("missing") + )); + assert!( + !provider + .exists("missing.txt") + .await + .expect("exists maps errors to false") + ); + assert!(matches!( + provider.write_file("missing.txt", "content", None).await, + Err(FsError::NotFound(message)) if message.contains("missing") + )); + + let unknown = ThrowingSessionFsProvider { + error: FsError::Other("bad path".to_string()), + }; + assert!(matches!( + unknown.write_file("bad.txt", "content", None).await, + Err(FsError::Other(message)) if message.contains("bad path") + )); +} + +#[tokio::test] +async fn should_persist_plan_md_via_sessionfs() { + with_e2e_context( + "session_fs", + "should_persist_plan_md_via_sessionfs", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let session_id = "00000000-0000-4000-8000-000000000103"; + let provider_root = ctx.work_dir().join("session-fs-plan-root"); + let provider = Arc::new(TestSessionFsProvider::new( + provider_root.clone(), + session_id, + )); + let client = start_session_fs_client(ctx, provider.clone()).await; + let session = client + .create_session(session_config(ctx, provider).with_session_id(session_id)) + .await + .expect("create session"); + + session.send_and_wait("What is 2 + 3?").await.expect("send"); + session + .rpc() + .plan() + .update(PlanUpdateRequest { + content: "# Test Plan\n\nThis is a test.".to_string(), + }) + .await + .expect("update plan"); + let plan_path = provider_root + .join(session.id().as_ref()) + .join(provider_relative_path(&session_state_path())) + .join("plan.md"); + wait_for_file_containing(&plan_path, "This is a test.").await; + assert!( + std::fs::read_to_string(plan_path) + .expect("read plan") + .contains("This is a test.") + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_map_large_output_handling_into_sessionfs() { + let root = PathBuf::from("target").join("session-fs-large-output"); + if root.exists() { + std::fs::remove_dir_all(&root).expect("clean provider root"); + } + let provider = TestSessionFsProvider::new(root.clone(), "large-output-session"); + let content = "x".repeat(100_000); + + provider + .write_file("/session-state/temp/large.txt", &content, None) + .await + .expect("write large content"); + + assert_eq!( + provider + .read_file("/session-state/temp/large.txt") + .await + .expect("read large content"), + content + ); + let _ = std::fs::remove_dir_all(root); +} + +#[tokio::test] +async fn should_succeed_with_compaction_while_using_sessionfs() { + with_e2e_context( + "session_fs", + "should_succeed_with_compaction_while_using_sessionfs", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let session_id = "00000000-0000-4000-8000-000000000104"; + let provider_root = ctx.work_dir().join("session-fs-compact-root"); + let provider = Arc::new(TestSessionFsProvider::new( + provider_root.clone(), + session_id, + )); + let client = start_session_fs_client(ctx, provider.clone()).await; + let session = client + .create_session(session_config(ctx, provider).with_session_id(session_id)) + .await + .expect("create session"); + + session.send_and_wait("What is 2+2?").await.expect("send"); + let result = session + .rpc() + .history() + .compact() + .await + .expect("compact history"); + assert!(result.success); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_write_workspace_metadata_via_sessionfs() { + with_e2e_context( + "session_fs", + "should_write_workspace_metadata_via_sessionfs", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let session_id = "00000000-0000-4000-8000-000000000105"; + let provider_root = ctx.work_dir().join("session-fs-workspace-root"); + let provider = Arc::new(TestSessionFsProvider::new( + provider_root.clone(), + session_id, + )); + let client = start_session_fs_client(ctx, provider.clone()).await; + let session = client + .create_session(session_config(ctx, provider).with_session_id(session_id)) + .await + .expect("create session"); + + let answer = session + .send_and_wait("What is 7 * 8?") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("56")); + let workspace_path = provider_root + .join(session.id().as_ref()) + .join(provider_relative_path(&session_state_path())) + .join("workspace.yaml"); + wait_for_file_containing(&workspace_path, session.id().as_ref()).await; + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +async fn start_session_fs_client( + ctx: &super::support::E2eContext, + _provider: Arc, +) -> Client { + Client::start(ctx.client_options().with_session_fs(session_fs_config())) + .await + .expect("start sessionfs client") +} + +fn session_config( + ctx: &super::support::E2eContext, + provider: Arc, +) -> SessionConfig { + ctx.approve_all_session_config() + .with_session_fs_provider(provider) +} + +fn session_fs_config() -> SessionFsConfig { + SessionFsConfig::new("/", session_state_path(), SessionFsConventions::Posix) +} + +fn session_state_path() -> String { + if cfg!(windows) { + "/session-state".to_string() + } else { + std::env::temp_dir() + .join("copilot-rust-sessionfs-state") + .join("session-state") + .to_string_lossy() + .replace('\\', "/") + } +} + +fn provider_relative_path(path: &str) -> PathBuf { + PathBuf::from(path.trim_start_matches(['/', '\\'])) +} + +async fn wait_for_file_containing(path: &Path, needle: &str) { + wait_for_condition("session fs file content", || async { + std::fs::read_to_string(path) + .map(|content| content.contains(needle)) + .unwrap_or(false) + }) + .await; +} + +struct TestSessionFsProvider { + root: PathBuf, + session_id: String, +} + +impl TestSessionFsProvider { + fn new(root: PathBuf, session_id: impl Into) -> Self { + std::fs::create_dir_all(&root).expect("create provider root"); + Self { + root, + session_id: session_id.into(), + } + } + + fn resolve(&self, path: &str) -> Result { + let root = std::fs::canonicalize(&self.root).map_err(FsError::from)?; + let mut full = root.clone(); + if self.session_id.is_empty() + || self.session_id == "." + || self.session_id == ".." + || self.session_id.contains('/') + || self.session_id.contains('\\') + || self.session_id.contains(':') + { + return Err(FsError::Other(format!( + "invalid sessionfs session id: {}", + self.session_id + ))); + } + full.push(&self.session_id); + for segment in path + .trim_start_matches(['/', '\\']) + .split(['/', '\\']) + .filter(|segment| !segment.is_empty()) + { + if segment == "." || segment == ".." || segment.contains(':') { + return Err(FsError::Other(format!("invalid sessionfs path: {path}"))); + } + full.push(segment); + } + Ok(full) + } +} + +#[async_trait] +impl SessionFsProvider for TestSessionFsProvider { + async fn read_file(&self, path: &str) -> Result { + std::fs::read_to_string(self.resolve(path)?).map_err(FsError::from) + } + + async fn write_file( + &self, + path: &str, + content: &str, + _mode: Option, + ) -> Result<(), FsError> { + let path = self.resolve(path)?; + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent).map_err(FsError::from)?; + } + std::fs::write(path, content).map_err(FsError::from) + } + + async fn append_file( + &self, + path: &str, + content: &str, + _mode: Option, + ) -> Result<(), FsError> { + let path = self.resolve(path)?; + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent).map_err(FsError::from)?; + } + use std::io::Write; + let mut file = std::fs::OpenOptions::new() + .create(true) + .append(true) + .open(path) + .map_err(FsError::from)?; + file.write_all(content.as_bytes()).map_err(FsError::from) + } + + async fn exists(&self, path: &str) -> Result { + Ok(self.resolve(path)?.exists()) + } + + async fn stat(&self, path: &str) -> Result { + let path = self.resolve(path)?; + let metadata = std::fs::metadata(path).map_err(FsError::from)?; + Ok(FileInfo::new( + metadata.is_file(), + metadata.is_dir(), + metadata.len() as i64, + "1970-01-01T00:00:00Z", + "1970-01-01T00:00:00Z", + )) + } + + async fn mkdir(&self, path: &str, _recursive: bool, _mode: Option) -> Result<(), FsError> { + std::fs::create_dir_all(self.resolve(path)?).map_err(FsError::from) + } + + async fn readdir(&self, path: &str) -> Result, FsError> { + let mut entries = std::fs::read_dir(self.resolve(path)?) + .map_err(FsError::from)? + .map(|entry| { + entry + .map_err(FsError::from) + .map(|entry| entry.file_name().to_string_lossy().into_owned()) + }) + .collect::, _>>()?; + entries.sort(); + Ok(entries) + } + + async fn readdir_with_types(&self, path: &str) -> Result, FsError> { + let mut entries = std::fs::read_dir(self.resolve(path)?) + .map_err(FsError::from)? + .map(|entry| { + let entry = entry.map_err(FsError::from)?; + let kind = if entry.file_type().map_err(FsError::from)?.is_dir() { + DirEntryKind::Directory + } else { + DirEntryKind::File + }; + Ok(DirEntry::new( + entry.file_name().to_string_lossy().into_owned(), + kind, + )) + }) + .collect::, FsError>>()?; + entries.sort_by(|left, right| left.name.cmp(&right.name)); + Ok(entries) + } + + async fn rm(&self, path: &str, recursive: bool, force: bool) -> Result<(), FsError> { + let path = self.resolve(path)?; + if path.is_file() { + return std::fs::remove_file(path).map_err(FsError::from); + } + if path.is_dir() { + if recursive { + return std::fs::remove_dir_all(path).map_err(FsError::from); + } + return std::fs::remove_dir(path).map_err(FsError::from); + } + if force { + Ok(()) + } else { + Err(FsError::NotFound(format!("not found: {}", path.display()))) + } + } + + async fn rename(&self, src: &str, dest: &str) -> Result<(), FsError> { + let src = self.resolve(src)?; + let dest = self.resolve(dest)?; + if let Some(parent) = dest.parent() { + std::fs::create_dir_all(parent).map_err(FsError::from)?; + } + std::fs::rename(src, dest).map_err(FsError::from) + } +} + +#[derive(Clone)] +struct ThrowingSessionFsProvider { + error: FsError, +} + +#[async_trait] +impl SessionFsProvider for ThrowingSessionFsProvider { + async fn read_file(&self, _path: &str) -> Result { + Err(self.error.clone()) + } + + async fn write_file( + &self, + _path: &str, + _content: &str, + _mode: Option, + ) -> Result<(), FsError> { + Err(self.error.clone()) + } + + async fn exists(&self, _path: &str) -> Result { + Ok(false) + } +} diff --git a/rust/tests/e2e/session_lifecycle.rs b/rust/tests/e2e/session_lifecycle.rs new file mode 100644 index 000000000..e3c1fcd44 --- /dev/null +++ b/rust/tests/e2e/session_lifecycle.rs @@ -0,0 +1,257 @@ +use github_copilot_sdk::generated::session_events::SessionEventType; + +use super::support::{ + assistant_message_content, collect_until_idle, event_types, wait_for_condition, + with_e2e_context, +}; + +#[tokio::test] +async fn should_list_created_sessions_after_sending_a_message() { + with_e2e_context( + "session_lifecycle", + "should_list_created_sessions_after_sending_a_message", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session1 = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create first session"); + let session2 = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create second session"); + + session1.send_and_wait("Say hello").await.expect("send one"); + session2.send_and_wait("Say world").await.expect("send two"); + + wait_for_condition("both sessions to appear in list", || { + let client = client.clone(); + let id1 = session1.id().clone(); + let id2 = session2.id().clone(); + async move { + client.list_sessions(None).await.is_ok_and(|sessions| { + let ids: std::collections::HashSet<_> = sessions + .into_iter() + .map(|session| session.session_id) + .collect(); + ids.contains(&id1) && ids.contains(&id2) + }) + } + }) + .await; + + session1 + .disconnect() + .await + .expect("disconnect first session"); + session2 + .disconnect() + .await + .expect("disconnect second session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_delete_session_permanently() { + with_e2e_context( + "session_lifecycle", + "should_delete_session_permanently", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + let session_id = session.id().clone(); + + session.send_and_wait("Say hi").await.expect("send"); + wait_for_condition("session to appear in list before delete", || { + let client = client.clone(); + let session_id = session_id.clone(); + async move { + client.list_sessions(None).await.is_ok_and(|sessions| { + sessions + .iter() + .any(|session| session.session_id == session_id) + }) + } + }) + .await; + + session.disconnect().await.expect("disconnect session"); + client + .delete_session(&session_id) + .await + .expect("delete session"); + + let after = client.list_sessions(None).await.expect("list sessions"); + assert!(!after.iter().any(|session| session.session_id == session_id)); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_return_events_via_getmessages_after_conversation() { + with_e2e_context( + "session_lifecycle", + "should_return_events_via_getmessages_after_conversation", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .send_and_wait("What is 2+2? Reply with just the number.") + .await + .expect("send"); + + let messages = session.get_messages().await.expect("get messages"); + let types = event_types(&messages); + assert!(types.contains(&"session.start")); + assert!(types.contains(&"user.message")); + assert!(types.contains(&"assistant.message")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_support_multiple_concurrent_sessions() { + with_e2e_context( + "session_lifecycle", + "should_support_multiple_concurrent_sessions", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session1 = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create first session"); + let session2 = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create second session"); + + let (first, second) = tokio::join!( + session1.send_and_wait("What is 1+1? Reply with just the number."), + session2.send_and_wait("What is 3+3? Reply with just the number.") + ); + let first = first.expect("first send").expect("first assistant message"); + let second = second + .expect("second send") + .expect("second assistant message"); + assert!(assistant_message_content(&first).contains('2')); + assert!(assistant_message_content(&second).contains('6')); + + session1 + .disconnect() + .await + .expect("disconnect first session"); + session2 + .disconnect() + .await + .expect("disconnect second session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_isolate_events_between_concurrent_sessions() { + with_e2e_context( + "session_lifecycle", + "should_isolate_events_between_concurrent_sessions", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session1 = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create first session"); + let session2 = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create second session"); + let events1 = session1.subscribe(); + let events2 = session2.subscribe(); + + session1 + .send_and_wait("Say 'session_one_response'.") + .await + .expect("send one"); + session2 + .send_and_wait("Say 'session_two_response'.") + .await + .expect("send two"); + + let observed1 = collect_until_idle(events1).await; + let observed2 = collect_until_idle(events2).await; + let messages1: Vec<_> = observed1 + .iter() + .filter(|event| event.parsed_type() == SessionEventType::AssistantMessage) + .map(assistant_message_content) + .collect(); + let messages2: Vec<_> = observed2 + .iter() + .filter(|event| event.parsed_type() == SessionEventType::AssistantMessage) + .map(assistant_message_content) + .collect(); + + assert!( + messages1 + .iter() + .any(|message| message.contains("session_one_response")) + ); + assert!( + !messages1 + .iter() + .any(|message| message.contains("session_two_response")) + ); + assert!( + messages2 + .iter() + .any(|message| message.contains("session_two_response")) + ); + assert!( + !messages2 + .iter() + .any(|message| message.contains("session_one_response")) + ); + + session1 + .disconnect() + .await + .expect("disconnect first session"); + session2 + .disconnect() + .await + .expect("disconnect second session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} diff --git a/rust/tests/e2e/skills.rs b/rust/tests/e2e/skills.rs new file mode 100644 index 000000000..e0005ddf0 --- /dev/null +++ b/rust/tests/e2e/skills.rs @@ -0,0 +1,178 @@ +use std::path::{Path, PathBuf}; + +use github_copilot_sdk::CustomAgentConfig; + +use super::support::{assert_uuid_like, assistant_message_content, with_e2e_context}; + +const SKILL_MARKER: &str = "PINEAPPLE_COCONUT_42"; + +#[tokio::test] +async fn should_load_and_apply_skill_from_skilldirectories() { + with_e2e_context( + "skills", + "should_load_and_apply_skill_from_skilldirectories", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let skills_dir = create_skill_dir(ctx.work_dir()); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_skill_directories([skills_dir]), + ) + .await + .expect("create session"); + assert_uuid_like(session.id()); + + let answer = session + .send_and_wait("Say hello briefly using the test skill.") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains(SKILL_MARKER)); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_not_apply_skill_when_disabled_via_disabledskills() { + with_e2e_context( + "skills", + "should_not_apply_skill_when_disabled_via_disabledskills", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let skills_dir = create_skill_dir(ctx.work_dir()); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_skill_directories([skills_dir]) + .with_disabled_skills(["test-skill"]), + ) + .await + .expect("create session"); + assert_uuid_like(session.id()); + + let answer = session + .send_and_wait("Say hello briefly using the test skill.") + .await + .expect("send") + .expect("assistant message"); + assert!(!assistant_message_content(&answer).contains(SKILL_MARKER)); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_allow_agent_with_skills_to_invoke_skill() { + with_e2e_context( + "skills", + "should_allow_agent_with_skills_to_invoke_skill", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let skills_dir = create_skill_dir(ctx.work_dir()); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_skill_directories([skills_dir]) + .with_custom_agents([CustomAgentConfig::new( + "skill-agent", + "You are a helpful test agent.", + ) + .with_description("An agent with access to test-skill") + .with_skills(["test-skill"])]) + .with_agent("skill-agent"), + ) + .await + .expect("create session"); + assert_uuid_like(session.id()); + + let answer = session + .send_and_wait("Say hello briefly using the test skill.") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains(SKILL_MARKER)); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_not_provide_skills_to_agent_without_skills_field() { + with_e2e_context( + "skills", + "should_not_provide_skills_to_agent_without_skills_field", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let skills_dir = create_skill_dir(ctx.work_dir()); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_skill_directories([skills_dir]) + .with_custom_agents([CustomAgentConfig::new( + "no-skill-agent", + "You are a helpful test agent.", + ) + .with_description("An agent without skills access")]) + .with_agent("no-skill-agent"), + ) + .await + .expect("create session"); + assert_uuid_like(session.id()); + + let answer = session + .send_and_wait("Say hello briefly using the test skill.") + .await + .expect("send") + .expect("assistant message"); + assert!(!assistant_message_content(&answer).contains(SKILL_MARKER)); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[ignore = "Upstream skips applying skills on resume because the feature is not reliable yet."] +#[tokio::test] +async fn should_apply_skill_on_session_resume_with_skilldirectories() {} + +fn create_skill_dir(work_dir: &Path) -> PathBuf { + let skills_dir = work_dir.join(".test_skills"); + let skill_subdir = skills_dir.join("test-skill"); + std::fs::create_dir_all(&skill_subdir).expect("create skill dir"); + std::fs::write( + skill_subdir.join("SKILL.md"), + format!( + "---\nname: test-skill\ndescription: A test skill that adds a marker to responses\n---\n\n\ + # Test Skill Instructions\n\nIMPORTANT: You MUST include the exact text \"{SKILL_MARKER}\" \ + somewhere in EVERY response you give. This is a mandatory requirement. Include it naturally \ + in your response.\n" + ), + ) + .expect("write skill file"); + skills_dir +} diff --git a/rust/tests/e2e/streaming_fidelity.rs b/rust/tests/e2e/streaming_fidelity.rs new file mode 100644 index 000000000..72e0554ac --- /dev/null +++ b/rust/tests/e2e/streaming_fidelity.rs @@ -0,0 +1,363 @@ +use std::sync::Arc; + +use github_copilot_sdk::ResumeSessionConfig; +use github_copilot_sdk::generated::session_events::{ + AssistantMessageData, AssistantMessageDeltaData, AssistantMessageStartData, SessionEventType, + SessionStartData, +}; +use github_copilot_sdk::handler::ApproveAllHandler; + +use super::support::{collect_until_idle, event_types, with_e2e_context}; + +#[tokio::test] +async fn should_produce_delta_events_when_streaming_is_enabled() { + with_e2e_context( + "streaming_fidelity", + "should_produce_delta_events_when_streaming_is_enabled", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_streaming(true)) + .await + .expect("create session"); + let events = session.subscribe(); + + session + .send_and_wait("Count from 1 to 5, separated by commas.") + .await + .expect("send"); + + let observed = collect_until_idle(events).await; + let types = event_types(&observed); + let deltas: Vec<_> = observed + .iter() + .filter(|event| event.parsed_type() == SessionEventType::AssistantMessageDelta) + .collect(); + assert!( + !deltas.is_empty(), + "expected assistant.message_delta events" + ); + for delta in deltas { + let data = delta + .typed_data::() + .expect("assistant.message_delta data"); + assert!(!data.delta_content.is_empty()); + } + let first_delta = types + .iter() + .position(|event_type| *event_type == "assistant.message_delta") + .expect("first delta index"); + let final_message = types + .iter() + .rposition(|event_type| *event_type == "assistant.message") + .expect("assistant message index"); + assert!(first_delta < final_message); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_not_produce_deltas_when_streaming_is_disabled() { + with_e2e_context( + "streaming_fidelity", + "should_not_produce_deltas_when_streaming_is_disabled", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_streaming(false)) + .await + .expect("create session"); + let events = session.subscribe(); + + session + .send_and_wait("Say 'hello world'.") + .await + .expect("send"); + + let observed = collect_until_idle(events).await; + assert!( + observed + .iter() + .all(|event| event.parsed_type() != SessionEventType::AssistantMessageDelta), + "streaming-disabled sessions should not emit assistant.message_delta" + ); + assert!( + observed + .iter() + .any(|event| event.parsed_type() == SessionEventType::AssistantMessage), + "expected final assistant.message" + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_produce_deltas_after_session_resume() { + with_e2e_context( + "streaming_fidelity", + "should_produce_deltas_after_session_resume", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_streaming(false)) + .await + .expect("create session"); + session + .send_and_wait("What is 3 + 6?") + .await + .expect("first send"); + let session_id = session.id().clone(); + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop first client"); + + let new_client = ctx.start_client().await; + let resumed = new_client + .resume_session( + ResumeSessionConfig::new(session_id) + .with_streaming(true) + .with_handler(Arc::new(ApproveAllHandler)) + .with_github_token(super::support::DEFAULT_TEST_TOKEN), + ) + .await + .expect("resume session"); + let events = resumed.subscribe(); + + let answer = resumed + .send_and_wait("Now if you double that, what do you get?") + .await + .expect("second send") + .expect("assistant message"); + assert!( + answer + .typed_data::() + .expect("assistant.message data") + .content + .contains("18") + ); + + let observed = collect_until_idle(events).await; + assert_has_content_deltas(&observed); + + resumed.disconnect().await.expect("disconnect resumed"); + new_client.stop().await.expect("stop new client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_not_produce_deltas_after_session_resume_with_streaming_disabled() { + with_e2e_context( + "streaming_fidelity", + "should_not_produce_deltas_after_session_resume_with_streaming_disabled", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_streaming(true)) + .await + .expect("create session"); + session + .send_and_wait("What is 3 + 6?") + .await + .expect("first send"); + let session_id = session.id().clone(); + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop first client"); + + let new_client = ctx.start_client().await; + let resumed = new_client + .resume_session( + ResumeSessionConfig::new(session_id) + .with_streaming(false) + .with_handler(Arc::new(ApproveAllHandler)) + .with_github_token(super::support::DEFAULT_TEST_TOKEN), + ) + .await + .expect("resume session"); + let events = resumed.subscribe(); + + let answer = resumed + .send_and_wait("Now if you double that, what do you get?") + .await + .expect("second send") + .expect("assistant message"); + assert!(answer + .typed_data::() + .expect("assistant.message data") + .content + .contains("18")); + + let observed = collect_until_idle(events).await; + assert!( + observed + .iter() + .all(|event| event.parsed_type() != SessionEventType::AssistantMessageDelta), + "streaming-disabled resumed sessions should not emit deltas" + ); + assert!(observed + .iter() + .any(|event| event.parsed_type() == SessionEventType::AssistantMessage)); + + resumed.disconnect().await.expect("disconnect resumed"); + new_client.stop().await.expect("stop new client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_emit_streaming_deltas_with_reasoning_effort_configured() { + with_e2e_context( + "streaming_fidelity", + "should_emit_streaming_deltas_with_reasoning_effort_configured", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_streaming(true) + .with_reasoning_effort("high"), + ) + .await + .expect("create session"); + let events = session.subscribe(); + + session + .send_and_wait("What is 15 * 17?") + .await + .expect("send"); + + let observed = collect_until_idle(events).await; + assert_has_content_deltas(&observed); + let assistant = observed + .iter() + .rev() + .find(|event| event.parsed_type() == SessionEventType::AssistantMessage) + .and_then(|event| event.typed_data::()) + .expect("assistant.message"); + assert!(assistant.content.contains("255")); + + let start = session + .get_messages() + .await + .expect("get messages") + .into_iter() + .find(|event| event.parsed_type() == SessionEventType::SessionStart) + .and_then(|event| event.typed_data::()) + .expect("session.start"); + assert_eq!(start.reasoning_effort.as_deref(), Some("high")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_emit_assistantmessage_start_before_deltas_with_matching_messageid() { + with_e2e_context( + "streaming_fidelity", + "should_emit_assistantmessagestart_before_deltas_with_matching_messageid", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config().with_streaming(true)) + .await + .expect("create session"); + let events = session.subscribe(); + + session + .send_and_wait("Count from 1 to 5, separated by commas.") + .await + .expect("send"); + + let observed = collect_until_idle(events).await; + let start_indices: Vec<_> = observed + .iter() + .enumerate() + .filter_map(|(index, event)| { + (event.parsed_type() == SessionEventType::AssistantMessageStart) + .then_some(index) + }) + .collect(); + let delta_indices: Vec<_> = observed + .iter() + .enumerate() + .filter_map(|(index, event)| { + (event.parsed_type() == SessionEventType::AssistantMessageDelta) + .then_some(index) + }) + .collect(); + assert!( + !start_indices.is_empty(), + "expected assistant.message_start" + ); + assert!( + !delta_indices.is_empty(), + "expected assistant.message_delta" + ); + assert!(start_indices[0] < delta_indices[0]); + + let message_ids: Vec<_> = observed + .iter() + .filter_map(|event| event.typed_data::()) + .map(|data| data.message_id) + .collect(); + for start_index in start_indices { + let data = observed[start_index] + .typed_data::() + .expect("assistant.message_start data"); + assert!(!data.message_id.is_empty()); + assert!(message_ids.contains(&data.message_id)); + } + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +fn assert_has_content_deltas(events: &[github_copilot_sdk::SessionEvent]) { + let deltas: Vec<_> = events + .iter() + .filter(|event| event.parsed_type() == SessionEventType::AssistantMessageDelta) + .collect(); + assert!( + !deltas.is_empty(), + "expected assistant.message_delta events" + ); + for delta in deltas { + let data = delta + .typed_data::() + .expect("assistant.message_delta data"); + assert!(!data.delta_content.is_empty()); + } +} diff --git a/rust/tests/e2e/support.rs b/rust/tests/e2e/support.rs new file mode 100644 index 000000000..e08e3535a --- /dev/null +++ b/rust/tests/e2e/support.rs @@ -0,0 +1,764 @@ +use std::ffi::OsString; +use std::future::Future; +use std::io::{BufRead, BufReader, Read, Write}; +use std::net::TcpStream; +use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::process::{Child, Command, Stdio}; +use std::sync::LazyLock; +use std::time::Duration; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::session::Session; +use github_copilot_sdk::subscription::{EventSubscription, LifecycleSubscription}; +use github_copilot_sdk::{ + CliProgram, Client, ClientOptions, SessionConfig, SessionEvent, SessionId, + SessionLifecycleEvent, Transport, +}; +use serde_json::json; +use tokio::sync::Semaphore; + +static E2E_CONCURRENCY: LazyLock = LazyLock::new(|| Semaphore::new(e2e_concurrency())); + +pub const DEFAULT_TEST_TOKEN: &str = "rust-e2e-token"; + +type TestFuture<'a> = Pin + 'a>>; + +pub async fn with_e2e_context(category: &str, snapshot_name: &str, test: F) +where + F: for<'a> FnOnce(&'a mut E2eContext) -> TestFuture<'a>, +{ + let _permit = E2E_CONCURRENCY + .acquire() + .await + .expect("E2E concurrency semaphore should stay open"); + let mut ctx = E2eContext::new(category, snapshot_name) + .await + .unwrap_or_else(|err| panic!("create E2E context: {err}")); + + let timed_out = tokio::time::timeout(default_test_timeout(), test(&mut ctx)) + .await + .is_err(); + ctx.cleanup(timed_out) + .await + .unwrap_or_else(|err| panic!("clean up E2E context: {err}")); + assert!( + !timed_out, + "timed out after {:?} running E2E test {category}/{snapshot_name}", + default_test_timeout() + ); +} + +pub struct E2eContext { + repo_root: PathBuf, + cli_path: PathBuf, + home_dir: tempfile::TempDir, + work_dir: tempfile::TempDir, + proxy: Option, +} + +impl E2eContext { + async fn new(category: &str, snapshot_name: &str) -> std::io::Result { + let repo_root = repo_root(); + let cli_path = cli_path(&repo_root)?; + let home_dir = tempfile::tempdir()?; + let work_dir = tempfile::tempdir()?; + let proxy_root = repo_root.clone(); + let proxy = tokio::task::spawn_blocking(move || CapiProxy::start(&proxy_root)) + .await + .map_err(|err| std::io::Error::other(format!("proxy startup task failed: {err}")))??; + let mut ctx = Self { + repo_root, + cli_path, + home_dir, + work_dir, + proxy: Some(proxy), + }; + ctx.configure(category, snapshot_name)?; + Ok(ctx) + } + + #[expect(dead_code, reason = "used by follow-on E2E ports")] + pub fn repo_root(&self) -> &Path { + &self.repo_root + } + + pub fn work_dir(&self) -> &Path { + self.work_dir.path() + } + + pub fn proxy_url(&self) -> &str { + self.proxy().url() + } + + pub fn snapshot_path(&self, category: &str, snapshot_name: &str) -> PathBuf { + self.repo_root + .join("test") + .join("snapshots") + .join(category) + .join(format!("{snapshot_name}.yaml")) + } + + pub fn client_options(&self) -> ClientOptions { + ClientOptions::new() + .with_program(CliProgram::Path(PathBuf::from(node_program()))) + .with_prefix_args([self.cli_path.as_os_str().to_owned()]) + .with_cwd(self.work_dir.path()) + .with_env(self.environment()) + .with_use_logged_in_user(false) + } + + pub fn client_options_with_transport(&self, transport: Transport) -> ClientOptions { + self.client_options().with_transport(transport) + } + + pub async fn start_client(&self) -> Client { + Client::start(self.client_options()) + .await + .expect("start E2E client") + } + + #[expect(dead_code, reason = "used by follow-on E2E ports")] + pub async fn start_tcp_client(&self, port: u16, token: &str) -> Client { + Client::start( + self.client_options_with_transport(Transport::Tcp { port }) + .with_tcp_connection_token(token), + ) + .await + .expect("start TCP E2E client") + } + + pub fn approve_all_session_config(&self) -> SessionConfig { + SessionConfig::default() + .with_handler(std::sync::Arc::new(ApproveAllHandler)) + .with_github_token(DEFAULT_TEST_TOKEN) + } + + pub fn set_default_copilot_user(&self) { + self.set_copilot_user_by_token(DEFAULT_TEST_TOKEN); + } + + pub fn set_copilot_user_by_token(&self, token: &str) { + self.set_copilot_user_by_token_with_login(token, "rust-e2e-user"); + } + + pub fn set_copilot_user_by_token_with_login(&self, token: &str, login: &str) { + self.set_copilot_user_by_token_with_login_and_quota(token, login, None); + } + + pub fn set_copilot_user_by_token_with_login_and_quota( + &self, + token: &str, + login: &str, + quota_snapshots: Option, + ) { + let mut user = json!({ + "login": login, + "copilot_plan": "individual_pro", + "endpoints": { + "api": self.proxy_url(), + "telemetry": "https://localhost:1/telemetry" + }, + "analytics_tracking_id": "rust-e2e-tracking-id" + }); + if let Some(quota_snapshots) = quota_snapshots { + user["quota_snapshots"] = quota_snapshots; + } + self.proxy() + .set_copilot_user_by_token(token, user) + .expect("configure copilot user"); + } + + pub fn exchanges(&self) -> Vec { + self.proxy() + .get_json("/exchanges") + .expect("get captured proxy exchanges") + } + + pub async fn cleanup(&mut self, skip_writing_cache: bool) -> std::io::Result<()> { + if let Some(mut proxy) = self.proxy.take() { + tokio::task::spawn_blocking(move || proxy.stop(skip_writing_cache)) + .await + .map_err(|err| { + std::io::Error::other(format!("proxy shutdown task failed: {err}")) + })??; + } + Ok(()) + } + + fn configure(&mut self, category: &str, snapshot_name: &str) -> std::io::Result<()> { + let snapshot_path = self.snapshot_path(category, snapshot_name); + self.proxy() + .configure(&snapshot_path, self.work_dir.path()) + .map_err(|err| { + std::io::Error::other(format!( + "configure proxy for {} failed: {err}", + snapshot_path.display() + )) + }) + } + + fn environment(&self) -> Vec<(OsString, OsString)> { + let mut env = self.proxy().proxy_env(); + env.extend([ + ("COPILOT_API_URL".into(), self.proxy_url().into()), + ( + "COPILOT_DEBUG_GITHUB_API_URL".into(), + self.proxy_url().into(), + ), + ( + "COPILOT_HOME".into(), + canonical_temp_path(self.home_dir.path()) + .as_os_str() + .to_owned(), + ), + ( + "GH_CONFIG_DIR".into(), + canonical_temp_path(self.home_dir.path()) + .as_os_str() + .to_owned(), + ), + ( + "XDG_CONFIG_HOME".into(), + canonical_temp_path(self.home_dir.path()) + .as_os_str() + .to_owned(), + ), + ( + "XDG_STATE_HOME".into(), + canonical_temp_path(self.home_dir.path()) + .as_os_str() + .to_owned(), + ), + ]); + if std::env::var("GITHUB_ACTIONS").as_deref() == Ok("true") { + env.push(("GH_TOKEN".into(), "fake-token-for-e2e-tests".into())); + env.push(("GITHUB_TOKEN".into(), "fake-token-for-e2e-tests".into())); + } + env + } + + fn proxy(&self) -> &CapiProxy { + self.proxy.as_ref().expect("proxy already stopped") + } +} + +impl Drop for E2eContext { + fn drop(&mut self) { + if let Some(mut proxy) = self.proxy.take() { + let _ = proxy.stop(true); + } + } +} + +pub async fn wait_for_event

( + events: EventSubscription, + description: &'static str, + predicate: P, +) -> SessionEvent +where + P: Fn(&SessionEvent) -> bool, +{ + wait_for_event_core(events, description, predicate, false).await +} + +pub async fn wait_for_event_allowing_rate_limit

( + events: EventSubscription, + description: &'static str, + predicate: P, +) -> SessionEvent +where + P: Fn(&SessionEvent) -> bool, +{ + wait_for_event_core(events, description, predicate, true).await +} + +async fn wait_for_event_core

( + mut events: EventSubscription, + description: &'static str, + predicate: P, + allow_rate_limit_error: bool, +) -> SessionEvent +where + P: Fn(&SessionEvent) -> bool, +{ + tokio::time::timeout(default_event_timeout(), async { + loop { + let event = events.recv().await.unwrap_or_else(|err| { + panic!("event stream closed while waiting for {description}: {err}") + }); + let is_allowed_rate_limit = allow_rate_limit_error + && event.parsed_type() + == github_copilot_sdk::generated::session_events::SessionEventType::SessionError + && event.data.get("errorType").and_then(|value| value.as_str()) + == Some("rate_limit"); + if event.parsed_type() + == github_copilot_sdk::generated::session_events::SessionEventType::SessionError + && !is_allowed_rate_limit + { + panic!( + "session.error while waiting for {description}: {}", + event.data + ); + } + if predicate(&event) { + return event; + } + } + }) + .await + .unwrap_or_else(|_| panic!("timed out waiting for {description}")) +} + +pub async fn recv_with_timeout( + receiver: &mut tokio::sync::mpsc::UnboundedReceiver, + description: &'static str, +) -> T { + tokio::time::timeout(default_event_timeout(), receiver.recv()) + .await + .unwrap_or_else(|_| panic!("timed out waiting for {description}")) + .unwrap_or_else(|| panic!("{description} channel closed")) +} + +pub async fn wait_for_lifecycle_event

( + mut events: LifecycleSubscription, + description: &'static str, + predicate: P, +) -> SessionLifecycleEvent +where + P: Fn(&SessionLifecycleEvent) -> bool, +{ + tokio::time::timeout(default_event_timeout(), async { + loop { + let event = events.recv().await.unwrap_or_else(|err| { + panic!("lifecycle stream closed while waiting for {description}: {err}") + }); + if predicate(&event) { + return event; + } + } + }) + .await + .unwrap_or_else(|_| panic!("timed out waiting for {description}")) +} + +pub async fn wait_for_condition(description: &'static str, mut predicate: F) +where + F: FnMut() -> Fut, + Fut: Future, +{ + let deadline = tokio::time::Instant::now() + default_event_timeout(); + loop { + if predicate().await { + return; + } + assert!( + tokio::time::Instant::now() < deadline, + "timed out waiting for {description}" + ); + tokio::time::sleep(Duration::from_millis(100)).await; + } +} + +pub async fn collect_until_idle(mut events: EventSubscription) -> Vec { + let mut observed = Vec::new(); + tokio::time::timeout(default_event_timeout(), async { + loop { + let event = events + .recv() + .await + .unwrap_or_else(|err| panic!("event stream closed while collecting events: {err}")); + let is_idle = event.parsed_type() + == github_copilot_sdk::generated::session_events::SessionEventType::SessionIdle; + if event.parsed_type() + == github_copilot_sdk::generated::session_events::SessionEventType::SessionError + { + panic!("session.error while collecting events: {}", event.data); + } + observed.push(event); + if is_idle { + return; + } + } + }) + .await + .expect("timed out collecting events through session.idle"); + observed +} + +pub fn event_types(events: &[SessionEvent]) -> Vec<&str> { + events + .iter() + .map(|event| event.event_type.as_str()) + .collect() +} + +#[allow(dead_code, reason = "used by follow-on E2E ports")] +pub async fn wait_for_idle(session: &Session) -> SessionEvent { + wait_for_event(session.subscribe(), "session.idle event", |event| { + event.parsed_type() + == github_copilot_sdk::generated::session_events::SessionEventType::SessionIdle + }) + .await +} + +#[allow(dead_code, reason = "used by follow-on E2E ports")] +pub async fn wait_for_final_assistant_message(session: &Session) -> SessionEvent { + wait_for_idle(session).await; + last_assistant_message(session).await +} + +#[allow(dead_code, reason = "used by follow-on E2E ports")] +pub async fn last_assistant_message(session: &Session) -> SessionEvent { + session + .get_messages() + .await + .expect("get session messages") + .into_iter() + .rev() + .find(|event| { + event.parsed_type() + == github_copilot_sdk::generated::session_events::SessionEventType::AssistantMessage + }) + .expect("assistant.message event") +} + +pub fn assistant_message_content(event: &SessionEvent) -> String { + event + .typed_data::() + .expect("assistant.message data") + .content +} + +pub fn assert_uuid_like(session_id: &SessionId) { + let text = session_id.as_str(); + let parsed = uuid::Uuid::parse_str(text).expect("session id should be UUID-shaped"); + assert_eq!( + parsed.hyphenated().to_string(), + text, + "session id should use canonical hyphenated UUID formatting" + ); +} + +fn default_event_timeout() -> Duration { + if cfg!(windows) { + Duration::from_secs(120) + } else { + Duration::from_secs(60) + } +} + +fn default_test_timeout() -> Duration { + if cfg!(windows) { + Duration::from_secs(300) + } else { + Duration::from_secs(180) + } +} + +fn e2e_concurrency() -> usize { + std::env::var("RUST_E2E_CONCURRENCY") + .ok() + .and_then(|value| value.parse::().ok()) + .filter(|&value| value > 0) + .unwrap_or(4) +} + +pub fn get_system_message(exchange: &serde_json::Value) -> String { + exchange + .get("request") + .and_then(|request| request.get("messages")) + .and_then(serde_json::Value::as_array) + .and_then(|messages| { + messages.iter().find_map(|message| { + let role = message.get("role").and_then(serde_json::Value::as_str)?; + if role == "system" { + message + .get("content") + .and_then(serde_json::Value::as_str) + .map(str::to_string) + } else { + None + } + }) + }) + .unwrap_or_default() +} + +pub fn get_tool_names(exchange: &serde_json::Value) -> Vec { + exchange + .get("request") + .and_then(|request| request.get("tools")) + .and_then(serde_json::Value::as_array) + .map(|tools| { + tools + .iter() + .filter_map(|tool| { + tool.get("function") + .and_then(|function| function.get("name")) + .and_then(serde_json::Value::as_str) + .map(str::to_string) + }) + .collect() + }) + .unwrap_or_default() +} + +fn repo_root() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .parent() + .expect("rust package has parent repo") + .to_path_buf() +} + +fn cli_path(repo_root: &Path) -> std::io::Result { + if let Some(path) = std::env::var_os("COPILOT_CLI_PATH") { + let path = PathBuf::from(path); + if path.exists() { + return Ok(path); + } + } + + let path = repo_root + .join("nodejs") + .join("node_modules") + .join("@github") + .join("copilot") + .join("index.js"); + if path.exists() { + return Ok(path); + } + + Err(std::io::Error::new( + std::io::ErrorKind::NotFound, + format!( + "CLI not found at {}; run npm install in nodejs first", + path.display() + ), + )) +} + +fn canonical_temp_path(path: &Path) -> PathBuf { + std::fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf()) +} + +struct CapiProxy { + child: Option, + proxy_url: String, + connect_proxy_url: String, + ca_file_path: String, +} + +impl CapiProxy { + fn start(repo_root: &Path) -> std::io::Result { + let mut child = Command::new(npx_program()) + .args(["tsx", "server.ts"]) + .current_dir(repo_root.join("test").join("harness")) + .env("GITHUB_ACTIONS", "true") + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::inherit()) + .spawn()?; + + let stdout = child.stdout.take().expect("proxy stdout"); + let reader = BufReader::new(stdout); + let re = regex::Regex::new(r"Listening: (http://[^\s]+)\s+(\{.*\})$").unwrap(); + for line in reader.lines() { + let line = line?; + if let Some(captures) = re.captures(&line) { + let metadata: serde_json::Value = + serde_json::from_str(captures.get(2).unwrap().as_str())?; + let connect_proxy_url = metadata + .get("connectProxyUrl") + .and_then(|value| value.as_str()) + .expect("connectProxyUrl") + .to_string(); + let ca_file_path = metadata + .get("caFilePath") + .and_then(|value| value.as_str()) + .expect("caFilePath") + .to_string(); + return Ok(Self { + child: Some(child), + proxy_url: captures.get(1).unwrap().as_str().to_string(), + connect_proxy_url, + ca_file_path, + }); + } + if line.contains("Listening: ") { + return Err(std::io::Error::other(format!( + "proxy startup line missing metadata: {line}" + ))); + } + } + + Err(std::io::Error::other("proxy exited before startup")) + } + + fn url(&self) -> &str { + &self.proxy_url + } + + fn configure(&self, file_path: &Path, work_dir: &Path) -> std::io::Result<()> { + self.post_json( + "/config", + &json!({ + "filePath": file_path, + "workDir": work_dir, + }) + .to_string(), + ) + } + + fn set_copilot_user_by_token( + &self, + token: &str, + response: serde_json::Value, + ) -> std::io::Result<()> { + self.post_json( + "/copilot-user-config", + &json!({ + "token": token, + "response": response, + }) + .to_string(), + ) + } + + fn stop(&mut self, skip_writing_cache: bool) -> std::io::Result<()> { + let path = if skip_writing_cache { + "/stop?skipWritingCache=true" + } else { + "/stop" + }; + let result = self.post_json(path, ""); + if let Some(mut child) = self.child.take() { + let _ = child.wait(); + } + result + } + + fn proxy_env(&self) -> Vec<(OsString, OsString)> { + let no_proxy = "127.0.0.1,localhost,::1"; + [ + ("HTTP_PROXY", self.connect_proxy_url.as_str()), + ("HTTPS_PROXY", self.connect_proxy_url.as_str()), + ("http_proxy", self.connect_proxy_url.as_str()), + ("https_proxy", self.connect_proxy_url.as_str()), + ("NO_PROXY", no_proxy), + ("no_proxy", no_proxy), + ("NODE_EXTRA_CA_CERTS", self.ca_file_path.as_str()), + ("SSL_CERT_FILE", self.ca_file_path.as_str()), + ("REQUESTS_CA_BUNDLE", self.ca_file_path.as_str()), + ("CURL_CA_BUNDLE", self.ca_file_path.as_str()), + ("GIT_SSL_CAINFO", self.ca_file_path.as_str()), + ("GH_TOKEN", ""), + ("GITHUB_TOKEN", ""), + ("GH_ENTERPRISE_TOKEN", ""), + ("GITHUB_ENTERPRISE_TOKEN", ""), + ] + .into_iter() + .map(|(key, value)| (key.into(), value.into())) + .collect() + } + + fn post_json(&self, path: &str, body: &str) -> std::io::Result<()> { + let response = self.request("POST", path, body)?; + if !response.starts_with("HTTP/1.1 200") && !response.starts_with("HTTP/1.1 204") { + return Err(std::io::Error::other(format!( + "proxy POST {path} failed: {response}" + ))); + } + Ok(()) + } + + fn get_json(&self, path: &str) -> std::io::Result { + let response = self.request("GET", path, "")?; + if !response.starts_with("HTTP/1.1 200") { + return Err(std::io::Error::other(format!( + "proxy GET {path} failed: {response}" + ))); + } + let body = response_body(&response)?; + serde_json::from_str(&body).map_err(std::io::Error::other) + } + + fn request(&self, method: &str, path: &str, body: &str) -> std::io::Result { + let (host, port) = parse_http_url(&self.proxy_url)?; + let mut stream = TcpStream::connect((host.as_str(), port))?; + write!( + stream, + "{method} {path} HTTP/1.1\r\nHost: {host}:{port}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}", + body.len() + )?; + + let mut response = String::new(); + stream.read_to_string(&mut response)?; + Ok(response) + } +} + +impl Drop for CapiProxy { + fn drop(&mut self) { + if self.child.is_some() { + let _ = self.stop(true); + } + } +} + +fn response_body(response: &str) -> std::io::Result { + let Some((headers, body)) = response.split_once("\r\n\r\n") else { + return Ok(String::new()); + }; + if headers + .lines() + .any(|line| line.eq_ignore_ascii_case("Transfer-Encoding: chunked")) + { + return decode_chunked_body(body); + } + Ok(body.to_string()) +} + +fn decode_chunked_body(body: &str) -> std::io::Result { + let mut rest = body; + let mut decoded = String::new(); + loop { + let Some((size_line, after_size)) = rest.split_once("\r\n") else { + return Err(std::io::Error::other("malformed chunked response")); + }; + let size_text = size_line + .split_once(';') + .map_or(size_line, |(size, _)| size); + let size = usize::from_str_radix(size_text.trim(), 16) + .map_err(|err| std::io::Error::other(format!("invalid chunk size: {err}")))?; + if size == 0 { + return Ok(decoded); + } + if after_size.len() < size + 2 { + return Err(std::io::Error::other("truncated chunked response")); + } + decoded.push_str(&after_size[..size]); + rest = &after_size[size + 2..]; + } +} + +fn parse_http_url(url: &str) -> std::io::Result<(String, u16)> { + let without_scheme = url + .strip_prefix("http://") + .ok_or_else(|| std::io::Error::other(format!("unsupported proxy URL: {url}")))?; + let (host, port) = without_scheme + .rsplit_once(':') + .ok_or_else(|| std::io::Error::other(format!("proxy URL missing port: {url}")))?; + let port = port + .parse::() + .map_err(|err| std::io::Error::other(format!("invalid proxy URL port: {err}")))?; + Ok((host.to_string(), port)) +} + +fn node_program() -> &'static str { + if cfg!(windows) { "node.exe" } else { "node" } +} + +fn npx_program() -> &'static str { + if cfg!(windows) { "npx.cmd" } else { "npx" } +} diff --git a/rust/tests/e2e/suspend.rs b/rust/tests/e2e/suspend.rs new file mode 100644 index 000000000..5a9386147 --- /dev/null +++ b/rust/tests/e2e/suspend.rs @@ -0,0 +1,88 @@ +use std::sync::Arc; + +use github_copilot_sdk::ResumeSessionConfig; + +use super::support::{DEFAULT_TEST_TOKEN, assistant_message_content, with_e2e_context}; + +#[tokio::test] +async fn should_suspend_idle_session_without_throwing() { + with_e2e_context( + "suspend", + "should_suspend_idle_session_without_throwing", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .send_and_wait("Reply with: SUSPEND_IDLE_OK") + .await + .expect("send"); + session.rpc().suspend().await.expect("suspend session"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_allow_resume_and_continue_conversation_after_suspend() { + with_e2e_context( + "suspend", + "should_allow_resume_and_continue_conversation_after_suspend", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session + .send_and_wait( + "Remember the magic word: SUSPENSE. Reply with: SUSPEND_TURN_ONE", + ) + .await + .expect("first send"); + let session_id = session.id().clone(); + session.rpc().suspend().await.expect("suspend session"); + session.disconnect().await.expect("disconnect first session"); + client.stop().await.expect("stop first client"); + + let second_client = ctx.start_client().await; + let resumed = second_client + .resume_session( + ResumeSessionConfig::new(session_id) + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new( + github_copilot_sdk::handler::ApproveAllHandler, + )), + ) + .await + .expect("resume session"); + let answer = resumed + .send_and_wait( + "What was the magic word I asked you to remember? Reply with just the word.", + ) + .await + .expect("follow-up send") + .expect("assistant message"); + assert!(assistant_message_content(&answer) + .to_lowercase() + .contains("suspense")); + + resumed.disconnect().await.expect("disconnect resumed"); + second_client.stop().await.expect("stop second client"); + }) + }, + ) + .await; +} diff --git a/rust/tests/e2e/system_message_transform.rs b/rust/tests/e2e/system_message_transform.rs new file mode 100644 index 000000000..10cc594ca --- /dev/null +++ b/rust/tests/e2e/system_message_transform.rs @@ -0,0 +1,187 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::transforms::{SystemMessageTransform, TransformContext}; +use github_copilot_sdk::{SectionOverride, SessionConfig, SystemMessageConfig}; +use tokio::sync::mpsc; + +use super::support::{DEFAULT_TEST_TOKEN, get_system_message, recv_with_timeout, with_e2e_context}; + +#[tokio::test] +async fn should_invoke_transform_callbacks_with_section_content() { + with_e2e_context( + "system_message_transform", + "should_invoke_transform_callbacks_with_section_content", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write(ctx.work_dir().join("test.txt"), "Hello transform!") + .expect("write test file"); + let (section_tx, mut section_rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) + .with_transform(Arc::new(RecordingTransform { + section_ids: vec!["identity", "tone"], + suffix: None, + section_tx, + })), + ) + .await + .expect("create session"); + + session + .send_and_wait("Read the contents of test.txt and tell me what it says") + .await + .expect("send"); + + let first = recv_with_timeout(&mut section_rx, "first transform").await; + let second = recv_with_timeout(&mut section_rx, "second transform").await; + assert!(first.1 > 0); + assert!(second.1 > 0); + let sections = [first.0, second.0]; + assert!(sections.contains(&"identity".to_string())); + assert!(sections.contains(&"tone".to_string())); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_apply_transform_modifications_to_section_content() { + with_e2e_context( + "system_message_transform", + "should_apply_transform_modifications_to_section_content", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write(ctx.work_dir().join("hello.txt"), "Hello!") + .expect("write hello file"); + let (section_tx, _section_rx) = mpsc::unbounded_channel(); + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) + .with_transform(Arc::new(RecordingTransform { + section_ids: vec!["identity"], + suffix: Some("\nAlways end your reply with TRANSFORM_MARKER"), + section_tx, + })), + ) + .await + .expect("create session"); + + session + .send_and_wait("Read the contents of hello.txt") + .await + .expect("send"); + + let exchanges = ctx.exchanges(); + assert!(!exchanges.is_empty()); + assert!(get_system_message(&exchanges[0]).contains("TRANSFORM_MARKER")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_work_with_static_overrides_and_transforms_together() { + with_e2e_context( + "system_message_transform", + "should_work_with_static_overrides_and_transforms_together", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write(ctx.work_dir().join("combo.txt"), "Combo test!") + .expect("write combo file"); + let (section_tx, mut section_rx) = mpsc::unbounded_channel(); + let mut sections = HashMap::new(); + sections.insert( + "safety".to_string(), + SectionOverride { + action: Some("remove".to_string()), + content: None, + }, + ); + let client = ctx.start_client().await; + let session = client + .create_session( + SessionConfig::default() + .with_github_token(DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(github_copilot_sdk::handler::ApproveAllHandler)) + .with_system_message( + SystemMessageConfig::new() + .with_mode("customize") + .with_sections(sections), + ) + .with_transform(Arc::new(RecordingTransform { + section_ids: vec!["identity"], + suffix: None, + section_tx, + })), + ) + .await + .expect("create session"); + + session + .send_and_wait("Read the contents of combo.txt and tell me what it says") + .await + .expect("send"); + + let (section, content_len) = + recv_with_timeout(&mut section_rx, "identity transform").await; + assert_eq!(section, "identity"); + assert!(content_len > 0); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +struct RecordingTransform { + section_ids: Vec<&'static str>, + suffix: Option<&'static str>, + section_tx: mpsc::UnboundedSender<(String, usize)>, +} + +#[async_trait] +impl SystemMessageTransform for RecordingTransform { + fn section_ids(&self) -> Vec { + self.section_ids + .iter() + .map(|section| (*section).to_string()) + .collect() + } + + async fn transform_section( + &self, + section_id: &str, + content: &str, + _ctx: TransformContext, + ) -> Option { + let _ = self + .section_tx + .send((section_id.to_string(), content.len())); + Some(match self.suffix { + Some(suffix) => format!("{content}{suffix}"), + None => content.to_string(), + }) + } +} diff --git a/rust/tests/e2e/telemetry.rs b/rust/tests/e2e/telemetry.rs new file mode 100644 index 000000000..0685ac284 --- /dev/null +++ b/rust/tests/e2e/telemetry.rs @@ -0,0 +1,233 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::{ + Client, Error, OtelExporterType, SessionConfig, TelemetryConfig, Tool, ToolInvocation, + ToolResult, +}; +use serde_json::json; + +use super::support::{assistant_message_content, wait_for_condition, with_e2e_context}; + +#[tokio::test] +async fn should_export_file_telemetry_for_sdk_interactions() { + with_e2e_context( + "telemetry", + "should_export_file_telemetry_for_sdk_interactions", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let telemetry_path = ctx.work_dir().join("rust-telemetry-e2e.jsonl"); + let source_name = "rust-sdk-telemetry-e2e"; + let tool_name = "echo_telemetry_marker"; + let marker = "copilot-sdk-telemetry-e2e"; + let prompt = format!( + "Use the {tool_name} tool with value '{marker}', then respond with TELEMETRY_E2E_DONE." + ); + + let client = Client::start(ctx.client_options().with_telemetry( + TelemetryConfig::new() + .with_file_path(&telemetry_path) + .with_exporter_type(OtelExporterType::File) + .with_source_name(source_name) + .with_capture_content(true), + )) + .await + .expect("start client"); + let router = ToolHandlerRouter::new( + vec![Box::new(EchoTelemetryTool { + name: tool_name.to_string(), + })], + Arc::new(ApproveAllHandler), + ); + let tools = router.tools(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(router)) + .with_tools(tools), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait(prompt.as_str()) + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("TELEMETRY_E2E_DONE")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + + let entries = read_telemetry_entries(&telemetry_path).await; + let spans: Vec<_> = entries + .iter() + .filter(|entry| string_property(entry, "type") == Some("span")) + .collect(); + assert!(!spans.is_empty(), "expected telemetry spans in {entries:?}"); + assert!(spans.iter().all(|span| { + span.get("instrumentationScope") + .and_then(|scope| string_property(scope, "name")) + == Some(source_name) + })); + + let trace_ids: std::collections::HashSet<_> = spans + .iter() + .filter_map(|span| string_property(span, "traceId")) + .collect(); + assert_eq!(trace_ids.len(), 1); + assert!(spans.iter().all(|span| status_code(span) != Some(2))); + + let invoke_agent = find_span(&spans, "invoke_agent"); + assert_eq!( + string_attribute(invoke_agent, "gen_ai.conversation.id").as_deref(), + Some(session.id().as_str()) + ); + let invoke_agent_span_id = + string_property(invoke_agent, "spanId").expect("invoke_agent span id"); + assert!(is_root_span(invoke_agent)); + + let chat_spans: Vec<_> = spans + .iter() + .copied() + .filter(|span| { + string_attribute(span, "gen_ai.operation.name").as_deref() == Some("chat") + }) + .collect(); + assert!(!chat_spans.is_empty()); + assert!(chat_spans.iter().all(|span| { + string_property(span, "parentSpanId") == Some(invoke_agent_span_id) + })); + assert!(chat_spans.iter().any(|span| string_attribute( + span, + "gen_ai.input.messages" + ) + .is_some_and(|messages| messages.contains(&prompt)))); + assert!(chat_spans.iter().any(|span| string_attribute( + span, + "gen_ai.output.messages" + ) + .is_some_and(|messages| messages.contains("TELEMETRY_E2E_DONE")))); + + let tool_span = find_span(&spans, "execute_tool"); + assert_eq!( + string_property(tool_span, "parentSpanId"), + Some(invoke_agent_span_id) + ); + assert_eq!( + string_attribute(tool_span, "gen_ai.tool.name").as_deref(), + Some(tool_name) + ); + assert_eq!( + string_attribute(tool_span, "gen_ai.tool.call.arguments").as_deref(), + Some(format!("{{\"value\":\"{marker}\"}}").as_str()) + ); + assert_eq!( + string_attribute(tool_span, "gen_ai.tool.call.result").as_deref(), + Some(marker) + ); + }) + }, + ) + .await; +} + +struct EchoTelemetryTool { + name: String, +} + +#[async_trait] +impl ToolHandler for EchoTelemetryTool { + fn tool(&self) -> Tool { + Tool::new(&self.name) + .with_description("Echoes a marker string for telemetry validation.") + .with_parameters(json!({ + "type": "object", + "properties": { + "value": { "type": "string" } + }, + "required": ["value"] + })) + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + Ok(ToolResult::Text( + invocation + .arguments + .get("value") + .and_then(serde_json::Value::as_str) + .unwrap_or_default() + .to_string(), + )) + } +} + +async fn read_telemetry_entries(path: &std::path::Path) -> Vec { + wait_for_condition("telemetry file to contain spans", || { + let path = path.to_path_buf(); + async move { + read_telemetry_entries_once(&path).is_ok_and(|entries| { + entries.iter().any(|entry| { + string_property(entry, "type") == Some("span") + && string_attribute(entry, "gen_ai.operation.name").as_deref() + == Some("invoke_agent") + }) + }) + } + }) + .await; + read_telemetry_entries_once(path).expect("read telemetry entries") +} + +fn read_telemetry_entries_once(path: &std::path::Path) -> std::io::Result> { + if !path.exists() || path.metadata()?.len() == 0 { + return Ok(Vec::new()); + } + std::fs::read_to_string(path).map(|content| { + content + .lines() + .filter(|line| !line.trim().is_empty()) + .map(|line| serde_json::from_str(line).expect("telemetry JSON line")) + .collect() + }) +} + +fn find_span<'a>(spans: &'a [&'a serde_json::Value], operation: &str) -> &'a serde_json::Value { + spans + .iter() + .copied() + .find(|span| string_attribute(span, "gen_ai.operation.name").as_deref() == Some(operation)) + .unwrap_or_else(|| panic!("span {operation} not found in {spans:?}")) +} + +fn string_property<'a>(value: &'a serde_json::Value, name: &str) -> Option<&'a str> { + value.get(name).and_then(serde_json::Value::as_str) +} + +fn string_attribute(value: &serde_json::Value, name: &str) -> Option { + value + .get("attributes") + .and_then(|attributes| attributes.get(name)) + .map(|value| match value { + serde_json::Value::String(value) => value.clone(), + serde_json::Value::Number(_) | serde_json::Value::Bool(_) => value.to_string(), + serde_json::Value::Array(_) | serde_json::Value::Object(_) => value.to_string(), + serde_json::Value::Null => String::new(), + }) +} + +fn status_code(value: &serde_json::Value) -> Option { + value + .get("status") + .and_then(|status| status.get("code")) + .and_then(serde_json::Value::as_i64) +} + +fn is_root_span(value: &serde_json::Value) -> bool { + string_property(value, "parentSpanId") + .is_none_or(|parent| parent.is_empty() || parent == "0000000000000000") +} diff --git a/rust/tests/e2e/tool_results.rs b/rust/tests/e2e/tool_results.rs new file mode 100644 index 000000000..260e25993 --- /dev/null +++ b/rust/tests/e2e/tool_results.rs @@ -0,0 +1,361 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use github_copilot_sdk::generated::session_events::{SessionEventType, ToolExecutionCompleteData}; +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::{ + Error, SessionConfig, Tool, ToolInvocation, ToolResult, ToolResultExpanded, +}; +use serde_json::json; +use tokio::sync::mpsc; + +use super::support::{assistant_message_content, collect_until_idle, with_e2e_context}; + +#[tokio::test] +async fn should_handle_structured_toolresultobject_from_custom_tool() { + with_e2e_context( + "tool_results", + "should_handle_structured_toolresultobject_from_custom_tool", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = create_tool_session(ctx, &client, WeatherTool).await; + + let answer = session + .send_and_wait("What's the weather in Paris?") + .await + .expect("send") + .expect("assistant message"); + let content = assistant_message_content(&answer).to_lowercase(); + assert!(content.contains("sunny") || content.contains("72")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_handle_tool_result_with_failure_resulttype() { + with_e2e_context( + "tool_results", + "should_handle_tool_result_with_failure_resulttype", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = create_tool_session(ctx, &client, CheckStatusTool).await; + + let answer = session + .send_and_wait("Check the status of the service using check_status. If it fails, say 'service is down'.") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer) + .to_lowercase() + .contains("service is down")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_preserve_tooltelemetry_and_not_stringify_structured_results_for_llm() { + with_e2e_context( + "tool_results", + "should_preserve_tooltelemetry_and_not_stringify_structured_results_for_llm", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = create_tool_session(ctx, &client, AnalyzeCodeTool).await; + + let answer = session + .send_and_wait("Analyze the file main.ts for issues.") + .await + .expect("send") + .expect("assistant message"); + assert!( + assistant_message_content(&answer) + .to_lowercase() + .contains("no issues") + ); + + let exchanges = ctx.exchanges(); + let tool_results: Vec<_> = exchanges + .last() + .and_then(|exchange| exchange.get("request")) + .and_then(|request| request.get("messages")) + .and_then(serde_json::Value::as_array) + .expect("messages") + .iter() + .filter(|message| { + message.get("role").and_then(serde_json::Value::as_str) == Some("tool") + }) + .collect(); + assert_eq!(tool_results.len(), 1); + let content = tool_results[0].to_string(); + assert!(!content.contains("toolTelemetry")); + assert!(!content.contains("resultType")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_handle_tool_result_with_rejected_resulttype() { + with_e2e_context( + "tool_results", + "should_handle_tool_result_with_rejected_resulttype", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let (call_tx, mut call_rx) = mpsc::unbounded_channel(); + let session = create_tool_session(ctx, &client, DeployTool { call_tx }).await; + let events = session.subscribe(); + + session + .send("Deploy the service using deploy_service. If it's rejected, tell me it was 'rejected by policy'.") + .await + .expect("send"); + recv_called(&mut call_rx, "deploy tool").await; + let observed = collect_until_idle(events).await; + let complete = observed + .iter() + .find(|event| event.parsed_type() == SessionEventType::ToolExecutionComplete) + .and_then(|event| event.typed_data::()) + .expect("tool.execution_complete"); + assert!(!complete.success); + let error = complete.error.expect("tool error"); + assert_eq!(error.code.as_deref(), Some("rejected")); + assert!(error.message.contains("Deployment rejected")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_handle_tool_result_with_denied_resulttype() { + with_e2e_context( + "tool_results", + "should_handle_tool_result_with_denied_resulttype", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let (call_tx, mut call_rx) = mpsc::unbounded_channel(); + let session = create_tool_session(ctx, &client, AccessSecretTool { call_tx }).await; + let events = session.subscribe(); + + session + .send("Use access_secret to get the API key. If access is denied, tell me it was 'access denied'.") + .await + .expect("send"); + recv_called(&mut call_rx, "access secret tool").await; + let observed = collect_until_idle(events).await; + let complete = observed + .iter() + .find(|event| event.parsed_type() == SessionEventType::ToolExecutionComplete) + .and_then(|event| event.typed_data::()) + .expect("tool.execution_complete"); + assert!(!complete.success); + let error = complete.error.expect("tool error"); + assert_eq!(error.code.as_deref(), Some("denied")); + assert!(error.message.contains("Access denied")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +async fn create_tool_session( + _ctx: &super::support::E2eContext, + client: &github_copilot_sdk::Client, + tool: T, +) -> github_copilot_sdk::session::Session +where + T: ToolHandler + 'static, +{ + let router = ToolHandlerRouter::new(vec![Box::new(tool)], Arc::new(ApproveAllHandler)); + let tools = router.tools(); + client + .create_session( + SessionConfig::default() + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(router)) + .with_tools(tools), + ) + .await + .expect("create session") +} + +async fn recv_called(receiver: &mut mpsc::UnboundedReceiver<()>, description: &'static str) { + tokio::time::timeout(std::time::Duration::from_secs(10), receiver.recv()) + .await + .unwrap_or_else(|_| panic!("timed out waiting for {description}")) + .unwrap_or_else(|| panic!("{description} channel closed")); +} + +fn expanded(text: impl Into, result_type: impl Into) -> ToolResult { + ToolResult::Expanded(ToolResultExpanded { + text_result_for_llm: text.into(), + result_type: result_type.into(), + binary_results_for_llm: None, + session_log: None, + error: None, + tool_telemetry: None, + }) +} + +struct WeatherTool; + +#[async_trait::async_trait] +impl ToolHandler for WeatherTool { + fn tool(&self) -> Tool { + string_tool( + "get_weather", + "Gets weather for a city", + "city", + "City name", + ) + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + let city = invocation + .arguments + .get("city") + .and_then(serde_json::Value::as_str) + .unwrap_or("Paris"); + Ok(expanded( + format!("The weather in {city} is sunny and 72\u{b0}F"), + "success", + )) + } +} + +struct CheckStatusTool; + +#[async_trait::async_trait] +impl ToolHandler for CheckStatusTool { + fn tool(&self) -> Tool { + Tool::new("check_status").with_description("Checks the status of a service") + } + + async fn call(&self, _invocation: ToolInvocation) -> Result { + let mut result = match expanded("Service unavailable", "failure") { + ToolResult::Expanded(result) => result, + _ => unreachable!(), + }; + result.error = Some("API timeout".to_string()); + Ok(ToolResult::Expanded(result)) + } +} + +struct AnalyzeCodeTool; + +#[async_trait::async_trait] +impl ToolHandler for AnalyzeCodeTool { + fn tool(&self) -> Tool { + string_tool( + "analyze_code", + "Analyzes code for issues", + "file", + "File to analyze", + ) + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + let file = invocation + .arguments + .get("file") + .and_then(serde_json::Value::as_str) + .unwrap_or("main.ts"); + let mut result = match expanded(format!("Analysis of {file}: no issues found"), "success") { + ToolResult::Expanded(result) => result, + _ => unreachable!(), + }; + result.tool_telemetry = Some(HashMap::from([( + "metrics".to_string(), + json!({ "analysisTimeMs": 150 }), + )])); + Ok(ToolResult::Expanded(result)) + } +} + +struct DeployTool { + call_tx: mpsc::UnboundedSender<()>, +} + +#[async_trait::async_trait] +impl ToolHandler for DeployTool { + fn tool(&self) -> Tool { + Tool::new("deploy_service").with_description("Deploys a service") + } + + async fn call(&self, _invocation: ToolInvocation) -> Result { + let _ = self.call_tx.send(()); + Ok(expanded( + "Deployment rejected: policy violation - production deployments require approval", + "rejected", + )) + } +} + +struct AccessSecretTool { + call_tx: mpsc::UnboundedSender<()>, +} + +#[async_trait::async_trait] +impl ToolHandler for AccessSecretTool { + fn tool(&self) -> Tool { + Tool::new("access_secret").with_description("Accesses a secret") + } + + async fn call(&self, _invocation: ToolInvocation) -> Result { + let _ = self.call_tx.send(()); + Ok(expanded( + "Access denied: insufficient permissions to read secrets", + "denied", + )) + } +} + +fn string_tool( + name: &str, + description: &str, + parameter: &str, + parameter_description: &str, +) -> Tool { + Tool::new(name) + .with_description(description) + .with_parameters(json!({ + "type": "object", + "properties": { + parameter: { + "type": "string", + "description": parameter_description, + } + }, + "required": [parameter], + })) +} diff --git a/rust/tests/e2e/tools.rs b/rust/tests/e2e/tools.rs new file mode 100644 index 000000000..19cc40249 --- /dev/null +++ b/rust/tests/e2e/tools.rs @@ -0,0 +1,756 @@ +use std::sync::Arc; + +use github_copilot_sdk::handler::{ApproveAllHandler, PermissionResult, SessionHandler}; +use github_copilot_sdk::tool::{ToolHandler, ToolHandlerRouter}; +use github_copilot_sdk::{ + Error, PermissionRequestData, RequestId, SessionConfig, SessionId, Tool, ToolInvocation, + ToolResult, +}; +use serde_json::json; +use tokio::sync::mpsc; + +use super::support::{assistant_message_content, recv_with_timeout, with_e2e_context}; + +#[tokio::test] +async fn invokes_built_in_tools() { + with_e2e_context("tools", "invokes_built_in_tools", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + std::fs::write( + ctx.work_dir().join("README.md"), + "# ELIZA, the only chatbot you'll ever need", + ) + .expect("write README"); + let client = ctx.start_client().await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let answer = session + .send_and_wait("What's the first line of README.md in this directory?") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("ELIZA")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn invokes_custom_tool() { + with_e2e_context("tools", "invokes_custom_tool", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let router = ToolHandlerRouter::new( + vec![Box::new(EncryptStringTool)], + Arc::new(ApproveAllHandler), + ); + let tools = router.tools(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(router)) + .with_tools(tools), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait("Use encrypt_string to encrypt this string: Hello") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("HELLO")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn handles_tool_calling_errors() { + with_e2e_context("tools", "handles_tool_calling_errors", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let router = + ToolHandlerRouter::new(vec![Box::new(ErrorTool)], Arc::new(ApproveAllHandler)); + let tools = router.tools(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(router)) + .with_tools(tools), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait("What is my location? If you can't find out, just say 'unknown'.") + .await + .expect("send") + .expect("assistant message"); + let content = assistant_message_content(&answer); + assert!(!content.contains("Melbourne")); + assert!(content.to_lowercase().contains("unknown")); + + let exchanges = ctx.exchanges(); + let tool_results: Vec<_> = exchanges + .last() + .and_then(|exchange| exchange.get("request")) + .and_then(|request| request.get("messages")) + .and_then(serde_json::Value::as_array) + .expect("messages") + .iter() + .filter(|message| { + message.get("role").and_then(serde_json::Value::as_str) == Some("tool") + }) + .collect(); + assert_eq!(tool_results.len(), 1); + assert!(!tool_results[0].to_string().contains("Melbourne")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn can_receive_and_return_complex_types() { + with_e2e_context("tools", "can_receive_and_return_complex_types", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let router = ToolHandlerRouter::new( + vec![Box::new(DbQueryTool)], + Arc::new(ApproveAllHandler), + ); + let tools = router.tools(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(router)) + .with_tools(tools), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait( + "Perform a DB query for the 'cities' table using IDs 12 and 19, sorting ascending. \ + Reply only with lines of the form: [cityname] [population]", + ) + .await + .expect("send") + .expect("assistant message"); + let content = assistant_message_content(&answer); + assert!(content.contains("Passos")); + assert!(content.contains("San Lorenzo")); + assert!(content.replace(',', "").contains("135460")); + assert!(content.replace(',', "").contains("204356")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn overrides_built_in_tool_with_custom_tool() { + with_e2e_context("tools", "overrides_built_in_tool_with_custom_tool", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let router = + ToolHandlerRouter::new(vec![Box::new(CustomGrepTool)], Arc::new(ApproveAllHandler)); + let tools = router.tools(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(router)) + .with_tools(tools), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait("Use grep to search for the word 'hello'") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("CUSTOM_GREP_RESULT")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn skippermission_sent_in_tool_definition() { + with_e2e_context("tools", "skippermission_sent_in_tool_definition", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let (permission_tx, mut permission_rx) = mpsc::unbounded_channel(); + let handler = Arc::new(RecordingPermissionHandler { + permission_tx, + decision: PermissionResult::Denied, + }); + let router = ToolHandlerRouter::new(vec![Box::new(SafeLookupTool)], handler); + let tools = router.tools(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(router)) + .with_tools(tools), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait("Use safe_lookup to look up 'test123'") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("RESULT")); + assert!( + tokio::time::timeout(std::time::Duration::from_millis(100), permission_rx.recv()) + .await + .is_err(), + "skip_permission tool should not request permission" + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[ignore = "Behaves as if no content was in the result. Binary tool results are not fully implemented yet."] +#[tokio::test] +async fn can_return_binary_result() {} + +#[tokio::test] +async fn invokes_custom_tool_with_permission_handler() { + with_e2e_context( + "tools", + "invokes_custom_tool_with_permission_handler", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let (permission_tx, mut permission_rx) = mpsc::unbounded_channel(); + let handler = Arc::new(RecordingPermissionHandler { + permission_tx, + decision: PermissionResult::Approved, + }); + let router = ToolHandlerRouter::new(vec![Box::new(EncryptStringTool)], handler); + let tools = router.tools(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(router)) + .with_tools(tools), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait("Use encrypt_string to encrypt this string: Hello") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("HELLO")); + let request = recv_with_timeout(&mut permission_rx, "custom tool permission").await; + assert!(request.extra.is_object() || request.kind.is_some()); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn denies_custom_tool_when_permission_denied() { + with_e2e_context( + "tools", + "denies_custom_tool_when_permission_denied", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let (call_tx, mut call_rx) = mpsc::unbounded_channel(); + let (permission_tx, _permission_rx) = mpsc::unbounded_channel(); + let handler = Arc::new(RecordingPermissionHandler { + permission_tx, + decision: PermissionResult::Denied, + }); + let router = ToolHandlerRouter::new( + vec![Box::new(TrackedEncryptStringTool { call_tx })], + handler, + ); + let tools = router.tools(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(router)) + .with_tools(tools), + ) + .await + .expect("create session"); + + session + .send_and_wait("Use encrypt_string to encrypt this string: Hello") + .await + .expect("send"); + assert!( + tokio::time::timeout(std::time::Duration::from_millis(100), call_rx.recv()) + .await + .is_err(), + "denied custom tool should not be invoked" + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_execute_multiple_custom_tools_in_parallel_single_turn() { + with_e2e_context( + "tools", + "should_execute_multiple_custom_tools_in_parallel_single_turn", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let (city_tx, mut city_rx) = mpsc::unbounded_channel(); + let (country_tx, mut country_rx) = mpsc::unbounded_channel(); + let router = ToolHandlerRouter::new( + vec![ + Box::new(LookupCityTool { call_tx: city_tx }), + Box::new(LookupCountryTool { + call_tx: country_tx, + }), + ], + Arc::new(ApproveAllHandler), + ); + let tools = router.tools(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(router)) + .with_tools(tools), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait("Use lookup_city with 'Paris' and lookup_country with 'France' at the same time, then combine both results in your reply.") + .await + .expect("send") + .expect("assistant message"); + assert_eq!(recv_with_timeout(&mut city_rx, "city tool").await, "Paris"); + assert_eq!( + recv_with_timeout(&mut country_rx, "country tool").await, + "France" + ); + let content = assistant_message_content(&answer); + assert!(content.contains("CITY_PARIS")); + assert!(content.contains("COUNTRY_FRANCE")); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_respect_availabletools_and_excludedtools_combined() { + with_e2e_context( + "tools", + "should_respect_availabletools_and_excludedtools_combined", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let (excluded_tx, mut excluded_rx) = mpsc::unbounded_channel(); + let router = ToolHandlerRouter::new( + vec![ + Box::new(AllowedTool), + Box::new(ExcludedTool { + call_tx: excluded_tx, + }), + ], + Arc::new(ApproveAllHandler), + ); + let tools = router.tools(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_handler(Arc::new(router)) + .with_tools(tools) + .with_available_tools(["allowed_tool", "excluded_tool"]) + .with_excluded_tools(["excluded_tool"]), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait( + "Use the allowed_tool with input 'test'. Do NOT use excluded_tool.", + ) + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&answer).contains("ALLOWED_TEST")); + assert!( + tokio::time::timeout(std::time::Duration::from_millis(100), excluded_rx.recv()) + .await + .is_err(), + "excluded tool should not be invoked" + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +struct EncryptStringTool; + +#[async_trait::async_trait] +impl ToolHandler for EncryptStringTool { + fn tool(&self) -> Tool { + Tool::new("encrypt_string") + .with_description("Encrypts a string") + .with_parameters(json!({ + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "String to encrypt" + } + }, + "required": ["input"] + })) + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + let input = invocation + .arguments + .get("input") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + Ok(ToolResult::Text(input.to_uppercase())) + } +} + +struct TrackedEncryptStringTool { + call_tx: mpsc::UnboundedSender<()>, +} + +#[async_trait::async_trait] +impl ToolHandler for TrackedEncryptStringTool { + fn tool(&self) -> Tool { + EncryptStringTool.tool() + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + let _ = self.call_tx.send(()); + EncryptStringTool.call(invocation).await + } +} + +struct ErrorTool; + +#[async_trait::async_trait] +impl ToolHandler for ErrorTool { + fn tool(&self) -> Tool { + Tool::new("get_user_location").with_description("Gets the user's location") + } + + async fn call(&self, _invocation: ToolInvocation) -> Result { + Ok(ToolResult::Text( + "Failed to execute `get_user_location` tool with arguments: {} due to error: Error: Tool execution failed" + .to_string(), + )) + } +} + +struct CustomGrepTool; + +#[async_trait::async_trait] +impl ToolHandler for CustomGrepTool { + fn tool(&self) -> Tool { + Tool::new("grep") + .with_description("A custom grep implementation that overrides the built-in") + .with_overrides_built_in_tool(true) + .with_parameters(json!({ + "type": "object", + "properties": { + "query": { "type": "string", "description": "Search query" } + }, + "required": ["query"] + })) + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + let query = invocation + .arguments + .get("query") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + Ok(ToolResult::Text(format!("CUSTOM_GREP_RESULT: {query}"))) + } +} + +struct SafeLookupTool; + +#[async_trait::async_trait] +impl ToolHandler for SafeLookupTool { + fn tool(&self) -> Tool { + Tool::new("safe_lookup") + .with_description("A tool that skips permission") + .with_skip_permission(true) + .with_parameters(json!({ + "type": "object", + "properties": { + "id": { "type": "string", "description": "Lookup ID" } + }, + "required": ["id"] + })) + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + let id = invocation + .arguments + .get("id") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + Ok(ToolResult::Text(format!("RESULT: {id}"))) + } +} + +struct LookupCityTool { + call_tx: mpsc::UnboundedSender, +} + +#[async_trait::async_trait] +impl ToolHandler for LookupCityTool { + fn tool(&self) -> Tool { + Tool::new("lookup_city") + .with_description("Looks up city information") + .with_parameters(json!({ + "type": "object", + "properties": { + "city": { "type": "string", "description": "City name" } + }, + "required": ["city"] + })) + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + let city = invocation + .arguments + .get("city") + .and_then(serde_json::Value::as_str) + .unwrap_or_default() + .to_string(); + let _ = self.call_tx.send(city.clone()); + Ok(ToolResult::Text(format!("CITY_{}", city.to_uppercase()))) + } +} + +struct LookupCountryTool { + call_tx: mpsc::UnboundedSender, +} + +#[async_trait::async_trait] +impl ToolHandler for LookupCountryTool { + fn tool(&self) -> Tool { + Tool::new("lookup_country") + .with_description("Looks up country information") + .with_parameters(json!({ + "type": "object", + "properties": { + "country": { "type": "string", "description": "Country name" } + }, + "required": ["country"] + })) + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + let country = invocation + .arguments + .get("country") + .and_then(serde_json::Value::as_str) + .unwrap_or_default() + .to_string(); + let _ = self.call_tx.send(country.clone()); + Ok(ToolResult::Text(format!( + "COUNTRY_{}", + country.to_uppercase() + ))) + } +} + +struct AllowedTool; + +#[async_trait::async_trait] +impl ToolHandler for AllowedTool { + fn tool(&self) -> Tool { + Tool::new("allowed_tool") + .with_description("An allowed tool") + .with_parameters(json!({ + "type": "object", + "properties": { + "input": { "type": "string", "description": "Input value" } + }, + "required": ["input"] + })) + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + let input = invocation + .arguments + .get("input") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + Ok(ToolResult::Text(format!( + "ALLOWED_{}", + input.to_uppercase() + ))) + } +} + +struct ExcludedTool { + call_tx: mpsc::UnboundedSender<()>, +} + +#[async_trait::async_trait] +impl ToolHandler for ExcludedTool { + fn tool(&self) -> Tool { + Tool::new("excluded_tool") + .with_description("A tool that should be excluded") + .with_parameters(json!({ + "type": "object", + "properties": { + "input": { "type": "string", "description": "Input value" } + }, + "required": ["input"] + })) + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + let _ = self.call_tx.send(()); + let input = invocation + .arguments + .get("input") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + Ok(ToolResult::Text(format!( + "EXCLUDED_{}", + input.to_uppercase() + ))) + } +} + +struct RecordingPermissionHandler { + permission_tx: mpsc::UnboundedSender, + decision: PermissionResult, +} + +#[async_trait::async_trait] +impl SessionHandler for RecordingPermissionHandler { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + data: PermissionRequestData, + ) -> PermissionResult { + let _ = self.permission_tx.send(data); + self.decision.clone() + } +} + +struct DbQueryTool; + +#[async_trait::async_trait] +impl ToolHandler for DbQueryTool { + fn tool(&self) -> Tool { + Tool::new("db_query") + .with_description("Performs a database query") + .with_parameters(json!({ + "type": "object", + "properties": { + "query": { + "type": "object", + "properties": { + "table": { "type": "string" }, + "ids": { + "type": "array", + "items": { "type": "integer" } + }, + "sortAscending": { "type": "boolean" } + }, + "required": ["table", "ids", "sortAscending"] + } + }, + "required": ["query"] + })) + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + let query = invocation.arguments.get("query").expect("query argument"); + assert_eq!( + query.get("table").and_then(serde_json::Value::as_str), + Some("cities") + ); + assert_eq!( + query.get("ids").and_then(serde_json::Value::as_array), + Some(&vec![json!(12), json!(19)]) + ); + assert_eq!( + query + .get("sortAscending") + .and_then(serde_json::Value::as_bool), + Some(true) + ); + Ok(ToolResult::Text( + r#"[{"cityName":"Passos","countryId":19,"population":135460},{"cityName":"San Lorenzo","countryId":12,"population":204356}]"# + .to_string(), + )) + } +} diff --git a/rust/tests/mode_handlers_e2e_test.rs b/rust/tests/mode_handlers_e2e_test.rs deleted file mode 100644 index 419124850..000000000 --- a/rust/tests/mode_handlers_e2e_test.rs +++ /dev/null @@ -1,663 +0,0 @@ -#![allow(clippy::unwrap_used)] - -use std::io::{BufRead, BufReader, Read, Write}; -use std::net::TcpStream; -use std::path::{Path, PathBuf}; -use std::process::{Child, Command, Stdio}; -use std::sync::Arc; -use std::time::Duration; - -use async_trait::async_trait; -use github_copilot_sdk::generated::session_events::{ - AutoModeSwitchCompletedData, AutoModeSwitchRequestedData, ExitPlanModeCompletedData, - ExitPlanModeRequestedData, SessionEventType, SessionModelChangeData, -}; -use github_copilot_sdk::handler::{AutoModeSwitchResponse, ExitPlanModeResult, SessionHandler}; -use github_copilot_sdk::subscription::EventSubscription; -use github_copilot_sdk::{ - CliProgram, Client, ClientOptions, ExitPlanModeData, SessionConfig, SessionEvent, SessionId, -}; -use serde_json::json; -use tokio::sync::mpsc; - -const MODE_HANDLER_TOKEN: &str = "mode-handler-token"; -const PLAN_SUMMARY: &str = "Greeting file implementation plan"; -const PLAN_PROMPT: &str = "Create a brief implementation plan for adding a greeting.txt file, then request approval with exit_plan_mode."; -const AUTO_MODE_PROMPT: &str = - "Explain that auto mode recovered from a rate limit in one short sentence."; - -#[derive(Debug)] -struct ModeHandler { - requests: mpsc::UnboundedSender<(SessionId, ExitPlanModeData)>, -} - -#[derive(Debug)] -struct AutoModeHandler { - requests: mpsc::UnboundedSender<(SessionId, Option, Option)>, -} - -#[async_trait] -impl SessionHandler for ModeHandler { - async fn on_exit_plan_mode( - &self, - session_id: SessionId, - data: ExitPlanModeData, - ) -> ExitPlanModeResult { - let _ = self.requests.send((session_id, data)); - ExitPlanModeResult { - approved: true, - selected_action: Some("interactive".to_string()), - feedback: Some("Approved by the Rust E2E test".to_string()), - } - } -} - -#[async_trait] -impl SessionHandler for AutoModeHandler { - async fn on_auto_mode_switch( - &self, - session_id: SessionId, - error_code: Option, - retry_after_seconds: Option, - ) -> AutoModeSwitchResponse { - let _ = self - .requests - .send((session_id, error_code, retry_after_seconds)); - AutoModeSwitchResponse::Yes - } -} - -#[tokio::test] -#[ignore] // requires the Node CLI and shared replay proxy dependencies -async fn should_invoke_exit_plan_mode_handler_when_model_uses_tool() { - let repo_root = repo_root(); - let cli_path = repo_root - .join("nodejs") - .join("node_modules") - .join("@github") - .join("copilot") - .join("index.js"); - assert!( - cli_path.exists(), - "CLI not found at {}; run npm install in nodejs first", - cli_path.display() - ); - - let home_dir = tempfile::tempdir().expect("create home dir"); - let work_dir = tempfile::tempdir().expect("create work dir"); - let mut proxy = CapiProxy::start(&repo_root).expect("start replay proxy"); - proxy - .configure( - &repo_root - .join("test") - .join("snapshots") - .join("mode_handlers") - .join("should_invoke_exit_plan_mode_handler_when_model_uses_tool.yaml"), - work_dir.path(), - ) - .expect("configure replay proxy"); - proxy - .set_copilot_user_by_token( - MODE_HANDLER_TOKEN, - json!({ - "login": "mode-handler-user", - "copilot_plan": "individual_pro", - "endpoints": { - "api": proxy.url(), - "telemetry": "https://localhost:1/telemetry" - }, - "analytics_tracking_id": "mode-handler-tracking-id" - }), - ) - .expect("configure copilot user"); - - let mut env = proxy.proxy_env(); - env.extend([ - ("COPILOT_API_URL".into(), proxy.url().into()), - ("COPILOT_DEBUG_GITHUB_API_URL".into(), proxy.url().into()), - ( - "COPILOT_HOME".into(), - home_dir.path().as_os_str().to_owned(), - ), - ( - "GH_CONFIG_DIR".into(), - home_dir.path().as_os_str().to_owned(), - ), - ( - "XDG_CONFIG_HOME".into(), - home_dir.path().as_os_str().to_owned(), - ), - ( - "XDG_STATE_HOME".into(), - home_dir.path().as_os_str().to_owned(), - ), - ]); - - let client = Client::start( - ClientOptions::new() - .with_program(CliProgram::Path(PathBuf::from(node_program()))) - .with_prefix_args([cli_path.as_os_str().to_owned()]) - .with_cwd(work_dir.path()) - .with_env(env) - .with_use_logged_in_user(false), - ) - .await - .expect("start client"); - - let (request_tx, mut request_rx) = mpsc::unbounded_channel(); - let session = client - .create_session( - SessionConfig::default() - .with_github_token(MODE_HANDLER_TOKEN) - .with_handler(Arc::new(ModeHandler { - requests: request_tx, - })) - .approve_all_permissions(), - ) - .await - .expect("create session"); - - let requested_event = tokio::spawn(wait_for_event( - session.subscribe(), - "exit_plan_mode.requested event", - |event| { - event.parsed_type() == SessionEventType::ExitPlanModeRequested - && event - .typed_data::() - .is_some_and(|data| data.summary == PLAN_SUMMARY) - }, - )); - let completed_event = tokio::spawn(wait_for_event( - session.subscribe(), - "exit_plan_mode.completed event", - |event| { - event.parsed_type() == SessionEventType::ExitPlanModeCompleted - && event - .typed_data::() - .is_some_and(|data| { - data.approved == Some(true) - && data.selected_action.as_deref() == Some("interactive") - }) - }, - )); - let idle_event = tokio::spawn(wait_for_event( - session.subscribe(), - "session.idle event", - |event| event.parsed_type() == SessionEventType::SessionIdle, - )); - - let send_result = session - .client() - .call( - "session.send", - Some(json!({ - "sessionId": session.id().as_str(), - "prompt": PLAN_PROMPT, - "mode": "plan", - })), - ) - .await - .expect("send plan-mode prompt"); - assert!( - send_result.get("messageId").is_some(), - "expected messageId in send result" - ); - - let (session_id, request) = tokio::time::timeout(Duration::from_secs(10), request_rx.recv()) - .await - .expect("timed out waiting for exit-plan-mode request") - .expect("exit-plan-mode request channel closed"); - assert_eq!(session_id, session.id().clone()); - assert_eq!(request.summary, PLAN_SUMMARY); - assert_eq!( - request.actions, - ["interactive", "autopilot", "exit_only"].map(str::to_string) - ); - assert_eq!(request.recommended_action, "interactive"); - - let requested = requested_event - .await - .expect("requested task") - .expect("requested event"); - let requested_data = requested - .typed_data::() - .expect("typed requested event"); - assert_eq!(requested_data.summary, request.summary); - assert_eq!(requested_data.actions, request.actions); - assert_eq!( - requested_data.recommended_action, - request.recommended_action - ); - - let completed = completed_event - .await - .expect("completed task") - .expect("completed event"); - let completed_data = completed - .typed_data::() - .expect("typed completed event"); - assert_eq!(completed_data.approved, Some(true)); - assert_eq!( - completed_data.selected_action.as_deref(), - Some("interactive") - ); - assert_eq!( - completed_data.feedback.as_deref(), - Some("Approved by the Rust E2E test") - ); - idle_event.await.expect("idle task").expect("idle event"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - proxy.stop(true).expect("stop replay proxy"); -} - -#[tokio::test] -#[ignore] // requires the Node CLI and shared replay proxy dependencies -async fn should_invoke_auto_mode_switch_handler_when_rate_limited() { - let repo_root = repo_root(); - let cli_path = repo_root - .join("nodejs") - .join("node_modules") - .join("@github") - .join("copilot") - .join("index.js"); - assert!( - cli_path.exists(), - "CLI not found at {}; run npm install in nodejs first", - cli_path.display() - ); - - let home_dir = tempfile::tempdir().expect("create home dir"); - let work_dir = tempfile::tempdir().expect("create work dir"); - let mut proxy = CapiProxy::start(&repo_root).expect("start replay proxy"); - proxy - .configure( - &repo_root - .join("test") - .join("snapshots") - .join("mode_handlers") - .join("should_invoke_auto_mode_switch_handler_when_rate_limited.yaml"), - work_dir.path(), - ) - .expect("configure replay proxy"); - proxy - .set_copilot_user_by_token( - MODE_HANDLER_TOKEN, - json!({ - "login": "mode-handler-user", - "copilot_plan": "individual_pro", - "endpoints": { - "api": proxy.url(), - "telemetry": "https://localhost:1/telemetry" - }, - "analytics_tracking_id": "mode-handler-tracking-id" - }), - ) - .expect("configure copilot user"); - - let mut env = proxy.proxy_env(); - env.extend([ - ("COPILOT_API_URL".into(), proxy.url().into()), - ("COPILOT_DEBUG_GITHUB_API_URL".into(), proxy.url().into()), - ( - "COPILOT_HOME".into(), - home_dir.path().as_os_str().to_owned(), - ), - ( - "GH_CONFIG_DIR".into(), - home_dir.path().as_os_str().to_owned(), - ), - ( - "XDG_CONFIG_HOME".into(), - home_dir.path().as_os_str().to_owned(), - ), - ( - "XDG_STATE_HOME".into(), - home_dir.path().as_os_str().to_owned(), - ), - ]); - - let client = Client::start( - ClientOptions::new() - .with_program(CliProgram::Path(PathBuf::from(node_program()))) - .with_prefix_args([cli_path.as_os_str().to_owned()]) - .with_cwd(work_dir.path()) - .with_env(env) - .with_use_logged_in_user(false), - ) - .await - .expect("start client"); - - let (request_tx, mut request_rx) = mpsc::unbounded_channel(); - let session = client - .create_session( - SessionConfig::default() - .with_github_token(MODE_HANDLER_TOKEN) - .with_handler(Arc::new(AutoModeHandler { - requests: request_tx, - })) - .approve_all_permissions(), - ) - .await - .expect("create session"); - - let requested_event = tokio::spawn(wait_for_event_allowing_rate_limit( - session.subscribe(), - "auto_mode_switch.requested event", - |event| { - event.parsed_type() == SessionEventType::AutoModeSwitchRequested - && event - .typed_data::() - .is_some_and(|data| { - data.error_code.as_deref() == Some("user_weekly_rate_limited") - && data.retry_after_seconds == Some(1.0) - }) - }, - )); - let completed_event = tokio::spawn(wait_for_event_allowing_rate_limit( - session.subscribe(), - "auto_mode_switch.completed event", - |event| { - event.parsed_type() == SessionEventType::AutoModeSwitchCompleted - && event - .typed_data::() - .is_some_and(|data| data.response == "yes") - }, - )); - let model_change_event = tokio::spawn(wait_for_event_allowing_rate_limit( - session.subscribe(), - "rate-limit auto-mode model change", - |event| { - event.parsed_type() == SessionEventType::SessionModelChange - && event - .typed_data::() - .is_some_and(|data| data.cause.as_deref() == Some("rate_limit_auto_switch")) - }, - )); - let idle_event = tokio::spawn(wait_for_event_allowing_rate_limit( - session.subscribe(), - "session.idle after auto-mode switch", - |event| event.parsed_type() == SessionEventType::SessionIdle, - )); - - let message_id = session - .send(AUTO_MODE_PROMPT) - .await - .expect("send auto-mode-switch prompt"); - assert!(!message_id.is_empty(), "expected message ID"); - - let (session_id, error_code, retry_after_seconds) = - tokio::time::timeout(Duration::from_secs(10), request_rx.recv()) - .await - .expect("timed out waiting for auto-mode-switch request") - .expect("auto-mode-switch request channel closed"); - assert_eq!(session_id, session.id().clone()); - assert_eq!(error_code.as_deref(), Some("user_weekly_rate_limited")); - assert_eq!(retry_after_seconds, Some(1.0)); - - let requested = requested_event - .await - .expect("requested task") - .expect("requested event"); - let requested_data = requested - .typed_data::() - .expect("typed requested event"); - assert_eq!(requested_data.error_code, error_code); - assert_eq!(requested_data.retry_after_seconds, retry_after_seconds); - - let completed = completed_event - .await - .expect("completed task") - .expect("completed event"); - let completed_data = completed - .typed_data::() - .expect("typed completed event"); - assert_eq!(completed_data.response, "yes"); - - let model_change = model_change_event - .await - .expect("model change task") - .expect("model change event"); - let model_change_data = model_change - .typed_data::() - .expect("typed model change event"); - assert_eq!( - model_change_data.cause.as_deref(), - Some("rate_limit_auto_switch") - ); - idle_event.await.expect("idle task").expect("idle event"); - - session.disconnect().await.expect("disconnect session"); - client.stop().await.expect("stop client"); - proxy.stop(true).expect("stop replay proxy"); -} - -async fn wait_for_event( - mut events: EventSubscription, - description: &'static str, - predicate: fn(&SessionEvent) -> bool, -) -> Result { - tokio::time::timeout(Duration::from_secs(30), async { - loop { - let event = events.recv().await.map_err(|err| { - format!("event stream closed while waiting for {description}: {err}") - })?; - if event.parsed_type() == SessionEventType::SessionError { - return Err(format!( - "session.error while waiting for {description}: {}", - event.data - )); - } - if predicate(&event) { - return Ok(event); - } - } - }) - .await - .map_err(|_| format!("timed out waiting for {description}"))? -} - -async fn wait_for_event_allowing_rate_limit( - mut events: EventSubscription, - description: &'static str, - predicate: fn(&SessionEvent) -> bool, -) -> Result { - tokio::time::timeout(Duration::from_secs(30), async { - loop { - let event = events.recv().await.map_err(|err| { - format!("event stream closed while waiting for {description}: {err}") - })?; - if event.parsed_type() == SessionEventType::SessionError - && event.data.get("errorType").and_then(|value| value.as_str()) - != Some("rate_limit") - { - return Err(format!( - "session.error while waiting for {description}: {}", - event.data - )); - } - if predicate(&event) { - return Ok(event); - } - } - }) - .await - .map_err(|_| format!("timed out waiting for {description}"))? -} - -fn repo_root() -> PathBuf { - PathBuf::from(env!("CARGO_MANIFEST_DIR")) - .parent() - .expect("rust package has parent repo") - .to_path_buf() -} - -struct CapiProxy { - child: Option, - proxy_url: String, - connect_proxy_url: String, - ca_file_path: String, -} - -impl CapiProxy { - fn start(repo_root: &Path) -> std::io::Result { - let mut child = Command::new(npx_program()) - .args(["tsx", "server.ts"]) - .current_dir(repo_root.join("test").join("harness")) - .stdin(Stdio::null()) - .stdout(Stdio::piped()) - .stderr(Stdio::inherit()) - .spawn()?; - - let stdout = child.stdout.take().expect("proxy stdout"); - let reader = BufReader::new(stdout); - let re = regex::Regex::new(r"Listening: (http://[^\s]+)\s+(\{.*\})$").unwrap(); - for line in reader.lines() { - let line = line?; - if let Some(captures) = re.captures(&line) { - let metadata: serde_json::Value = - serde_json::from_str(captures.get(2).unwrap().as_str())?; - let connect_proxy_url = metadata - .get("connectProxyUrl") - .and_then(|value| value.as_str()) - .expect("connectProxyUrl") - .to_string(); - let ca_file_path = metadata - .get("caFilePath") - .and_then(|value| value.as_str()) - .expect("caFilePath") - .to_string(); - return Ok(Self { - child: Some(child), - proxy_url: captures.get(1).unwrap().as_str().to_string(), - connect_proxy_url, - ca_file_path, - }); - } - if line.contains("Listening: ") { - return Err(std::io::Error::other(format!( - "proxy startup line missing metadata: {line}" - ))); - } - } - - Err(std::io::Error::other("proxy exited before startup")) - } - - fn url(&self) -> &str { - &self.proxy_url - } - - fn configure(&self, file_path: &Path, work_dir: &Path) -> std::io::Result<()> { - self.post_json( - "/config", - &json!({ - "filePath": file_path, - "workDir": work_dir, - }) - .to_string(), - ) - } - - fn set_copilot_user_by_token( - &self, - token: &str, - response: serde_json::Value, - ) -> std::io::Result<()> { - self.post_json( - "/copilot-user-config", - &json!({ - "token": token, - "response": response, - }) - .to_string(), - ) - } - - fn stop(&mut self, skip_writing_cache: bool) -> std::io::Result<()> { - let path = if skip_writing_cache { - "/stop?skipWritingCache=true" - } else { - "/stop" - }; - let result = self.post_json(path, ""); - if let Some(mut child) = self.child.take() { - let _ = child.wait(); - } - result - } - - fn proxy_env(&self) -> Vec<(std::ffi::OsString, std::ffi::OsString)> { - let no_proxy = "127.0.0.1,localhost,::1"; - [ - ("HTTP_PROXY", self.connect_proxy_url.as_str()), - ("HTTPS_PROXY", self.connect_proxy_url.as_str()), - ("http_proxy", self.connect_proxy_url.as_str()), - ("https_proxy", self.connect_proxy_url.as_str()), - ("NO_PROXY", no_proxy), - ("no_proxy", no_proxy), - ("NODE_EXTRA_CA_CERTS", self.ca_file_path.as_str()), - ("SSL_CERT_FILE", self.ca_file_path.as_str()), - ("REQUESTS_CA_BUNDLE", self.ca_file_path.as_str()), - ("CURL_CA_BUNDLE", self.ca_file_path.as_str()), - ("GIT_SSL_CAINFO", self.ca_file_path.as_str()), - ("GH_TOKEN", ""), - ("GITHUB_TOKEN", ""), - ("GH_ENTERPRISE_TOKEN", ""), - ("GITHUB_ENTERPRISE_TOKEN", ""), - ] - .into_iter() - .map(|(key, value)| (key.into(), value.into())) - .collect() - } - - fn post_json(&self, path: &str, body: &str) -> std::io::Result<()> { - let (host, port) = parse_http_url(&self.proxy_url)?; - let mut stream = TcpStream::connect((host.as_str(), port))?; - write!( - stream, - "POST {path} HTTP/1.1\r\nHost: {host}:{port}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}", - body.len() - )?; - - let mut response = String::new(); - stream.read_to_string(&mut response)?; - if !response.starts_with("HTTP/1.1 200") && !response.starts_with("HTTP/1.1 204") { - return Err(std::io::Error::other(format!( - "proxy POST {path} failed: {response}" - ))); - } - Ok(()) - } -} - -impl Drop for CapiProxy { - fn drop(&mut self) { - if self.child.is_some() { - let _ = self.stop(true); - } - } -} - -fn node_program() -> &'static str { - if cfg!(windows) { "node.exe" } else { "node" } -} - -fn npx_program() -> &'static str { - if cfg!(windows) { "npx.cmd" } else { "npx" } -} - -fn parse_http_url(url: &str) -> std::io::Result<(String, u16)> { - let without_scheme = url - .strip_prefix("http://") - .ok_or_else(|| std::io::Error::other(format!("expected http URL, got {url}")))?; - let authority = without_scheme.split('/').next().unwrap_or(without_scheme); - let (host, port) = authority - .rsplit_once(':') - .ok_or_else(|| std::io::Error::other(format!("missing port in URL {url}")))?; - let port = port - .parse() - .map_err(|err| std::io::Error::other(format!("invalid port in URL {url}: {err}")))?; - Ok((host.to_string(), port)) -} diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index 74c6eb90b..c98c04d89 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -137,12 +137,11 @@ async fn create_session_pair_with_capabilities( capabilities: Value, ) -> (github_copilot_sdk::session::Session, FakeServer) { let (client, server_read, server_write) = make_client(); - let session_id = format!("test-session-{}", rand_id()); let mut server = FakeServer { read: server_read, write: server_write, - session_id: session_id.clone(), + session_id: String::new(), }; let create_handle = tokio::spawn({ @@ -158,8 +157,9 @@ async fn create_session_pair_with_capabilities( let create_req = server.read_request().await; assert_eq!(create_req["method"], "session.create"); + server.session_id = requested_session_id(&create_req).to_string(); let mut result = serde_json::json!({ - "sessionId": session_id, + "sessionId": server.session_id.clone(), "workspacePath": "/tmp/workspace" }); if !capabilities.is_null() { @@ -176,6 +176,12 @@ fn rand_id() -> u64 { COUNTER.fetch_add(1, Ordering::Relaxed) as u64 } +fn requested_session_id(request: &Value) -> &str { + request["params"]["sessionId"] + .as_str() + .expect("session request should include sessionId") +} + #[tokio::test] async fn session_subscribe_yields_events_observe_only() { let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; @@ -263,15 +269,16 @@ async fn create_session_sends_correct_rpc() { assert_eq!(request["params"]["model"], "gpt-4"); let id = request["id"].as_u64().unwrap(); + let session_id = requested_session_id(&request).to_string(); let response = serde_json::json!({ "jsonrpc": "2.0", "id": id, - "result": { "sessionId": "s1", "workspacePath": "/ws" }, + "result": { "sessionId": session_id.clone(), "workspacePath": "/ws" }, }); write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); - assert_eq!(session.id(), "s1"); + assert_eq!(session.id(), session_id.as_str()); assert_eq!(session.workspace_path(), Some(Path::new("/ws"))); } @@ -1404,7 +1411,8 @@ async fn router_routes_to_correct_session() { // Create two sessions on the same client let mut sessions = Vec::new(); - for (tx, sid) in [(tx1, "s-one"), (tx2, "s-two")] { + let mut session_ids = Vec::new(); + for tx in [tx1, tx2] { let h = tokio::spawn({ let client = client.clone(); async move { @@ -1418,11 +1426,13 @@ async fn router_routes_to_correct_session() { }); let req = read_framed(&mut server_read).await; let id = req["id"].as_u64().unwrap(); + let session_id = requested_session_id(&req).to_string(); let resp = serde_json::json!({ "jsonrpc": "2.0", "id": id, - "result": { "sessionId": sid }, + "result": { "sessionId": session_id.clone() }, }); write_framed(&mut server_write, &serde_json::to_vec(&resp).unwrap()).await; + session_ids.push(session_id); sessions.push(timeout(TIMEOUT, h).await.unwrap().unwrap()); } @@ -1431,7 +1441,7 @@ async fn router_routes_to_correct_session() { "jsonrpc": "2.0", "method": "session.event", "params": { - "sessionId": "s-two", + "sessionId": session_ids[1].clone(), "event": { "id": "e1", "timestamp": "2025-01-01T00:00:00Z", "type": "assistant.message", "data": {} }, }, }); @@ -1447,7 +1457,7 @@ async fn router_routes_to_correct_session() { "jsonrpc": "2.0", "method": "session.event", "params": { - "sessionId": "s-one", + "sessionId": session_ids[0].clone(), "event": { "id": "e2", "timestamp": "2025-01-01T00:00:00Z", "type": "session.idle", "data": {} }, }, }); @@ -1982,11 +1992,12 @@ async fn capabilities_captured_from_create_response() { let request = read_framed(&mut server_read).await; let id = request["id"].as_u64().unwrap(); + let session_id = requested_session_id(&request); let response = serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": { - "sessionId": "cap-session", + "sessionId": session_id, "capabilities": { "ui": { "elicitation": true } } @@ -2053,10 +2064,11 @@ async fn request_elicitation_sent_in_create_params() { assert_eq!(request["params"]["requestAutoModeSwitch"], true); let id = request["id"].as_u64().unwrap(); + let session_id = requested_session_id(&request); let response = serde_json::json!({ "jsonrpc": "2.0", "id": id, - "result": { "sessionId": "s-elicit" }, + "result": { "sessionId": session_id }, }); write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); @@ -2083,18 +2095,20 @@ async fn env_value_mode_hardcoded_direct_on_create_and_resume() { assert_eq!(request["params"]["envValueMode"], "direct"); let id = request["id"].as_u64().unwrap(); + let session_id = requested_session_id(&request).to_string(); let response = serde_json::json!({ "jsonrpc": "2.0", "id": id, - "result": { "sessionId": "s-env-create" }, + "result": { "sessionId": session_id.clone() }, }); write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); let resume_handle = tokio::spawn({ let client = client.clone(); + let session_id = session_id.clone(); async move { - let cfg = ResumeSessionConfig::new(SessionId::from("s-env-create")) + let cfg = ResumeSessionConfig::new(SessionId::from(session_id)) .with_handler(Arc::new(NoopHandler)); client.resume_session(cfg).await.unwrap() } @@ -2108,7 +2122,7 @@ async fn env_value_mode_hardcoded_direct_on_create_and_resume() { let response = serde_json::json!({ "jsonrpc": "2.0", "id": id, - "result": { "sessionId": "s-env-create" }, + "result": { "sessionId": session_id }, }); write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; @@ -2153,12 +2167,11 @@ async fn create_session_pair_with_hooks( hooks: Arc, ) -> (github_copilot_sdk::session::Session, FakeServer) { let (client, server_read, server_write) = make_client(); - let session_id = format!("test-session-{}", rand_id()); let mut server = FakeServer { read: server_read, write: server_write, - session_id: session_id.clone(), + session_id: String::new(), }; let create_handle = tokio::spawn({ @@ -2180,11 +2193,12 @@ async fn create_session_pair_with_hooks( assert_eq!(create_req["method"], "session.create"); // Verify hooks: true is auto-set in the config assert_eq!(create_req["params"]["hooks"], true); + server.session_id = requested_session_id(&create_req).to_string(); server .respond( &create_req, serde_json::json!({ - "sessionId": session_id, + "sessionId": server.session_id, "workspacePath": "/tmp/workspace" }), ) @@ -2286,12 +2300,11 @@ async fn create_session_pair_with_transforms( transforms: Arc, ) -> (github_copilot_sdk::session::Session, FakeServer) { let (client, server_read, server_write) = make_client(); - let session_id = format!("test-session-{}", rand_id()); let mut server = FakeServer { read: server_read, write: server_write, - session_id: session_id.clone(), + session_id: String::new(), }; let create_handle = tokio::spawn({ @@ -2313,11 +2326,12 @@ async fn create_session_pair_with_transforms( assert_eq!(create_req["method"], "session.create"); // Verify transforms inject customize mode and section overrides assert_eq!(create_req["params"]["systemMessage"]["mode"], "customize"); + server.session_id = requested_session_id(&create_req).to_string(); server .respond( &create_req, serde_json::json!({ - "sessionId": session_id, + "sessionId": server.session_id, "workspacePath": "/tmp/workspace" }), ) @@ -2473,13 +2487,11 @@ async fn client_stop_sends_session_destroy_for_each_active_session() { // One client, two registered sessions. Client::stop must send // session.destroy for each before returning Ok. let (client, server_read, server_write) = make_client(); - let session_id_a = format!("test-session-{}", rand_id()); - let session_id_b = format!("test-session-{}", rand_id()); let mut server = FakeServer { read: server_read, write: server_write, - session_id: session_id_a.clone(), + session_id: String::new(), }; // Spawn both create_session calls. @@ -2494,10 +2506,11 @@ async fn client_stop_sends_session_destroy_for_each_active_session() { }); let create_a_req = server.read_request().await; assert_eq!(create_a_req["method"], "session.create"); + let session_id_a = requested_session_id(&create_a_req).to_string(); server .respond( &create_a_req, - serde_json::json!({ "sessionId": session_id_a, "workspacePath": "/tmp/ws-a" }), + serde_json::json!({ "sessionId": session_id_a.clone(), "workspacePath": "/tmp/ws-a" }), ) .await; let _session_a = timeout(TIMEOUT, create_a).await.unwrap(); @@ -2513,10 +2526,11 @@ async fn client_stop_sends_session_destroy_for_each_active_session() { }); let create_b_req = server.read_request().await; assert_eq!(create_b_req["method"], "session.create"); + let session_id_b = requested_session_id(&create_b_req).to_string(); server .respond( &create_b_req, - serde_json::json!({ "sessionId": session_id_b, "workspacePath": "/tmp/ws-b" }), + serde_json::json!({ "sessionId": session_id_b.clone(), "workspacePath": "/tmp/ws-b" }), ) .await; let _session_b = timeout(TIMEOUT, create_b).await.unwrap(); @@ -2657,12 +2671,11 @@ async fn create_session_pair_with_commands( commands: Vec, ) -> (github_copilot_sdk::session::Session, FakeServer, Value) { let (client, server_read, server_write) = make_client(); - let session_id = format!("test-session-{}", rand_id()); let mut server = FakeServer { read: server_read, write: server_write, - session_id: session_id.clone(), + session_id: String::new(), }; let create_handle = tokio::spawn({ @@ -2682,11 +2695,12 @@ async fn create_session_pair_with_commands( let create_req = server.read_request().await; assert_eq!(create_req["method"], "session.create"); + server.session_id = requested_session_id(&create_req).to_string(); server .respond( &create_req, serde_json::json!({ - "sessionId": session_id, + "sessionId": server.session_id, "workspacePath": "/tmp/workspace" }), ) @@ -2954,12 +2968,11 @@ async fn create_session_pair_with_fs_provider( provider: Arc, ) -> (github_copilot_sdk::session::Session, FakeServer) { let (client, server_read, server_write) = make_client(); - let session_id = format!("test-session-{}", rand_id()); let mut server = FakeServer { read: server_read, write: server_write, - session_id: session_id.clone(), + session_id: String::new(), }; let create_handle = tokio::spawn({ @@ -2979,11 +2992,12 @@ async fn create_session_pair_with_fs_provider( let create_req = server.read_request().await; assert_eq!(create_req["method"], "session.create"); + server.session_id = requested_session_id(&create_req).to_string(); server .respond( &create_req, serde_json::json!({ - "sessionId": session_id, + "sessionId": server.session_id, "workspacePath": "/tmp/workspace" }), ) @@ -3217,10 +3231,11 @@ async fn on_get_trace_context_called_on_session_create() { assert_eq!(req["method"], "session.create"); assert_eq!(req["params"]["traceparent"], "00-aaaa-bbbb-01"); assert_eq!(req["params"]["tracestate"], "vendor=value"); + server.session_id = requested_session_id(&req).to_string(); server .respond( &req, - serde_json::json!({"sessionId": "trace-create", "workspacePath": "/tmp/ws"}), + serde_json::json!({"sessionId": server.session_id.clone(), "workspacePath": "/tmp/ws"}), ) .await; timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); @@ -3297,10 +3312,11 @@ async fn on_get_trace_context_called_on_session_send() { } }); let create_req = server.read_request().await; + server.session_id = requested_session_id(&create_req).to_string(); server .respond( &create_req, - serde_json::json!({"sessionId": "trace-send", "workspacePath": "/tmp/ws"}), + serde_json::json!({"sessionId": server.session_id.clone(), "workspacePath": "/tmp/ws"}), ) .await; let session = Arc::new(timeout(TIMEOUT, create_handle).await.unwrap().unwrap()); @@ -3349,10 +3365,11 @@ async fn message_options_trace_context_overrides_callback() { } }); let create_req = server.read_request().await; + server.session_id = requested_session_id(&create_req).to_string(); server .respond( &create_req, - serde_json::json!({"sessionId": "trace-override", "workspacePath": "/tmp/ws"}), + serde_json::json!({"sessionId": server.session_id.clone(), "workspacePath": "/tmp/ws"}), ) .await; let session = Arc::new(timeout(TIMEOUT, create_handle).await.unwrap().unwrap()); diff --git a/test/snapshots/client/listmodels_withcustomhandler_callshandler.yaml b/test/snapshots/client/listmodels_withcustomhandler_callshandler.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/client/listmodels_withcustomhandler_callshandler.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/client/should_force_stop_client.yaml b/test/snapshots/client/should_force_stop_client.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/client/should_force_stop_client.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/client/should_get_authenticated_status.yaml b/test/snapshots/client/should_get_authenticated_status.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/client/should_get_authenticated_status.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/client/should_get_status.yaml b/test/snapshots/client/should_get_status.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/client/should_get_status.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/client/should_list_models_when_authenticated.yaml b/test/snapshots/client/should_list_models_when_authenticated.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/client/should_list_models_when_authenticated.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/client/should_start_ping_and_stop_stdio_client.yaml b/test/snapshots/client/should_start_ping_and_stop_stdio_client.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/client/should_start_ping_and_stop_stdio_client.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/client/should_start_ping_and_stop_tcp_client.yaml b/test/snapshots/client/should_start_ping_and_stop_tcp_client.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/client/should_start_ping_and_stop_tcp_client.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/client/should_stop_client_with_active_session.yaml b/test/snapshots/client/should_stop_client_with_active_session.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/client/should_stop_client_with_active_session.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/client_options/should_listen_on_configured_tcp_port.yaml b/test/snapshots/client_options/should_listen_on_configured_tcp_port.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/client_options/should_listen_on_configured_tcp_port.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/commands/session_with_commands_creates_successfully.yaml b/test/snapshots/commands/session_with_commands_creates_successfully.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/commands/session_with_commands_creates_successfully.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/commands/session_with_commands_resumes_successfully.yaml b/test/snapshots/commands/session_with_commands_resumes_successfully.yaml new file mode 100644 index 000000000..0981462bf --- /dev/null +++ b/test/snapshots/commands/session_with_commands_resumes_successfully.yaml @@ -0,0 +1,10 @@ +models: + - claude-sonnet-4.5 +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: Say OK. + - role: assistant + content: OK diff --git a/test/snapshots/commands/session_with_no_commands_creates_successfully.yaml b/test/snapshots/commands/session_with_no_commands_creates_successfully.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/commands/session_with_no_commands_creates_successfully.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/elicitation/confirm_returns_false_when_handler_declines.yaml b/test/snapshots/elicitation/confirm_returns_false_when_handler_declines.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/elicitation/confirm_returns_false_when_handler_declines.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/elicitation/confirm_returns_true_when_handler_accepts.yaml b/test/snapshots/elicitation/confirm_returns_true_when_handler_accepts.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/elicitation/confirm_returns_true_when_handler_accepts.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/elicitation/defaults_capabilities_when_not_provided.yaml b/test/snapshots/elicitation/defaults_capabilities_when_not_provided.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/elicitation/defaults_capabilities_when_not_provided.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/elicitation/elicitation_returns_all_action_shapes.yaml b/test/snapshots/elicitation/elicitation_returns_all_action_shapes.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/elicitation/elicitation_returns_all_action_shapes.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/elicitation/elicitation_throws_when_capability_is_missing.yaml b/test/snapshots/elicitation/elicitation_throws_when_capability_is_missing.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/elicitation/elicitation_throws_when_capability_is_missing.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/elicitation/input_returns_freeform_value.yaml b/test/snapshots/elicitation/input_returns_freeform_value.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/elicitation/input_returns_freeform_value.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/elicitation/select_returns_selected_option.yaml b/test/snapshots/elicitation/select_returns_selected_option.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/elicitation/select_returns_selected_option.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/elicitation/sends_requestelicitation_when_handler_provided.yaml b/test/snapshots/elicitation/sends_requestelicitation_when_handler_provided.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/elicitation/sends_requestelicitation_when_handler_provided.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/elicitation/session_without_elicitationhandler_creates_successfully.yaml b/test/snapshots/elicitation/session_without_elicitationhandler_creates_successfully.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/elicitation/session_without_elicitationhandler_creates_successfully.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/elicitation/should_report_elicitation_capability_based_on_handler_presence.yaml b/test/snapshots/elicitation/should_report_elicitation_capability_based_on_handler_presence.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/elicitation/should_report_elicitation_capability_based_on_handler_presence.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/mcp_and_agents/should_handle_custom_agent_with_mcp_servers.yaml b/test/snapshots/mcp_and_agents/should_handle_custom_agent_with_mcp_servers.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/mcp_and_agents/should_handle_custom_agent_with_mcp_servers.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/mcp_and_agents/should_handle_custom_agent_with_tools_configuration.yaml b/test/snapshots/mcp_and_agents/should_handle_custom_agent_with_tools_configuration.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/mcp_and_agents/should_handle_custom_agent_with_tools_configuration.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/mcp_and_agents/should_handle_multiple_custom_agents.yaml b/test/snapshots/mcp_and_agents/should_handle_multiple_custom_agents.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/mcp_and_agents/should_handle_multiple_custom_agents.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/mcp_and_agents/should_handle_multiple_mcp_servers.yaml b/test/snapshots/mcp_and_agents/should_handle_multiple_mcp_servers.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/mcp_and_agents/should_handle_multiple_mcp_servers.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/per-session-auth/session_auth_status_is_unauthenticated_without_token.yaml b/test/snapshots/per-session-auth/session_auth_status_is_unauthenticated_without_token.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/per-session-auth/session_auth_status_is_unauthenticated_without_token.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/per-session-auth/session_fails_with_invalid_token.yaml b/test/snapshots/per-session-auth/session_fails_with_invalid_token.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/per-session-auth/session_fails_with_invalid_token.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/per-session-auth/session_token_overrides_client_token.yaml b/test/snapshots/per-session-auth/session_token_overrides_client_token.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/per-session-auth/session_token_overrides_client_token.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/per-session-auth/session_uses_client_token_when_no_session_token_is_supplied.yaml b/test/snapshots/per-session-auth/session_uses_client_token_when_no_session_token_is_supplied.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/per-session-auth/session_uses_client_token_when_no_session_token_is_supplied.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_additional_edge_cases/mode_set_to_same_value_multiple_times_stays_stable.yaml b/test/snapshots/rpc_additional_edge_cases/mode_set_to_same_value_multiple_times_stays_stable.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_additional_edge_cases/mode_set_to_same_value_multiple_times_stays_stable.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_additional_edge_cases/name_set_with_unicode_round_trips.yaml b/test/snapshots/rpc_additional_edge_cases/name_set_with_unicode_round_trips.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_additional_edge_cases/name_set_with_unicode_round_trips.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_additional_edge_cases/permissions_reset_session_approvals_on_fresh_session_is_noop.yaml b/test/snapshots/rpc_additional_edge_cases/permissions_reset_session_approvals_on_fresh_session_is_noop.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_additional_edge_cases/permissions_reset_session_approvals_on_fresh_session_is_noop.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_additional_edge_cases/permissions_set_approve_all_toggle_round_trips.yaml b/test/snapshots/rpc_additional_edge_cases/permissions_set_approve_all_toggle_round_trips.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_additional_edge_cases/permissions_set_approve_all_toggle_round_trips.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_additional_edge_cases/plan_delete_when_none_exists_is_idempotent.yaml b/test/snapshots/rpc_additional_edge_cases/plan_delete_when_none_exists_is_idempotent.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_additional_edge_cases/plan_delete_when_none_exists_is_idempotent.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_additional_edge_cases/plan_update_with_empty_content_then_read_returns_empty.yaml b/test/snapshots/rpc_additional_edge_cases/plan_update_with_empty_content_then_read_returns_empty.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_additional_edge_cases/plan_update_with_empty_content_then_read_returns_empty.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_additional_edge_cases/shell_exec_with_zero_timeout_does_not_kill_long_running_command.yaml b/test/snapshots/rpc_additional_edge_cases/shell_exec_with_zero_timeout_does_not_kill_long_running_command.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_additional_edge_cases/shell_exec_with_zero_timeout_does_not_kill_long_running_command.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_additional_edge_cases/usage_get_metrics_on_fresh_session_returns_zero_tokens.yaml b/test/snapshots/rpc_additional_edge_cases/usage_get_metrics_on_fresh_session_returns_zero_tokens.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_additional_edge_cases/usage_get_metrics_on_fresh_session_returns_zero_tokens.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_additional_edge_cases/workspaces_create_file_with_empty_content_round_trips.yaml b/test/snapshots/rpc_additional_edge_cases/workspaces_create_file_with_empty_content_round_trips.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_additional_edge_cases/workspaces_create_file_with_empty_content_round_trips.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_additional_edge_cases/workspaces_create_file_with_large_content_round_trips.yaml b/test/snapshots/rpc_additional_edge_cases/workspaces_create_file_with_large_content_round_trips.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_additional_edge_cases/workspaces_create_file_with_large_content_round_trips.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_additional_edge_cases/workspaces_create_file_with_unicode_content_round_trips.yaml b/test/snapshots/rpc_additional_edge_cases/workspaces_create_file_with_unicode_content_round_trips.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_additional_edge_cases/workspaces_create_file_with_unicode_content_round_trips.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_additional_edge_cases/workspaces_createfile_then_listfiles_returns_sorted_or_stable_order.yaml b/test/snapshots/rpc_additional_edge_cases/workspaces_createfile_then_listfiles_returns_sorted_or_stable_order.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_additional_edge_cases/workspaces_createfile_then_listfiles_returns_sorted_or_stable_order.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_additional_edge_cases/workspaces_getworkspace_returns_stable_result_across_calls.yaml b/test/snapshots/rpc_additional_edge_cases/workspaces_getworkspace_returns_stable_result_across_calls.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_additional_edge_cases/workspaces_getworkspace_returns_stable_result_across_calls.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_agents/should_call_agent_reload.yaml b/test/snapshots/rpc_agents/should_call_agent_reload.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_agents/should_call_agent_reload.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_agents/should_deselect_current_agent.yaml b/test/snapshots/rpc_agents/should_deselect_current_agent.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_agents/should_deselect_current_agent.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_agents/should_emit_subagent_selected_and_deselected_events.yaml b/test/snapshots/rpc_agents/should_emit_subagent_selected_and_deselected_events.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_agents/should_emit_subagent_selected_and_deselected_events.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_agents/should_list_available_custom_agents.yaml b/test/snapshots/rpc_agents/should_list_available_custom_agents.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_agents/should_list_available_custom_agents.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_agents/should_return_empty_list_when_no_custom_agents_configured.yaml b/test/snapshots/rpc_agents/should_return_empty_list_when_no_custom_agents_configured.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_agents/should_return_empty_list_when_no_custom_agents_configured.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_agents/should_return_null_when_no_agent_is_selected.yaml b/test/snapshots/rpc_agents/should_return_null_when_no_agent_is_selected.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_agents/should_return_null_when_no_agent_is_selected.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_agents/should_select_and_get_current_agent.yaml b/test/snapshots/rpc_agents/should_select_and_get_current_agent.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_agents/should_select_and_get_current_agent.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_event_side_effects/should_emit_mode_changed_event_when_mode_set.yaml b/test/snapshots/rpc_event_side_effects/should_emit_mode_changed_event_when_mode_set.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_event_side_effects/should_emit_mode_changed_event_when_mode_set.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_event_side_effects/should_emit_plan_changed_event_for_update_and_delete.yaml b/test/snapshots/rpc_event_side_effects/should_emit_plan_changed_event_for_update_and_delete.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_event_side_effects/should_emit_plan_changed_event_for_update_and_delete.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_event_side_effects/should_emit_plan_changed_update_operation_on_second_update.yaml b/test/snapshots/rpc_event_side_effects/should_emit_plan_changed_update_operation_on_second_update.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_event_side_effects/should_emit_plan_changed_update_operation_on_second_update.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_event_side_effects/should_emit_title_changed_event_when_name_set.yaml b/test/snapshots/rpc_event_side_effects/should_emit_title_changed_event_when_name_set.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_event_side_effects/should_emit_title_changed_event_when_name_set.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_event_side_effects/should_emit_workspace_file_changed_event_when_file_created.yaml b/test/snapshots/rpc_event_side_effects/should_emit_workspace_file_changed_event_when_file_created.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_event_side_effects/should_emit_workspace_file_changed_event_when_file_created.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_mcp_and_skills/should_list_and_toggle_session_skills.yaml b/test/snapshots/rpc_mcp_and_skills/should_list_and_toggle_session_skills.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_mcp_and_skills/should_list_and_toggle_session_skills.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_mcp_and_skills/should_list_extensions.yaml b/test/snapshots/rpc_mcp_and_skills/should_list_extensions.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_mcp_and_skills/should_list_extensions.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_mcp_and_skills/should_list_mcp_servers_with_configured_server.yaml b/test/snapshots/rpc_mcp_and_skills/should_list_mcp_servers_with_configured_server.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_mcp_and_skills/should_list_mcp_servers_with_configured_server.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_mcp_and_skills/should_list_plugins.yaml b/test/snapshots/rpc_mcp_and_skills/should_list_plugins.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_mcp_and_skills/should_list_plugins.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_mcp_and_skills/should_reload_session_skills.yaml b/test/snapshots/rpc_mcp_and_skills/should_reload_session_skills.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_mcp_and_skills/should_reload_session_skills.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_mcp_and_skills/should_report_error_when_extensions_are_not_available.yaml b/test/snapshots/rpc_mcp_and_skills/should_report_error_when_extensions_are_not_available.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_mcp_and_skills/should_report_error_when_extensions_are_not_available.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_mcp_and_skills/should_report_error_when_mcp_host_is_not_initialized.yaml b/test/snapshots/rpc_mcp_and_skills/should_report_error_when_mcp_host_is_not_initialized.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_mcp_and_skills/should_report_error_when_mcp_host_is_not_initialized.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_mcp_and_skills/should_report_error_when_mcp_oauth_server_is_not_configured.yaml b/test/snapshots/rpc_mcp_and_skills/should_report_error_when_mcp_oauth_server_is_not_configured.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_mcp_and_skills/should_report_error_when_mcp_oauth_server_is_not_configured.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_mcp_and_skills/should_report_error_when_mcp_oauth_server_is_not_remote.yaml b/test/snapshots/rpc_mcp_and_skills/should_report_error_when_mcp_oauth_server_is_not_remote.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_mcp_and_skills/should_report_error_when_mcp_oauth_server_is_not_remote.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_mcp_config/should_call_server_mcp_config_rpcs.yaml b/test/snapshots/rpc_mcp_config/should_call_server_mcp_config_rpcs.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_mcp_config/should_call_server_mcp_config_rpcs.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_mcp_config/should_round_trip_http_mcp_oauth_config_rpc.yaml b/test/snapshots/rpc_mcp_config/should_round_trip_http_mcp_oauth_config_rpc.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_mcp_config/should_round_trip_http_mcp_oauth_config_rpc.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_server/should_call_rpc_account_get_quota_when_authenticated.yaml b/test/snapshots/rpc_server/should_call_rpc_account_get_quota_when_authenticated.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_server/should_call_rpc_account_get_quota_when_authenticated.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_server/should_call_rpc_models_list_with_typed_result.yaml b/test/snapshots/rpc_server/should_call_rpc_models_list_with_typed_result.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_server/should_call_rpc_models_list_with_typed_result.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_server/should_call_rpc_ping_with_typed_params_and_result.yaml b/test/snapshots/rpc_server/should_call_rpc_ping_with_typed_params_and_result.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_server/should_call_rpc_ping_with_typed_params_and_result.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_server/should_call_rpc_tools_list_with_typed_result.yaml b/test/snapshots/rpc_server/should_call_rpc_tools_list_with_typed_result.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_server/should_call_rpc_tools_list_with_typed_result.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_server/should_discover_server_mcp_and_skills.yaml b/test/snapshots/rpc_server/should_discover_server_mcp_and_skills.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_server/should_discover_server_mcp_and_skills.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_session_state/should_call_session_rpc_model_getcurrent.yaml b/test/snapshots/rpc_session_state/should_call_session_rpc_model_getcurrent.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_session_state/should_call_session_rpc_model_getcurrent.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_session_state/should_call_session_rpc_model_switchto.yaml b/test/snapshots/rpc_session_state/should_call_session_rpc_model_switchto.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_session_state/should_call_session_rpc_model_switchto.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_session_state/should_call_session_usage_and_permission_rpcs.yaml b/test/snapshots/rpc_session_state/should_call_session_usage_and_permission_rpcs.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_session_state/should_call_session_usage_and_permission_rpcs.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_session_state/should_call_workspace_file_rpc_methods.yaml b/test/snapshots/rpc_session_state/should_call_workspace_file_rpc_methods.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_session_state/should_call_workspace_file_rpc_methods.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_session_state/should_create_workspace_file_with_nested_path_auto_creating_dirs.yaml b/test/snapshots/rpc_session_state/should_create_workspace_file_with_nested_path_auto_creating_dirs.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_session_state/should_create_workspace_file_with_nested_path_auto_creating_dirs.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_session_state/should_emit_title_changed_event_each_time_name_set_is_called.yaml b/test/snapshots/rpc_session_state/should_emit_title_changed_event_each_time_name_set_is_called.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_session_state/should_emit_title_changed_event_each_time_name_set_is_called.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_session_state/should_get_and_set_session_metadata.yaml b/test/snapshots/rpc_session_state/should_get_and_set_session_metadata.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_session_state/should_get_and_set_session_metadata.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_session_state/should_get_and_set_session_mode.yaml b/test/snapshots/rpc_session_state/should_get_and_set_session_mode.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_session_state/should_get_and_set_session_mode.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_session_state/should_handle_forking_session_without_persisted_events.yaml b/test/snapshots/rpc_session_state/should_handle_forking_session_without_persisted_events.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_session_state/should_handle_forking_session_without_persisted_events.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_session_state/should_read_update_and_delete_plan.yaml b/test/snapshots/rpc_session_state/should_read_update_and_delete_plan.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_session_state/should_read_update_and_delete_plan.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_session_state/should_reject_empty_or_whitespace_session_name.yaml b/test/snapshots/rpc_session_state/should_reject_empty_or_whitespace_session_name.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_session_state/should_reject_empty_or_whitespace_session_name.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_session_state/should_reject_workspace_file_path_traversal.yaml b/test/snapshots/rpc_session_state/should_reject_workspace_file_path_traversal.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_session_state/should_reject_workspace_file_path_traversal.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_session_state/should_report_error_reading_nonexistent_workspace_file.yaml b/test/snapshots/rpc_session_state/should_report_error_reading_nonexistent_workspace_file.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_session_state/should_report_error_reading_nonexistent_workspace_file.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_session_state/should_report_implemented_errors_for_unsupported_session_rpc_paths.yaml b/test/snapshots/rpc_session_state/should_report_implemented_errors_for_unsupported_session_rpc_paths.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_session_state/should_report_implemented_errors_for_unsupported_session_rpc_paths.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_session_state/should_set_and_get_each_session_mode_value.yaml b/test/snapshots/rpc_session_state/should_set_and_get_each_session_mode_value.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_session_state/should_set_and_get_each_session_mode_value.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_session_state/should_update_existing_workspace_file_with_update_operation.yaml b/test/snapshots/rpc_session_state/should_update_existing_workspace_file_with_update_operation.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_session_state/should_update_existing_workspace_file_with_update_operation.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_shell_and_fleet/should_execute_shell_command.yaml b/test/snapshots/rpc_shell_and_fleet/should_execute_shell_command.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_shell_and_fleet/should_execute_shell_command.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_shell_and_fleet/should_kill_shell_process.yaml b/test/snapshots/rpc_shell_and_fleet/should_kill_shell_process.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_shell_and_fleet/should_kill_shell_process.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_shell_edge_cases/shell_exec_with_custom_cwd_honors_override.yaml b/test/snapshots/rpc_shell_edge_cases/shell_exec_with_custom_cwd_honors_override.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_shell_edge_cases/shell_exec_with_custom_cwd_honors_override.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_shell_edge_cases/shell_exec_with_large_stdout_cleans_up.yaml b/test/snapshots/rpc_shell_edge_cases/shell_exec_with_large_stdout_cleans_up.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_shell_edge_cases/shell_exec_with_large_stdout_cleans_up.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_shell_edge_cases/shell_exec_with_nonexistent_command_returns_processid_and_cleans_up.yaml b/test/snapshots/rpc_shell_edge_cases/shell_exec_with_nonexistent_command_returns_processid_and_cleans_up.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_shell_edge_cases/shell_exec_with_nonexistent_command_returns_processid_and_cleans_up.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_shell_edge_cases/shell_exec_with_stderr_output_cleans_up.yaml b/test/snapshots/rpc_shell_edge_cases/shell_exec_with_stderr_output_cleans_up.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_shell_edge_cases/shell_exec_with_stderr_output_cleans_up.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_shell_edge_cases/shell_exec_with_timeout_kills_long_running_command.yaml b/test/snapshots/rpc_shell_edge_cases/shell_exec_with_timeout_kills_long_running_command.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_shell_edge_cases/shell_exec_with_timeout_kills_long_running_command.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_shell_edge_cases/shell_kill_cleans_up_after_terminating_signal.yaml b/test/snapshots/rpc_shell_edge_cases/shell_kill_cleans_up_after_terminating_signal.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_shell_edge_cases/shell_kill_cleans_up_after_terminating_signal.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_shell_edge_cases/shell_kill_unknown_processid_returns_false.yaml b/test/snapshots/rpc_shell_edge_cases/shell_kill_unknown_processid_returns_false.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_shell_edge_cases/shell_kill_unknown_processid_returns_false.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_tasks_and_handlers/should_list_task_state_and_return_false_for_missing_task_operations.yaml b/test/snapshots/rpc_tasks_and_handlers/should_list_task_state_and_return_false_for_missing_task_operations.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_tasks_and_handlers/should_list_task_state_and_return_false_for_missing_task_operations.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_tasks_and_handlers/should_report_implemented_error_for_invalid_task_agent_model.yaml b/test/snapshots/rpc_tasks_and_handlers/should_report_implemented_error_for_invalid_task_agent_model.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_tasks_and_handlers/should_report_implemented_error_for_invalid_task_agent_model.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_tasks_and_handlers/should_report_implemented_error_for_missing_task_agent_type.yaml b/test/snapshots/rpc_tasks_and_handlers/should_report_implemented_error_for_missing_task_agent_type.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_tasks_and_handlers/should_report_implemented_error_for_missing_task_agent_type.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rpc_tasks_and_handlers/should_return_expected_results_for_missing_pending_handler_requestids.yaml b/test/snapshots/rpc_tasks_and_handlers/should_return_expected_results_for_missing_pending_handler_requestids.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/rpc_tasks_and_handlers/should_return_expected_results_for_missing_pending_handler_requestids.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/rust_multi_client/both_clients_see_tool_request_and_completion_events.yaml b/test/snapshots/rust_multi_client/both_clients_see_tool_request_and_completion_events.yaml new file mode 100644 index 000000000..20eefc57a --- /dev/null +++ b/test/snapshots/rust_multi_client/both_clients_see_tool_request_and_completion_events.yaml @@ -0,0 +1,21 @@ +models: + - claude-sonnet-4.5 +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: Use the magic_number tool with seed 'hello' and tell me the result + - role: assistant + content: I'll use the magic_number tool with seed 'hello' for you. + tool_calls: + - id: toolcall_0 + type: function + function: + name: magic_number + arguments: '{"seed":"hello"}' + - role: tool + tool_call_id: toolcall_0 + content: MAGIC_hello_42 + - role: assistant + content: The magic number for seed 'hello' is **MAGIC_hello_42**. diff --git a/test/snapshots/rust_multi_client/disconnecting_client_removes_its_tools.yaml b/test/snapshots/rust_multi_client/disconnecting_client_removes_its_tools.yaml new file mode 100644 index 000000000..192105ac7 --- /dev/null +++ b/test/snapshots/rust_multi_client/disconnecting_client_removes_its_tools.yaml @@ -0,0 +1,69 @@ +models: + - claude-sonnet-4.5 +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: Use the stable_tool with input 'test1' and tell me the result. + - role: assistant + content: I'll call the stable_tool with input 'test1' for you. + tool_calls: + - id: toolcall_0 + type: function + function: + name: stable_tool + arguments: '{"input":"test1"}' + - role: tool + tool_call_id: toolcall_0 + content: STABLE_test1 + - role: assistant + content: "The stable_tool returned: **STABLE_test1**" + - role: user + content: Use the ephemeral_tool with input 'test2' and tell me the result. + - role: assistant + content: I'll call the ephemeral_tool with input 'test2' for you. + tool_calls: + - id: toolcall_1 + type: function + function: + name: ephemeral_tool + arguments: '{"input":"test2"}' + - role: tool + tool_call_id: toolcall_1 + content: EPHEMERAL_test2 + - role: assistant + content: "The ephemeral_tool returned: **EPHEMERAL_test2**" + - role: user + content: >- + + + Tools no longer available: ephemeral_tool + + + Important: Do not attempt to call tools that are no longer available unless you've been notified that they're + available again. + + + + + Use the stable_tool with input 'still_here'. Also try using ephemeral_tool if it is available. + - role: assistant + content: I'll call the stable_tool with input 'still_here'. The ephemeral_tool is no longer available, so I can only use + the stable_tool. + tool_calls: + - id: toolcall_2 + type: function + function: + name: stable_tool + arguments: '{"input":"still_here"}' + - role: tool + tool_call_id: toolcall_2 + content: STABLE_still_here + - role: assistant + content: >- + The stable_tool returned: **STABLE_still_here** + + + The ephemeral_tool is not available anymore (it was removed as indicated in the tools_changed_notice), so I + could only call the stable_tool. diff --git a/test/snapshots/rust_multi_client/one_client_approves_permission_and_both_see_the_result.yaml b/test/snapshots/rust_multi_client/one_client_approves_permission_and_both_see_the_result.yaml new file mode 100644 index 000000000..e67357589 --- /dev/null +++ b/test/snapshots/rust_multi_client/one_client_approves_permission_and_both_see_the_result.yaml @@ -0,0 +1,50 @@ +models: + - claude-sonnet-4.5 +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: Create a file called hello.txt containing the text 'hello world' + - role: assistant + content: I'll create the hello.txt file for you. + - role: assistant + tool_calls: + - id: toolcall_0 + type: function + function: + name: report_intent + arguments: '{"intent":"Creating hello.txt file"}' + - role: assistant + tool_calls: + - id: toolcall_1 + type: function + function: + name: create + arguments: '{"file_text":"hello world","path":"${workdir}/hello.txt"}' + - messages: + - role: system + content: ${system} + - role: user + content: Create a file called hello.txt containing the text 'hello world' + - role: assistant + content: I'll create the hello.txt file for you. + tool_calls: + - id: toolcall_0 + type: function + function: + name: report_intent + arguments: '{"intent":"Creating hello.txt file"}' + - id: toolcall_1 + type: function + function: + name: create + arguments: '{"file_text":"hello world","path":"${workdir}/hello.txt"}' + - role: tool + tool_call_id: toolcall_0 + content: Intent logged + - role: tool + tool_call_id: toolcall_1 + content: Created file ${workdir}/hello.txt with 11 characters + - role: assistant + content: Done - I created hello.txt containing "hello world". diff --git a/test/snapshots/rust_multi_client/one_client_rejects_permission_and_both_see_the_result.yaml b/test/snapshots/rust_multi_client/one_client_rejects_permission_and_both_see_the_result.yaml new file mode 100644 index 000000000..ba9db87d0 --- /dev/null +++ b/test/snapshots/rust_multi_client/one_client_rejects_permission_and_both_see_the_result.yaml @@ -0,0 +1,25 @@ +models: + - claude-sonnet-4.5 +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: Edit protected.txt and replace 'protected' with 'hacked'. + - role: assistant + content: I'll help you edit protected.txt to replace 'protected' with 'hacked'. Let me first view the file and then make + the change. + - role: assistant + tool_calls: + - id: toolcall_0 + type: function + function: + name: report_intent + arguments: '{"intent":"Editing protected.txt file"}' + - role: assistant + tool_calls: + - id: toolcall_1 + type: function + function: + name: view + arguments: '{"path":"${workdir}/protected.txt"}' diff --git a/test/snapshots/rust_multi_client/two_clients_register_different_tools_and_agent_uses_both.yaml b/test/snapshots/rust_multi_client/two_clients_register_different_tools_and_agent_uses_both.yaml new file mode 100644 index 000000000..c97e969df --- /dev/null +++ b/test/snapshots/rust_multi_client/two_clients_register_different_tools_and_agent_uses_both.yaml @@ -0,0 +1,36 @@ +models: + - claude-sonnet-4.5 +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: Use the city_lookup tool with countryCode 'US' and tell me the result. + - role: assistant + content: I'll call the city_lookup tool with the country code 'US' for you. + tool_calls: + - id: toolcall_0 + type: function + function: + name: city_lookup + arguments: '{"countryCode":"US"}' + - role: tool + tool_call_id: toolcall_0 + content: CITY_FOR_US + - role: assistant + content: The city_lookup tool returned **"CITY_FOR_US"** for the country code 'US'. + - role: user + content: Now use the currency_lookup tool with countryCode 'US' and tell me the result. + - role: assistant + content: I'll call the currency_lookup tool with the country code 'US' for you. + tool_calls: + - id: toolcall_1 + type: function + function: + name: currency_lookup + arguments: '{"countryCode":"US"}' + - role: tool + tool_call_id: toolcall_1 + content: CURRENCY_FOR_US + - role: assistant + content: The currency_lookup tool returned **"CURRENCY_FOR_US"** for the country code 'US'. diff --git a/test/snapshots/session_config/should_apply_all_reasoning_effort_values_on_session_create.yaml b/test/snapshots/session_config/should_apply_all_reasoning_effort_values_on_session_create.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/session_config/should_apply_all_reasoning_effort_values_on_session_create.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/session_config/should_apply_reasoning_effort_on_session_create.yaml b/test/snapshots/session_config/should_apply_reasoning_effort_on_session_create.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/session_config/should_apply_reasoning_effort_on_session_create.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/session_config/should_create_session_with_custom_provider_config.yaml b/test/snapshots/session_config/should_create_session_with_custom_provider_config.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/session_config/should_create_session_with_custom_provider_config.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/session_config/should_use_custom_session_id.yaml b/test/snapshots/session_config/should_use_custom_session_id.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/session_config/should_use_custom_session_id.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: []