From a7e990fac83ece646a11dad4b35076346eaa15c0 Mon Sep 17 00:00:00 2001 From: aecsocket Date: Sat, 4 Apr 2026 16:39:00 +0100 Subject: [PATCH] Fix buffer allocations in server pinging code --- packages/app-lib/src/util/server_ping.rs | 81 ++++++++++++++++--- packages/async-minecraft-ping/src/protocol.rs | 67 ++++++++++++++- 2 files changed, 134 insertions(+), 14 deletions(-) diff --git a/packages/app-lib/src/util/server_ping.rs b/packages/app-lib/src/util/server_ping.rs index d03991fbca..59e721b137 100644 --- a/packages/app-lib/src/util/server_ping.rs +++ b/packages/app-lib/src/util/server_ping.rs @@ -8,6 +8,33 @@ use tokio::net::ToSocketAddrs; use tokio::select; use url::Url; +const MAX_MINECRAFT_STATUS_STRING_LENGTH: usize = 32_767; +const MAX_MODERN_STATUS_PACKET_LENGTH: usize = + MAX_MINECRAFT_STATUS_STRING_LENGTH + 4; +const MAX_LEGACY_STATUS_UTF16_LENGTH: usize = + MAX_MINECRAFT_STATUS_STRING_LENGTH; + +/// Ensures the length of a packet as stated by a server is not longer than a +/// hard-coded limit. +/// +/// For example, if we ping a server that says its status packet is 2 billion +/// bytes long, we don't try to allocate a 2 billion byte buffer, since that +/// will OOM our machine. +/// +/// Implemented as a function so that you can easily find callsites and see +/// where we accept unvalidated input from servers. +fn cap_length( + length: usize, + max_length: usize, + context: &'static str, +) -> Result { + if length > max_length { + return Err(ErrorKind::InputError(context.to_string()).into()); + } + + Ok(length) +} + #[derive(Deserialize, Serialize, Debug, Clone)] #[serde(rename_all = "camelCase")] pub struct ServerStatus { @@ -128,13 +155,11 @@ mod modern { stream.write_all(&[0x01, 0x00]).await?; stream.flush().await?; - let packet_length = varint::read(stream).await?; - if packet_length < 0 { - return Err(ErrorKind::InputError( - "Invalid status response packet length".to_string(), - ) - .into()); - } + let packet_length = cap_varint_length( + varint::read(stream).await?, + super::MAX_MODERN_STATUS_PACKET_LENGTH, + "invalid status response packet length", + )?; let mut packet_stream = stream.take(packet_length as u64); let packet_id = varint::read(&mut packet_stream).await?; @@ -144,8 +169,12 @@ mod modern { ) .into()); } - let response_length = varint::read(&mut packet_stream).await?; - let mut json_response = vec![0_u8; response_length as usize]; + let response_length = cap_varint_length( + varint::read(&mut packet_stream).await?, + super::MAX_MINECRAFT_STATUS_STRING_LENGTH, + "invalid status response length", + )?; + let mut json_response = vec![0_u8; response_length]; packet_stream.read_exact(&mut json_response).await?; if packet_stream.limit() > 0 { @@ -155,6 +184,27 @@ mod modern { Ok(serde_json::from_slice(&json_response)?) } + /// Ensures the length of a varint as stated by a server is not longer than a + /// hard-coded limit. + /// + /// For example, if we ping a server that says its status packet is 2 billion + /// bytes long, we don't try to allocate a 2 billion byte buffer, since that + /// will OOM our machine. + /// + /// Implemented as a function so that you can easily find callsites and see + /// where we accept unvalidated input from servers. + fn cap_varint_length( + length: i32, + max_length: usize, + context: &'static str, + ) -> crate::Result { + if length < 0 { + return Err(ErrorKind::InputError(context.to_string()).into()); + } + + super::cap_length(length as usize, max_length, context) + } + async fn ping(stream: &mut TcpStream) -> crate::Result { let ping_magic = chrono::Utc::now().timestamp_millis(); @@ -275,8 +325,17 @@ mod legacy { ))); } - let data_length = stream.read_u16().await?; - let mut data = vec![0u8; data_length as usize * 2]; + let data_length = super::cap_length( + stream.read_u16().await? as usize, + super::MAX_LEGACY_STATUS_UTF16_LENGTH, + "invalid legacy status response length", + )?; + let data_byte_length = data_length.checked_mul(2).ok_or_else(|| { + ErrorKind::InputError( + "invalid legacy status response length".to_string(), + ) + })?; + let mut data = vec![0u8; data_byte_length]; stream.read_exact(&mut data).await?; drop(stream); diff --git a/packages/async-minecraft-ping/src/protocol.rs b/packages/async-minecraft-ping/src/protocol.rs index b161e3515d..e758fce515 100644 --- a/packages/async-minecraft-ping/src/protocol.rs +++ b/packages/async-minecraft-ping/src/protocol.rs @@ -31,6 +31,27 @@ pub enum ProtocolError { Timeout(#[from] tokio::time::error::Elapsed), } +const MAX_MINECRAFT_STRING_LENGTH: usize = 32_767; +const MAX_STATUS_RESPONSE_PACKET_LENGTH: usize = 32_771; +const MAX_PONG_PACKET_LENGTH: usize = 9; + +/// Ensures the length of a packet as stated by a server is not longer than a +/// hard-coded limit. +/// +/// For example, if we ping a server that says its status packet is 2 billion +/// bytes long, we don't try to allocate a 2 billion byte buffer, since that +/// will OOM our machine. +/// +/// Implemented as a function so that you can easily find callsites and see +/// where we accept unvalidated input from servers. +fn cap_length(length: usize, max_length: usize) -> Result { + if length > max_length { + return Err(ProtocolError::InvalidPacketLength); + } + + Ok(length) +} + /// State represents the desired next state of the /// exchange. /// @@ -98,7 +119,7 @@ impl AsyncWireReadExt for R { } async fn read_string(&mut self) -> Result { - let length = self.read_varint().await?; + let length = cap_length(self.read_varint().await?, MAX_MINECRAFT_STRING_LENGTH)?; let mut buffer = vec![0; length]; self.read_exact(&mut buffer).await?; @@ -157,6 +178,7 @@ pub trait PacketId { /// to generically get a packet's expected ID. pub trait ExpectedPacketId { fn get_expected_packet_id() -> usize; + fn get_max_packet_length() -> usize; } /// AsyncReadFromBuffer is used to allow @@ -196,7 +218,7 @@ impl AsyncReadRawPacket for R { async fn read_packet( &mut self, ) -> Result { - let length = self.read_varint().await?; + let length = cap_length(self.read_varint().await?, T::get_max_packet_length())?; if length == 0 { return Err(ProtocolError::InvalidPacketLength); @@ -213,7 +235,10 @@ impl AsyncReadRawPacket for R { }); } - let mut buffer = vec![0; length - 1]; + let payload_length = length + .checked_sub(1) + .ok_or(ProtocolError::InvalidPacketLength)?; + let mut buffer = vec![0; payload_length]; self.read_exact(&mut buffer).await?; T::read_from_buffer(buffer).await @@ -357,6 +382,10 @@ impl ExpectedPacketId for ResponsePacket { fn get_expected_packet_id() -> usize { 0 } + + fn get_max_packet_length() -> usize { + MAX_STATUS_RESPONSE_PACKET_LENGTH + } } #[async_trait] @@ -411,6 +440,10 @@ impl ExpectedPacketId for PongPacket { fn get_expected_packet_id() -> usize { 1 } + + fn get_max_packet_length() -> usize { + MAX_PONG_PACKET_LENGTH + } } #[async_trait] @@ -573,4 +606,32 @@ mod tests { let result = reader.read_varint().await; assert!(matches!(result, Err(ProtocolError::InvalidVarInt))); } + + #[tokio::test] + async fn test_oversized_string_length_is_rejected() { + let mut writer = Cursor::new(Vec::new()); + writer + .write_varint(MAX_MINECRAFT_STRING_LENGTH + 1) + .await + .unwrap(); + + let mut reader = Cursor::new(writer.into_inner()); + let result = reader.read_string().await; + + assert!(matches!(result, Err(ProtocolError::InvalidPacketLength))); + } + + #[tokio::test] + async fn test_oversized_packet_length_is_rejected() { + let mut writer = Cursor::new(Vec::new()); + writer + .write_varint(MAX_STATUS_RESPONSE_PACKET_LENGTH + 1) + .await + .unwrap(); + + let mut reader = Cursor::new(writer.into_inner()); + let result: Result = reader.read_packet().await; + + assert!(matches!(result, Err(ProtocolError::InvalidPacketLength))); + } }