use super::*;

use range_set_blaze::RangeSetBlaze;
use std::io::{Error, ErrorKind};
use std::sync::atomic::{AtomicU16, Ordering};

// AssemblyBuffer Version 1 properties
const VERSION_1: u8 = 1;
type LengthType = u16;
type SequenceType = u16;
const HEADER_LEN: usize = 8;
const MAX_LEN: usize = LengthType::MAX as usize;

// XXX: keep statistics on all drops and why we dropped them
// XXX: move to config eventually?

// Each framed fragment fits in 1280 bytes (1280 is the IPv6 minimum MTU)
pub const FRAGMENT_LEN: usize = 1280 - HEADER_LEN;
const MAX_CONCURRENT_HOSTS: usize = 256;
const MAX_ASSEMBLIES_PER_HOST: usize = 256;
const MAX_BUFFER_PER_HOST: usize = 256 * 1024;
const MAX_ASSEMBLY_AGE_US: u64 = 10_000_000;

/////////////////////////////////////////////////////////

#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct PeerKey {
    remote_addr: SocketAddr,
}

#[derive(Clone, Eq, PartialEq)]
struct MessageAssembly {
    timestamp: u64,
    seq: SequenceType,
    data: Vec<u8>,
    parts: RangeSetBlaze<LengthType>,
}

#[derive(Clone, Eq, PartialEq)]
struct PeerMessages {
    total_buffer: usize,
    assemblies: VecDeque<MessageAssembly>,
}

impl PeerMessages {
    pub fn new() -> Self {
        Self {
            total_buffer: 0,
            assemblies: VecDeque::new(),
        }
    }

    fn merge_in_data(
        &mut self,
        timestamp: u64,
        ass: usize,
        off: LengthType,
        len: LengthType,
        chunk: &[u8],
    ) -> bool {
        let assembly = &mut self.assemblies[ass];

        // Ensure the new fragment hasn't redefined the message length, reusing the same seq
        if assembly.data.len() != len as usize {
            // Drop the assembly and just go with the new fragment as starting a new assembly
            let seq = assembly.seq;
            self.remove_assembly(ass);
            self.new_assembly(timestamp, seq, off, len, chunk);
            return false;
        }

        let part_start = off;
        let part_end = off + chunk.len() as LengthType - 1;
        let part = RangeSetBlaze::from_iter([part_start..=part_end]);

        // If fragments overlap, drop the old assembly and go with a new one
        if !assembly.parts.is_disjoint(&part) {
            let seq = assembly.seq;
            self.remove_assembly(ass);
            self.new_assembly(timestamp, seq, off, len, chunk);
            return false;
        }

        // Merge part
        assembly.parts |= part;
        assembly.data[part_start as usize..=part_end as usize].copy_from_slice(chunk);

        // Check to see if this part is done
        if assembly.parts.ranges_len() == 1
            && assembly.parts.first().unwrap() == 0
            && assembly.parts.last().unwrap() == len - 1
        {
            return true;
        }
        false
    }

    fn new_assembly(
        &mut self,
        timestamp: u64,
        seq: SequenceType,
        off: LengthType,
        len: LengthType,
        chunk: &[u8],
    ) -> usize {
        // Ensure we have enough space for the new assembly
        self.reclaim_space(len as usize);

        // Make the assembly
        let part_start = off;
        let part_end = off + chunk.len() as LengthType - 1;
        let mut assembly = MessageAssembly {
            timestamp,
            seq,
            data: vec![0u8; len as usize],
            parts: RangeSetBlaze::from_iter([part_start..=part_end]),
        };
        assembly.data[part_start as usize..=part_end as usize].copy_from_slice(chunk);

        // Add the buffer length in
        self.total_buffer += assembly.data.len();
        self.assemblies.push_front(assembly);

        // Was pushed front, return the front index
        0
    }

    fn remove_assembly(&mut self, index: usize) -> MessageAssembly {
        let assembly = self.assemblies.remove(index).unwrap();
        self.total_buffer -= assembly.data.len();
        assembly
    }

    fn truncate_assemblies(&mut self, new_len: usize) {
        for an in new_len..self.assemblies.len() {
            self.total_buffer -= self.assemblies[an].data.len();
        }
        self.assemblies.truncate(new_len);
    }

    fn reclaim_space(&mut self, needed_space: usize) {
        // If we have too many assemblies or too much buffer, rotate some out
        while self.assemblies.len() > (MAX_ASSEMBLIES_PER_HOST - 1)
            || self.total_buffer > (MAX_BUFFER_PER_HOST - needed_space)
        {
            self.remove_assembly(self.assemblies.len() - 1);
        }
    }
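
    /// Insert a fragment into the assembly matching its sequence number,
    /// starting a new assembly if none matches. Returns the reassembled
    /// message once the last missing fragment arrives, otherwise None.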
    pub fn insert_fragment(
        &mut self,
        seq: SequenceType,
        off: LengthType,
        len: LengthType,
        chunk: &[u8],
    ) -> Option<Vec<u8>> {
        // Get the current timestamp
        let cur_ts = get_timestamp();

        // Get the assembly this belongs to by its sequence number
        let mut ass = None;
        for an in 0..self.assemblies.len() {
            // If this assembly's timestamp is too old, then everything after it will be too; drop them all
            let age = cur_ts.saturating_sub(self.assemblies[an].timestamp);
            if age > MAX_ASSEMBLY_AGE_US {
                self.truncate_assemblies(an);
                break;
            }
            // If this assembly has a matching seq, then assemble with it
            if self.assemblies[an].seq == seq {
                ass = Some(an);
            }
        }
        if ass.is_none() {
            // Add a new assembly to the front
            self.new_assembly(cur_ts, seq, off, len, chunk);
            return None;
        }
        let ass = ass.unwrap();

        // Now that we have an assembly, merge in the fragment
        let done = self.merge_in_data(cur_ts, ass, off, len, chunk);

        // If the assembly now covers the entire range, then return it
        if done {
            let assembly = self.remove_assembly(ass);
            return Some(assembly.data);
        }

        // Otherwise, do nothing
        None
    }
}

/////////////////////////////////////////////////////////

struct AssemblyBufferInner {
    peer_message_map: HashMap<PeerKey, PeerMessages>,
}

struct AssemblyBufferUnlockedInner {
    outbound_lock_table: AsyncTagLockTable<SocketAddr>,
    next_seq: AtomicU16,
}

/// Packet reassembly and fragmentation handler
/// No retry, no acknowledgment, no flow control
/// Just trying to survive a lower path MTU for larger messages
#[derive(Clone)]
pub struct AssemblyBuffer {
    inner: Arc<Mutex<AssemblyBufferInner>>,
    unlocked_inner: Arc<AssemblyBufferUnlockedInner>,
}

impl AssemblyBuffer {
    fn new_unlocked_inner() -> AssemblyBufferUnlockedInner {
        AssemblyBufferUnlockedInner {
            outbound_lock_table: AsyncTagLockTable::new(),
            next_seq: AtomicU16::new(0),
        }
    }
    fn new_inner() -> AssemblyBufferInner {
        AssemblyBufferInner {
            peer_message_map: HashMap::new(),
        }
    }

    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(Self::new_inner())),
            unlocked_inner: Arc::new(Self::new_unlocked_inner()),
        }
    }

    /// Receive a packet chunk and add it to the message assembly;
    /// if a message has been completed, return it
    pub fn insert_frame(
        &self,
        frame: &[u8],
        remote_addr: SocketAddr,
    ) -> NetworkResult<Option<Vec<u8>>> {
        // If we receive a zero-length frame, pass it through as-is
        if frame.is_empty() {
            return NetworkResult::value(Some(frame.to_vec()));
        }

        // If we receive a frame smaller than or equal to the length of the header, drop it,
        // or if this frame is larger than our max message length, then drop it
        if frame.len() <= HEADER_LEN || frame.len() > MAX_LEN {
            #[cfg(feature = "network-result-extra")]
            return NetworkResult::invalid_message(format!(
                "invalid header length: frame.len={}",
                frame.len()
            ));
            #[cfg(not(feature = "network-result-extra"))]
            return NetworkResult::invalid_message("invalid header length");
        }

        // --- Decode the header

        // Drop versions we don't understand
        if frame[0] != VERSION_1 {
            #[cfg(feature = "network-result-extra")]
            return NetworkResult::invalid_message(format!(
                "invalid frame version: frame[0]={}",
                frame[0]
            ));
            #[cfg(not(feature = "network-result-extra"))]
            return NetworkResult::invalid_message("invalid frame version");
        }

        // Version 1 header
        let seq = SequenceType::from_be_bytes(frame[2..4].try_into().unwrap());
        let off = LengthType::from_be_bytes(frame[4..6].try_into().unwrap());
        let len = LengthType::from_be_bytes(frame[6..HEADER_LEN].try_into().unwrap());
        let chunk = &frame[HEADER_LEN..];

        // See if we have a whole message and not a fragment
        if off == 0 && len as usize == chunk.len() {
            return NetworkResult::value(Some(chunk.to_vec()));
        }

        // Drop fragments with offsets greater than or equal to the message length
        if off >= len {
            #[cfg(feature = "network-result-extra")]
            return NetworkResult::invalid_message(format!(
                "offset greater than length: off={} >= len={}",
                off, len
            ));
            #[cfg(not(feature = "network-result-extra"))]
            return NetworkResult::invalid_message("offset greater than length");
        }
        // Drop fragments where the chunk would be applied beyond the message length
        if off as usize + chunk.len() > len as usize {
            #[cfg(feature = "network-result-extra")]
            return NetworkResult::invalid_message(format!(
                "chunk applied beyond message length: off={} + chunk.len={} > len={}",
                off,
                chunk.len(),
                len
            ));
            #[cfg(not(feature = "network-result-extra"))]
            return NetworkResult::invalid_message("chunk applied beyond message length");
        }

        // Get or create the peer message assemblies
        // and drop the packet if we have too many peers
        let mut inner = self.inner.lock();
        let peer_key = PeerKey { remote_addr };
        let peer_count = inner.peer_message_map.len();
        match inner.peer_message_map.entry(peer_key) {
            std::collections::hash_map::Entry::Occupied(mut e) => {
                let peer_messages = e.get_mut();

                // Insert the fragment and see what comes out
                let out = peer_messages.insert_fragment(seq, off, len, chunk);

                // If we are returning a message, see if there are any more assemblies for this peer
                // If not, remove the peer
                if out.is_some() && peer_messages.assemblies.is_empty() {
                    e.remove();
                }
                NetworkResult::value(out)
            }
            std::collections::hash_map::Entry::Vacant(v) => {
                // See if we have room for one more peer
                if peer_count == MAX_CONCURRENT_HOSTS {
                    return NetworkResult::value(None);
                }

                // Add the peer
                let peer_messages = v.insert(PeerMessages::new());

                // Insert the fragment and see what comes out
                NetworkResult::value(peer_messages.insert_fragment(seq, off, len, chunk))
            }
        }
    }
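
    // For reference, the Version 1 frame layout that frame_chunk() below
    // produces and insert_frame() above consumes (all multi-byte fields
    // big-endian, HEADER_LEN = 8):
    //
    //   byte 0    byte 1     bytes 2..4  bytes 4..6  bytes 6..8   bytes 8..
    //   version   reserved   seq (u16)   off (u16)   len (u16)    chunk
    //
    // `off` is the chunk's offset within the message and `len` is the total
    // message length; a frame with off == 0 and len == chunk length carries
    // an entire unfragmented message.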
    /// Add framing to chunk to send to the wire
    fn frame_chunk(chunk: &[u8], offset: usize, message_len: usize, seq: SequenceType) -> Vec<u8> {
        assert!(!chunk.is_empty());
        assert!(message_len <= MAX_LEN);
        assert!(offset + chunk.len() <= message_len);

        let off: LengthType = offset as LengthType;
        let len: LengthType = message_len as LengthType;

        unsafe {
            // Uninitialized vector, careful!
            let mut out = unaligned_u8_vec_uninit(chunk.len() + HEADER_LEN);

            // Write out header
            out[0] = VERSION_1;
            out[1] = 0; // reserved
            out[2..4].copy_from_slice(&seq.to_be_bytes()); // sequence number
            out[4..6].copy_from_slice(&off.to_be_bytes()); // offset of chunk inside message
            out[6..HEADER_LEN].copy_from_slice(&len.to_be_bytes()); // total length of message

            // Write out body
            out[HEADER_LEN..].copy_from_slice(chunk);
            out
        }
    }

    /// Split a message into packets and send them serially, ensuring
    /// that they are sent consecutively to a particular remote address,
    /// never interleaving packets from one message and another, to minimize reassembly problems
    pub async fn split_message<S, F>(
        &self,
        data: Vec<u8>,
        remote_addr: SocketAddr,
        mut sender: S,
    ) -> std::io::Result<NetworkResult<()>>
    where
        S: FnMut(Vec<u8>, SocketAddr) -> F,
        F: Future<Output = std::io::Result<NetworkResult<()>>>,
    {
        if data.len() > MAX_LEN {
            return Err(Error::from(ErrorKind::InvalidData));
        }

        // Do not frame or split anything zero bytes long, just send it
        if data.is_empty() {
            return sender(data, remote_addr).await;
        }

        // Lock per remote addr
        let _tag_lock = self
            .unlocked_inner
            .outbound_lock_table
            .lock_tag(remote_addr)
            .await;

        // Get a message seq
        let seq = self.unlocked_inner.next_seq.fetch_add(1, Ordering::Relaxed);

        // Chunk it up
        let mut offset = 0usize;
        let message_len = data.len();
        for chunk in data.chunks(FRAGMENT_LEN) {
            // Frame chunk
            let framed_chunk = Self::frame_chunk(chunk, offset, message_len, seq);
            // Send chunk
            network_result_try!(sender(framed_chunk, remote_addr).await?);
            // Go to next chunk
            offset += chunk.len();
        }

        Ok(NetworkResult::value(()))
    }
}
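
// ---------------------------------------------------------------------
// Round-trip sketch, not part of the original module: split a message with
// `split_message` and feed the resulting frames back through `insert_frame`.
// Assumptions to note: a `tokio` dev-dependency provides the test executor
// (this crate may use a different async runtime), `Mutex` is the same
// non-poisoning lock that `insert_frame` above uses, and `NetworkResult<T>`
// has a `Value(T)` variant, as the `NetworkResult::value()` constructor
// used throughout this file suggests.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn split_and_reassemble_round_trip() {
        let assbuf_out = AssemblyBuffer::new();
        let assbuf_in = AssemblyBuffer::new();
        let remote_addr: SocketAddr = "127.0.0.1:5150".parse().unwrap();

        // A message long enough to require four fragments
        let message = vec![0xA5u8; FRAGMENT_LEN * 3 + 17];

        // Collect the framed chunks the sender would put on the wire
        let frames: Arc<Mutex<Vec<Vec<u8>>>> = Arc::new(Mutex::new(Vec::new()));
        let sender_frames = frames.clone();
        let _ = assbuf_out
            .split_message(message.clone(), remote_addr, move |frame, _addr| {
                let sender_frames = sender_frames.clone();
                async move {
                    sender_frames.lock().push(frame);
                    Ok(NetworkResult::value(()))
                }
            })
            .await
            .expect("sender should not fail");

        // Feed each frame to the receiving side; the last fragment completes the message
        let mut reassembled = None;
        for frame in frames.lock().iter() {
            if let NetworkResult::Value(Some(data)) = assbuf_in.insert_frame(frame, remote_addr) {
                reassembled = Some(data);
            }
        }
        assert_eq!(reassembled, Some(message));
    }
}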