frag work
This commit is contained in:
		@@ -196,7 +196,6 @@ impl AttachmentManager {
 | 
			
		||||
            if let Err(err) = netman.startup().await {
 | 
			
		||||
                error!("network startup failed: {}", err);
 | 
			
		||||
                netman.shutdown().await;
 | 
			
		||||
                restart = true;
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -8,7 +8,7 @@ const VERSION_1: u8 = 1;
 | 
			
		||||
type LengthType = u16;
 | 
			
		||||
type SequenceType = u16;
 | 
			
		||||
const HEADER_LEN: usize = 8;
 | 
			
		||||
const MAX_MESSAGE_LEN: usize = LengthType::MAX as usize;
 | 
			
		||||
const MAX_LEN: usize = LengthType::MAX as usize;
 | 
			
		||||
 | 
			
		||||
// XXX: keep statistics on all drops and why we dropped them
 | 
			
		||||
// XXX: move to config
 | 
			
		||||
@@ -16,14 +16,10 @@ const FRAGMENT_LEN: usize = 1280 - HEADER_LEN;
 | 
			
		||||
const MAX_CONCURRENT_HOSTS: usize = 256;
 | 
			
		||||
const MAX_ASSEMBLIES_PER_HOST: usize = 256;
 | 
			
		||||
const MAX_BUFFER_PER_HOST: usize = 256 * 1024;
 | 
			
		||||
const MAX_ASSEMBLY_AGE_US: u64 = 10_000_000;
 | 
			
		||||
 | 
			
		||||
/////////////////////////////////////////////////////////
 | 
			
		||||
 | 
			
		||||
pub struct Message {
 | 
			
		||||
    data: Vec<u8>,
 | 
			
		||||
    remote_addr: SocketAddr,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
 | 
			
		||||
struct PeerKey {
 | 
			
		||||
    remote_addr: SocketAddr,
 | 
			
		||||
@@ -31,6 +27,7 @@ struct PeerKey {
 | 
			
		||||
 | 
			
		||||
#[derive(Clone, Eq, PartialEq)]
 | 
			
		||||
struct MessageAssembly {
 | 
			
		||||
    timestamp: Timestamp,
 | 
			
		||||
    seq: SequenceType,
 | 
			
		||||
    data: Vec<u8>,
 | 
			
		||||
    parts: RangeSetBlaze<LengthType>,
 | 
			
		||||
@@ -38,15 +35,29 @@ struct MessageAssembly {
 | 
			
		||||
 | 
			
		||||
#[derive(Clone, Eq, PartialEq)]
 | 
			
		||||
struct PeerMessages {
 | 
			
		||||
    assemblies: Vec<MessageAssembly>,
 | 
			
		||||
    assemblies: LinkedList<MessageAssembly>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl PeerMessages {
 | 
			
		||||
    pub fn new() -> Self {
 | 
			
		||||
        Self {
 | 
			
		||||
            assemblies: Vec::new(),
 | 
			
		||||
            assemblies: LinkedList::new(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn insert_fragment(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        seq: SequenceType,
 | 
			
		||||
        off: LengthType,
 | 
			
		||||
        len: LengthType,
 | 
			
		||||
        chunk: &[u8],
 | 
			
		||||
    ) -> Option<Vec<u8>> {
 | 
			
		||||
        // Get the current timestamp
 | 
			
		||||
        let cur_ts = get_timestamp();
 | 
			
		||||
 | 
			
		||||
        // Get the assembly this belongs to by its sequence number
 | 
			
		||||
        for a in self.assemblies {}
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/////////////////////////////////////////////////////////
 | 
			
		||||
@@ -70,19 +81,19 @@ pub struct AssemblyBuffer {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl AssemblyBuffer {
 | 
			
		||||
    pub fn new_unlocked_inner() -> AssemblyBufferUnlockedInner {
 | 
			
		||||
    fn new_unlocked_inner() -> AssemblyBufferUnlockedInner {
 | 
			
		||||
        AssemblyBufferUnlockedInner {
 | 
			
		||||
            outbound_lock_table: AsyncTagLockTable::new(),
 | 
			
		||||
            next_seq: AtomicU16::new(0),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    pub fn new_inner() -> AssemblyBufferInner {
 | 
			
		||||
    fn new_inner() -> AssemblyBufferInner {
 | 
			
		||||
        AssemblyBufferInner {
 | 
			
		||||
            peer_message_map: HashMap::new(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn new(frag_len: usize) -> Self {
 | 
			
		||||
    pub fn new() -> Self {
 | 
			
		||||
        Self {
 | 
			
		||||
            inner: Arc::new(Mutex::new(Self::new_inner())),
 | 
			
		||||
            unlocked_inner: Arc::new(Self::new_unlocked_inner()),
 | 
			
		||||
@@ -91,18 +102,15 @@ impl AssemblyBuffer {
 | 
			
		||||
 | 
			
		||||
    /// Receive a packet chunk and add to the message assembly
 | 
			
		||||
    /// if a message has been completely, return it
 | 
			
		||||
    pub fn receive_packet(&self, frame: &[u8], remote_addr: SocketAddr) -> Option<Message> {
 | 
			
		||||
    pub fn insert_frame(&self, frame: &[u8], remote_addr: SocketAddr) -> Option<Vec<u8>> {
 | 
			
		||||
        // If we receive a zero length frame, send it
 | 
			
		||||
        if frame.len() == 0 {
 | 
			
		||||
            return Some(Message {
 | 
			
		||||
                data: frame.to_vec(),
 | 
			
		||||
                remote_addr,
 | 
			
		||||
            });
 | 
			
		||||
            return Some(frame.to_vec());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // If we receive a frame smaller than or equal to the length of the header, drop it
 | 
			
		||||
        // or if this frame is larger than our max message length, then drop it
 | 
			
		||||
        if frame.len() <= HEADER_LEN || frame.len() > MAX_MESSAGE_LEN {
 | 
			
		||||
        if frame.len() <= HEADER_LEN || frame.len() > MAX_LEN {
 | 
			
		||||
            return None;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@@ -120,10 +128,7 @@ impl AssemblyBuffer {
 | 
			
		||||
 | 
			
		||||
        // See if we have a whole message and not a fragment
 | 
			
		||||
        if off == 0 && len as usize == chunk.len() {
 | 
			
		||||
            return Some(Message {
 | 
			
		||||
                data: frame.to_vec(),
 | 
			
		||||
                remote_addr,
 | 
			
		||||
            });
 | 
			
		||||
            return Some(frame.to_vec());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Drop fragments with offsets greater than or equal to the message length
 | 
			
		||||
@@ -139,25 +144,32 @@ impl AssemblyBuffer {
 | 
			
		||||
        // and drop the packet if we have too many peers
 | 
			
		||||
        let mut inner = self.inner.lock();
 | 
			
		||||
        let peer_key = PeerKey { remote_addr };
 | 
			
		||||
        let peer_messages = match inner.peer_message_map.entry(peer_key) {
 | 
			
		||||
            std::collections::hash_map::Entry::Occupied(e) => e.get_mut(),
 | 
			
		||||
        let peer_count = inner.peer_message_map.len();
 | 
			
		||||
        match inner.peer_message_map.entry(peer_key) {
 | 
			
		||||
            std::collections::hash_map::Entry::Occupied(mut e) => {
 | 
			
		||||
                let peer_messages = e.get_mut();
 | 
			
		||||
 | 
			
		||||
                // Insert the fragment and see what comes out
 | 
			
		||||
                peer_messages.insert_fragment(seq, off, len, chunk)
 | 
			
		||||
            }
 | 
			
		||||
            std::collections::hash_map::Entry::Vacant(v) => {
 | 
			
		||||
                // See if we have room for one more
 | 
			
		||||
                if inner.peer_message_map.len() == MAX_CONCURRENT_HOSTS {
 | 
			
		||||
                if peer_count == MAX_CONCURRENT_HOSTS {
 | 
			
		||||
                    return None;
 | 
			
		||||
                }
 | 
			
		||||
                // Add the peer
 | 
			
		||||
                v.insert(PeerMessages::new())
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
                let peer_messages = v.insert(PeerMessages::new());
 | 
			
		||||
 | 
			
		||||
        None
 | 
			
		||||
                // Insert the fragment and see what comes out
 | 
			
		||||
                peer_messages.insert_fragment(seq, off, len, chunk)
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Add framing to chunk to send to the wire
 | 
			
		||||
    fn frame_chunk(chunk: &[u8], offset: usize, message_len: usize, seq: SequenceType) -> Vec<u8> {
 | 
			
		||||
        assert!(chunk.len() > 0);
 | 
			
		||||
        assert!(message_len <= MAX_MESSAGE_LEN);
 | 
			
		||||
        assert!(message_len <= MAX_LEN);
 | 
			
		||||
        assert!(offset + chunk.len() <= message_len);
 | 
			
		||||
 | 
			
		||||
        let off: LengthType = offset as LengthType;
 | 
			
		||||
@@ -175,7 +187,7 @@ impl AssemblyBuffer {
 | 
			
		||||
            out[6..HEADER_LEN].copy_from_slice(&len.to_be_bytes()); // total length of message
 | 
			
		||||
 | 
			
		||||
            // Write out body
 | 
			
		||||
            out[HEADER_LEN..out.len()].copy_from_slice(chunk);
 | 
			
		||||
            out[HEADER_LEN..].copy_from_slice(chunk);
 | 
			
		||||
            out
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
@@ -183,25 +195,30 @@ impl AssemblyBuffer {
 | 
			
		||||
    /// Split a message into packets and send them serially, ensuring
 | 
			
		||||
    /// that they are sent consecutively to a particular remote address,
 | 
			
		||||
    /// never interleaving packets from one message and other to minimize reassembly problems
 | 
			
		||||
    pub async fn split_message<F>(&self, message: Message, sender: F) -> std::io::Result<()>
 | 
			
		||||
    pub async fn split_message<S, F>(
 | 
			
		||||
        &self,
 | 
			
		||||
        data: Vec<u8>,
 | 
			
		||||
        remote_addr: SocketAddr,
 | 
			
		||||
        sender: S,
 | 
			
		||||
    ) -> std::io::Result<NetworkResult<()>>
 | 
			
		||||
    where
 | 
			
		||||
        F: Fn(Vec<u8>, SocketAddr) -> SendPinBoxFuture<std::io::Result<()>>,
 | 
			
		||||
        S: Fn(Vec<u8>, SocketAddr) -> F,
 | 
			
		||||
        F: Future<Output = std::io::Result<NetworkResult<()>>>,
 | 
			
		||||
    {
 | 
			
		||||
        if message.data.len() > MAX_MESSAGE_LEN {
 | 
			
		||||
        if data.len() > MAX_LEN {
 | 
			
		||||
            return Err(Error::from(ErrorKind::InvalidData));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Do not frame or split anything zero bytes long, just send it
 | 
			
		||||
        if message.data.len() == 0 {
 | 
			
		||||
            sender(message.data, message.remote_addr).await?;
 | 
			
		||||
            return Ok(());
 | 
			
		||||
        if data.len() == 0 {
 | 
			
		||||
            return sender(data, remote_addr).await;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Lock per remote addr
 | 
			
		||||
        let _tag_lock = self
 | 
			
		||||
            .unlocked_inner
 | 
			
		||||
            .outbound_lock_table
 | 
			
		||||
            .lock_tag(message.remote_addr)
 | 
			
		||||
            .lock_tag(remote_addr)
 | 
			
		||||
            .await;
 | 
			
		||||
 | 
			
		||||
        // Get a message seq
 | 
			
		||||
@@ -209,16 +226,16 @@ impl AssemblyBuffer {
 | 
			
		||||
 | 
			
		||||
        // Chunk it up
 | 
			
		||||
        let mut offset = 0usize;
 | 
			
		||||
        let message_len = message.data.len();
 | 
			
		||||
        for chunk in message.data.chunks(FRAGMENT_LEN) {
 | 
			
		||||
        let message_len = data.len();
 | 
			
		||||
        for chunk in data.chunks(FRAGMENT_LEN) {
 | 
			
		||||
            // Frame chunk
 | 
			
		||||
            let framed_chunk = Self::frame_chunk(chunk, offset, message_len, seq);
 | 
			
		||||
            // Send chunk
 | 
			
		||||
            sender(framed_chunk, message.remote_addr).await?;
 | 
			
		||||
            network_result_try!(sender(framed_chunk, remote_addr).await?);
 | 
			
		||||
            // Go to next chunk
 | 
			
		||||
            offset += chunk.len()
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(())
 | 
			
		||||
        Ok(NetworkResult::value(()))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -56,11 +56,11 @@ impl RawTcpNetworkConnection {
 | 
			
		||||
        stream.flush().await.into_network_result()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(level="trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
 | 
			
		||||
    //#[instrument(level="trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
 | 
			
		||||
    pub async fn send(&self, message: Vec<u8>) -> io::Result<NetworkResult<()>> {
 | 
			
		||||
        let mut stream = self.stream.clone();
 | 
			
		||||
        let out = Self::send_internal(&mut stream, message).await?;
 | 
			
		||||
        tracing::Span::current().record("network_result", &tracing::field::display(&out));
 | 
			
		||||
        //tracing::Span::current().record("network_result", &tracing::field::display(&out));
 | 
			
		||||
        Ok(out)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -18,13 +18,25 @@ impl RawUdpProtocolHandler {
 | 
			
		||||
 | 
			
		||||
    // #[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.descriptor))]
 | 
			
		||||
    pub async fn recv_message(&self, data: &mut [u8]) -> io::Result<(usize, ConnectionDescriptor)> {
 | 
			
		||||
        let (size, descriptor) = loop {
 | 
			
		||||
        let (message_len, descriptor) = loop {
 | 
			
		||||
            // Get a packet
 | 
			
		||||
            let (size, remote_addr) = network_result_value_or_log!(self.socket.recv_from(data).await.into_network_result()? => continue);
 | 
			
		||||
            if size > MAX_MESSAGE_SIZE {
 | 
			
		||||
 | 
			
		||||
            // Insert into assembly buffer
 | 
			
		||||
            let Some(message) = self.assembly_buffer.insert_frame(&data[0..size], remote_addr) else {
 | 
			
		||||
                continue;
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            // Check length of reassembled message (same for all protocols)
 | 
			
		||||
            if message.len() > MAX_MESSAGE_SIZE {
 | 
			
		||||
                log_net!(debug "{}({}) at {}@{}:{}", "Invalid message".green(), "received too large UDP message", file!(), line!(), column!());
 | 
			
		||||
                continue;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // Copy assemble message out if we got one
 | 
			
		||||
            data[0..message.len()].copy_from_slice(&message);
 | 
			
		||||
 | 
			
		||||
            // Return a connection descriptor and the amount of data in the message
 | 
			
		||||
            let peer_addr = PeerAddress::new(
 | 
			
		||||
                SocketAddress::from_socket_addr(remote_addr),
 | 
			
		||||
                ProtocolType::UDP,
 | 
			
		||||
@@ -35,25 +47,46 @@ impl RawUdpProtocolHandler {
 | 
			
		||||
                SocketAddress::from_socket_addr(local_socket_addr),
 | 
			
		||||
            );
 | 
			
		||||
 | 
			
		||||
            break (size, descriptor);
 | 
			
		||||
            break (message.len(), descriptor);
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        // tracing::Span::current().record("ret.len", &size);
 | 
			
		||||
        // tracing::Span::current().record("ret.len", &message_len);
 | 
			
		||||
        // tracing::Span::current().record("ret.descriptor", &format!("{:?}", descriptor).as_str());
 | 
			
		||||
        Ok((size, descriptor))
 | 
			
		||||
        Ok((message_len, descriptor))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.descriptor))]
 | 
			
		||||
    //#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.descriptor))]
 | 
			
		||||
    pub async fn send_message(
 | 
			
		||||
        &self,
 | 
			
		||||
        data: Vec<u8>,
 | 
			
		||||
        socket_addr: SocketAddr,
 | 
			
		||||
        remote_addr: SocketAddr,
 | 
			
		||||
    ) -> io::Result<NetworkResult<ConnectionDescriptor>> {
 | 
			
		||||
        if data.len() > MAX_MESSAGE_SIZE {
 | 
			
		||||
            bail_io_error_other!("sending too large UDP message");
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Fragment and send
 | 
			
		||||
        let sender = |framed_chunk: Vec<u8>, remote_addr: SocketAddr| async move {
 | 
			
		||||
            let len = network_result_try!(self
 | 
			
		||||
                .socket
 | 
			
		||||
                .send_to(&framed_chunk, remote_addr)
 | 
			
		||||
                .await
 | 
			
		||||
                .into_network_result()?);
 | 
			
		||||
            if len != framed_chunk.len() {
 | 
			
		||||
                bail_io_error_other!("UDP partial send")
 | 
			
		||||
            }
 | 
			
		||||
            Ok(NetworkResult::value(()))
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        network_result_try!(
 | 
			
		||||
            self.assembly_buffer
 | 
			
		||||
                .split_message(data, remote_addr, sender)
 | 
			
		||||
                .await?
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
        // Return a connection descriptor for the sent message
 | 
			
		||||
        let peer_addr = PeerAddress::new(
 | 
			
		||||
            SocketAddress::from_socket_addr(socket_addr),
 | 
			
		||||
            SocketAddress::from_socket_addr(remote_addr),
 | 
			
		||||
            ProtocolType::UDP,
 | 
			
		||||
        );
 | 
			
		||||
        let local_socket_addr = self.socket.local_addr()?;
 | 
			
		||||
@@ -63,17 +96,7 @@ impl RawUdpProtocolHandler {
 | 
			
		||||
            SocketAddress::from_socket_addr(local_socket_addr),
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
        let len = network_result_try!(self
 | 
			
		||||
            .socket
 | 
			
		||||
            .send_to(&data, socket_addr)
 | 
			
		||||
            .await
 | 
			
		||||
            .into_network_result()?);
 | 
			
		||||
        if len != data.len() {
 | 
			
		||||
            bail_io_error_other!("UDP partial send")
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        tracing::Span::current().record("ret.len", &len);
 | 
			
		||||
        tracing::Span::current().record("ret.descriptor", &format!("{:?}", descriptor).as_str());
 | 
			
		||||
        // tracing::Span::current().record("ret.descriptor", &format!("{:?}", descriptor).as_str());
 | 
			
		||||
        Ok(NetworkResult::value(descriptor))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -72,7 +72,7 @@ where
 | 
			
		||||
    //         .map_err(to_io_error_other)
 | 
			
		||||
    // }
 | 
			
		||||
 | 
			
		||||
    #[instrument(level = "trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
 | 
			
		||||
    //#[instrument(level = "trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
 | 
			
		||||
    pub async fn send(&self, message: Vec<u8>) -> io::Result<NetworkResult<()>> {
 | 
			
		||||
        if message.len() > MAX_MESSAGE_SIZE {
 | 
			
		||||
            bail_io_error_other!("received too large WS message");
 | 
			
		||||
@@ -89,7 +89,7 @@ where
 | 
			
		||||
            Ok(v) => NetworkResult::value(v),
 | 
			
		||||
            Err(e) => err_to_network_result(e),
 | 
			
		||||
        };
 | 
			
		||||
        tracing::Span::current().record("network_result", &tracing::field::display(&out));
 | 
			
		||||
        //tracing::Span::current().record("network_result", &tracing::field::display(&out));
 | 
			
		||||
        Ok(out)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -64,7 +64,7 @@ impl WebsocketNetworkConnection {
 | 
			
		||||
    //     self.inner.ws_meta.close().await.map_err(to_io).map(drop)
 | 
			
		||||
    // }
 | 
			
		||||
 | 
			
		||||
    #[instrument(level = "trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
 | 
			
		||||
    //#[instrument(level = "trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
 | 
			
		||||
    pub async fn send(&self, message: Vec<u8>) -> io::Result<NetworkResult<()>> {
 | 
			
		||||
        if message.len() > MAX_MESSAGE_SIZE {
 | 
			
		||||
            bail_io_error_other!("sending too large WS message");
 | 
			
		||||
@@ -79,7 +79,7 @@ impl WebsocketNetworkConnection {
 | 
			
		||||
        .map_err(to_io)
 | 
			
		||||
        .into_network_result()?;
 | 
			
		||||
 | 
			
		||||
        tracing::Span::current().record("network_result", &tracing::field::display(&out));
 | 
			
		||||
        //tracing::Span::current().record("network_result", &tracing::field::display(&out));
 | 
			
		||||
        Ok(out)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user