From 180628beefdd7894e82096a4506dd27a26b152ba Mon Sep 17 00:00:00 2001 From: John Smith Date: Sun, 12 Jun 2022 20:58:02 -0400 Subject: [PATCH] refactor for cooperative cancellation --- Cargo.lock | 13 ++ veilid-core/Cargo.toml | 1 + veilid-core/src/attachment_manager.rs | 7 +- .../src/network_manager/connection_manager.rs | 95 ++++++-- .../src/network_manager/connection_table.rs | 11 + veilid-core/src/network_manager/mod.rs | 38 ++-- veilid-core/src/network_manager/native/mod.rs | 29 ++- .../native/network_class_discovery.rs | 7 +- .../src/network_manager/native/network_tcp.rs | 200 +++++++++------- .../src/network_manager/native/network_udp.rs | 69 ++++-- .../src/network_manager/network_connection.rs | 46 ++-- veilid-core/src/receipt_manager.rs | 71 ++++-- veilid-core/src/routing_table/mod.rs | 60 +++-- veilid-core/src/rpc_processor/mod.rs | 31 ++- .../src/tests/common/test_host_interface.rs | 38 ++++ veilid-core/src/xx/mod.rs | 5 + veilid-core/src/xx/must_join_handle.rs | 43 ++++ veilid-core/src/xx/must_join_single_future.rs | 213 ++++++++++++++++++ veilid-core/src/xx/tick_task.rs | 61 +++-- 19 files changed, 810 insertions(+), 228 deletions(-) create mode 100644 veilid-core/src/xx/must_join_handle.rs create mode 100644 veilid-core/src/xx/must_join_single_future.rs diff --git a/Cargo.lock b/Cargo.lock index 2713536d..c48dd07c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4410,6 +4410,18 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "stop-token" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af91f480ee899ab2d9f8435bfdfc14d08a5754bd9d3fef1f1a1c23336aad6c8b" +dependencies = [ + "async-channel", + "cfg-if 1.0.0", + "futures-core", + "pin-project-lite", +] + [[package]] name = "strsim" version = "0.10.0" @@ -5209,6 +5221,7 @@ dependencies = [ "simplelog", "socket2", "static_assertions", + "stop-token", "thiserror", "tracing", "tracing-error", diff --git a/veilid-core/Cargo.toml b/veilid-core/Cargo.toml index 9a5c4396..386d544d 100644 --- a/veilid-core/Cargo.toml +++ b/veilid-core/Cargo.toml @@ -41,6 +41,7 @@ flume = { version = "^0", features = ["async"] } enumset = { version= "^1", features = ["serde"] } backtrace = { version = "^0", optional = true } owo-colors = "^3" +stop-token = "^0" ed25519-dalek = { version = "^1", default_features = false, features = ["alloc", "u64_backend"] } x25519-dalek = { package = "x25519-dalek-ng", version = "^1", default_features = false, features = ["u64_backend"] } diff --git a/veilid-core/src/attachment_manager.rs b/veilid-core/src/attachment_manager.rs index 15af8144..2bd24d7c 100644 --- a/veilid-core/src/attachment_manager.rs +++ b/veilid-core/src/attachment_manager.rs @@ -109,7 +109,7 @@ pub struct AttachmentManagerInner { maintain_peers: bool, attach_timestamp: Option, update_callback: Option, - attachment_maintainer_jh: Option>, + attachment_maintainer_jh: Option>, } #[derive(Clone)] @@ -306,8 +306,9 @@ impl AttachmentManager { // Create long-running connection maintenance routine let this = self.clone(); self.inner.lock().maintain_peers = true; - self.inner.lock().attachment_maintainer_jh = - Some(intf::spawn(this.attachment_maintainer())); + self.inner.lock().attachment_maintainer_jh = Some(MustJoinHandle::new(intf::spawn( + this.attachment_maintainer(), + ))); } #[instrument(level = "trace", skip(self))] diff --git 
a/veilid-core/src/network_manager/connection_manager.rs b/veilid-core/src/network_manager/connection_manager.rs index 43087c27..ebfdfaa1 100644 --- a/veilid-core/src/network_manager/connection_manager.rs +++ b/veilid-core/src/network_manager/connection_manager.rs @@ -9,11 +9,12 @@ use network_connection::*; #[derive(Debug)] struct ConnectionManagerInner { connection_table: ConnectionTable, + stop_source: Option, } struct ConnectionManagerArc { network_manager: NetworkManager, - inner: AsyncMutex, + inner: AsyncMutex>, } impl core::fmt::Debug for ConnectionManagerArc { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { @@ -31,14 +32,14 @@ pub struct ConnectionManager { impl ConnectionManager { fn new_inner(config: VeilidConfig) -> ConnectionManagerInner { ConnectionManagerInner { + stop_source: Some(StopSource::new()), connection_table: ConnectionTable::new(config), } } fn new_arc(network_manager: NetworkManager) -> ConnectionManagerArc { - let config = network_manager.config(); ConnectionManagerArc { network_manager, - inner: AsyncMutex::new(Self::new_inner(config)), + inner: AsyncMutex::new(None), } } pub fn new(network_manager: NetworkManager) -> Self { @@ -53,12 +54,32 @@ impl ConnectionManager { pub async fn startup(&self) { trace!("startup connection manager"); - //let mut inner = self.arc.inner.lock().await; + let mut inner = self.arc.inner.lock().await; + if inner.is_some() { + panic!("shouldn't start connection manager twice without shutting it down first"); + } + + *inner = Some(Self::new_inner(self.network_manager().config())); } pub async fn shutdown(&self) { - // Drops connection table, which drops all connections in it - *self.arc.inner.lock().await = Self::new_inner(self.arc.network_manager.config()); + // Remove the inner from the lock + let mut inner = { + let mut inner_lock = self.arc.inner.lock().await; + let inner = match inner_lock.take() { + Some(v) => v, + None => { + panic!("not started"); + } + }; + inner + }; + + // Stop all the connections + drop(inner.stop_source.take()); + + // Wait for the connections to complete + inner.connection_table.join().await; } // Returns a network connection if one already is established @@ -67,6 +88,12 @@ impl ConnectionManager { descriptor: ConnectionDescriptor, ) -> Option { let mut inner = self.arc.inner.lock().await; + let inner = match &mut *inner { + Some(v) => v, + None => { + panic!("not started"); + } + }; inner.connection_table.get_connection(descriptor) } @@ -81,24 +108,18 @@ impl ConnectionManager { log_net!("on_new_protocol_network_connection: {:?}", conn); // Wrap with NetworkConnection object to start the connection processing loop - let conn = NetworkConnection::from_protocol(self.clone(), conn); + let stop_token = match &inner.stop_source { + Some(ss) => ss.token(), + None => return Err("not creating connection because we are stopping".to_owned()), + }; + + let conn = NetworkConnection::from_protocol(self.clone(), stop_token, conn); let handle = conn.get_handle(); // Add to the connection table inner.connection_table.add_connection(conn)?; Ok(handle) } - // Called by low-level network when any connection-oriented protocol connection appears - // either from incoming connections. 
- pub(super) async fn on_accepted_protocol_network_connection( - &self, - conn: ProtocolNetworkConnection, - ) -> Result<(), String> { - let mut inner = self.arc.inner.lock().await; - self.on_new_protocol_network_connection(&mut *inner, conn) - .map(drop) - } - // Called when we want to create a new connection or get the current one that already exists // This will kill off any connections that are in conflict with the new connection to be made // in order to make room for the new connection in the system's connection table @@ -107,6 +128,14 @@ impl ConnectionManager { local_addr: Option, dial_info: DialInfo, ) -> Result { + let mut inner = self.arc.inner.lock().await; + let inner = match &mut *inner { + Some(v) => v, + None => { + panic!("not started"); + } + }; + log_net!( "== get_or_create_connection local_addr={:?} dial_info={:?}", local_addr.green(), @@ -123,7 +152,6 @@ impl ConnectionManager { // If any connection to this remote exists that has the same protocol, return it // Any connection will do, we don't have to match the local address - let mut inner = self.arc.inner.lock().await; if let Some(conn) = inner .connection_table @@ -197,10 +225,39 @@ impl ConnectionManager { self.on_new_protocol_network_connection(&mut *inner, conn) } + /////////////////////////////////////////////////////////////////////////////////////////////////////// + /// Callbacks + + // Called by low-level network when any connection-oriented protocol connection appears + // either from incoming connections. + pub(super) async fn on_accepted_protocol_network_connection( + &self, + conn: ProtocolNetworkConnection, + ) -> Result<(), String> { + let mut inner = self.arc.inner.lock().await; + let inner = match &mut *inner { + Some(v) => v, + None => { + // If we are shutting down, just drop this and return + return Ok(()); + } + }; + self.on_new_protocol_network_connection(inner, conn) + .map(drop) + } + // Callback from network connection receive loop when it exits // cleans up the entry in the connection table pub(super) async fn report_connection_finished(&self, descriptor: ConnectionDescriptor) { let mut inner = self.arc.inner.lock().await; + let inner = match &mut *inner { + Some(v) => v, + None => { + // If we're shutting down, do nothing here + return; + } + }; + if let Err(e) = inner.connection_table.remove_connection(descriptor) { log_net!(error e); } diff --git a/veilid-core/src/network_manager/connection_table.rs b/veilid-core/src/network_manager/connection_table.rs index 8873d171..8feb7434 100644 --- a/veilid-core/src/network_manager/connection_table.rs +++ b/veilid-core/src/network_manager/connection_table.rs @@ -1,5 +1,6 @@ use super::*; use alloc::collections::btree_map::Entry; +use futures_util::StreamExt; use hashlink::LruCache; #[derive(Debug)] @@ -41,6 +42,16 @@ impl ConnectionTable { } } + pub async fn join(&mut self) { + let mut unord = FuturesUnordered::new(); + for table in &mut self.conn_by_descriptor { + for (_, v) in table.drain() { + unord.push(v); + } + } + while unord.next().await.is_some() {} + } + pub fn add_connection(&mut self, conn: NetworkConnection) -> Result<(), String> { let descriptor = conn.connection_descriptor(); let ip_addr = descriptor.remote_address().to_ip_addr(); diff --git a/veilid-core/src/network_manager/mod.rs b/veilid-core/src/network_manager/mod.rs index df797204..fdd8ae56 100644 --- a/veilid-core/src/network_manager/mod.rs +++ b/veilid-core/src/network_manager/mod.rs @@ -171,8 +171,8 @@ impl NetworkManager { let this2 = this.clone(); this.unlocked_inner 
.rolling_transfers_task - .set_routine(move |l, t| { - Box::pin(this2.clone().rolling_transfers_task_routine(l, t)) + .set_routine(move |s, l, t| { + Box::pin(this2.clone().rolling_transfers_task_routine(s, l, t)) }); } // Set relay management tick task @@ -180,8 +180,8 @@ impl NetworkManager { let this2 = this.clone(); this.unlocked_inner .relay_management_task - .set_routine(move |l, t| { - Box::pin(this2.clone().relay_management_task_routine(l, t)) + .set_routine(move |s, l, t| { + Box::pin(this2.clone().relay_management_task_routine(s, l, t)) }); } this @@ -275,10 +275,10 @@ impl NetworkManager { }); // Start network components + connection_manager.startup().await; + net.startup().await?; rpc_processor.startup().await?; receipt_manager.startup().await?; - net.startup().await?; - connection_manager.startup().await; trace!("NetworkManager::internal_startup end"); @@ -302,20 +302,20 @@ impl NetworkManager { trace!("NetworkManager::shutdown begin"); // Cancel all tasks - if let Err(e) = self.unlocked_inner.rolling_transfers_task.cancel().await { - warn!("rolling_transfers_task not cancelled: {}", e); + if let Err(e) = self.unlocked_inner.rolling_transfers_task.stop().await { + warn!("rolling_transfers_task not stopped: {}", e); } - if let Err(e) = self.unlocked_inner.relay_management_task.cancel().await { - warn!("relay_management_task not cancelled: {}", e); + if let Err(e) = self.unlocked_inner.relay_management_task.stop().await { + warn!("relay_management_task not stopped: {}", e); } // Shutdown network components if they started up let components = self.inner.lock().components.clone(); if let Some(components) = components { - components.connection_manager.shutdown().await; - components.net.shutdown().await; components.receipt_manager.shutdown().await; components.rpc_processor.shutdown().await; + components.net.shutdown().await; + components.connection_manager.shutdown().await; } // reset the state @@ -1202,7 +1202,12 @@ impl NetworkManager { // Keep relays assigned and accessible #[instrument(level = "trace", skip(self), err)] - async fn relay_management_task_routine(self, _last_ts: u64, cur_ts: u64) -> Result<(), String> { + async fn relay_management_task_routine( + self, + stop_token: StopToken, + _last_ts: u64, + cur_ts: u64, + ) -> Result<(), String> { // log_net!("--- network manager relay_management task"); // Get our node's current node info and network class and do the right thing @@ -1255,7 +1260,12 @@ impl NetworkManager { // Compute transfer statistics for the low level network #[instrument(level = "trace", skip(self), err)] - async fn rolling_transfers_task_routine(self, last_ts: u64, cur_ts: u64) -> Result<(), String> { + async fn rolling_transfers_task_routine( + self, + stop_token: StopToken, + last_ts: u64, + cur_ts: u64, + ) -> Result<(), String> { // log_net!("--- network manager rolling_transfers task"); { let inner = &mut *self.inner.lock(); diff --git a/veilid-core/src/network_manager/native/mod.rs b/veilid-core/src/network_manager/native/mod.rs index 8fa47461..4112a966 100644 --- a/veilid-core/src/network_manager/native/mod.rs +++ b/veilid-core/src/network_manager/native/mod.rs @@ -42,7 +42,8 @@ struct NetworkInner { protocol_config: Option, static_public_dialinfo: ProtocolSet, network_class: Option, - join_handles: Vec>, + join_handles: Vec>, + stop_source: Option, udp_port: u16, tcp_port: u16, ws_port: u16, @@ -82,6 +83,7 @@ impl Network { static_public_dialinfo: ProtocolSet::empty(), network_class: None, join_handles: Vec::new(), + stop_source: None, udp_port: 
0u16, tcp_port: 0u16, ws_port: 0u16, @@ -115,8 +117,8 @@ impl Network { let this2 = this.clone(); this.unlocked_inner .update_network_class_task - .set_routine(move |l, t| { - Box::pin(this2.clone().update_network_class_task_routine(l, t)) + .set_routine(move |s, l, t| { + Box::pin(this2.clone().update_network_class_task_routine(s, l, t)) }); } @@ -200,7 +202,7 @@ impl Network { Ok(config) } - fn add_to_join_handles(&self, jh: JoinHandle<()>) { + fn add_to_join_handles(&self, jh: MustJoinHandle<()>) { let mut inner = self.inner.lock(); inner.join_handles.push(jh); } @@ -506,17 +508,28 @@ impl Network { let network_manager = self.network_manager(); let routing_table = self.routing_table(); - // Cancel all tasks - if let Err(e) = self.unlocked_inner.update_network_class_task.cancel().await { - warn!("update_network_class_task not cancelled: {}", e); + // Stop all tasks + if let Err(e) = self.unlocked_inner.update_network_class_task.stop().await { + error!("update_network_class_task not cancelled: {}", e); } + let mut unord = FuturesUnordered::new(); + { + let mut inner = self.inner.lock(); + // Drop the stop + drop(inner.stop_source.take()); + // take the join handles out + for h in inner.join_handles.drain(..) { + unord.push(h); + } + } + // Wait for everything to stop + while unord.next().await.is_some() {} // Drop all dial info routing_table.clear_dial_info_details(RoutingDomain::PublicInternet); routing_table.clear_dial_info_details(RoutingDomain::LocalNetwork); // Reset state including network class - // Cancels all async background tasks by dropping join handles *self.inner.lock() = Self::new_inner(network_manager); info!("network stopped"); diff --git a/veilid-core/src/network_manager/native/network_class_discovery.rs b/veilid-core/src/network_manager/native/network_class_discovery.rs index 6ce879cd..c9b27855 100644 --- a/veilid-core/src/network_manager/native/network_class_discovery.rs +++ b/veilid-core/src/network_manager/native/network_class_discovery.rs @@ -465,7 +465,12 @@ impl Network { } #[instrument(level = "trace", skip(self), err)] - pub async fn update_network_class_task_routine(self, _l: u64, _t: u64) -> Result<(), String> { + pub async fn update_network_class_task_routine( + self, + stop_token: StopToken, + _l: u64, + _t: u64, + ) -> Result<(), String> { // Ensure we aren't trying to update this without clearing it first let old_network_class = self.inner.lock().network_class; assert_eq!(old_network_class, None); diff --git a/veilid-core/src/network_manager/native/network_tcp.rs b/veilid-core/src/network_manager/native/network_tcp.rs index d4bd36c5..43b83157 100644 --- a/veilid-core/src/network_manager/native/network_tcp.rs +++ b/veilid-core/src/network_manager/native/network_tcp.rs @@ -2,6 +2,7 @@ use super::*; use crate::intf::*; use async_tls::TlsAcceptor; use sockets::*; +use stop_token::future::FutureExt; ///////////////////////////////////////////////////////////////// @@ -91,6 +92,106 @@ impl Network { Ok(None) } + async fn tcp_acceptor( + self, + tcp_stream: async_std::io::Result, + listener_state: Arc>, + connection_manager: ConnectionManager, + connection_initial_timeout: u64, + tls_connection_initial_timeout: u64, + ) { + let tcp_stream = match tcp_stream { + Ok(v) => v, + Err(_) => { + // If this happened our low-level listener socket probably died + // so it's time to restart the network + self.inner.lock().network_needs_restart = true; + return; + } + }; + + let listener_state = listener_state.clone(); + let connection_manager = connection_manager.clone(); + 
+ // Limit the number of connections from the same IP address + // and the number of total connections + let addr = match tcp_stream.peer_addr() { + Ok(addr) => addr, + Err(e) => { + log_net!(error "failed to get peer address: {}", e); + return; + } + }; + // XXX limiting + + log_net!("TCP connection from: {}", addr); + + // Create a stream we can peek on + let ps = AsyncPeekStream::new(tcp_stream.clone()); + + ///////////////////////////////////////////////////////////// + let mut first_packet = [0u8; PEEK_DETECT_LEN]; + + // read a chunk of the stream + if io::timeout( + Duration::from_micros(connection_initial_timeout), + ps.peek_exact(&mut first_packet), + ) + .await + .is_err() + { + // If we fail to get a packet within the connection initial timeout + // then we punt this connection + log_net!(warn "connection initial timeout from: {:?}", addr); + return; + } + + // Run accept handlers on accepted stream + + // Check is this could be TLS + let ls = listener_state.read().clone(); + + let conn = if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 { + self.try_tls_handlers( + ls.tls_acceptor.as_ref().unwrap(), + ps, + tcp_stream, + addr, + &ls.tls_protocol_handlers, + tls_connection_initial_timeout, + ) + .await + } else { + self.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers) + .await + }; + + let conn = match conn { + Ok(Some(c)) => { + log_net!("protocol handler found for {:?}: {:?}", addr, c); + c + } + Ok(None) => { + // No protocol handlers matched? drop it. + log_net!(warn "no protocol handler for connection from {:?}", addr); + return; + } + Err(e) => { + // Failed to negotiate connection? drop it. + log_net!(warn "failed to negotiate connection from {:?}: {}", addr, e); + return; + } + }; + + // Register the new connection in the connection manager + if let Err(e) = connection_manager + .on_accepted_protocol_network_connection(conn) + .await + { + log_net!(error "failed to register new connection: {}", e); + } + } + async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> { // Get config let (connection_initial_timeout, tls_connection_initial_timeout) = { @@ -123,111 +224,40 @@ impl Network { // Spawn the socket task let this = self.clone(); + let stop_token = self.inner.lock().stop_source.as_ref().unwrap().token(); let connection_manager = self.connection_manager(); //////////////////////////////////////////////////////////// let jh = spawn(async move { // moves listener object in and get incoming iterator // when this task exists, the listener will close the socket - listener + let _ = listener .incoming() - .for_each_concurrent(None, |tcp_stream| async { - let tcp_stream = tcp_stream.unwrap(); + .for_each_concurrent(None, |tcp_stream| { + let this = this.clone(); let listener_state = listener_state.clone(); let connection_manager = connection_manager.clone(); - - // Limit the number of connections from the same IP address - // and the number of total connections - let addr = match tcp_stream.peer_addr() { - Ok(addr) => addr, - Err(e) => { - log_net!(error "failed to get peer address: {}", e); - return; - } - }; - // XXX limiting - - log_net!("TCP connection from: {}", addr); - - // Create a stream we can peek on - let ps = AsyncPeekStream::new(tcp_stream.clone()); - - ///////////////////////////////////////////////////////////// - let mut first_packet = [0u8; PEEK_DETECT_LEN]; - - // read a chunk of the stream - if io::timeout( - Duration::from_micros(connection_initial_timeout), - ps.peek_exact(&mut first_packet), + 
Self::tcp_acceptor( + this, + tcp_stream, + listener_state, + connection_manager, + connection_initial_timeout, + tls_connection_initial_timeout, ) - .await - .is_err() - { - // If we fail to get a packet within the connection initial timeout - // then we punt this connection - log_net!(warn "connection initial timeout from: {:?}", addr); - return; - } - - // Run accept handlers on accepted stream - - // Check is this could be TLS - let ls = listener_state.read().clone(); - - let conn = if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 { - this.try_tls_handlers( - ls.tls_acceptor.as_ref().unwrap(), - ps, - tcp_stream, - addr, - &ls.tls_protocol_handlers, - tls_connection_initial_timeout, - ) - .await - } else { - this.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers) - .await - }; - - let conn = match conn { - Ok(Some(c)) => { - log_net!("protocol handler found for {:?}: {:?}", addr, c); - c - } - Ok(None) => { - // No protocol handlers matched? drop it. - log_net!(warn "no protocol handler for connection from {:?}", addr); - return; - } - Err(e) => { - // Failed to negotiate connection? drop it. - log_net!(warn "failed to negotiate connection from {:?}: {}", addr, e); - return; - } - }; - - // Register the new connection in the connection manager - if let Err(e) = connection_manager - .on_accepted_protocol_network_connection(conn) - .await - { - log_net!(error "failed to register new connection: {}", e); - } }) + .timeout_at(stop_token) .await; + log_net!(debug "exited incoming loop for {}", addr); // Remove our listener state from this address if we're stopping this.inner.lock().listener_states.remove(&addr); log_net!(debug "listener state removed for {}", addr); - - // If this happened our low-level listener socket probably died - // so it's time to restart the network - this.inner.lock().network_needs_restart = true; }); //////////////////////////////////////////////////////////// // Add to join handles - self.add_to_join_handles(jh); + self.add_to_join_handles(MustJoinHandle::new(jh)); Ok(()) } diff --git a/veilid-core/src/network_manager/native/network_udp.rs b/veilid-core/src/network_manager/native/network_udp.rs index ec189e04..940a0131 100644 --- a/veilid-core/src/network_manager/native/network_udp.rs +++ b/veilid-core/src/network_manager/native/network_udp.rs @@ -1,5 +1,6 @@ use super::*; use sockets::*; +use stop_token::future::FutureExt; impl Network { pub(super) async fn create_udp_listener_tasks(&self) -> Result<(), String> { @@ -43,47 +44,75 @@ impl Network { // Spawn a local async task for each socket let mut protocol_handlers_unordered = FuturesUnordered::new(); let network_manager = this.network_manager(); + let stop_token = this.inner.lock().stop_source.as_ref().unwrap().token(); for ph in protocol_handlers { let network_manager = network_manager.clone(); + let stop_token = stop_token.clone(); let jh = spawn_local(async move { let mut data = vec![0u8; 65536]; - while let Ok((size, descriptor)) = ph.recv_message(&mut data).await { - // XXX: Limit the number of packets from the same IP address? 
- log_net!("UDP packet: {:?}", descriptor); - - // Network accounting - network_manager.stats_packet_rcvd( - descriptor.remote_address().to_ip_addr(), - size as u64, - ); - - // Pass it up for processing - if let Err(e) = network_manager - .on_recv_envelope(&data[..size], descriptor) + loop { + match ph + .recv_message(&mut data) + .timeout_at(stop_token.clone()) .await { - log_net!(error "failed to process received udp envelope: {}", e); + Ok(Ok((size, descriptor))) => { + // XXX: Limit the number of packets from the same IP address? + log_net!("UDP packet: {:?}", descriptor); + + // Network accounting + network_manager.stats_packet_rcvd( + descriptor.remote_address().to_ip_addr(), + size as u64, + ); + + // Pass it up for processing + if let Err(e) = network_manager + .on_recv_envelope(&data[..size], descriptor) + .await + { + log_net!(error "failed to process received udp envelope: {}", e); + } + } + Ok(Err(_)) => { + return false; + } + Err(_) => { + return true; + } } } }); protocol_handlers_unordered.push(jh); } - // Now we wait for any join handle to exit, - // which would indicate an error needing + // Now we wait for join handles to exit, + // if any error out it indicates an error needing // us to completely restart the network - let _ = protocol_handlers_unordered.next().await; + loop { + match protocol_handlers_unordered.next().await { + Some(v) => { + // true = stopped, false = errored + if !v { + // If any protocol handler fails, our socket died and we need to restart the network + this.inner.lock().network_needs_restart = true; + } + } + None => { + // All protocol handlers exited + break; + } + } + } trace!("UDP listener task stopped"); - // If this loop fails, our socket died and we need to restart the network - this.inner.lock().network_needs_restart = true; }); //////////////////////////////////////////////////////////// // Add to join handle - self.add_to_join_handles(jh); + self.add_to_join_handles(MustJoinHandle::new(jh)); } Ok(()) diff --git a/veilid-core/src/network_manager/network_connection.rs b/veilid-core/src/network_manager/network_connection.rs index ea02ea40..44a387ef 100644 --- a/veilid-core/src/network_manager/network_connection.rs +++ b/veilid-core/src/network_manager/network_connection.rs @@ -1,5 +1,6 @@ use super::*; use futures_util::{FutureExt, StreamExt}; +use stop_token::prelude::*; cfg_if::cfg_if! 
{ if #[cfg(target_arch = "wasm32")] { @@ -84,7 +85,7 @@ pub struct NetworkConnectionStats { #[derive(Debug)] pub struct NetworkConnection { descriptor: ConnectionDescriptor, - _processor: Option>, + processor: Option>, established_time: u64, stats: Arc>, sender: flume::Sender>, @@ -97,7 +98,7 @@ impl NetworkConnection { Self { descriptor, - _processor: None, + processor: None, established_time: intf::get_timestamp(), stats: Arc::new(Mutex::new(NetworkConnectionStats { last_message_sent_time: None, @@ -109,6 +110,7 @@ impl NetworkConnection { pub(super) fn from_protocol( connection_manager: ConnectionManager, + stop_token: StopToken, protocol_connection: ProtocolNetworkConnection, ) -> Self { // Get timeout @@ -132,19 +134,20 @@ impl NetworkConnection { })); // Spawn connection processor and pass in protocol connection - let processor = intf::spawn_local(Self::process_connection( + let processor = MustJoinHandle::new(intf::spawn_local(Self::process_connection( connection_manager, + stop_token, descriptor.clone(), receiver, protocol_connection, inactivity_timeout, stats.clone(), - )); + ))); // Return the connection Self { descriptor, - _processor: Some(processor), + processor: Some(processor), established_time: intf::get_timestamp(), stats, sender, @@ -197,6 +200,7 @@ impl NetworkConnection { // Connection receiver loop fn process_connection( connection_manager: ConnectionManager, + stop_token: StopToken, descriptor: ConnectionDescriptor, receiver: flume::Receiver>, protocol_connection: ProtocolNetworkConnection, @@ -289,26 +293,28 @@ impl NetworkConnection { } // Process futures - match unord.next().await { - Some(RecvLoopAction::Send) => { + match unord.next().timeout_at(stop_token.clone()).await { + Ok(Some(RecvLoopAction::Send)) => { // Don't reset inactivity timer if we're only sending - need_sender = true; } - Some(RecvLoopAction::Recv) => { + Ok(Some(RecvLoopAction::Recv)) => { // Reset inactivity timer since we got something from this connection timer.set(new_timer()); need_receiver = true; } - Some(RecvLoopAction::Finish) | Some(RecvLoopAction::Timeout) => { + Ok(Some(RecvLoopAction::Finish) | Some(RecvLoopAction::Timeout)) => { break; } - - None => { + Ok(None) => { // Should not happen unreachable!(); } + Err(_) => { + // Stop token + break; + } } } @@ -317,9 +323,23 @@ impl NetworkConnection { descriptor.green() ); + // Let the connection manager know the receive loop exited connection_manager .report_connection_finished(descriptor) - .await + .await; }) } } + +// Resolves ready when the connection loop has terminated +impl Future for NetworkConnection { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll { + if let Some(mut processor) = self.processor.as_mut() { + Pin::new(&mut processor).poll(cx) + } else { + task::Poll::Ready(()) + } + } +} diff --git a/veilid-core/src/receipt_manager.rs b/veilid-core/src/receipt_manager.rs index bb9a1b80..a8fc681a 100644 --- a/veilid-core/src/receipt_manager.rs +++ b/veilid-core/src/receipt_manager.rs @@ -4,6 +4,7 @@ use dht::*; use futures_util::stream::{FuturesUnordered, StreamExt}; use network_manager::*; use routing_table::*; +use stop_token::future::FutureExt; use xx::*; #[derive(Clone, Debug, PartialEq, Eq)] @@ -170,7 +171,8 @@ pub struct ReceiptManagerInner { network_manager: NetworkManager, records_by_nonce: BTreeMap>>, next_oldest_ts: Option, - timeout_task: SingleFuture<()>, + stop_source: Option, + timeout_task: MustJoinSingleFuture<()>, } #[derive(Clone)] @@ -184,7 +186,8 @@ impl 
ReceiptManager { network_manager, records_by_nonce: BTreeMap::new(), next_oldest_ts: None, - timeout_task: SingleFuture::new(), + stop_source: None, + timeout_task: MustJoinSingleFuture::new(), } } @@ -201,13 +204,14 @@ impl ReceiptManager { pub async fn startup(&self) -> Result<(), String> { trace!("startup receipt manager"); // Retrieve config - /* - { - let config = self.core().config(); - let c = config.get(); - let mut inner = self.inner.lock(); - } - */ + + { + // let config = self.core().config(); + // let c = config.get(); + let mut inner = self.inner.lock(); + inner.stop_source = Some(StopSource::new()); + } + Ok(()) } @@ -235,7 +239,7 @@ impl ReceiptManager { } #[instrument(level = "trace", skip(self))] - pub async fn timeout_task_routine(self, now: u64) { + pub async fn timeout_task_routine(self, now: u64, stop_token: StopToken) { // Go through all receipts and build a list of expired nonces let mut new_next_oldest_ts: Option = None; let mut expired_records = Vec::new(); @@ -276,13 +280,25 @@ impl ReceiptManager { } // Wait on all the multi-call callbacks - while callbacks.next().await.is_some() {} + loop { + match callbacks.next().timeout_at(stop_token.clone()).await { + Ok(Some(_)) => {} + Ok(None) | Err(_) => break, + } + } } pub async fn tick(&self) -> Result<(), String> { - let (next_oldest_ts, timeout_task) = { + let (next_oldest_ts, timeout_task, stop_token) = { let inner = self.inner.lock(); - (inner.next_oldest_ts, inner.timeout_task.clone()) + let stop_token = match inner.stop_source.as_ref() { + Some(ss) => ss.token(), + None => { + // Do nothing if we're shutting down + return Ok(()); + } + }; + (inner.next_oldest_ts, inner.timeout_task.clone(), stop_token) }; let now = intf::get_timestamp(); // If we have at least one timestamp to expire, lets do it @@ -290,7 +306,7 @@ impl ReceiptManager { if now >= next_oldest_ts { // Single-spawn the timeout task routine let _ = timeout_task - .single_spawn(self.clone().timeout_task_routine(now)) + .single_spawn(self.clone().timeout_task_routine(now, stop_token)) .await; } } @@ -299,6 +315,20 @@ impl ReceiptManager { pub async fn shutdown(&self) { let network_manager = self.network_manager(); + + // Stop all tasks + let timeout_task = { + let mut inner = self.inner.lock(); + // Drop the stop + drop(inner.stop_source.take()); + inner.timeout_task.clone() + }; + + // Wait for everything to stop + if !timeout_task.join().await.is_ok() { + panic!("joining timeout task failed"); + } + *self.inner.lock() = Self::new_inner(network_manager); } @@ -410,9 +440,16 @@ impl ReceiptManager { ); // Increment return count - let callback_future = { + let (callback_future, stop_token) = { // Look up the receipt record from the nonce let mut inner = self.inner.lock(); + let stop_token = match inner.stop_source.as_ref() { + Some(ss) => ss.token(), + None => { + // If we're stopping do nothing here + return Ok(()); + } + }; let record = match inner.records_by_nonce.get(&receipt_nonce) { Some(r) => r.clone(), None => { @@ -438,12 +475,12 @@ impl ReceiptManager { Self::update_next_oldest_timestamp(&mut *inner); } - callback_future + (callback_future, stop_token) }; // Issue the callback if let Some(callback_future) = callback_future { - callback_future.await; + let _ = callback_future.timeout_at(stop_token).await; } Ok(()) diff --git a/veilid-core/src/routing_table/mod.rs b/veilid-core/src/routing_table/mod.rs index 13c89825..6b769c9e 100644 --- a/veilid-core/src/routing_table/mod.rs +++ b/veilid-core/src/routing_table/mod.rs @@ -71,7 +71,7 @@ struct 
RoutingTableUnlockedInner { bootstrap_task: TickTask, peer_minimum_refresh_task: TickTask, ping_validator_task: TickTask, - node_info_update_single_future: SingleFuture<()>, + node_info_update_single_future: MustJoinSingleFuture<()>, } #[derive(Clone)] @@ -103,7 +103,7 @@ impl RoutingTable { bootstrap_task: TickTask::new(1), peer_minimum_refresh_task: TickTask::new_ms(c.network.dht.min_peer_refresh_time_ms), ping_validator_task: TickTask::new(1), - node_info_update_single_future: SingleFuture::new(), + node_info_update_single_future: MustJoinSingleFuture::new(), } } pub fn new(network_manager: NetworkManager) -> Self { @@ -118,8 +118,8 @@ impl RoutingTable { let this2 = this.clone(); this.unlocked_inner .rolling_transfers_task - .set_routine(move |l, t| { - Box::pin(this2.clone().rolling_transfers_task_routine(l, t)) + .set_routine(move |s, l, t| { + Box::pin(this2.clone().rolling_transfers_task_routine(s, l, t)) }); } // Set bootstrap tick task @@ -127,15 +127,15 @@ impl RoutingTable { let this2 = this.clone(); this.unlocked_inner .bootstrap_task - .set_routine(move |_l, _t| Box::pin(this2.clone().bootstrap_task_routine())); + .set_routine(move |s, _l, _t| Box::pin(this2.clone().bootstrap_task_routine(s))); } // Set peer minimum refresh tick task { let this2 = this.clone(); this.unlocked_inner .peer_minimum_refresh_task - .set_routine(move |_l, _t| { - Box::pin(this2.clone().peer_minimum_refresh_task_routine()) + .set_routine(move |s, _l, _t| { + Box::pin(this2.clone().peer_minimum_refresh_task_routine(s)) }); } // Set ping validator tick task @@ -143,7 +143,9 @@ impl RoutingTable { let this2 = this.clone(); this.unlocked_inner .ping_validator_task - .set_routine(move |l, t| Box::pin(this2.clone().ping_validator_task_routine(l, t))); + .set_routine(move |s, l, t| { + Box::pin(this2.clone().ping_validator_task_routine(s, l, t)) + }); } this } @@ -373,26 +375,26 @@ impl RoutingTable { pub async fn terminate(&self) { // Cancel all tasks being ticked - if let Err(e) = self.unlocked_inner.rolling_transfers_task.cancel().await { - warn!("rolling_transfers_task not cancelled: {}", e); + if let Err(e) = self.unlocked_inner.rolling_transfers_task.stop().await { + error!("rolling_transfers_task not stopped: {}", e); } - if let Err(e) = self.unlocked_inner.bootstrap_task.cancel().await { - warn!("bootstrap_task not cancelled: {}", e); + if let Err(e) = self.unlocked_inner.bootstrap_task.stop().await { + error!("bootstrap_task not stopped: {}", e); } - if let Err(e) = self.unlocked_inner.peer_minimum_refresh_task.cancel().await { - warn!("peer_minimum_refresh_task not cancelled: {}", e); + if let Err(e) = self.unlocked_inner.peer_minimum_refresh_task.stop().await { + error!("peer_minimum_refresh_task not stopped: {}", e); } - if let Err(e) = self.unlocked_inner.ping_validator_task.cancel().await { - warn!("ping_validator_task not cancelled: {}", e); + if let Err(e) = self.unlocked_inner.ping_validator_task.stop().await { + error!("ping_validator_task not stopped: {}", e); } if self .unlocked_inner .node_info_update_single_future - .cancel() + .join() .await .is_err() { - warn!("node_info_update_single_future not cancelled"); + error!("node_info_update_single_future not stopped"); } *self.inner.lock() = Self::new_inner(self.network_manager()); @@ -990,7 +992,7 @@ impl RoutingTable { } #[instrument(level = "trace", skip(self), err)] - async fn bootstrap_task_routine(self) -> Result<(), String> { + async fn bootstrap_task_routine(self, stop_token: StopToken) -> Result<(), String> { let (bootstrap, 
bootstrap_nodes) = { let c = self.config.get(); ( @@ -1093,7 +1095,7 @@ impl RoutingTable { // Ask our remaining peers to give us more peers before we go // back to the bootstrap servers to keep us from bothering them too much #[instrument(level = "trace", skip(self), err)] - async fn peer_minimum_refresh_task_routine(self) -> Result<(), String> { + async fn peer_minimum_refresh_task_routine(self, stop_token: StopToken) -> Result<(), String> { // get list of all peers we know about, even the unreliable ones, and ask them to find nodes close to our node too let noderefs = { let mut inner = self.inner.lock(); @@ -1125,7 +1127,12 @@ impl RoutingTable { // Ping each node in the routing table if they need to be pinged // to determine their reliability #[instrument(level = "trace", skip(self), err)] - async fn ping_validator_task_routine(self, _last_ts: u64, cur_ts: u64) -> Result<(), String> { + async fn ping_validator_task_routine( + self, + stop_token: StopToken, + _last_ts: u64, + cur_ts: u64, + ) -> Result<(), String> { // log_rtab!("--- ping_validator task"); let rpc = self.rpc_processor(); @@ -1144,7 +1151,9 @@ impl RoutingTable { nr, e.state_debug_info(cur_ts) ); - unord.push(intf::spawn_local(rpc.clone().rpc_call_status(nr))); + unord.push(MustJoinHandle::new(intf::spawn_local( + rpc.clone().rpc_call_status(nr), + ))); } Option::<()>::None }); @@ -1158,7 +1167,12 @@ impl RoutingTable { // Compute transfer statistics to determine how 'fast' a node is #[instrument(level = "trace", skip(self), err)] - async fn rolling_transfers_task_routine(self, last_ts: u64, cur_ts: u64) -> Result<(), String> { + async fn rolling_transfers_task_routine( + self, + stop_token: StopToken, + last_ts: u64, + cur_ts: u64, + ) -> Result<(), String> { // log_rtab!("--- rolling_transfers task"); let inner = &mut *self.inner.lock(); diff --git a/veilid-core/src/rpc_processor/mod.rs b/veilid-core/src/rpc_processor/mod.rs index 10154efd..1b3642b8 100644 --- a/veilid-core/src/rpc_processor/mod.rs +++ b/veilid-core/src/rpc_processor/mod.rs @@ -10,9 +10,11 @@ use crate::intf::*; use crate::xx::*; use capnp::message::ReaderSegments; use coders::*; +use futures_util::StreamExt; use network_manager::*; use receipt_manager::*; use routing_table::*; +use stop_token::future::FutureExt; use super::*; ///////////////////////////////////////////////////////////////////// @@ -167,7 +169,8 @@ pub struct RPCProcessorInner { timeout: u64, max_route_hop_count: usize, waiting_rpc_table: BTreeMap>, - worker_join_handles: Vec>, + stop_source: Option, + worker_join_handles: Vec>, } #[derive(Clone)] @@ -189,6 +192,7 @@ impl RPCProcessor { timeout: 10000000, max_route_hop_count: 7, waiting_rpc_table: BTreeMap::new(), + stop_source: None, worker_join_handles: Vec::new(), } } @@ -1368,8 +1372,8 @@ impl RPCProcessor { } } - async fn rpc_worker(self, receiver: flume::Receiver) { - while let Ok(msg) = receiver.recv_async().await { + async fn rpc_worker(self, stop_token: StopToken, receiver: flume::Receiver) { + while let Ok(Ok(msg)) = receiver.recv_async().timeout_at(stop_token.clone()).await { let _ = self .process_rpc_message(msg) .await @@ -1409,20 +1413,37 @@ impl RPCProcessor { inner.max_route_hop_count = max_route_hop_count; let channel = flume::bounded(queue_size as usize); inner.send_channel = Some(channel.0.clone()); + inner.stop_source = Some(StopSource::new()); // spin up N workers trace!("Spinning up {} RPC workers", concurrency); for _ in 0..concurrency { let this = self.clone(); let receiver = channel.1.clone(); - let jh = 
spawn(Self::rpc_worker(this, receiver)); - inner.worker_join_handles.push(jh); + let jh = spawn(Self::rpc_worker(this, inner.stop_source.as_ref().unwrap().token(), receiver)); + inner.worker_join_handles.push(MustJoinHandle::new(jh)); } Ok(()) } pub async fn shutdown(&self) { + // Stop the rpc workers + let mut unord = FuturesUnordered::new(); + { + let mut inner = self.inner.lock(); + // drop the stop + drop(inner.stop_source.take()); + // take the join handles out + for h in inner.worker_join_handles.drain(..) { + unord.push(h); + } + } + + // Wait for them to complete + while unord.next().await.is_some() {} + + // Release the rpc processor *self.inner.lock() = Self::new_inner(self.network_manager()); } diff --git a/veilid-core/src/tests/common/test_host_interface.rs b/veilid-core/src/tests/common/test_host_interface.rs index 1e9741cf..e8874d0c 100644 --- a/veilid-core/src/tests/common/test_host_interface.rs +++ b/veilid-core/src/tests/common/test_host_interface.rs @@ -545,6 +545,43 @@ pub async fn test_single_future() { assert_eq!(sf.check().await, Ok(None)); } +pub async fn test_must_join_single_future() { + info!("testing must join single future"); + let sf = MustJoinSingleFuture::::new(); + assert_eq!(sf.check().await, Ok(None)); + assert_eq!( + sf.single_spawn(async { + intf::sleep(2000).await; + 69 + }) + .await, + Ok(None) + ); + assert_eq!(sf.check().await, Ok(None)); + assert_eq!(sf.single_spawn(async { panic!() }).await, Ok(None)); + assert_eq!(sf.join().await, Ok(Some(69))); + assert_eq!( + sf.single_spawn(async { + intf::sleep(1000).await; + 37 + }) + .await, + Ok(None) + ); + intf::sleep(2000).await; + assert_eq!( + sf.single_spawn(async { + intf::sleep(1000).await; + 27 + }) + .await, + Ok(Some(37)) + ); + intf::sleep(2000).await; + assert_eq!(sf.join().await, Ok(Some(27))); + assert_eq!(sf.check().await, Ok(None)); +} + pub async fn test_tools() { info!("testing retry_falloff_log"); let mut last_us = 0u64; @@ -568,6 +605,7 @@ pub async fn test_all() { #[cfg(not(target_arch = "wasm32"))] test_network_interfaces().await; test_single_future().await; + test_must_join_single_future().await; test_eventual().await; test_eventual_value().await; test_eventual_value_clone().await; diff --git a/veilid-core/src/xx/mod.rs b/veilid-core/src/xx/mod.rs index fcfc8c02..36894216 100644 --- a/veilid-core/src/xx/mod.rs +++ b/veilid-core/src/xx/mod.rs @@ -8,6 +8,8 @@ mod eventual_value_clone; mod ip_addr_port; mod ip_extra; mod log_thru; +mod must_join_handle; +mod must_join_single_future; mod mutable_future; mod single_future; mod single_shot_eventual; @@ -25,6 +27,7 @@ pub use owo_colors::OwoColorize; pub use parking_lot::*; pub use split_url::*; pub use static_assertions::*; +pub use stop_token::*; pub use tracing::*; pub type PinBox = Pin>; @@ -105,6 +108,8 @@ pub use eventual_value::*; pub use eventual_value_clone::*; pub use ip_addr_port::*; pub use ip_extra::*; +pub use must_join_handle::*; +pub use must_join_single_future::*; pub use mutable_future::*; pub use single_future::*; pub use single_shot_eventual::*; diff --git a/veilid-core/src/xx/must_join_handle.rs b/veilid-core/src/xx/must_join_handle.rs new file mode 100644 index 00000000..284c9938 --- /dev/null +++ b/veilid-core/src/xx/must_join_handle.rs @@ -0,0 +1,43 @@ +use async_executors::JoinHandle; +use core::future::Future; +use core::pin::Pin; +use core::sync::atomic::{AtomicBool, Ordering}; +use core::task::{Context, Poll}; + +#[derive(Debug)] +pub struct MustJoinHandle { + join_handle: JoinHandle, + completed: AtomicBool, +} 
+ +impl MustJoinHandle { + pub fn new(join_handle: JoinHandle) -> Self { + Self { + join_handle, + completed: AtomicBool::new(false), + } + } +} + +impl Drop for MustJoinHandle { + fn drop(&mut self) { + // panic if we haven't completed + if !self.completed.load(Ordering::Relaxed) { + panic!("MustJoinHandle was not completed upon drop. Add cooperative cancellation where appropriate to ensure this is completed before drop.") + } + } +} + +impl Future for MustJoinHandle { + type Output = T; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.join_handle).poll(cx) { + Poll::Ready(t) => { + self.completed.store(true, Ordering::Relaxed); + Poll::Ready(t) + } + Poll::Pending => Poll::Pending, + } + } +} diff --git a/veilid-core/src/xx/must_join_single_future.rs b/veilid-core/src/xx/must_join_single_future.rs new file mode 100644 index 00000000..773a4db1 --- /dev/null +++ b/veilid-core/src/xx/must_join_single_future.rs @@ -0,0 +1,213 @@ +use super::*; +use crate::intf::*; +use cfg_if::*; +use core::task::Poll; +use futures_util::poll; + +#[derive(Debug)] +struct MustJoinSingleFutureInner +where + T: 'static, +{ + locked: bool, + join_handle: Option>, +} + +/// Spawns a single background processing task idempotently, possibly returning the return value of the previously executed background task +/// This does not queue, just ensures that no more than a single copy of the task is running at a time, but allowing tasks to be retriggered +#[derive(Debug, Clone)] +pub struct MustJoinSingleFuture +where + T: 'static, +{ + inner: Arc>>, +} + +impl Default for MustJoinSingleFuture +where + T: 'static, +{ + fn default() -> Self { + Self::new() + } +} + +impl MustJoinSingleFuture +where + T: 'static, +{ + pub fn new() -> Self { + Self { + inner: Arc::new(Mutex::new(MustJoinSingleFutureInner { + locked: false, + join_handle: None, + })), + } + } + + fn try_lock(&self) -> Result>, ()> { + let mut inner = self.inner.lock(); + if inner.locked { + // If already locked error out + return Err(()); + } + inner.locked = true; + // If we got the lock, return what we have for a join handle if anything + Ok(inner.join_handle.take()) + } + + fn unlock(&self, jh: Option>) { + let mut inner = self.inner.lock(); + assert!(inner.locked); + assert!(inner.join_handle.is_none()); + inner.locked = false; + inner.join_handle = jh; + } + + // Check the result + pub async fn check(&self) -> Result, ()> { + let mut out: Option = None; + + // See if we have a result we can return + let maybe_jh = match self.try_lock() { + Ok(v) => v, + Err(_) => { + // If we are already polling somewhere else, don't hand back a result + return Err(()); + } + }; + if maybe_jh.is_some() { + let mut jh = maybe_jh.unwrap(); + + // See if we finished, if so, return the value of the last execution + if let Poll::Ready(r) = poll!(&mut jh) { + out = Some(r); + // Task finished, unlock with nothing + self.unlock(None); + } else { + // Still running put the join handle back so we can check on it later + self.unlock(Some(jh)); + } + } else { + // No task, unlock with nothing + self.unlock(None); + } + + // Return the prior result if we have one + Ok(out) + } + + // Wait for the result + pub async fn join(&self) -> Result, ()> { + let mut out: Option = None; + + // See if we have a result we can return + let maybe_jh = match self.try_lock() { + Ok(v) => v, + Err(_) => { + // If we are already polling somewhere else, + // that's an error because you can only join + // these things once + return Err(()); + } + }; + if 
maybe_jh.is_some() { + let jh = maybe_jh.unwrap(); + // Wait for return value of the last execution + out = Some(jh.await); + // Task finished, unlock with nothing + } else { + // No task, unlock with nothing + } + self.unlock(None); + + // Return the prior result if we have one + Ok(out) + } + + // Possibly spawn the future possibly returning the value of the last execution + cfg_if! { + if #[cfg(target_arch = "wasm32")] { + pub async fn single_spawn( + &self, + future: impl Future + 'static, + ) -> Result, ()> { + let mut out: Option = None; + + // See if we have a result we can return + let maybe_jh = match self.try_lock() { + Ok(v) => v, + Err(_) => { + // If we are already polling somewhere else, don't hand back a result + return Err(()); + } + }; + let mut run = true; + + if maybe_jh.is_some() { + let mut jh = maybe_jh.unwrap(); + + // See if we finished, if so, return the value of the last execution + if let Poll::Ready(r) = poll!(&mut jh) { + out = Some(r); + // Task finished, unlock with a new task + } else { + // Still running, don't run again, unlock with the current join handle + run = false; + self.unlock(Some(jh)); + } + } + + // Run if we should do that + if run { + self.unlock(Some(MustJoinHandle::new(spawn_local(future)))); + } + + // Return the prior result if we have one + Ok(out) + } + } + } +} +cfg_if! { + if #[cfg(not(target_arch = "wasm32"))] { + impl MustJoinSingleFuture + where + T: 'static + Send, + { + pub async fn single_spawn( + &self, + future: impl Future + Send + 'static, + ) -> Result, ()> { + let mut out: Option = None; + // See if we have a result we can return + let maybe_jh = match self.try_lock() { + Ok(v) => v, + Err(_) => { + // If we are already polling somewhere else, don't hand back a result + return Err(()); + } + }; + let mut run = true; + if maybe_jh.is_some() { + let mut jh = maybe_jh.unwrap(); + // See if we finished, if so, return the value of the last execution + if let Poll::Ready(r) = poll!(&mut jh) { + out = Some(r); + // Task finished, unlock with a new task + } else { + // Still running, don't run again, unlock with the current join handle + run = false; + self.unlock(Some(jh)); + } + } + // Run if we should do that + if run { + self.unlock(Some(MustJoinHandle::new(spawn(future)))); + } + // Return the prior result if we have one + Ok(out) + } + } + } +} diff --git a/veilid-core/src/xx/tick_task.rs b/veilid-core/src/xx/tick_task.rs index 9e6b5335..1dc8c7b9 100644 --- a/veilid-core/src/xx/tick_task.rs +++ b/veilid-core/src/xx/tick_task.rs @@ -6,10 +6,10 @@ use once_cell::sync::OnceCell; cfg_if! 
{ if #[cfg(target_arch = "wasm32")] { type TickTaskRoutine = - dyn Fn(u64, u64) -> PinBoxFuture> + 'static; + dyn Fn(StopToken, u64, u64) -> PinBoxFuture> + 'static; } else { type TickTaskRoutine = - dyn Fn(u64, u64) -> SendPinBoxFuture> + Send + Sync + 'static; + dyn Fn(StopToken, u64, u64) -> SendPinBoxFuture> + Send + Sync + 'static; } } @@ -20,7 +20,8 @@ pub struct TickTask { last_timestamp_us: AtomicU64, tick_period_us: u64, routine: OnceCell>, - single_future: SingleFuture>, + stop_source: AsyncMutex>, + single_future: MustJoinSingleFuture>, } impl TickTask { @@ -29,7 +30,8 @@ impl TickTask { last_timestamp_us: AtomicU64::new(0), tick_period_us, routine: OnceCell::new(), - single_future: SingleFuture::new(), + stop_source: AsyncMutex::new(None), + single_future: MustJoinSingleFuture::new(), } } pub fn new_ms(tick_period_ms: u32) -> Self { @@ -37,7 +39,8 @@ impl TickTask { last_timestamp_us: AtomicU64::new(0), tick_period_us: (tick_period_ms as u64) * 1000u64, routine: OnceCell::new(), - single_future: SingleFuture::new(), + stop_source: AsyncMutex::new(None), + single_future: MustJoinSingleFuture::new(), } } pub fn new(tick_period_sec: u32) -> Self { @@ -45,7 +48,8 @@ impl TickTask { last_timestamp_us: AtomicU64::new(0), tick_period_us: (tick_period_sec as u64) * 1000000u64, routine: OnceCell::new(), - single_future: SingleFuture::new(), + stop_source: AsyncMutex::new(None), + single_future: MustJoinSingleFuture::new(), } } @@ -53,22 +57,31 @@ impl TickTask { if #[cfg(target_arch = "wasm32")] { pub fn set_routine( &self, - routine: impl Fn(u64, u64) -> PinBoxFuture> + 'static, + routine: impl Fn(StopToken, u64, u64) -> PinBoxFuture> + 'static, ) { self.routine.set(Box::new(routine)).map_err(drop).unwrap(); } } else { pub fn set_routine( &self, - routine: impl Fn(u64, u64) -> SendPinBoxFuture> + Send + Sync + 'static, + routine: impl Fn(StopToken, u64, u64) -> SendPinBoxFuture> + Send + Sync + 'static, ) { self.routine.set(Box::new(routine)).map_err(drop).unwrap(); } } } - pub async fn cancel(&self) -> Result<(), String> { - match self.single_future.cancel().await { + pub async fn stop(&self) -> Result<(), String> { + // drop the stop source if we have one + let opt_stop_source = &mut *self.stop_source.lock().await; + if opt_stop_source.is_none() { + // already stopped, just return + return Ok(()); + } + *opt_stop_source = None; + + // wait for completion of the tick task + match self.single_future.join().await { Ok(Some(Err(err))) => Err(err), _ => Ok(()), } @@ -80,27 +93,35 @@ impl TickTask { if last_timestamp_us == 0u64 || (now - last_timestamp_us) >= self.tick_period_us { // Run the singlefuture + let opt_stop_source = &mut *self.stop_source.lock().await; + let stop_source = StopSource::new(); match self .single_future - .single_spawn(self.routine.get().unwrap()(last_timestamp_us, now)) + .single_spawn(self.routine.get().unwrap()( + stop_source.token(), + last_timestamp_us, + now, + )) .await { - Ok(Some(Err(err))) => { - // If the last execution errored out then we should pass that error up + // Single future ran this tick + Ok(Some(ret)) => { + // Set new timer self.last_timestamp_us.store(now, Ordering::Release); - return Err(err); + // Save new stopper + *opt_stop_source = Some(stop_source); + ret } + // Single future did not run this tick Ok(None) | Err(()) => { // If the execution didn't happen this time because it was already running // then we should try again the next tick and not reset the timestamp so we try as soon as possible - } - _ => { - // Execution happened, next 
execution attempt should happen only after tick period - self.last_timestamp_us.store(now, Ordering::Release); + Ok(()) } } + } else { + // It's not time yet + Ok(()) } - - Ok(()) } }
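
The cooperative-cancellation shape this patch applies to rpc_worker(), the UDP listener loops, NetworkConnection::process_connection(), and the receipt-manager callbacks is: a StopSource owned by the component's inner state, a StopToken cloned into each spawned task, and stop_token::future::FutureExt::timeout_at() wrapped around otherwise-blocking awaits so that dropping the StopSource wakes the task and lets it exit cleanly. Below is a minimal standalone sketch of that shape, using async-std and flume as the patch does; the worker() function, the u32 payload, and the channel wiring are illustrative, not Veilid APIs.

use stop_token::future::FutureExt; // provides timeout_at()
use stop_token::{StopSource, StopToken};

async fn worker(stop_token: StopToken, receiver: flume::Receiver<u32>) {
    // Ok(Ok(msg)) = got a message before cancellation
    // Ok(Err(_))  = channel closed; exit
    // Err(_)      = StopSource dropped, timeout_at fired; exit
    while let Ok(Ok(msg)) = receiver.recv_async().timeout_at(stop_token.clone()).await {
        println!("got {}", msg);
    }
    println!("worker exited cleanly");
}

fn main() {
    async_std::task::block_on(async {
        let stop_source = StopSource::new();
        let (tx, rx) = flume::unbounded::<u32>();
        let handle = async_std::task::spawn(worker(stop_source.token(), rx));

        tx.send_async(42).await.ok();

        // Cooperative cancellation: dropping the source resolves every token,
        // so the worker's recv_async() returns instead of blocking forever.
        drop(stop_source);
        handle.await;
    });
}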
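
The matching shutdown sequence, repeated in Network::stop_network(), RPCProcessor::shutdown(), and ConnectionManager::shutdown(), is to drop the StopSource first and then join every outstanding handle through a FuturesUnordered so nothing is left running when the inner state is reset. In veilid-core those handles are wrapped in the new MustJoinHandle, which panics if dropped before completion; the sketch below substitutes plain async-std join handles so it stands alone, and run_worker() is an illustrative stand-in for the real listener loops.

use std::time::Duration;
use futures_util::stream::{FuturesUnordered, StreamExt};
use stop_token::future::FutureExt;
use stop_token::{StopSource, StopToken};

async fn run_worker(stop_token: StopToken, id: usize) {
    // Stand-in for a real worker loop: park on a long sleep wrapped in
    // timeout_at(), which returns Err(_) as soon as the StopSource is dropped.
    let _ = async_std::task::sleep(Duration::from_secs(3600))
        .timeout_at(stop_token)
        .await;
    println!("worker {} stopped", id);
}

fn main() {
    async_std::task::block_on(async {
        let mut stop_source = Some(StopSource::new());
        let mut join_handles = Vec::new();
        for id in 0..3 {
            let token = stop_source.as_ref().unwrap().token();
            join_handles.push(async_std::task::spawn(run_worker(token, id)));
        }

        // Shutdown: drop the stop source, then wait for every handle to finish.
        drop(stop_source.take());
        let mut unord = FuturesUnordered::new();
        for h in join_handles.drain(..) {
            unord.push(h);
        }
        while unord.next().await.is_some() {}
        println!("all workers joined");
    });
}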