diff --git a/veilid-core/src/dht/key.rs b/veilid-core/src/dht/key.rs index 3b9cd313..7d3ff345 100644 --- a/veilid-core/src/dht/key.rs +++ b/veilid-core/src/dht/key.rs @@ -3,7 +3,6 @@ use core::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd}; use core::convert::{TryFrom, TryInto}; use core::fmt; use core::hash::{Hash, Hasher}; -use hex; use crate::veilid_rng::*; use ed25519_dalek::{Keypair, PublicKey, Signature}; diff --git a/veilid-core/src/network_manager/connection_handle.rs b/veilid-core/src/network_manager/connection_handle.rs new file mode 100644 index 00000000..ac5b6389 --- /dev/null +++ b/veilid-core/src/network_manager/connection_handle.rs @@ -0,0 +1,38 @@ +use super::*; + +#[derive(Clone, Debug)] +pub struct ConnectionHandle { + descriptor: ConnectionDescriptor, + channel: flume::Sender>, +} + +impl ConnectionHandle { + pub(super) fn new(descriptor: ConnectionDescriptor, channel: flume::Sender>) -> Self { + Self { + descriptor, + channel, + } + } + + pub fn connection_descriptor(&self) -> ConnectionDescriptor { + self.descriptor.clone() + } + + pub fn send(&self, message: Vec) -> Result<(), String> { + self.channel.send(message).map_err(map_to_string) + } + pub async fn send_async(&self, message: Vec) -> Result<(), String> { + self.channel + .send_async(message) + .await + .map_err(map_to_string) + } +} + +impl PartialEq for ConnectionHandle { + fn eq(&self, other: &Self) -> bool { + self.descriptor == other.descriptor + } +} + +impl Eq for ConnectionHandle {} diff --git a/veilid-core/src/network_manager/connection_limits.rs b/veilid-core/src/network_manager/connection_limits.rs index 5d9f9803..451ba173 100644 --- a/veilid-core/src/network_manager/connection_limits.rs +++ b/veilid-core/src/network_manager/connection_limits.rs @@ -1,7 +1,5 @@ -use crate::xx::*; -use crate::*; +use super::*; use alloc::collections::btree_map::Entry; -use core::fmt; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum AddressFilterError { diff --git a/veilid-core/src/network_manager/connection_manager.rs b/veilid-core/src/network_manager/connection_manager.rs index 5af9f094..c36c4fa8 100644 --- a/veilid-core/src/network_manager/connection_manager.rs +++ b/veilid-core/src/network_manager/connection_manager.rs @@ -3,8 +3,6 @@ use crate::xx::*; use connection_table::*; use network_connection::*; -const CONNECTION_PROCESSOR_CHANNEL_SIZE: usize = 128usize; - /////////////////////////////////////////////////////////// // Connection manager @@ -59,8 +57,7 @@ impl ConnectionManager { } pub async fn shutdown(&self) { - // xxx close all connections in the connection table - + // Drops connection table, which drops all connections in it *self.arc.inner.lock().await = Self::new_inner(self.arc.network_manager.config()); } @@ -68,47 +65,48 @@ impl ConnectionManager { pub async fn get_connection( &self, descriptor: ConnectionDescriptor, - ) -> Option { + ) -> Option { let mut inner = self.arc.inner.lock().await; inner.connection_table.get_connection(descriptor) } - // Internal routine to register new connection atomically - fn on_new_connection_internal( + // Internal routine to register new connection atomically. + // Registers connection in the connection table for later access + // and spawns a message processing loop for the connection + fn on_new_protocol_network_connection( &self, inner: &mut ConnectionManagerInner, - conn: NetworkConnection, - ) -> Result<(), String> { - log_net!("on_new_connection_internal: {:?}", conn); - let tx = inner - .connection_add_channel_tx - .as_ref() - .ok_or_else(fn_string!("connection channel isn't open yet"))? - .clone(); + conn: ProtocolNetworkConnection, + ) -> Result { + log_net!("on_new_protocol_network_connection: {:?}", conn); - let receiver_loop_future = Self::process_connection(self.clone(), conn.clone()); - tx.try_send(receiver_loop_future) - .map_err(map_to_string) - .map_err(logthru_net!(error "failed to start receiver loop"))?; - - // If the receiver loop started successfully, - // add the new connection to the table - inner.connection_table.add_connection(conn) + // Wrap with NetworkConnection object to start the connection processing loop + let conn = NetworkConnection::from_protocol(self.clone(), conn); + let handle = conn.get_handle(); + // Add to the connection table + inner.connection_table.add_connection(conn)?; + Ok(handle) } // Called by low-level network when any connection-oriented protocol connection appears - // either from incoming or outgoing connections. Registers connection in the connection table for later access - // and spawns a message processing loop for the connection - pub async fn on_new_connection(&self, conn: NetworkConnection) -> Result<(), String> { + // either from incoming connections. + pub(super) async fn on_accepted_protocol_network_connection( + &self, + conn: ProtocolNetworkConnection, + ) -> Result<(), String> { let mut inner = self.arc.inner.lock().await; - self.on_new_connection_internal(&mut *inner, conn) + self.on_new_protocol_network_connection(&mut *inner, conn) + .map(drop) } + // Called when we want to create a new connection or get the current one that already exists + // This will kill off any connections that are in conflict with the new connection to be made + // in order to make room for the new connection in the system's connection table pub async fn get_or_create_connection( &self, local_addr: Option, dial_info: DialInfo, - ) -> Result { + ) -> Result { log_net!( "== get_or_create_connection local_addr={:?} dial_info={:?}", local_addr.green(), @@ -146,8 +144,10 @@ impl ConnectionManager { if local_addr.port() != 0 { for pt in [ProtocolType::TCP, ProtocolType::WS, ProtocolType::WSS] { let pa = PeerAddress::new(descriptor.remote.socket_address, pt); - for conn in inner.connection_table.get_connections_by_remote(pa) { - let desc = conn.connection_descriptor(); + for desc in inner + .connection_table + .get_connection_descriptors_by_remote(pa) + { let mut kill = false; if let Some(conn_local) = desc.local { if (local_addr.ip().is_unspecified() @@ -163,7 +163,9 @@ impl ConnectionManager { local_addr.green(), pa.green() ); - conn.close().await?; + if let Err(e) = inner.connection_table.remove_connection(descriptor) { + log_net!(error e); + } } } } @@ -171,73 +173,17 @@ impl ConnectionManager { } // Attempt new connection - let conn = NetworkConnection::connect(local_addr, dial_info).await?; + let conn = ProtocolNetworkConnection::connect(local_addr, dial_info).await?; - self.on_new_connection_internal(&mut *inner, conn.clone())?; - - Ok(conn) + self.on_new_protocol_network_connection(&mut *inner, conn) } - // Connection receiver loop - fn process_connection( - this: ConnectionManager, - conn: NetworkConnection, - ) -> SystemPinBoxFuture<()> { - log_net!("Starting process_connection loop for {:?}", conn.green()); - let network_manager = this.network_manager(); - Box::pin(async move { - // - let descriptor = conn.connection_descriptor(); - let inactivity_timeout = this - .network_manager() - .config() - .get() - .network - .connection_inactivity_timeout_ms; - loop { - // process inactivity timeout on receives only - // if you want a keepalive, it has to be requested from the other side - let message = select! { - res = conn.recv().fuse() => { - match res { - Ok(v) => v, - Err(e) => { - log_net!(debug e); - break; - } - } - } - _ = intf::sleep(inactivity_timeout).fuse()=> { - // timeout - log_net!("connection timeout on {:?}", descriptor.green()); - break; - } - }; - if let Err(e) = network_manager - .on_recv_envelope(message.as_slice(), descriptor) - .await - { - log_net!(error e); - break; - } - } - - log_net!( - "== Connection loop finished local_addr={:?} remote={:?}", - descriptor.local.green(), - descriptor.remote.green() - ); - - if let Err(e) = this - .arc - .inner - .lock() - .await - .connection_table - .remove_connection(descriptor) - { - log_net!(error e); - } - }) + // Callback from network connection receive loop when it exits + // cleans up the entry in the connection table + pub(super) async fn report_connection_finished(&self, descriptor: ConnectionDescriptor) { + let mut inner = self.arc.inner.lock().await; + if let Err(e) = inner.connection_table.remove_connection(descriptor) { + log_net!(error e); + } } } diff --git a/veilid-core/src/network_manager/connection_table.rs b/veilid-core/src/network_manager/connection_table.rs index 12d3e8b3..21ce47ff 100644 --- a/veilid-core/src/network_manager/connection_table.rs +++ b/veilid-core/src/network_manager/connection_table.rs @@ -1,7 +1,4 @@ -use super::connection_limits::*; -use super::network_connection::*; -use crate::xx::*; -use crate::*; +use super::*; use alloc::collections::btree_map::Entry; use hashlink::LruCache; @@ -9,7 +6,7 @@ use hashlink::LruCache; pub struct ConnectionTable { max_connections: Vec, conn_by_descriptor: Vec>, - conns_by_remote: BTreeMap>, + descriptors_by_remote: BTreeMap>, address_filter: ConnectionLimits, } @@ -39,7 +36,7 @@ impl ConnectionTable { LruCache::new_unbounded(), LruCache::new_unbounded(), ], - conns_by_remote: BTreeMap::new(), + descriptors_by_remote: BTreeMap::new(), address_filter: ConnectionLimits::new(config), } } @@ -60,7 +57,7 @@ impl ConnectionTable { self.address_filter.add(ip_addr).map_err(map_to_string)?; // Add the connection to the table - let res = self.conn_by_descriptor[index].insert(descriptor, conn.clone()); + let res = self.conn_by_descriptor[index].insert(descriptor.clone(), conn); assert!(res.is_none()); // if we have reached the maximum number of connections per protocol type @@ -73,49 +70,54 @@ impl ConnectionTable { } // add connection records - let conns = self.conns_by_remote.entry(descriptor.remote).or_default(); + let descriptors = self + .descriptors_by_remote + .entry(descriptor.remote) + .or_default(); - warn!("add_connection: {:?}", conn); - conns.push(conn); + warn!("add_connection: {:?}", descriptor); + descriptors.push(descriptor); Ok(()) } - pub fn get_connection( - &mut self, - descriptor: ConnectionDescriptor, - ) -> Option { + pub fn get_connection(&mut self, descriptor: ConnectionDescriptor) -> Option { + warn!("get_connection: {:?}", descriptor); let index = protocol_to_index(descriptor.protocol_type()); - let out = self.conn_by_descriptor[index].get(&descriptor).cloned(); - warn!("get_connection: {:?} -> {:?}", descriptor, out); - out + let out = self.conn_by_descriptor[index].get(&descriptor); + out.map(|c| c.get_handle()) } pub fn get_last_connection_by_remote( &mut self, remote: PeerAddress, - ) -> Option { - let out = self - .conns_by_remote + ) -> Option { + warn!("get_last_connection_by_remote: {:?}", remote); + let descriptor = self + .descriptors_by_remote .get(&remote) .map(|v| v[(v.len() - 1)].clone()); - warn!("get_last_connection_by_remote: {:?} -> {:?}", remote, out); - if let Some(connection) = &out { + if let Some(descriptor) = descriptor { // lru bump - let index = protocol_to_index(connection.connection_descriptor().protocol_type()); - let _ = self.conn_by_descriptor[index].get(&connection.connection_descriptor()); + let index = protocol_to_index(descriptor.protocol_type()); + let handle = self.conn_by_descriptor[index] + .get(&descriptor) + .map(|c| c.get_handle()); + handle + } else { + None } - out } - pub fn get_connections_by_remote(&mut self, remote: PeerAddress) -> Vec { - let out = self - .conns_by_remote + pub fn get_connection_descriptors_by_remote( + &mut self, + remote: PeerAddress, + ) -> Vec { + warn!("get_connection_descriptors_by_remote: {:?}", remote); + self.descriptors_by_remote .get(&remote) .cloned() - .unwrap_or_default(); - warn!("get_connections_by_remote: {:?} -> {:?}", remote, out); - out + .unwrap_or_default() } pub fn connection_count(&self) -> usize { @@ -126,7 +128,7 @@ impl ConnectionTable { let ip_addr = descriptor.remote.socket_address.to_ip_addr(); // conns_by_remote - match self.conns_by_remote.entry(descriptor.remote) { + match self.descriptors_by_remote.entry(descriptor.remote) { Entry::Vacant(_) => { panic!("inconsistency in connection table") } @@ -135,7 +137,7 @@ impl ConnectionTable { // Remove one matching connection from the list for (n, elem) in v.iter().enumerate() { - if elem.connection_descriptor() == descriptor { + if *elem == descriptor { v.remove(n); break; } @@ -151,18 +153,14 @@ impl ConnectionTable { .expect("Inconsistency in connection table"); } - pub fn remove_connection( - &mut self, - descriptor: ConnectionDescriptor, - ) -> Result { + pub fn remove_connection(&mut self, descriptor: ConnectionDescriptor) -> Result<(), String> { warn!("remove_connection: {:?}", descriptor); let index = protocol_to_index(descriptor.protocol_type()); - let out = self.conn_by_descriptor[index] + let _ = self.conn_by_descriptor[index] .remove(&descriptor) .ok_or_else(|| format!("Connection not in table: {:?}", descriptor))?; self.remove_connection_records(descriptor); - - Ok(out) + Ok(()) } } diff --git a/veilid-core/src/network_manager/mod.rs b/veilid-core/src/network_manager/mod.rs index 895a91e2..978c6d06 100644 --- a/veilid-core/src/network_manager/mod.rs +++ b/veilid-core/src/network_manager/mod.rs @@ -5,6 +5,7 @@ mod native; #[cfg(target_arch = "wasm32")] mod wasm; +mod connection_handle; mod connection_limits; mod connection_manager; mod connection_table; @@ -17,8 +18,9 @@ pub mod tests; pub use network_connection::*; //////////////////////////////////////////////////////////////////////////////////////// - +use connection_limits::*; use connection_manager::*; +use connection_handle::*; use dht::*; use hashlink::LruCache; use intf::*; @@ -1034,7 +1036,7 @@ impl NetworkManager { // Called when a packet potentially containing an RPC envelope is received by a low-level // network protocol handler. Processes the envelope, authenticates and decrypts the RPC message // and passes it to the RPC handler - pub async fn on_recv_envelope( + async fn on_recv_envelope( &self, data: &[u8], descriptor: ConnectionDescriptor, diff --git a/veilid-core/src/network_manager/native/mod.rs b/veilid-core/src/network_manager/native/mod.rs index f400cf4e..e7d859ff 100644 --- a/veilid-core/src/network_manager/native/mod.rs +++ b/veilid-core/src/network_manager/native/mod.rs @@ -341,7 +341,7 @@ impl Network { log_net!("send_data_to_existing_connection to {:?}", descriptor); // connection exists, send over it - conn.send(data).await.map_err(logthru_net!())?; + conn.send_async(data).await.map_err(logthru_net!())?; // Network accounting self.network_manager() @@ -389,7 +389,7 @@ impl Network { .get_or_create_connection(Some(local_addr), dial_info.clone()) .await?; - let res = conn.send(data).await.map_err(logthru_net!(error)); + let res = conn.send_async(data).await.map_err(logthru_net!(error)); if res.is_ok() { // Network accounting self.network_manager() diff --git a/veilid-core/src/network_manager/native/network_tcp.rs b/veilid-core/src/network_manager/native/network_tcp.rs index def3075b..d4bd36c5 100644 --- a/veilid-core/src/network_manager/native/network_tcp.rs +++ b/veilid-core/src/network_manager/native/network_tcp.rs @@ -7,7 +7,7 @@ use sockets::*; #[derive(Clone)] pub struct ListenerState { - pub protocol_handlers: Vec>, + pub protocol_accept_handlers: Vec>, pub tls_protocol_handlers: Vec>, pub tls_acceptor: Option, } @@ -15,7 +15,7 @@ pub struct ListenerState { impl ListenerState { pub fn new() -> Self { Self { - protocol_handlers: Vec::new(), + protocol_accept_handlers: Vec::new(), tls_protocol_handlers: Vec::new(), tls_acceptor: None, } @@ -46,7 +46,7 @@ impl Network { addr: SocketAddr, protocol_handlers: &[Box], tls_connection_initial_timeout: u64, - ) -> Result, String> { + ) -> Result, String> { let ts = tls_acceptor .accept(stream) .await @@ -76,9 +76,9 @@ impl Network { stream: AsyncPeekStream, tcp_stream: TcpStream, addr: SocketAddr, - protocol_handlers: &[Box], - ) -> Result, String> { - for ah in protocol_handlers.iter() { + protocol_accept_handlers: &[Box], + ) -> Result, String> { + for ah in protocol_accept_handlers.iter() { if let Some(nc) = ah .on_accept(stream.clone(), tcp_stream.clone(), addr) .await @@ -185,7 +185,7 @@ impl Network { ) .await } else { - this.try_handlers(ps, tcp_stream, addr, &ls.protocol_handlers) + this.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers) .await }; @@ -207,7 +207,10 @@ impl Network { }; // Register the new connection in the connection manager - if let Err(e) = connection_manager.on_new_connection(conn).await { + if let Err(e) = connection_manager + .on_accepted_protocol_network_connection(conn) + .await + { log_net!(error "failed to register new connection: {}", e); } }) @@ -270,7 +273,7 @@ impl Network { )); } else { ls.write() - .protocol_handlers + .protocol_accept_handlers .push(new_protocol_accept_handler( self.network_manager().config(), false, diff --git a/veilid-core/src/network_manager/native/protocol/mod.rs b/veilid-core/src/network_manager/native/protocol/mod.rs index 8a761afb..401d3ce4 100644 --- a/veilid-core/src/network_manager/native/protocol/mod.rs +++ b/veilid-core/src/network_manager/native/protocol/mod.rs @@ -21,7 +21,7 @@ impl ProtocolNetworkConnection { pub async fn connect( local_address: Option, dial_info: DialInfo, - ) -> Result { + ) -> Result { match dial_info.protocol_type() { ProtocolType::UDP => { panic!("Should not connect to UDP dialinfo"); @@ -55,6 +55,16 @@ impl ProtocolNetworkConnection { } } + pub fn descriptor(&self) -> ConnectionDescriptor { + match self { + Self::Dummy(d) => d.descriptor(), + Self::RawTcp(t) => t.descriptor(), + Self::WsAccepted(w) => w.descriptor(), + Self::Ws(w) => w.descriptor(), + Self::Wss(w) => w.descriptor(), + } + } + pub async fn close(&self) -> Result<(), String> { match self { Self::Dummy(d) => d.close(), diff --git a/veilid-core/src/network_manager/native/protocol/tcp.rs b/veilid-core/src/network_manager/native/protocol/tcp.rs index 950d136e..7bbf1014 100644 --- a/veilid-core/src/network_manager/native/protocol/tcp.rs +++ b/veilid-core/src/network_manager/native/protocol/tcp.rs @@ -3,6 +3,7 @@ use futures_util::{AsyncReadExt, AsyncWriteExt}; use sockets::*; pub struct RawTcpNetworkConnection { + descriptor: ConnectionDescriptor, stream: AsyncPeekStream, tcp_stream: TcpStream, } @@ -14,8 +15,20 @@ impl fmt::Debug for RawTcpNetworkConnection { } impl RawTcpNetworkConnection { - pub fn new(stream: AsyncPeekStream, tcp_stream: TcpStream) -> Self { - Self { stream, tcp_stream } + pub fn new( + descriptor: ConnectionDescriptor, + stream: AsyncPeekStream, + tcp_stream: TcpStream, + ) -> Self { + Self { + descriptor, + stream, + tcp_stream, + } + } + + pub fn descriptor(&self) -> ConnectionDescriptor { + self.descriptor.clone() } pub async fn close(&self) -> Result<(), String> { @@ -33,7 +46,7 @@ impl RawTcpNetworkConnection { .map_err(logthru_net!()) } - pub async fn send(&self, message: Vec) -> Result<(), String> { + async fn send_internal(mut stream: AsyncPeekStream, message: Vec) -> Result<(), String> { log_net!("sending TCP message of size {}", message.len()); if message.len() > MAX_MESSAGE_SIZE { return Err("sending too large TCP message".to_owned()); @@ -41,7 +54,6 @@ impl RawTcpNetworkConnection { let len = message.len() as u16; let header = [b'V', b'L', len as u8, (len >> 8) as u8]; - let mut stream = self.stream.clone(); stream .write_all(&header) .await @@ -54,6 +66,11 @@ impl RawTcpNetworkConnection { .map_err(logthru_net!()) } + pub async fn send(&self, message: Vec) -> Result<(), String> { + let stream = self.stream.clone(); + Self::send_internal(stream, message).await + } + pub async fn recv(&self) -> Result, String> { let mut header = [0u8; 4]; @@ -108,7 +125,7 @@ impl RawTcpProtocolHandler { stream: AsyncPeekStream, tcp_stream: TcpStream, socket_addr: SocketAddr, - ) -> Result, String> { + ) -> Result, String> { log_net!("TCP: on_accept_async: enter"); let mut peekbuf: [u8; PEEK_DETECT_LEN] = [0u8; PEEK_DETECT_LEN]; let peeklen = stream @@ -123,10 +140,11 @@ impl RawTcpProtocolHandler { ProtocolType::TCP, ); let local_address = self.inner.lock().local_address; - let conn = NetworkConnection::from_protocol( + let conn = ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new( ConnectionDescriptor::new(peer_addr, SocketAddress::from_socket_addr(local_address)), - ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(stream, tcp_stream)), - ); + stream, + tcp_stream, + )); log_net!(debug "TCP: on_accept_async from: {}", socket_addr); @@ -136,7 +154,7 @@ impl RawTcpProtocolHandler { pub async fn connect( local_address: Option, dial_info: DialInfo, - ) -> Result { + ) -> Result { // Get remote socket address to connect to let remote_socket_addr = dial_info.to_socket_addr(); @@ -161,13 +179,15 @@ impl RawTcpProtocolHandler { let ps = AsyncPeekStream::new(ts.clone()); // Wrap the stream in a network connection and return it - let conn = NetworkConnection::from_protocol( + let conn = ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new( ConnectionDescriptor { local: Some(SocketAddress::from_socket_addr(actual_local_address)), remote: dial_info.to_peer_address(), }, - ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)), - ); + ps, + ts, + )); + Ok(conn) } @@ -194,24 +214,15 @@ impl RawTcpProtocolHandler { .map_err(logthru_net!(error "remote_addr={}", socket_addr))?; // See what local address we ended up with and turn this into a stream - let actual_local_address = ts - .local_addr() - .map_err(map_to_string) - .map_err(logthru_net!("could not get local address from TCP stream"))?; + // let actual_local_address = ts + // .local_addr() + // .map_err(map_to_string) + // .map_err(logthru_net!("could not get local address from TCP stream"))?; let ps = AsyncPeekStream::new(ts.clone()); - // Wrap the stream in a network connection and return it - let conn = NetworkConnection::from_protocol( - ConnectionDescriptor { - local: Some(SocketAddress::from_socket_addr(actual_local_address)), - remote: PeerAddress::new( - SocketAddress::from_socket_addr(socket_addr), - ProtocolType::TCP, - ), - }, - ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)), - ); - conn.send(data).await + // Send directly from the raw network connection + // this builds the connection and tears it down immediately after the send + RawTcpNetworkConnection::send_internal(ps, data).await } } @@ -221,7 +232,7 @@ impl ProtocolAcceptHandler for RawTcpProtocolHandler { stream: AsyncPeekStream, tcp_stream: TcpStream, peer_addr: SocketAddr, - ) -> SystemPinBoxFuture, String>> { + ) -> SystemPinBoxFuture, String>> { Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr)) } } diff --git a/veilid-core/src/network_manager/native/protocol/ws.rs b/veilid-core/src/network_manager/native/protocol/ws.rs index a3975ac9..4ad6d1b4 100644 --- a/veilid-core/src/network_manager/native/protocol/ws.rs +++ b/veilid-core/src/network_manager/native/protocol/ws.rs @@ -15,6 +15,7 @@ pub struct WebsocketNetworkConnection where T: io::Read + io::Write + Send + Unpin + 'static, { + descriptor: ConnectionDescriptor, stream: CloneStream>, tcp_stream: TcpStream, } @@ -32,13 +33,22 @@ impl WebsocketNetworkConnection where T: io::Read + io::Write + Send + Unpin + 'static, { - pub fn new(stream: WebSocketStream, tcp_stream: TcpStream) -> Self { + pub fn new( + descriptor: ConnectionDescriptor, + stream: WebSocketStream, + tcp_stream: TcpStream, + ) -> Self { Self { + descriptor, stream: CloneStream::new(stream), tcp_stream, } } + pub fn descriptor(&self) -> ConnectionDescriptor { + self.descriptor.clone() + } + pub async fn close(&self) -> Result<(), String> { // Make an attempt to flush the stream self.stream @@ -132,7 +142,7 @@ impl WebsocketProtocolHandler { ps: AsyncPeekStream, tcp_stream: TcpStream, socket_addr: SocketAddr, - ) -> Result, String> { + ) -> Result, String> { log_net!("WS: on_accept_async: enter"); let request_path_len = self.arc.request_path.len() + 2; @@ -179,25 +189,24 @@ impl WebsocketProtocolHandler { let peer_addr = PeerAddress::new(SocketAddress::from_socket_addr(socket_addr), protocol_type); - let conn = NetworkConnection::from_protocol( + let conn = ProtocolNetworkConnection::WsAccepted(WebsocketNetworkConnection::new( ConnectionDescriptor::new( peer_addr, SocketAddress::from_socket_addr(self.arc.local_address), ), - ProtocolNetworkConnection::WsAccepted(WebsocketNetworkConnection::new( - ws_stream, tcp_stream, - )), - ); + ws_stream, + tcp_stream, + )); log_net!(debug "{}: on_accept_async from: {}", if self.arc.tls { "WSS" } else { "WS" }, socket_addr); Ok(Some(conn)) } - pub async fn connect( + async fn connect_internal( local_address: Option, dial_info: DialInfo, - ) -> Result { + ) -> Result { // Split dial info up let (tls, scheme) = match &dial_info { DialInfo::WS(_) => (false, "ws"), @@ -251,26 +260,27 @@ impl WebsocketProtocolHandler { .map_err(map_to_string) .map_err(logthru_net!(error))?; - Ok(NetworkConnection::from_protocol( - descriptor, - ProtocolNetworkConnection::Wss(WebsocketNetworkConnection::new( - ws_stream, tcp_stream, - )), + Ok(ProtocolNetworkConnection::Wss( + WebsocketNetworkConnection::new(descriptor, ws_stream, tcp_stream), )) } else { let (ws_stream, _response) = client_async(request, tcp_stream.clone()) .await .map_err(map_to_string) .map_err(logthru_net!(error))?; - Ok(NetworkConnection::from_protocol( - descriptor, - ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new( - ws_stream, tcp_stream, - )), + Ok(ProtocolNetworkConnection::Ws( + WebsocketNetworkConnection::new(descriptor, ws_stream, tcp_stream), )) } } + pub async fn connect( + local_address: Option, + dial_info: DialInfo, + ) -> Result { + Self::connect_internal(local_address, dial_info).await + } + pub async fn send_unbound_message(dial_info: DialInfo, data: Vec) -> Result<(), String> { if data.len() > MAX_MESSAGE_SIZE { return Err("sending too large unbound WS message".to_owned()); @@ -281,11 +291,11 @@ impl WebsocketProtocolHandler { dial_info, ); - let conn = Self::connect(None, dial_info.clone()) + let protconn = Self::connect_internal(None, dial_info.clone()) .await .map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?; - conn.send(data).await + protconn.send(data).await } } @@ -295,7 +305,7 @@ impl ProtocolAcceptHandler for WebsocketProtocolHandler { stream: AsyncPeekStream, tcp_stream: TcpStream, peer_addr: SocketAddr, - ) -> SystemPinBoxFuture, String>> { + ) -> SystemPinBoxFuture, String>> { Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr)) } } diff --git a/veilid-core/src/network_manager/network_connection.rs b/veilid-core/src/network_manager/network_connection.rs index e667939f..87607521 100644 --- a/veilid-core/src/network_manager/network_connection.rs +++ b/veilid-core/src/network_manager/network_connection.rs @@ -1,5 +1,5 @@ use super::*; -use crate::xx::*; +use futures_util::{FutureExt, StreamExt}; cfg_if::cfg_if! { if #[cfg(target_arch = "wasm32")] { @@ -16,7 +16,7 @@ cfg_if::cfg_if! { stream: AsyncPeekStream, tcp_stream: TcpStream, peer_addr: SocketAddr, - ) -> SystemPinBoxFuture, String>>; + ) -> SystemPinBoxFuture, String>>; } pub trait ProtocolAcceptHandlerClone { @@ -45,9 +45,14 @@ cfg_if::cfg_if! { // Dummy protocol network connection for testing #[derive(Debug)] -pub struct DummyNetworkConnection {} +pub struct DummyNetworkConnection { + descriptor: ConnectionDescriptor, +} impl DummyNetworkConnection { + pub fn descriptor(&self) -> ConnectionDescriptor { + self.descriptor.clone() + } pub fn close(&self) -> Result<(), String> { Ok(()) } @@ -62,6 +67,14 @@ impl DummyNetworkConnection { /////////////////////////////////////////////////////////// // Top-level protocol independent network connection object +#[derive(Clone, Copy, Debug)] +enum RecvLoopAction { + Send, + Recv, + Finish, + Timeout, +} + #[derive(Debug, Clone)] pub struct NetworkConnectionStats { last_message_sent_time: Option, @@ -69,107 +82,249 @@ pub struct NetworkConnectionStats { } #[derive(Debug)] -struct NetworkConnectionInner { - stats: NetworkConnectionStats, -} - -#[derive(Debug)] -struct NetworkConnectionArc { - descriptor: ConnectionDescriptor, - protocol_connection: ProtocolNetworkConnection, - established_time: u64, - inner: Mutex, -} - -#[derive(Clone, Debug)] pub struct NetworkConnection { - arc: Arc, + descriptor: ConnectionDescriptor, + _processor: Option>, + established_time: u64, + stats: Arc>, + sender: flume::Sender>, } -impl PartialEq for NetworkConnection { - fn eq(&self, other: &Self) -> bool { - Arc::as_ptr(&self.arc) == Arc::as_ptr(&other.arc) - } -} - -impl Eq for NetworkConnection {} impl NetworkConnection { - fn new_inner() -> NetworkConnectionInner { - NetworkConnectionInner { - stats: NetworkConnectionStats { + pub(super) fn dummy(descriptor: ConnectionDescriptor) -> Self { + // Create handle for sending (dummy is immediately disconnected) + let (sender, _receiver) = flume::bounded(intf::get_concurrency() as usize); + + Self { + descriptor, + _processor: None, + established_time: intf::get_timestamp(), + stats: Arc::new(Mutex::new(NetworkConnectionStats { last_message_sent_time: None, last_message_recv_time: None, - }, - } - } - fn new_arc( - descriptor: ConnectionDescriptor, - protocol_connection: ProtocolNetworkConnection, - ) -> NetworkConnectionArc { - NetworkConnectionArc { - descriptor, - protocol_connection, - established_time: intf::get_timestamp(), - inner: Mutex::new(Self::new_inner()), + })), + sender, } } - pub fn dummy(descriptor: ConnectionDescriptor) -> Self { - NetworkConnection::from_protocol( - descriptor, - ProtocolNetworkConnection::Dummy(DummyNetworkConnection {}), - ) - } - - pub fn from_protocol( - descriptor: ConnectionDescriptor, + pub(super) fn from_protocol( + connection_manager: ConnectionManager, protocol_connection: ProtocolNetworkConnection, ) -> Self { - Self { - arc: Arc::new(Self::new_arc(descriptor, protocol_connection)), - } - } + // Get timeout + let network_manager = connection_manager.network_manager(); + let inactivity_timeout = network_manager + .config() + .get() + .network + .connection_inactivity_timeout_ms; - pub async fn connect( - local_address: Option, - dial_info: DialInfo, - ) -> Result { - ProtocolNetworkConnection::connect(local_address, dial_info).await + // Get descriptor + let descriptor = protocol_connection.descriptor(); + + // Create handle for sending + let (sender, receiver) = flume::bounded(intf::get_concurrency() as usize); + + // Create stats + let stats = Arc::new(Mutex::new(NetworkConnectionStats { + last_message_sent_time: None, + last_message_recv_time: None, + })); + + // Spawn connection processor and pass in protocol connection + let processor = intf::spawn_local(Self::process_connection( + connection_manager, + descriptor.clone(), + receiver, + protocol_connection, + inactivity_timeout, + stats.clone(), + )); + + // Return the connection + Self { + descriptor, + _processor: Some(processor), + established_time: intf::get_timestamp(), + stats, + sender, + } } pub fn connection_descriptor(&self) -> ConnectionDescriptor { - self.arc.descriptor + self.descriptor.clone() } - pub async fn close(&self) -> Result<(), String> { - self.arc.protocol_connection.close().await + pub fn get_handle(&self) -> ConnectionHandle { + ConnectionHandle::new(self.descriptor.clone(), self.sender.clone()) } - pub async fn send(&self, message: Vec) -> Result<(), String> { + async fn send_internal( + protocol_connection: &ProtocolNetworkConnection, + stats: Arc>, + message: Vec, + ) -> Result<(), String> { let ts = intf::get_timestamp(); - let out = self.arc.protocol_connection.send(message).await; + let out = protocol_connection.send(message).await; if out.is_ok() { - let mut inner = self.arc.inner.lock(); - inner.stats.last_message_sent_time.max_assign(Some(ts)); + let mut stats = stats.lock(); + stats.last_message_sent_time.max_assign(Some(ts)); } out } - pub async fn recv(&self) -> Result, String> { + async fn recv_internal( + protocol_connection: &ProtocolNetworkConnection, + stats: Arc>, + ) -> Result, String> { let ts = intf::get_timestamp(); - let out = self.arc.protocol_connection.recv().await; + let out = protocol_connection.recv().await; if out.is_ok() { - let mut inner = self.arc.inner.lock(); - inner.stats.last_message_recv_time.max_assign(Some(ts)); + let mut stats = stats.lock(); + stats.last_message_recv_time.max_assign(Some(ts)); } out } pub fn stats(&self) -> NetworkConnectionStats { - let inner = self.arc.inner.lock(); - inner.stats.clone() + let stats = self.stats.lock(); + stats.clone() } pub fn established_time(&self) -> u64 { - self.arc.established_time + self.established_time + } + + // Connection receiver loop + fn process_connection( + connection_manager: ConnectionManager, + descriptor: ConnectionDescriptor, + receiver: flume::Receiver>, + protocol_connection: ProtocolNetworkConnection, + connection_inactivity_timeout_ms: u32, + stats: Arc>, + ) -> SystemPinBoxFuture<()> { + Box::pin(async move { + log_net!( + "Starting process_connection loop for {:?}", + descriptor.green() + ); + + let network_manager = connection_manager.network_manager(); + let mut unord = FuturesUnordered::new(); + let mut need_receiver = true; + let mut need_sender = true; + + // Push mutable timer so we can reset it + // Normally we would use an io::timeout here, but WASM won't support that, so we use a mutable sleep future + let new_timer = || { + intf::sleep(connection_inactivity_timeout_ms).then(|_| async { + // timeout + log_net!("connection timeout on {:?}", descriptor.green()); + RecvLoopAction::Timeout + }) + }; + let timer = MutableFuture::new(new_timer()); + unord.push(timer.clone().boxed()); + + loop { + // Add another message sender future if necessary + if need_sender { + need_sender = false; + unord.push( + receiver + .recv_async() + .then(|res| async { + match res { + Ok(message) => { + // send the packet + if let Err(e) = Self::send_internal( + &protocol_connection, + stats.clone(), + message, + ) + .await + { + // Sending the packet along can fail, if so, this connection is dead + log_net!(debug e); + RecvLoopAction::Finish + } else { + RecvLoopAction::Send + } + } + Err(e) => { + // All senders gone, shouldn't happen since we store one alongside the join handle + log_net!(warn e); + RecvLoopAction::Finish + } + } + }) + .boxed(), + ); + } + + // Add another message receiver future if necessary + if need_receiver { + need_sender = false; + unord.push( + Self::recv_internal(&protocol_connection, stats.clone()) + .then(|res| async { + match res { + Ok(message) => { + // Pass received messages up to the network manager for processing + if let Err(e) = network_manager + .on_recv_envelope(message.as_slice(), descriptor) + .await + { + log_net!(error e); + RecvLoopAction::Finish + } else { + RecvLoopAction::Recv + } + } + Err(e) => { + // Connection unable to receive, closed + log_net!(warn e); + RecvLoopAction::Finish + } + } + }) + .boxed(), + ); + } + + // Process futures + match unord.next().await { + Some(RecvLoopAction::Send) => { + // Don't reset inactivity timer if we're only sending + + need_sender = true; + } + Some(RecvLoopAction::Recv) => { + // Reset inactivity timer since we got something from this connection + timer.set(new_timer()); + + need_receiver = true; + } + Some(RecvLoopAction::Finish) | Some(RecvLoopAction::Timeout) => { + break; + } + + None => { + // Should not happen + unreachable!(); + } + } + } + + log_net!( + "== Connection loop finished local_addr={:?} remote={:?}", + descriptor.local.green(), + descriptor.remote.green() + ); + + connection_manager + .report_connection_finished(descriptor) + .await + }) } } diff --git a/veilid-core/src/network_manager/tests/test_connection_table.rs b/veilid-core/src/network_manager/tests/test_connection_table.rs index dbfc2d07..bc0d2695 100644 --- a/veilid-core/src/network_manager/tests/test_connection_table.rs +++ b/veilid-core/src/network_manager/tests/test_connection_table.rs @@ -52,10 +52,15 @@ pub async fn test_add_get_remove() { ); let c1 = NetworkConnection::dummy(a1); + let c1h = c1.get_handle(); let c2 = NetworkConnection::dummy(a2); + //let c2h = c2.get_handle(); let c3 = NetworkConnection::dummy(a3); + //let c3h = c3.get_handle(); let c4 = NetworkConnection::dummy(a4); + //let c4h = c4.get_handle(); let c5 = NetworkConnection::dummy(a5); + //let c5h = c5.get_handle(); assert_eq!(a1, c2.connection_descriptor()); assert_ne!(a3, c4.connection_descriptor()); @@ -63,36 +68,39 @@ pub async fn test_add_get_remove() { assert_eq!(table.connection_count(), 0); assert_eq!(table.get_connection(a1), None); - table.add_connection(c1.clone()).unwrap(); + table.add_connection(c1).unwrap(); assert_eq!(table.connection_count(), 1); assert_err!(table.remove_connection(a3)); assert_err!(table.remove_connection(a4)); assert_eq!(table.connection_count(), 1); - assert_eq!(table.get_connection(a1), Some(c1.clone())); - assert_eq!(table.get_connection(a1), Some(c1.clone())); + assert_eq!(table.get_connection(a1), Some(c1h.clone())); + assert_eq!(table.get_connection(a1), Some(c1h.clone())); assert_eq!(table.connection_count(), 1); - assert_err!(table.add_connection(c1.clone())); - assert_err!(table.add_connection(c2.clone())); + assert_err!(table.add_connection(c2)); assert_eq!(table.connection_count(), 1); - assert_eq!(table.get_connection(a1), Some(c1.clone())); - assert_eq!(table.get_connection(a1), Some(c1.clone())); + assert_eq!(table.get_connection(a1), Some(c1h.clone())); + assert_eq!(table.get_connection(a1), Some(c1h.clone())); assert_eq!(table.connection_count(), 1); - assert_eq!(table.remove_connection(a2), Ok(c1.clone())); + assert_eq!(table.remove_connection(a2), Ok(())); assert_eq!(table.connection_count(), 0); assert_err!(table.remove_connection(a2)); assert_eq!(table.connection_count(), 0); assert_eq!(table.get_connection(a2), None); assert_eq!(table.get_connection(a1), None); assert_eq!(table.connection_count(), 0); - table.add_connection(c1.clone()).unwrap(); + let c1 = NetworkConnection::dummy(a1); + //let c1h = c1.get_handle(); + table.add_connection(c1).unwrap(); + let c2 = NetworkConnection::dummy(a2); + //let c2h = c2.get_handle(); assert_err!(table.add_connection(c2)); - table.add_connection(c3.clone()).unwrap(); - table.add_connection(c4.clone()).unwrap(); + table.add_connection(c3).unwrap(); + table.add_connection(c4).unwrap(); assert_eq!(table.connection_count(), 3); - assert_eq!(table.remove_connection(a2), Ok(c1)); - assert_eq!(table.remove_connection(a3), Ok(c3)); - assert_eq!(table.remove_connection(a4), Ok(c4)); + assert_eq!(table.remove_connection(a2), Ok(())); + assert_eq!(table.remove_connection(a3), Ok(())); + assert_eq!(table.remove_connection(a4), Ok(())); assert_eq!(table.connection_count(), 0); } diff --git a/veilid-core/src/network_manager/wasm/protocol/ws.rs b/veilid-core/src/network_manager/wasm/protocol/ws.rs index 767ced1a..4c3d35b8 100644 --- a/veilid-core/src/network_manager/wasm/protocol/ws.rs +++ b/veilid-core/src/network_manager/wasm/protocol/ws.rs @@ -13,6 +13,7 @@ struct WebsocketNetworkConnectionInner { #[derive(Clone)] pub struct WebsocketNetworkConnection { + descriptor: ConnectionDescriptor, inner: Arc, } @@ -23,8 +24,11 @@ impl fmt::Debug for WebsocketNetworkConnection { } impl WebsocketNetworkConnection { - pub fn new(ws_meta: WsMeta, ws_stream: WsStream) -> Self { + pub fn new( + descriptor: ConnectionDescriptor, + ws_meta: WsMeta, ws_stream: WsStream) -> Self { Self { + descriptor, inner: Arc::new(WebsocketNetworkConnectionInner { ws_meta, ws_stream: CloneStream::new(ws_stream), @@ -32,6 +36,10 @@ impl WebsocketNetworkConnection { } } + pub fn descriptor(&self) -> ConnectionDescriptor { + self.descriptor.clone() + } + pub async fn close(&self) -> Result<(), String> { self.inner.ws_meta.close().await.map_err(map_to_string).map(drop) } @@ -73,7 +81,7 @@ impl WebsocketProtocolHandler { pub async fn connect( local_address: Option, dial_info: DialInfo, - ) -> Result { + ) -> Result { assert!(local_address.is_none()); @@ -96,10 +104,10 @@ impl WebsocketProtocolHandler { // Make our connection descriptor - Ok(NetworkConnection::from_protocol(ConnectionDescriptor { + Ok(ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(ConnectionDescriptor { local: None, remote: dial_info.to_peer_address(), - },ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(wsmeta, wsio)))) + }, wsmeta, wsio))) } pub async fn send_unbound_message(dial_info: DialInfo, data: Vec) -> Result<(), String> { diff --git a/veilid-core/src/xx/mod.rs b/veilid-core/src/xx/mod.rs index 93cb068a..95609b14 100644 --- a/veilid-core/src/xx/mod.rs +++ b/veilid-core/src/xx/mod.rs @@ -8,6 +8,7 @@ mod eventual_value_clone; mod ip_addr_port; mod ip_extra; mod log_thru; +mod mutable_future; mod single_future; mod single_shot_eventual; mod split_url; @@ -104,6 +105,7 @@ pub use eventual_value::*; pub use eventual_value_clone::*; pub use ip_addr_port::*; pub use ip_extra::*; +pub use mutable_future::*; pub use single_future::*; pub use single_shot_eventual::*; pub use tick_task::*; diff --git a/veilid-core/src/xx/mutable_future.rs b/veilid-core/src/xx/mutable_future.rs new file mode 100644 index 00000000..1e3ca478 --- /dev/null +++ b/veilid-core/src/xx/mutable_future.rs @@ -0,0 +1,33 @@ +use super::*; + +pub struct MutableFuture> { + inner: Arc>>>, +} + +impl> MutableFuture { + pub fn new(inner: T) -> Self { + Self { + inner: Arc::new(Mutex::new(Box::pin(inner))), + } + } + + pub fn set(&self, inner: T) { + *self.inner.lock() = Box::pin(inner); + } +} + +impl> Clone for MutableFuture { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +impl> Future for MutableFuture { + type Output = O; + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll { + let mut inner = self.inner.lock(); + T::poll(inner.as_mut(), cx) + } +}