diff --git a/src/main.rs b/src/main.rs index 11ad58f..9b281b2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,6 +26,20 @@ use crate::ratelimiter::RateLimiter; use borsh::BorshDeserialize; use borsh::BorshSerialize; +#[derive(Debug)] +enum FastsyncError { + #[allow(dead_code)] // allow for debug output + AppError(String), + #[allow(dead_code)] // allow for debug output + IoError(std::io::Error), +} + +impl From for FastsyncError { + fn from(err: std::io::Error) -> FastsyncError { + FastsyncError::IoError(err) + } +} + #[derive(Debug, Clone, Copy, PartialEq)] enum Verbosity { Silent, @@ -262,7 +276,7 @@ struct SendState { enum SendResult { Done, - FileVanished, + FileVanished { fname: PathBuf }, Progress { bytes_sent: u64 }, } @@ -304,7 +318,9 @@ impl SendState { let res = match std::fs::File::open(fname) { Ok(f) => f, Err(e) if e.kind() == std::io::ErrorKind::NotFound => { - return Ok(SendResult::FileVanished) + return Ok(SendResult::FileVanished { + fname: fname.clone(), + }); } Err(e) => return Err(e), }; @@ -420,7 +436,7 @@ fn main_send( sender_events: std::sync::mpsc::Sender, max_bandwidth_mbps: Option, verbosity: Verbosity, -) -> Result<()> { +) -> std::result::Result<(), FastsyncError> { let mut plan = TransferPlan { proto_version: protocol_version, files: Vec::new(), @@ -471,8 +487,14 @@ fn main_send( _ = limiter_mutex.lock().unwrap().insert(ratelimiter); } + let (error_tx, error_rx) = std::sync::mpsc::channel::(); + let mut start_time_opt: Option = None; loop { + if let Some(err) = error_rx.try_recv().ok() { + return Err(err); + } + let (mut stream, addr) = listener.accept()?; let start_time = *start_time_opt.get_or_insert_with(Instant::now); if Verbosity::Verbose == verbosity { @@ -500,6 +522,7 @@ fn main_send( let state_clone = state_arc.clone(); let limiter_mutex_2 = limiter_mutex.clone(); let total_bytes_sent_clone = total_bytes_sent.clone(); + let thread_error_tx = error_tx.clone(); let push_thread = std::thread::spawn(move || { // All the threads iterate through all the files one by one, so all // the threads collaborate on sending the first one, then the second @@ -517,10 +540,20 @@ fn main_send( std::thread::sleep(to_wait.unwrap()); } match file.send_one(&mut stream, verbosity) { - Ok(SendResult::FileVanished) => { - if Verbosity::Verbose == verbosity { - println!("File {:?} vanished", file.id); + Ok(SendResult::FileVanished { fname }) => { + let error_msg = format!( + "File {:?} vanished during transfer, cannot perform full transfer.", + fname + ); + match thread_error_tx.send(FastsyncError::AppError(error_msg.clone())) { + Ok(_) => {} + Err(_) => { + // If other thread reported error already the channel will + // be closed, in this case just log error and exit. + println!("Error channel closed already. Thread encountered an error: {error_msg}"); + } } + return; } Ok(SendResult::Progress { bytes_sent: bytes_written, @@ -539,7 +572,13 @@ fn main_send( continue 'chunks; } Ok(SendResult::Done) => continue 'files, - Err(err) => panic!("Failed to send: {err}"), + Err(err) => { + let error_msg = format!("Failed to send: {err}"); + thread_error_tx + .send(FastsyncError::AppError(error_msg.clone())) + .expect("expected error channel to be open"); + return; + } } } } @@ -556,6 +595,11 @@ fn main_send( push_thread.join().expect("Failed to wait for push thread."); } + // Before we exit, check if any of the threads reported an error + if let Some(err) = error_rx.try_recv().ok() { + return Err(err); + } + Ok(()) } @@ -642,7 +686,7 @@ fn main_recv( write_mode: WriteMode, protocol_version: u16, verbosity: Verbosity, -) -> Result<()> { +) -> std::result::Result<(), FastsyncError> { // First we initiate one connection. The sender will send the plan over // that. We read it. Unbuffered, because we want to skip the buffer for the // remaining reads, but the header is tiny so it should be okay. @@ -655,7 +699,8 @@ fn main_recv( "Sender is version {} and we only support {WIRE_PROTO_VERSION}", plan.proto_version ), - )); + ) + .into()); } if write_mode == WriteMode::AskConfirm { plan.ask_confirm_receive()?; @@ -668,7 +713,7 @@ fn main_recv( // time, and if the network is faster the channel will be full all the time. let (sender, receiver) = mpsc::sync_channel::(16); - let writer_thread = std::thread::spawn::<_, ()>(move || { + let writer_thread = std::thread::spawn(move || { let total_len: u64 = plan.files.iter().map(|f| f.len).sum(); let mut files: Vec<_> = plan.files.into_iter().map(FileReceiver::new).collect(); @@ -686,8 +731,11 @@ fn main_recv( } if bytes_received < total_len { - panic!("Transmission ended, but not all data was received."); + return Err(FastsyncError::AppError( + "Transmission ended, but not all data was received.".to_string(), + )); } + Ok(()) }); // We make n threads that "pull" the data from a socket. The first socket we @@ -755,7 +803,9 @@ fn main_recv( std::mem::drop(sender); for pull_thread in pull_threads { - pull_thread.join().expect("Failed to join pull thread.")?; + pull_thread.join().map_err(|err| { + FastsyncError::AppError(format!("failed to join pull thread: {:?}", err)) + })??; } // After all pulls are done and the transfer is complete, the sender is @@ -773,7 +823,9 @@ fn main_recv( Err(_) => {} } - writer_thread.join().expect("Failed to join writer thread."); + writer_thread.join().map_err(|err| { + FastsyncError::AppError(format!("failed to join writer thread: {:?}", err)) + })??; Ok(()) } @@ -841,10 +893,13 @@ mod tests { 1, Verbosity::Silent, ); - assert_eq!( - res.expect_err("Expected failure").kind(), - ErrorKind::InvalidData - ); + match res { + Ok(_) => panic!("Expected failure, but got success."), + Err(FastsyncError::IoError(err)) => { + assert_eq!(err.kind(), ErrorKind::InvalidData); + } + Err(err) => panic!("Expected IoError, but got {err:?}"), + } } } } @@ -946,4 +1001,64 @@ mod tests { } } } + + #[test] + fn file_deleted_before_send() { + let (events_tx, events_rx) = std::sync::mpsc::channel::(); + env::set_current_dir("/tmp/").unwrap(); + let cwd = env::current_dir().unwrap(); + + let td = TempDir::new_in(".").unwrap(); + let tmp_path = td.path().strip_prefix(cwd).unwrap(); + let path1 = tmp_path.join("file1"); + let path2 = tmp_path.join("file_deleted_before_send"); + let fnames = vec![ + path1.clone().into_os_string().into_string().unwrap(), + path2.clone().into_os_string().into_string().unwrap(), + ]; + + let sender_handle = thread::spawn(move || { + { + for path in &fnames { + let mut f = std::fs::File::create(path).unwrap(); + f.write_all(&vec![0u8; MAX_CHUNK_LEN as usize * 3]).unwrap(); + } + } + + let res = main_send( + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0), + &fnames, + 1, + events_tx, + None, + Verbosity::Silent, + ); + res + }); + + match events_rx.recv().unwrap() { + SenderEvent::Listening(port) => { + // Remove file after the sender has listed it + std::fs::remove_file(path2).expect("Failed to delete file before send"); + + main_recv( + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port), + 1, + WriteMode::Force, + 1, + Verbosity::Silent, + ) + .expect_err("expected receiver to fail"); + } + } + + let result = sender_handle.join().expect("Failed to join sender thread."); + match result { + Ok(_) => panic!("Expected failure, but got success."), + Err(FastsyncError::AppError(err)) => { + assert!(err.contains("vanished during transfer")); + } + Err(err) => panic!("Expected AppError, but got {err:?}"), + } + } }