diff --git a/.github/target.toml b/.github/target.toml new file mode 100644 index 0000000..58a06b0 --- /dev/null +++ b/.github/target.toml @@ -0,0 +1,13 @@ +# Wind CI Build Target Configuration + +[[target]] +os = "ubuntu-latest" +target = "x86_64-unknown-linux-gnu" +release-name = "x86_64-linux" +skip-test = false + +[[target]] +os = "macos-latest" +target = "aarch64-apple-darwin" +release-name = "aarch64-darwin" +skip-test = false diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..0714ad5 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,17 @@ +name: CI + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + build: + uses: rust-proxy/workflows/.github/workflows/rust-build.yml@main + with: + packages: "wind" + run-tests: true + only-clippy-tests-on-pr: true diff --git a/crates/wind-core/Cargo.toml b/crates/wind-core/Cargo.toml index cb307c7..cd3d3d6 100644 --- a/crates/wind-core/Cargo.toml +++ b/crates/wind-core/Cargo.toml @@ -1,9 +1,9 @@ [package] name = "wind-core" -version.workspace = true -repository.workspace = true -edition.workspace = true -description.workspace = true +version = "0.1.1" +repository = "https://github.com/proxy-rs/wind" +edition = "2024" +description = "Wind core networking abstractions" license = "MIT OR Apache-2.0" [features] @@ -15,7 +15,7 @@ pin-project = "1" tokio = { version = "1", default-features = false, features = ["io-util", "macros", "time", "net"] } tokio-util = { version = "0.7", features = ["rt"] } -quinn = { version = "0.11", default-features = false, optional = true } +quinn = { workspace = true, default-features = false, optional = true } quinn-udp = "0.5" socket2 = "0.6" diff --git a/crates/wind-core/src/udp.rs b/crates/wind-core/src/udp.rs index 6f3e83e..296a9ab 100644 --- a/crates/wind-core/src/udp.rs +++ b/crates/wind-core/src/udp.rs @@ -10,8 +10,6 @@ use std::{ use bytes::Bytes; use futures::future::poll_fn; -#[cfg(feature = "quic")] -pub use quinn::UdpPoller; pub use quinn_udp::{EcnCodepoint, RecvMeta as QuinnRecvMeta, Transmit, UdpSocketState}; // Re-export quinn-udp's RecvMeta directly // pub use quinn_udp::RecvMeta; @@ -19,7 +17,6 @@ use tokio::io::Interest; use crate::types::TargetAddr; -#[cfg(not(feature = "quic"))] pub trait UdpPoller: Send + Sync + Debug + 'static { fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll>; } @@ -31,9 +28,9 @@ pub trait UdpPoller: Send + Sync + Debug + 'static { #[derive(Debug, Clone)] pub struct RecvMeta { /// The source address of the datagram(s) contained in the buffer - pub addr: SocketAddr, + pub addr: SocketAddr, /// The number of bytes the associated buffer has - pub len: usize, + pub len: usize, /// The size of a single datagram in the associated buffer /// /// When GRO (Generic Receive Offload) is used this indicates the size of a @@ -41,15 +38,15 @@ pub struct RecvMeta { /// [`len`] is greater then this value, then the individual datagrams /// contained have their boundaries at `stride` increments from the start. /// The last datagram could be smaller than `stride`. - pub stride: usize, + pub stride: usize, /// The Explicit Congestion Notification bits for the datagram(s) in the /// buffer - pub ecn: Option, + pub ecn: Option, /// The destination IP address which was encoded in this datagram /// /// Populated on platforms: Windows, Linux, Android (API level > 25), /// FreeBSD, OpenBSD, NetBSD, macOS, and iOS. - pub dst_ip: Option, + pub dst_ip: Option, /// The destination address that this packet is intended for /// This is our custom field for better packet routing pub destination: Option, @@ -59,11 +56,11 @@ impl Default for RecvMeta { /// Constructs a value with arbitrary fields, intended to be overwritten fn default() -> Self { Self { - addr: SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), - len: 0, - stride: 0, - ecn: None, - dst_ip: None, + addr: SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), + len: 0, + stride: 0, + ecn: None, + dst_ip: None, destination: None, } } @@ -72,11 +69,11 @@ impl Default for RecvMeta { impl From for RecvMeta { fn from(meta: QuinnRecvMeta) -> Self { Self { - addr: meta.addr, - len: meta.len, - stride: meta.stride, - ecn: meta.ecn, - dst_ip: meta.dst_ip, + addr: meta.addr, + len: meta.len, + stride: meta.stride, + ecn: meta.ecn, + dst_ip: meta.dst_ip, destination: None, } } @@ -84,8 +81,8 @@ impl From for RecvMeta { #[derive(Debug, Clone)] pub struct UdpPacket { - pub source: Option, - pub target: TargetAddr, + pub source: Option, + pub target: TargetAddr, pub payload: Bytes, } @@ -130,11 +127,11 @@ pub trait AbstractUdpSocket: Send + Sync { /// Sends data on the socket to the given address. fn poll_send(&self, _cx: &mut Context<'_>, buf: &[u8], target: SocketAddr) -> Poll> { let transmit = Transmit { - destination: target, - contents: buf, - ecn: None, + destination: target, + contents: buf, + ecn: None, segment_size: None, - src_ip: None, + src_ip: None, }; match self.try_send(&transmit) { Ok(_) => Poll::Ready(Ok(buf.len())), @@ -150,14 +147,14 @@ pub trait AbstractUdpSocket: Send + Sync { #[derive(Debug)] pub struct TokioUdpSocket { - io: tokio::net::UdpSocket, + io: tokio::net::UdpSocket, inner: UdpSocketState, } impl TokioUdpSocket { pub fn new(sock: std::net::UdpSocket) -> std::io::Result { Ok(Self { inner: UdpSocketState::new((&sock).into())?, - io: tokio::net::UdpSocket::from_std(sock)?, + io: tokio::net::UdpSocket::from_std(sock)?, }) } } diff --git a/crates/wind-socks/Cargo.toml b/crates/wind-socks/Cargo.toml index 7db02cc..1db8571 100644 --- a/crates/wind-socks/Cargo.toml +++ b/crates/wind-socks/Cargo.toml @@ -1,13 +1,13 @@ [package] name = "wind-socks" -version.workspace = true -repository.workspace = true -edition.workspace = true -description.workspace = true +version = "0.1.1" +repository = "https://github.com/proxy-rs/wind" +edition = "2024" +description = "Wind SOCKS5 implementation" license = "MIT OR Apache-2.0" [dependencies] -wind-core = { version = "0.1.1", path = "../wind-core"} +wind-core = { path = "../wind-core"} # Async tokio = { version = "1", default-features = false, features = ["net"] } tokio-util = { version = "0.7", features = ["codec"] } diff --git a/crates/wind-tuic/Cargo.toml b/crates/wind-tuic/Cargo.toml index 05a776b..3323bda 100644 --- a/crates/wind-tuic/Cargo.toml +++ b/crates/wind-tuic/Cargo.toml @@ -1,9 +1,9 @@ [package] name = "wind-tuic" -version.workspace = true -repository.workspace = true -edition.workspace = true -description.workspace = true +version = "0.1.1" +repository = "https://github.com/proxy-rs/wind" +edition = "2024" +description = "Wind TUIC protocol implementation" license = "MIT OR Apache-2.0" [features] @@ -22,13 +22,13 @@ ring = [ ] [dependencies] -wind-core = { version = "0.1.1", path = "../wind-core", features = ["quic"]} +wind-core = { path = "../wind-core", features = ["quic"]} # Async tokio = { version = "1", default-features = false, features = ["net"] } tokio-util = { version = "0.7", features = ["codec"] } -quinn = { version = "0.11", default-features = false, features = ["runtime-tokio"]} +quinn = { workspace = true, default-features = false, features = ["runtime-tokio", "qlog"] } crossfire = { version = "2", features = ["tokio"] } tokio-stream = "0.1" diff --git a/crates/wind-tuic/src/inbound.rs b/crates/wind-tuic/src/inbound.rs index 4c1449f..140bc67 100644 --- a/crates/wind-tuic/src/inbound.rs +++ b/crates/wind-tuic/src/inbound.rs @@ -138,8 +138,8 @@ impl Default for TuicInboundOpts { /// TUIC inbound server pub struct TuicInbound { pub ctx: Arc, - opts: TuicInboundOpts, - cancel: CancellationToken, + opts: TuicInboundOpts, + cancel: CancellationToken, } impl TuicInbound { @@ -243,16 +243,16 @@ impl AbstractInbound for TuicInbound { /// Represents an authenticated connection struct InboundCtx { - conn: quinn::Connection, - uuid: Arc>>, - users: HashMap, + conn: quinn::Connection, + uuid: Arc>>, + users: HashMap, udp_sessions: Arc>>, } /// UDP session tracking #[allow(dead_code)] struct UdpSession { - assoc_id: u16, + assoc_id: u16, // Track packet fragments if needed fragments: Cache>, } @@ -323,7 +323,7 @@ async fn handle_connection( } Ok(recv) => recv, }; - + let conn = connection.clone(); if let Err(e) = handle_uni_stream(conn, recv, callback).await { error!("Uni stream error: {:?}", e); @@ -338,7 +338,7 @@ async fn handle_connection( } Ok(streams) => streams, }; - + let conn = connection.clone(); if let Err(e) = handle_bi_stream(conn, send, recv, callback).await { error!("Bi stream error: {:?}", e); @@ -353,7 +353,7 @@ async fn handle_connection( } Ok(datagram) => datagram, }; - + let conn = connection.clone(); if let Err(e) = handle_datagram(conn, datagram, callback).await { error!("Datagram error: {:?}", e); @@ -390,7 +390,7 @@ async fn handle_uni_stream( // Decode address let addr = crate::proto::decode_address(&mut buf, "uni stream packet")?; let payload = buf.split_to(size as usize).freeze(); - + // Convert address to TargetAddr using helper function let target_addr = crate::proto::address_to_target(addr)?; handle_udp_packet(&ctx, assoc_id, target_addr, payload, callback).await?; @@ -491,7 +491,7 @@ async fn handle_datagram( if let Command::Packet { assoc_id, size, .. } = cmd { let addr = crate::proto::decode_address(&mut buf, "datagram packet")?; let payload = buf.split_to(size as usize).freeze(); - + // Convert address to TargetAddr using helper function let target_addr = crate::proto::address_to_target(addr)?; handle_udp_packet(&connection, assoc_id, target_addr, payload, callback).await?; diff --git a/crates/wind-tuic/src/lib.rs b/crates/wind-tuic/src/lib.rs index d5c0cff..8d381a2 100644 --- a/crates/wind-tuic/src/lib.rs +++ b/crates/wind-tuic/src/lib.rs @@ -1,6 +1,7 @@ #![feature(error_generic_member_access)] pub mod proto; +pub mod simple_udp; mod task; pub mod tls; diff --git a/crates/wind-tuic/src/outbound.rs b/crates/wind-tuic/src/outbound.rs index 113557c..bb7c7a6 100644 --- a/crates/wind-tuic/src/outbound.rs +++ b/crates/wind-tuic/src/outbound.rs @@ -26,27 +26,27 @@ use crate::{ }; pub struct TuicOutboundOpts { - pub peer_addr: SocketAddr, - pub sni: String, - pub auth: (Uuid, Arc<[u8]>), + pub peer_addr: SocketAddr, + pub sni: String, + pub auth: (Uuid, Arc<[u8]>), pub zero_rtt_handshake: bool, - pub heartbeat: Duration, - pub gc_interval: Duration, - pub gc_lifetime: Duration, - pub skip_cert_verify: bool, - pub alpn: Vec, + pub heartbeat: Duration, + pub gc_interval: Duration, + pub gc_lifetime: Duration, + pub skip_cert_verify: bool, + pub alpn: Vec, } pub struct TuicOutbound { - pub ctx: Arc, - pub endpoint: quinn::Endpoint, - pub peer_addr: SocketAddr, - pub sni: String, - pub opts: TuicOutboundOpts, - pub connection: quinn::Connection, + pub ctx: Arc, + pub endpoint: quinn::Endpoint, + pub peer_addr: SocketAddr, + pub sni: String, + pub opts: TuicOutboundOpts, + pub connection: quinn::Connection, pub udp_assoc_counter: AtomicU16, - pub token: CancellationToken, - pub udp_session: Cache>, + pub token: CancellationToken, + pub udp_session: Cache>, } impl TuicOutbound { @@ -70,7 +70,7 @@ impl TuicOutbound { )); let mut transport_config = quinn::TransportConfig::default(); transport_config - .congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default())) + .congestion_controller_factory(Arc::new(quinn::congestion::Bbr3Config::default())) .keep_alive_interval(None); client_config.transport_config(Arc::new(transport_config)); @@ -82,7 +82,7 @@ impl TuicOutbound { .map_err(|e| eyre::eyre!("Failed to bind socket to {}: {}", socket_addr, e))? .into_std()?; - let mut endpoint = quinn::Endpoint::new(quinn::EndpointConfig::default(), None, socket, Arc::new(TokioRuntime))?; + let endpoint = quinn::Endpoint::new(quinn::EndpointConfig::default(), None, socket, Arc::new(TokioRuntime))?; endpoint.set_default_client_config(client_config); let connection = endpoint .connect(peer_addr, &server_name) @@ -301,7 +301,7 @@ impl AbstractOutbound for TuicOutbound { } Ok(packet) => packet, }; - + // Received packet from remote, send to local socket // overrided in socks inbound const UNSPECIFIED_V4: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); @@ -320,7 +320,7 @@ impl AbstractOutbound for TuicOutbound { } Ok(packet) => packet, }; - + // Send packet to remote via UDP stream let payload_len = packet.payload.len(); if let Err(e) = udp_stream.send_packet(packet).await { @@ -368,7 +368,7 @@ impl AbstractOutbound for TuicOutbound { } Ok(meta) => meta, }; - + // In outbound context, get target address from meta.destination or use meta.addr let target_addr = meta.destination .as_ref() @@ -450,7 +450,6 @@ impl AbstractOutbound for TuicOutbound { } } - // Clean up the UDP association before exiting if let Err(err) = self.connection.drop_udp(assoc_id).await { info!(target: "[OUT]", "Error dropping UDP association {:#06x}: {}", assoc_id, err); @@ -459,3 +458,97 @@ impl AbstractOutbound for TuicOutbound { Ok(()) } } + +impl TuicOutbound { + pub async fn handle_udp_simple( + &self, + assoc_id: u16, + ) -> Result<(crate::simple_udp::SimpleUdpChannel, crate::simple_udp::SimpleUdpChannelTx), Error> { + use crate::simple_udp::{SimpleUdpChannel, SimpleUdpPacket}; + use std::sync::Arc; + + info!(target: "[OUT]", "Creating new UDP association with simple channel: {:#06x}", assoc_id); + + let connection = self.connection.clone(); + let (channel, channel_tx) = SimpleUdpChannel::new(128); + + // Use crossfire channel compatible with UdpStream + let (wind_tx, wind_rx) = crossfire::mpmc::bounded_async::(128); + let udp_stream = Arc::new(UdpStream::new(connection.clone(), assoc_id, wind_tx)); + self.udp_session.insert(assoc_id, udp_stream.clone()).await; + + let cancel = self.token.child_token(); + let channel_tx_clone = channel_tx.clone(); + let udp_session = self.udp_session.clone(); + + let mut gc_interval = tokio::time::interval(self.opts.gc_interval); + gc_interval.tick().await; // consume the immediate first tick + + self.ctx.tasks.spawn(async move { + loop { + tokio::select! { + _ = cancel.cancelled() => { + info!(target: "[OUT]", "UDP simple stream for association {:#06x} cancelled", assoc_id); + break; + } + + // Remote → caller: forward reassembled packets into the SimpleUdpChannel + result = wind_rx.recv() => { + let Ok(packet) = result else { + warn!(target: "[OUT]", "UDP simple rx channel closed for association {:#06x}", assoc_id); + break; + }; + + let target = match &packet.target { + wind_core::types::TargetAddr::IPv4(ip, port) => SocketAddr::from((*ip, *port)), + wind_core::types::TargetAddr::IPv6(ip, port) => SocketAddr::from((*ip, *port)), + _ => continue, + }; + + let simple_packet = SimpleUdpPacket::new(None, target, packet.payload); + + if channel_tx_clone.send_from_remote(simple_packet).await.is_err() { + warn!(target: "[OUT]", "SimpleUdpChannel receiver dropped for association {:#06x}, closing", assoc_id); + break; + } + } + + // Caller → remote: forward packets from the SimpleUdpChannelTx into the TUIC connection + result = channel_tx_clone.to_remote_rx.recv() => { + let Ok(simple_packet) = result else { + warn!(target: "[OUT]", "SimpleUdpChannel to_remote channel closed for association {:#06x}", assoc_id); + break; + }; + + let target = wind_core::types::TargetAddr::from(simple_packet.target); + let wind_packet = wind_core::udp::UdpPacket { + source: None, + target, + payload: simple_packet.payload, + }; + + let payload_len = wind_packet.payload.len(); + if let Err(e) = udp_stream.send_packet(wind_packet).await { + warn!(target: "[OUT]", "Failed to send UDP packet to remote (assoc {:#06x}): {}", assoc_id, e); + } else { + info!(target: "[OUT]", "Sent UDP packet to remote ({} bytes, assoc {:#06x})", payload_len, assoc_id); + } + } + + // Periodic GC: evict stale fragment reassembly state + _ = gc_interval.tick() => { + udp_stream.collect_garbage().await; + } + } + } + + // Cleanup: remove from session table and send Dissociate to peer + udp_session.remove(&assoc_id).await; + if let Err(e) = connection.drop_udp(assoc_id).await { + info!(target: "[OUT]", "Error dropping UDP association {:#06x}: {}", assoc_id, e); + } + }); + + Ok((channel, channel_tx)) + } +} diff --git a/crates/wind-tuic/src/proto/addr.rs b/crates/wind-tuic/src/proto/addr.rs index 625f414..85d022a 100644 --- a/crates/wind-tuic/src/proto/addr.rs +++ b/crates/wind-tuic/src/proto/addr.rs @@ -38,10 +38,10 @@ pub enum Address { #[derive(IntoPrimitive, FromPrimitive, Copy, Clone, Debug, PartialEq)] #[repr(u8)] pub enum AddressType { - None = u8::MAX, + None = u8::MAX, Domain = 0, - IPv4 = 1, - IPv6 = 2, + IPv4 = 1, + IPv6 = 2, #[num_enum(catch_all)] Other(u8), } @@ -80,7 +80,12 @@ impl Decoder for AddressCodec { // Parse address type from first byte let addr_type = AddressType::from(src[0]); - ensure!(!matches!(addr_type, AddressType::Other(_)), UnknownAddressTypeSnafu { value: u8::from(addr_type) }); + ensure!( + !matches!(addr_type, AddressType::Other(_)), + UnknownAddressTypeSnafu { + value: u8::from(addr_type) + } + ); match addr_type { AddressType::None => { diff --git a/crates/wind-tuic/src/proto/cmd.rs b/crates/wind-tuic/src/proto/cmd.rs index 736dcb3..1aa3b4e 100644 --- a/crates/wind-tuic/src/proto/cmd.rs +++ b/crates/wind-tuic/src/proto/cmd.rs @@ -11,16 +11,16 @@ pub struct CmdCodec(pub CmdType); #[derive(Debug, Clone, PartialEq)] pub enum Command { Auth { - uuid: uuid::Uuid, + uuid: uuid::Uuid, token: [u8; 32], }, Connect, Packet { - assoc_id: u16, - pkt_id: u16, + assoc_id: u16, + pkt_id: u16, frag_total: u8, - frag_id: u8, - size: u16, + frag_id: u8, + size: u16, }, Dissociate { assoc_id: u16, @@ -54,11 +54,11 @@ impl Decoder for CmdCodec { } Ok(Some(Command::Packet { - assoc_id: src.get_u16(), - pkt_id: src.get_u16(), + assoc_id: src.get_u16(), + pkt_id: src.get_u16(), frag_total: src.get_u8(), - frag_id: src.get_u8(), - size: src.get_u16(), + frag_id: src.get_u8(), + size: src.get_u16(), })) } CmdType::Dissociate => { @@ -132,16 +132,16 @@ mod test { async fn test_cmd_1() -> eyre::Result<()> { let vars = vec![ Command::Auth { - uuid: Uuid::parse_str("02f09a3f-1624-3b1d-8409-44eff7708208")?, + uuid: Uuid::parse_str("02f09a3f-1624-3b1d-8409-44eff7708208")?, token: [1; 32], }, Command::Connect, Command::Packet { - assoc_id: 123, - pkt_id: 123, + assoc_id: 123, + pkt_id: 123, frag_total: 5, - frag_id: 1, - size: 8, + frag_id: 1, + size: 8, }, Command::Dissociate { assoc_id: 23 }, Command::Heartbeat, @@ -173,15 +173,15 @@ mod test { async fn test_cmd_2() -> eyre::Result<()> { let vars = vec![ Command::Auth { - uuid: Uuid::parse_str("02f09a3f-1624-3b1d-8409-44eff7708208")?, + uuid: Uuid::parse_str("02f09a3f-1624-3b1d-8409-44eff7708208")?, token: [1; 32], }, Command::Packet { - assoc_id: 123, - pkt_id: 123, + assoc_id: 123, + pkt_id: 123, frag_total: 5, - frag_id: 1, - size: 8, + frag_id: 1, + size: 8, }, Command::Dissociate { assoc_id: 23 }, ]; diff --git a/crates/wind-tuic/src/proto/error.rs b/crates/wind-tuic/src/proto/error.rs index dfff1b4..1ca557d 100644 --- a/crates/wind-tuic/src/proto/error.rs +++ b/crates/wind-tuic/src/proto/error.rs @@ -7,44 +7,44 @@ use snafu::prelude::*; #[snafu(visibility(pub))] pub enum ProtoError { VersionDismatch { - expect: u8, - current: u8, + expect: u8, + current: u8, backtrace: Backtrace, }, #[snafu(display("Unknown command type {value}"))] UnknownCommandType { - value: u8, + value: u8, backtrace: Backtrace, }, #[snafu(display("Unable to decode address due to type {value}"))] UnknownAddressType { - value: u8, + value: u8, backtrace: Backtrace, }, FailParseDomain { // HEX - raw: String, - source: Utf8Error, + raw: String, + source: Utf8Error, backtrace: Backtrace, }, DomainTooLong { - domain: String, + domain: String, backtrace: Backtrace, }, // Caller should yield BytesRemaining, Io { // #[snafu(backtrace)] - source: std::io::Error, + source: std::io::Error, backtrace: Backtrace, }, NumericOverflow { - field: String, - num: String, + field: String, + num: String, backtrace: Backtrace, }, ReadToEnd { - source: ReadToEndError, + source: ReadToEndError, backtrace: Backtrace, }, } diff --git a/crates/wind-tuic/src/proto/header.rs b/crates/wind-tuic/src/proto/header.rs index 268ba69..e3db606 100644 --- a/crates/wind-tuic/src/proto/header.rs +++ b/crates/wind-tuic/src/proto/header.rs @@ -18,11 +18,11 @@ pub struct Header { #[derive(IntoPrimitive, FromPrimitive, Copy, Clone, Debug, PartialEq)] #[repr(u8)] pub enum CmdType { - Auth = 0, - Connect = 1, - Packet = 2, + Auth = 0, + Connect = 1, + Packet = 2, Dissociate = 3, - Heartbeat = 4, + Heartbeat = 4, #[num_enum(catch_all)] Other(u8), } @@ -43,11 +43,20 @@ impl Decoder for HeaderCodec { return Ok(None); } let ver = src.get_u8(); - ensure!(ver == VER, VersionDismatchSnafu { expect: VER, current: ver }); + ensure!( + ver == VER, + VersionDismatchSnafu { + expect: VER, + current: ver + } + ); let cmd = CmdType::from(src.get_u8()); - ensure!(!matches!(cmd, CmdType::Other(..)), UnknownCommandTypeSnafu { value: u8::from(cmd) }); + ensure!( + !matches!(cmd, CmdType::Other(..)), + UnknownCommandTypeSnafu { value: u8::from(cmd) } + ); Ok(Some(Header::new(cmd))) } diff --git a/crates/wind-tuic/src/proto/mod.rs b/crates/wind-tuic/src/proto/mod.rs index 5072d40..a5bd6ac 100644 --- a/crates/wind-tuic/src/proto/mod.rs +++ b/crates/wind-tuic/src/proto/mod.rs @@ -26,19 +26,22 @@ pub const VER: u8 = 5; /// Helper function to decode header with better error reporting pub fn decode_header(buf: &mut BytesMut, context: &str) -> Result { - HeaderCodec.decode(buf)? + HeaderCodec + .decode(buf)? .ok_or_else(|| eyre!("Incomplete header in {}", context).into()) } /// Helper function to decode command with better error reporting pub fn decode_command(cmd_type: CmdType, buf: &mut BytesMut, context: &str) -> Result { - CmdCodec(cmd_type).decode(buf)? + CmdCodec(cmd_type) + .decode(buf)? .ok_or_else(|| eyre!("Incomplete command in {}", context).into()) } /// Helper function to decode address with better error reporting pub fn decode_address(buf: &mut BytesMut, context: &str) -> Result { - AddressCodec.decode(buf)? + AddressCodec + .decode(buf)? .ok_or_else(|| eyre!("Incomplete address in {}", context).into()) } diff --git a/crates/wind-tuic/src/proto/tests.rs b/crates/wind-tuic/src/proto/tests.rs index f8d55c6..8767181 100644 --- a/crates/wind-tuic/src/proto/tests.rs +++ b/crates/wind-tuic/src/proto/tests.rs @@ -22,7 +22,7 @@ mod test { #[test_log::test(tokio::test)] async fn hex_check_auth_encode() -> eyre::Result<()> { let auth_cmd = Command::Auth { - uuid: Uuid::from_u128(0), + uuid: Uuid::from_u128(0), token: [1u8; 32], }; let mut buf = BytesMut::with_capacity(50); diff --git a/crates/wind-tuic/src/proto/udp_stream.rs b/crates/wind-tuic/src/proto/udp_stream.rs index f44fa68..faf0b15 100644 --- a/crates/wind-tuic/src/proto/udp_stream.rs +++ b/crates/wind-tuic/src/proto/udp_stream.rs @@ -27,30 +27,30 @@ fn init_time() -> &'static Instant { /// Fragment information for reassembly struct FragmentInfo { - assoc_id: u16, - pkt_id: u16, + assoc_id: u16, + pkt_id: u16, frag_total: u8, - frag_id: u8, - source: Option, - target: TargetAddr, + frag_id: u8, + source: Option, + target: TargetAddr, } pub struct UdpStream { - connection: quinn::Connection, - assoc_id: u16, - receive_tx: MAsyncTx, - next_pkt_id: AtomicU16, // Track packet IDs for fragmentation + connection: quinn::Connection, + assoc_id: u16, + receive_tx: MAsyncTx, + next_pkt_id: AtomicU16, // Track packet IDs for fragmentation // Fragment reassembly state (wrapped in Mutex for interior mutability) fragment_buffer: FragmentReassemblyBuffer, } /// Structure to track fragments of a packet for reassembly struct FragmentMetadata { - frag_total: u8, - fragments: Cache, + frag_total: u8, + fragments: Cache, last_updated: AtomicU64, - source: ArcSwapOption, - target: ArcSwap, + source: ArcSwapOption, + target: ArcSwap, } /// Buffer for reassembling fragmented packets @@ -116,7 +116,7 @@ impl FragmentReassemblyBuffer { meta.value().fragments.run_pending_tasks().await; // Check if all fragments have been received - if meta.value().fragments.entry_count() == meta.value().frag_total.into() { + if meta.value().fragments.entry_count() == u64::from(meta.value().frag_total) { // All fragments received, reassemble the packet return self.reassemble_packet(key).await; } @@ -219,7 +219,6 @@ impl UdpStream { return Ok(()); } - self.send_fragmented_packet(packet).await } @@ -395,7 +394,6 @@ mod tests { } } - /// SPEC.md Section 8.6: Fragmentation Size Calculations #[test] fn test_fragment_count_calculation() { @@ -489,12 +487,12 @@ mod tests { let result = buffer .add_fragment( FragmentInfo { - assoc_id: 1, - pkt_id: 100, + assoc_id: 1, + pkt_id: 100, frag_total: 1, - frag_id: 0, - source: None, - target: target.clone(), + frag_id: 0, + source: None, + target: target.clone(), }, payload.clone(), ) @@ -518,12 +516,12 @@ mod tests { let result1 = buffer .add_fragment( FragmentInfo { - assoc_id: 1, - pkt_id: 200, + assoc_id: 1, + pkt_id: 200, frag_total: 2, - frag_id: 0, - source: None, - target: target.clone(), + frag_id: 0, + source: None, + target: target.clone(), }, frag1.clone(), ) @@ -534,12 +532,12 @@ mod tests { let result2 = buffer .add_fragment( FragmentInfo { - assoc_id: 1, - pkt_id: 200, + assoc_id: 1, + pkt_id: 200, frag_total: 2, - frag_id: 1, - source: None, - target: target.clone(), + frag_id: 1, + source: None, + target: target.clone(), }, frag2.clone(), ) @@ -565,12 +563,12 @@ mod tests { buffer .add_fragment( FragmentInfo { - assoc_id: 1, - pkt_id: 300, + assoc_id: 1, + pkt_id: 300, frag_total: 3, - frag_id: 2, - source: None, - target: target.clone(), + frag_id: 2, + source: None, + target: target.clone(), }, frag2.clone(), ) @@ -581,12 +579,12 @@ mod tests { buffer .add_fragment( FragmentInfo { - assoc_id: 1, - pkt_id: 300, + assoc_id: 1, + pkt_id: 300, frag_total: 3, - frag_id: 0, - source: None, - target: target.clone(), + frag_id: 0, + source: None, + target: target.clone(), }, frag0.clone(), ) @@ -597,12 +595,12 @@ mod tests { let result = buffer .add_fragment( FragmentInfo { - assoc_id: 1, - pkt_id: 300, + assoc_id: 1, + pkt_id: 300, frag_total: 3, - frag_id: 1, - source: None, - target: target.clone(), + frag_id: 1, + source: None, + target: target.clone(), }, frag1.clone(), ) @@ -623,12 +621,12 @@ mod tests { buffer .add_fragment( FragmentInfo { - assoc_id: 1, - pkt_id: 100, + assoc_id: 1, + pkt_id: 100, frag_total: 2, - frag_id: 0, - source: None, - target: target.clone(), + frag_id: 0, + source: None, + target: target.clone(), }, Bytes::from("A1"), ) @@ -636,12 +634,12 @@ mod tests { buffer .add_fragment( FragmentInfo { - assoc_id: 1, - pkt_id: 101, + assoc_id: 1, + pkt_id: 101, frag_total: 2, - frag_id: 0, - source: None, - target: target.clone(), + frag_id: 0, + source: None, + target: target.clone(), }, Bytes::from("B1"), ) @@ -651,12 +649,12 @@ mod tests { let result1 = buffer .add_fragment( FragmentInfo { - assoc_id: 1, - pkt_id: 100, + assoc_id: 1, + pkt_id: 100, frag_total: 2, - frag_id: 1, - source: None, - target: target.clone(), + frag_id: 1, + source: None, + target: target.clone(), }, Bytes::from("A2"), ) @@ -668,12 +666,12 @@ mod tests { let result2 = buffer .add_fragment( FragmentInfo { - assoc_id: 1, - pkt_id: 101, + assoc_id: 1, + pkt_id: 101, frag_total: 2, - frag_id: 1, - source: None, - target: target.clone(), + frag_id: 1, + source: None, + target: target.clone(), }, Bytes::from("B2"), ) @@ -692,12 +690,12 @@ mod tests { buffer .add_fragment( FragmentInfo { - assoc_id: 1, - pkt_id: 400, + assoc_id: 1, + pkt_id: 400, frag_total: 2, - frag_id: 0, - source: None, - target: target.clone(), + frag_id: 0, + source: None, + target: target.clone(), }, Bytes::from("test"), ) diff --git a/crates/wind-tuic/src/simple_udp.rs b/crates/wind-tuic/src/simple_udp.rs new file mode 100644 index 0000000..5fc5257 --- /dev/null +++ b/crates/wind-tuic/src/simple_udp.rs @@ -0,0 +1,76 @@ +//! Simplified UDP packet channel for wind-tuic +//! +//! This replaces the heavy AbstractUdpSocket trait with simple channels + +use bytes::Bytes; +use crossfire::{MAsyncRx, MAsyncTx}; +use std::net::SocketAddr; + +/// Simple UDP packet structure +#[derive(Debug, Clone)] +pub struct SimpleUdpPacket { + pub source: Option, + pub target: SocketAddr, + pub payload: Bytes, +} + +impl SimpleUdpPacket { + /// Create a new UDP packet + pub fn new(source: Option, target: SocketAddr, payload: Bytes) -> Self { + Self { source, target, payload } + } +} + +/// Simple bidirectional channel for UDP packets +pub struct SimpleUdpChannel { + /// Sender for packets going to remote + pub to_remote_tx: MAsyncTx, + /// Receiver for packets coming from remote + pub from_remote_rx: MAsyncRx, +} + +impl SimpleUdpChannel { + /// Create a new UDP channel with specified buffer size + pub fn new(buffer_size: usize) -> (Self, SimpleUdpChannelTx) { + let (to_remote_tx, to_remote_rx) = crossfire::mpmc::bounded_async::(buffer_size); + let (from_remote_tx, from_remote_rx) = crossfire::mpmc::bounded_async::(buffer_size); + + let channel = Self { + to_remote_tx, + from_remote_rx, + }; + + let tx = SimpleUdpChannelTx { + to_remote_rx, + from_remote_tx, + }; + + (channel, tx) + } +} + +/// The "other side" of the UDP channel +#[derive(Clone)] +pub struct SimpleUdpChannelTx { + /// Receiver for packets going to remote (read from here to send to TUIC) + pub to_remote_rx: MAsyncRx, + /// Sender for packets coming from remote (write here when received from TUIC) + pub from_remote_tx: MAsyncTx, +} + +impl SimpleUdpChannelTx { + /// Send a packet from remote to local + pub async fn send_from_remote(&self, packet: SimpleUdpPacket) -> Result<(), crossfire::SendError> { + self.from_remote_tx.send(packet).await + } + + /// Receive a packet from local to send to remote + pub async fn recv_to_remote(&self) -> Result { + self.to_remote_rx.recv().await + } + + /// Try to receive a packet from local (non-blocking) + pub fn try_recv_to_remote(&self) -> Result { + self.to_remote_rx.try_recv() + } +} diff --git a/crates/wind-tuic/src/task.rs b/crates/wind-tuic/src/task.rs index f521779..1c93fbc 100644 --- a/crates/wind-tuic/src/task.rs +++ b/crates/wind-tuic/src/task.rs @@ -39,7 +39,7 @@ where Err(e) => unimplemented!("unhandled error {e:?}"), Ok(item) => item, }; - + info!("Accepted new {}", name); if let Err(e) = tx.send_timeout(item, Duration::from_secs(1)).await { unimplemented!("unhandled error {e:?}"); diff --git a/crates/wind/Cargo.toml b/crates/wind/Cargo.toml index 37fc510..aa677a0 100644 --- a/crates/wind/Cargo.toml +++ b/crates/wind/Cargo.toml @@ -1,15 +1,15 @@ [package] name = "wind" -version.workspace = true -repository.workspace = true -edition.workspace = true +version = "0.1.1" +repository = "https://github.com/proxy-rs/wind" +edition = "2024" description = "A proxy tool written in Rust" license = "AGPL-3.0-or-later" [dependencies] -wind-core = { version = "0.1.1", path = "../wind-core"} -wind-socks = { version = "0.1.1", path = "../wind-socks"} -wind-tuic = { version = "0.1.1", path = "../wind-tuic"} +wind-core = { path = "../wind-core"} +wind-socks = { path = "../wind-socks"} +wind-tuic = { path = "../wind-tuic"} # Async tokio = { version = "1", features = ["rt-multi-thread", "signal", "net"] }