fix cooperative cancellation

This commit is contained in:
John Smith 2022-06-15 14:05:04 -04:00
parent 180628beef
commit c33f78ac8b
24 changed files with 520 additions and 299 deletions

.vscode/launch.json

@@ -29,13 +29,17 @@
      "type": "lldb",
      "request": "launch",
      "name": "Launch veilid-cli",
-      "args": ["--debug"],
+      "args": [
+        "--debug"
+      ],
      "program": "${workspaceFolder}/target/debug/veilid-cli",
      "windows": {
        "program": "${workspaceFolder}/target/debug/veilid-cli.exe"
      },
      "cwd": "${workspaceFolder}/target/debug/",
-      "sourceLanguages": ["rust"],
+      "sourceLanguages": [
+        "rust"
+      ],
      "terminal": "console"
    },
    // {
@@ -48,20 +52,21 @@
    // "args": ["--trace"],
    // "cwd": "${workspaceFolder}/veilid-server"
    // }
    {
      "type": "lldb",
      "request": "launch",
      "name": "Debug veilid-server",
      "program": "${workspaceFolder}/target/debug/veilid-server",
-      "args": ["--trace", "--attach=true"],
+      "args": [
+        "--trace",
+        "--attach=true"
+      ],
      "cwd": "${workspaceFolder}/target/debug/",
      "env": {
        "RUST_BACKTRACE": "1"
      },
      "terminal": "console"
    },
    {
      "type": "lldb",
      "request": "launch",
@@ -78,10 +83,11 @@
          "name": "veilid-core"
        }
      },
-      "args": ["${selectedText}"],
+      "args": [
+        "${selectedText}"
+      ],
      "cwd": "${workspaceFolder}/target/debug/"
    },
    {
      "type": "lldb",
      "request": "launch",
@@ -98,10 +104,11 @@
          "name": "veilid-server"
        }
      },
-      "args": ["${selectedText}"],
+      "args": [
+        "${selectedText}"
+      ],
      "cwd": "${workspaceFolder}/veilid-server"
    },
    {
      "type": "lldb",
      "request": "launch",
@@ -118,10 +125,11 @@
          "name": "keyvaluedb-sqlite"
        }
      },
-      "args": ["${selectedText}"],
+      "args": [
+        "${selectedText}"
+      ],
      "cwd": "${workspaceFolder}/external/keyvaluedb/keyvaluedb-sqlite"
    },
    {
      "type": "lldb",
      "request": "launch",
@@ -138,8 +146,10 @@
          "name": "keyring"
        }
      },
-      "args": ["${selectedText}"],
+      "args": [
+        "${selectedText}"
+      ],
      "cwd": "${workspaceFolder}/external/keyring-rs"
    }
  ]
}


@@ -334,7 +334,7 @@ impl NetworkInterfaces {
        self.valid = false;
        let last_interfaces = core::mem::take(&mut self.interfaces);
-        let mut platform_support = PlatformSupport::new().map_err(logthru_net!())?;
+        let mut platform_support = PlatformSupport::new()?;
        platform_support
            .get_interfaces(&mut self.interfaces)
            .await?;


@@ -2,19 +2,28 @@ use super::*;
use crate::xx::*;
use connection_table::*;
use network_connection::*;
+use stop_token::future::FutureExt;

///////////////////////////////////////////////////////////
// Connection manager

+#[derive(Debug)]
+enum ConnectionManagerEvent {
+    Accepted(ProtocolNetworkConnection),
+    Finished(ConnectionDescriptor),
+}
+
#[derive(Debug)]
struct ConnectionManagerInner {
    connection_table: ConnectionTable,
+    sender: flume::Sender<ConnectionManagerEvent>,
+    async_processor_jh: Option<MustJoinHandle<()>>,
    stop_source: Option<StopSource>,
}

struct ConnectionManagerArc {
    network_manager: NetworkManager,
-    inner: AsyncMutex<Option<ConnectionManagerInner>>,
+    inner: Mutex<Option<ConnectionManagerInner>>,
}
impl core::fmt::Debug for ConnectionManagerArc {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
@@ -30,16 +39,23 @@ pub struct ConnectionManager {
}

impl ConnectionManager {
-    fn new_inner(config: VeilidConfig) -> ConnectionManagerInner {
+    fn new_inner(
+        config: VeilidConfig,
+        stop_source: StopSource,
+        sender: flume::Sender<ConnectionManagerEvent>,
+        async_processor_jh: JoinHandle<()>,
+    ) -> ConnectionManagerInner {
        ConnectionManagerInner {
-            stop_source: Some(StopSource::new()),
+            stop_source: Some(stop_source),
+            sender: sender,
+            async_processor_jh: Some(MustJoinHandle::new(async_processor_jh)),
            connection_table: ConnectionTable::new(config),
        }
    }
    fn new_arc(network_manager: NetworkManager) -> ConnectionManagerArc {
        ConnectionManagerArc {
            network_manager,
-            inner: AsyncMutex::new(None),
+            inner: Mutex::new(None),
        }
    }
    pub fn new(network_manager: NetworkManager) -> Self {
@@ -54,18 +70,34 @@ impl ConnectionManager {
    pub async fn startup(&self) {
        trace!("startup connection manager");
-        let mut inner = self.arc.inner.lock().await;
+        let mut inner = self.arc.inner.lock();
        if inner.is_some() {
            panic!("shouldn't start connection manager twice without shutting it down first");
        }
-        *inner = Some(Self::new_inner(self.network_manager().config()));
+        // Create channel for async_processor to receive notifications of networking events
+        let (sender, receiver) = flume::unbounded();
+
+        // Create the stop source we'll use to stop the processor and the connection table
+        let stop_source = StopSource::new();
+
+        // Spawn the async processor
+        let async_processor = spawn(self.clone().async_processor(stop_source.token(), receiver));
+
+        // Store in the inner object
+        *inner = Some(Self::new_inner(
+            self.network_manager().config(),
+            stop_source,
+            sender,
+            async_processor,
+        ));
    }

    pub async fn shutdown(&self) {
+        debug!("starting connection manager shutdown");
        // Remove the inner from the lock
        let mut inner = {
-            let mut inner_lock = self.arc.inner.lock().await;
+            let mut inner_lock = self.arc.inner.lock();
            let inner = match inner_lock.take() {
                Some(v) => v,
                None => {
@@ -75,11 +107,17 @@ impl ConnectionManager {
            inner
        };

-        // Stop all the connections
+        // Stop all the connections and the async processor
+        debug!("stopping async processor task");
        drop(inner.stop_source.take());
+        let async_processor_jh = inner.async_processor_jh.take().unwrap();
+        // wait for the async processor to stop
+        debug!("waiting for async processor to stop");
+        async_processor_jh.await;

        // Wait for the connections to complete
+        debug!("waiting for connection handlers to complete");
        inner.connection_table.join().await;
+        debug!("finished connection manager shutdown");
    }

    // Returns a network connection if one already is established
@@ -87,7 +125,7 @@ impl ConnectionManager {
        &self,
        descriptor: ConnectionDescriptor,
    ) -> Option<ConnectionHandle> {
-        let mut inner = self.arc.inner.lock().await;
+        let mut inner = self.arc.inner.lock();
        let inner = match &mut *inner {
            Some(v) => v,
            None => {
@ -128,86 +166,95 @@ impl ConnectionManager {
local_addr: Option<SocketAddr>, local_addr: Option<SocketAddr>,
dial_info: DialInfo, dial_info: DialInfo,
) -> Result<ConnectionHandle, String> { ) -> Result<ConnectionHandle, String> {
let mut inner = self.arc.inner.lock().await; let killed = {
let inner = match &mut *inner { let mut inner = self.arc.inner.lock();
Some(v) => v, let inner = match &mut *inner {
None => { Some(v) => v,
panic!("not started"); None => {
} panic!("not started");
}; }
};
log_net!(
"== get_or_create_connection local_addr={:?} dial_info={:?}",
local_addr.green(),
dial_info.green()
);
let peer_address = dial_info.to_peer_address();
let descriptor = match local_addr {
Some(la) => {
ConnectionDescriptor::new(peer_address, SocketAddress::from_socket_addr(la))
}
None => ConnectionDescriptor::new_no_local(peer_address),
};
// If any connection to this remote exists that has the same protocol, return it
// Any connection will do, we don't have to match the local address
if let Some(conn) = inner
.connection_table
.get_last_connection_by_remote(descriptor.remote())
{
log_net!( log_net!(
"== Returning existing connection local_addr={:?} peer_address={:?}", "== get_or_create_connection local_addr={:?} dial_info={:?}",
local_addr.green(), local_addr.green(),
peer_address.green() dial_info.green()
); );
return Ok(conn); let peer_address = dial_info.to_peer_address();
} let descriptor = match local_addr {
Some(la) => {
ConnectionDescriptor::new(peer_address, SocketAddress::from_socket_addr(la))
}
None => ConnectionDescriptor::new_no_local(peer_address),
};
// Drop any other protocols connections to this remote that have the same local addr // If any connection to this remote exists that has the same protocol, return it
// otherwise this connection won't succeed due to binding // Any connection will do, we don't have to match the local address
let mut killed = false;
if let Some(local_addr) = local_addr { if let Some(conn) = inner
if local_addr.port() != 0 { .connection_table
for pt in [ProtocolType::TCP, ProtocolType::WS, ProtocolType::WSS] { .get_last_connection_by_remote(descriptor.remote())
let pa = PeerAddress::new(descriptor.remote_address().clone(), pt); {
for prior_descriptor in inner log_net!(
.connection_table "== Returning existing connection local_addr={:?} peer_address={:?}",
.get_connection_descriptors_by_remote(pa) local_addr.green(),
{ peer_address.green()
let mut kill = false; );
// See if the local address would collide
if let Some(prior_local) = prior_descriptor.local() { return Ok(conn);
if (local_addr.ip().is_unspecified() }
|| prior_local.to_ip_addr().is_unspecified()
|| (local_addr.ip() == prior_local.to_ip_addr())) // Drop any other protocols connections to this remote that have the same local addr
&& prior_local.port() == local_addr.port() // otherwise this connection won't succeed due to binding
{ let mut killed = Vec::<NetworkConnection>::new();
kill = true; if let Some(local_addr) = local_addr {
if local_addr.port() != 0 {
for pt in [ProtocolType::TCP, ProtocolType::WS, ProtocolType::WSS] {
let pa = PeerAddress::new(descriptor.remote_address().clone(), pt);
for prior_descriptor in inner
.connection_table
.get_connection_descriptors_by_remote(pa)
{
let mut kill = false;
// See if the local address would collide
if let Some(prior_local) = prior_descriptor.local() {
if (local_addr.ip().is_unspecified()
|| prior_local.to_ip_addr().is_unspecified()
|| (local_addr.ip() == prior_local.to_ip_addr()))
&& prior_local.port() == local_addr.port()
{
kill = true;
}
} }
} if kill {
if kill { log_net!(debug
log_net!(debug ">< Terminating connection prior_descriptor={:?}",
">< Terminating connection prior_descriptor={:?}", prior_descriptor
prior_descriptor );
); let mut conn = inner
if let Err(e) = .connection_table
inner.connection_table.remove_connection(prior_descriptor) .remove_connection(prior_descriptor)
{ .expect("connection not in table");
log_net!(error e);
conn.close();
killed.push(conn);
} }
killed = true;
} }
} }
} }
} }
killed
};
// Wait for the killed connections to end their recv loops
let mut retry_count = if !killed.is_empty() { 2 } else { 0 };
for k in killed {
k.await;
} }
// Attempt new connection // Attempt new connection
let mut retry_count = if killed { 2 } else { 0 };
let conn = loop { let conn = loop {
match ProtocolNetworkConnection::connect(local_addr, dial_info.clone()).await { match ProtocolNetworkConnection::connect(local_addr, dial_info.clone()).await {
Ok(v) => break Ok(v), Ok(v) => break Ok(v),
@ -222,44 +269,113 @@ impl ConnectionManager {
} }
}?; }?;
self.on_new_protocol_network_connection(&mut *inner, conn) // Add to the connection table
let mut inner = self.arc.inner.lock();
let inner = match &mut *inner {
Some(v) => v,
None => {
return Err("shutting down".to_owned());
}
};
self.on_new_protocol_network_connection(inner, conn)
} }
/////////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////////
/// Callbacks /// Callbacks
#[instrument(level = "trace", skip_all)]
async fn async_processor(
self,
stop_token: StopToken,
receiver: flume::Receiver<ConnectionManagerEvent>,
) {
// Process async commands
while let Ok(Ok(event)) = receiver.recv_async().timeout_at(stop_token.clone()).await {
match event {
ConnectionManagerEvent::Accepted(conn) => {
let mut inner = self.arc.inner.lock();
match &mut *inner {
Some(inner) => {
// Register the connection
// We don't care if this fails, since nobody here asked for the inbound connection.
// If it does, we just drop the connection
let _ = self.on_new_protocol_network_connection(inner, conn);
}
None => {
// If this somehow happens, we're shutting down
}
};
}
ConnectionManagerEvent::Finished(desc) => {
let conn = {
let mut inner_lock = self.arc.inner.lock();
match &mut *inner_lock {
Some(inner) => {
// Remove the connection and wait for the connection loop to terminate
if let Ok(conn) = inner.connection_table.remove_connection(desc) {
// Must close and wait to ensure things join
Some(conn)
} else {
None
}
}
None => None,
}
};
if let Some(mut conn) = conn {
conn.close();
conn.await;
}
}
}
}
}
// Called by low-level network when any connection-oriented protocol connection appears // Called by low-level network when any connection-oriented protocol connection appears
// either from incoming connections. // either from incoming connections.
pub(super) async fn on_accepted_protocol_network_connection( pub(super) async fn on_accepted_protocol_network_connection(
&self, &self,
conn: ProtocolNetworkConnection, conn: ProtocolNetworkConnection,
) -> Result<(), String> { ) -> Result<(), String> {
let mut inner = self.arc.inner.lock().await; // Get channel sender
let inner = match &mut *inner { let sender = {
Some(v) => v, let mut inner = self.arc.inner.lock();
None => { let inner = match &mut *inner {
// If we are shutting down, just drop this and return Some(v) => v,
return Ok(()); None => {
} // If we are shutting down, just drop this and return
return Ok(());
}
};
inner.sender.clone()
}; };
self.on_new_protocol_network_connection(inner, conn)
.map(drop) // Inform the processor of the event
let _ = sender
.send_async(ConnectionManagerEvent::Accepted(conn))
.await;
Ok(())
} }
// Callback from network connection receive loop when it exits // Callback from network connection receive loop when it exits
// cleans up the entry in the connection table // cleans up the entry in the connection table
pub(super) async fn report_connection_finished(&self, descriptor: ConnectionDescriptor) { pub(super) async fn report_connection_finished(&self, descriptor: ConnectionDescriptor) {
let mut inner = self.arc.inner.lock().await; // Get channel sender
let inner = match &mut *inner { let sender = {
Some(v) => v, let mut inner = self.arc.inner.lock();
None => { let inner = match &mut *inner {
// If we're shutting down, do nothing here Some(v) => v,
return; None => {
} // If we are shutting down, just drop this and return
return;
}
};
inner.sender.clone()
}; };
if let Err(e) = inner.connection_table.remove_connection(descriptor) { // Inform the processor of the event
log_net!(error e); let _ = sender
} .send_async(ConnectionManagerEvent::Finished(descriptor))
.await;
} }
} }
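The new connection manager pattern above boils down to: spawn a background processor fed by a flume channel, cancel it cooperatively through a stop_token, then join it during shutdown. Below is a minimal, self-contained sketch of that pattern, not the veilid-core code itself; the recv_async()/timeout_at() call shape mirrors what this diff uses, while the flume, stop-token, and async-std crates plus the spawn/block_on wiring are assumptions for illustration (veilid-core uses its own spawn wrapper and MustJoinHandle).

// Sketch of the channel + stop-token processor pattern (illustrative, not veilid-core).
use stop_token::future::FutureExt as _; // provides .timeout_at(), as used in this diff
use stop_token::{StopSource, StopToken};

#[derive(Debug)]
enum Event {
    Accepted(u32),
    Finished(u32),
}

async fn async_processor(stop_token: StopToken, receiver: flume::Receiver<Event>) {
    // Exits when the stop token fires or the channel is closed.
    while let Ok(Ok(event)) = receiver.recv_async().timeout_at(stop_token.clone()).await {
        match event {
            Event::Accepted(id) => println!("accepted connection {}", id),
            Event::Finished(id) => println!("connection {} finished", id),
        }
    }
}

fn main() {
    async_std::task::block_on(async {
        let (sender, receiver) = flume::unbounded();
        let stop_source = StopSource::new();
        let processor_jh = async_std::task::spawn(async_processor(stop_source.token(), receiver));

        let _ = sender.send_async(Event::Accepted(1)).await;
        let _ = sender.send_async(Event::Finished(1)).await;

        // Cooperative shutdown: dropping the StopSource cancels the token, then we
        // join the task, just as shutdown() above does before joining the table.
        drop(stop_source);
        processor_jh.await;
    });
}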


@@ -160,13 +160,16 @@ impl ConnectionTable {
            .expect("Inconsistency in connection table");
    }

-    pub fn remove_connection(&mut self, descriptor: ConnectionDescriptor) -> Result<(), String> {
+    pub fn remove_connection(
+        &mut self,
+        descriptor: ConnectionDescriptor,
+    ) -> Result<NetworkConnection, String> {
        let index = protocol_to_index(descriptor.protocol_type());
-        let _ = self.conn_by_descriptor[index]
+        let conn = self.conn_by_descriptor[index]
            .remove(&descriptor)
            .ok_or_else(|| format!("Connection not in table: {:?}", descriptor))?;

        self.remove_connection_records(descriptor);
-        Ok(())
+        Ok(conn)
    }
}
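Returning the removed NetworkConnection (instead of ()) is what lets the connection manager close it and await its receive loop. A toy illustration of that ownership hand-off, with made-up types rather than veilid-core's:

// Toy types for illustration only: the point is that removal returns ownership.
struct ToyConnection {
    name: String,
}

impl ToyConnection {
    fn close(&mut self) {
        println!("signalling {} to stop", self.name);
    }
    async fn join(self) {
        println!("{} fully stopped", self.name);
    }
}

struct ToyTable {
    conns: Vec<ToyConnection>,
}

impl ToyTable {
    // Mirrors the new shape: hand the removed connection back to the caller.
    fn remove_connection(&mut self, name: &str) -> Result<ToyConnection, String> {
        let idx = self
            .conns
            .iter()
            .position(|c| c.name == name)
            .ok_or_else(|| format!("Connection not in table: {:?}", name))?;
        Ok(self.conns.swap_remove(idx))
    }
}

async fn drop_one(table: &mut ToyTable) {
    if let Ok(mut conn) = table.remove_connection("tcp://peer") {
        conn.close(); // ask the connection's receive loop to exit
        conn.join().await; // wait until it actually has
    }
}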


@@ -299,26 +299,31 @@ impl NetworkManager {
    #[instrument(level = "debug", skip_all)]
    pub async fn shutdown(&self) {
-        trace!("NetworkManager::shutdown begin");
+        debug!("starting network manager shutdown");

        // Cancel all tasks
+        debug!("stopping rolling transfers task");
        if let Err(e) = self.unlocked_inner.rolling_transfers_task.stop().await {
            warn!("rolling_transfers_task not stopped: {}", e);
        }
+        debug!("stopping relay management task task");
        if let Err(e) = self.unlocked_inner.relay_management_task.stop().await {
            warn!("relay_management_task not stopped: {}", e);
        }

        // Shutdown network components if they started up
+        debug!("shutting down network components");
        let components = self.inner.lock().components.clone();
        if let Some(components) = components {
-            components.receipt_manager.shutdown().await;
-            components.rpc_processor.shutdown().await;
            components.net.shutdown().await;
            components.connection_manager.shutdown().await;
+            components.rpc_processor.shutdown().await;
+            components.receipt_manager.shutdown().await;
        }

        // reset the state
+        debug!("resetting network manager state");
        {
            let mut inner = self.inner.lock();
            inner.components = None;
@@ -326,9 +331,10 @@ impl NetworkManager {
        }

        // send update
+        debug!("sending network state update");
        self.send_network_update();

-        trace!("NetworkManager::shutdown end");
+        debug!("finished network manager shutdown");
    }

    pub fn update_client_whitelist(&self, client: DHTKey) {
@ -883,8 +889,7 @@ impl NetworkManager {
match self match self
.net() .net()
.send_data_to_existing_connection(descriptor, data) .send_data_to_existing_connection(descriptor, data)
.await .await?
.map_err(logthru_net!())?
{ {
None => Ok(()), None => Ok(()),
Some(_) => Err("unable to send over reverse connection".to_owned()), Some(_) => Err("unable to send over reverse connection".to_owned()),
@ -973,8 +978,7 @@ impl NetworkManager {
match self match self
.net() .net()
.send_data_to_existing_connection(descriptor, data) .send_data_to_existing_connection(descriptor, data)
.await .await?
.map_err(logthru_net!())?
{ {
None => Ok(()), None => Ok(()),
Some(_) => Err("unable to send over hole punch".to_owned()), Some(_) => Err("unable to send over hole punch".to_owned()),
@ -1005,8 +1009,7 @@ impl NetworkManager {
match this match this
.net() .net()
.send_data_to_existing_connection(descriptor, data) .send_data_to_existing_connection(descriptor, data)
.await .await?
.map_err(logthru_net!())?
{ {
None => { None => {
return Ok(if descriptor.matches_peer_scope(PeerScope::Local) { return Ok(if descriptor.matches_peer_scope(PeerScope::Local) {
@ -1150,7 +1153,7 @@ impl NetworkManager {
"failed to resolve recipient node for relay, dropping outbound relayed packet...: {:?}", "failed to resolve recipient node for relay, dropping outbound relayed packet...: {:?}",
e e
) )
}).map_err(logthru_net!())? })?
} else { } else {
// If this is not a node in the client whitelist, only allow inbound relay // If this is not a node in the client whitelist, only allow inbound relay
// which only performs a lightweight lookup before passing the packet back out // which only performs a lightweight lookup before passing the packet back out


@ -279,20 +279,14 @@ impl Network {
let res = match dial_info.protocol_type() { let res = match dial_info.protocol_type() {
ProtocolType::UDP => { ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr(); let peer_socket_addr = dial_info.to_socket_addr();
RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data) RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
.await
.map_err(logthru_net!())
} }
ProtocolType::TCP => { ProtocolType::TCP => {
let peer_socket_addr = dial_info.to_socket_addr(); let peer_socket_addr = dial_info.to_socket_addr();
RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data) RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
.await
.map_err(logthru_net!())
} }
ProtocolType::WS | ProtocolType::WSS => { ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data) WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data).await
.await
.map_err(logthru_net!())
} }
}; };
if res.is_ok() { if res.is_ok() {
@ -324,10 +318,7 @@ impl Network {
descriptor descriptor
); );
ph.clone() ph.clone().send_message(data, peer_socket_addr).await?;
.send_message(data, peer_socket_addr)
.await
.map_err(logthru_net!())?;
// Network accounting // Network accounting
self.network_manager() self.network_manager()
@ -345,7 +336,7 @@ impl Network {
log_net!("send_data_to_existing_connection to {:?}", descriptor); log_net!("send_data_to_existing_connection to {:?}", descriptor);
// connection exists, send over it // connection exists, send over it
conn.send_async(data).await.map_err(logthru_net!())?; conn.send_async(data).await?;
// Network accounting // Network accounting
self.network_manager() self.network_manager()
@ -372,10 +363,7 @@ impl Network {
if dial_info.protocol_type() == ProtocolType::UDP { if dial_info.protocol_type() == ProtocolType::UDP {
let peer_socket_addr = dial_info.to_socket_addr(); let peer_socket_addr = dial_info.to_socket_addr();
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) { if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
let res = ph let res = ph.send_message(data, peer_socket_addr).await;
.send_message(data, peer_socket_addr)
.await
.map_err(logthru_net!());
if res.is_ok() { if res.is_ok() {
// Network accounting // Network accounting
self.network_manager() self.network_manager()
@ -383,8 +371,7 @@ impl Network {
} }
return res; return res;
} }
return Err("no appropriate UDP protocol handler for dial_info".to_owned()) return Err("no appropriate UDP protocol handler for dial_info".to_owned());
.map_err(logthru_net!(error));
} }
// Handle connection-oriented protocols // Handle connection-oriented protocols
@ -394,7 +381,7 @@ impl Network {
.get_or_create_connection(Some(local_addr), dial_info.clone()) .get_or_create_connection(Some(local_addr), dial_info.clone())
.await?; .await?;
let res = conn.send_async(data).await.map_err(logthru_net!(error)); let res = conn.send_async(data).await;
if res.is_ok() { if res.is_ok() {
// Network accounting // Network accounting
self.network_manager() self.network_manager()
@ -414,43 +401,51 @@ impl Network {
// initialize interfaces // initialize interfaces
let mut interfaces = NetworkInterfaces::new(); let mut interfaces = NetworkInterfaces::new();
interfaces.refresh().await?; interfaces.refresh().await?;
self.inner.lock().interfaces = interfaces;
// get protocol config
let protocol_config = { let protocol_config = {
let c = self.config.get(); let mut inner = self.inner.lock();
let mut inbound = ProtocolSet::new();
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp { // Create stop source
inbound.insert(ProtocolType::UDP); inner.stop_source = Some(StopSource::new());
} inner.interfaces = interfaces;
if c.network.protocol.tcp.listen && c.capabilities.protocol_accept_tcp {
inbound.insert(ProtocolType::TCP);
}
if c.network.protocol.ws.listen && c.capabilities.protocol_accept_ws {
inbound.insert(ProtocolType::WS);
}
if c.network.protocol.wss.listen && c.capabilities.protocol_accept_wss {
inbound.insert(ProtocolType::WSS);
}
let mut outbound = ProtocolSet::new(); // get protocol config
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp { let protocol_config = {
outbound.insert(ProtocolType::UDP); let c = self.config.get();
} let mut inbound = ProtocolSet::new();
if c.network.protocol.tcp.connect && c.capabilities.protocol_connect_tcp {
outbound.insert(ProtocolType::TCP);
}
if c.network.protocol.ws.connect && c.capabilities.protocol_connect_ws {
outbound.insert(ProtocolType::WS);
}
if c.network.protocol.wss.connect && c.capabilities.protocol_connect_wss {
outbound.insert(ProtocolType::WSS);
}
ProtocolConfig { inbound, outbound } if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
inbound.insert(ProtocolType::UDP);
}
if c.network.protocol.tcp.listen && c.capabilities.protocol_accept_tcp {
inbound.insert(ProtocolType::TCP);
}
if c.network.protocol.ws.listen && c.capabilities.protocol_accept_ws {
inbound.insert(ProtocolType::WS);
}
if c.network.protocol.wss.listen && c.capabilities.protocol_accept_wss {
inbound.insert(ProtocolType::WSS);
}
let mut outbound = ProtocolSet::new();
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
outbound.insert(ProtocolType::UDP);
}
if c.network.protocol.tcp.connect && c.capabilities.protocol_connect_tcp {
outbound.insert(ProtocolType::TCP);
}
if c.network.protocol.ws.connect && c.capabilities.protocol_connect_ws {
outbound.insert(ProtocolType::WS);
}
if c.network.protocol.wss.connect && c.capabilities.protocol_connect_wss {
outbound.insert(ProtocolType::WSS);
}
ProtocolConfig { inbound, outbound }
};
inner.protocol_config = Some(protocol_config);
protocol_config
}; };
self.inner.lock().protocol_config = Some(protocol_config);
// start listeners // start listeners
if protocol_config.inbound.contains(ProtocolType::UDP) { if protocol_config.inbound.contains(ProtocolType::UDP) {
@ -503,28 +498,32 @@ impl Network {
#[instrument(level = "debug", skip_all)] #[instrument(level = "debug", skip_all)]
pub async fn shutdown(&self) { pub async fn shutdown(&self) {
info!("stopping network"); debug!("starting low level network shutdown");
let network_manager = self.network_manager(); let network_manager = self.network_manager();
let routing_table = self.routing_table(); let routing_table = self.routing_table();
// Stop all tasks // Stop all tasks
debug!("stopping update network class task");
if let Err(e) = self.unlocked_inner.update_network_class_task.stop().await { if let Err(e) = self.unlocked_inner.update_network_class_task.stop().await {
error!("update_network_class_task not cancelled: {}", e); error!("update_network_class_task not cancelled: {}", e);
} }
let mut unord = FuturesUnordered::new(); let mut unord = FuturesUnordered::new();
{ {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
// Drop the stop
drop(inner.stop_source.take());
// take the join handles out // take the join handles out
for h in inner.join_handles.drain(..) { for h in inner.join_handles.drain(..) {
unord.push(h); unord.push(h);
} }
// Drop the stop
drop(inner.stop_source.take());
} }
debug!("stopping {} low level network tasks", unord.len());
// Wait for everything to stop // Wait for everything to stop
while unord.next().await.is_some() {} while unord.next().await.is_some() {}
debug!("clearing dial info");
// Drop all dial info // Drop all dial info
routing_table.clear_dial_info_details(RoutingDomain::PublicInternet); routing_table.clear_dial_info_details(RoutingDomain::PublicInternet);
routing_table.clear_dial_info_details(RoutingDomain::LocalNetwork); routing_table.clear_dial_info_details(RoutingDomain::LocalNetwork);
@ -532,7 +531,7 @@ impl Network {
// Reset state including network class // Reset state including network class
*self.inner.lock() = Self::new_inner(network_manager); *self.inner.lock() = Self::new_inner(network_manager);
info!("network stopped"); debug!("finished low level network shutdown");
} }
////////////////////////////////////////// //////////////////////////////////////////


@ -1,6 +1,7 @@
use super::*; use super::*;
use futures_util::stream::FuturesUnordered; use futures_util::stream::FuturesUnordered;
use futures_util::FutureExt; use futures_util::FutureExt;
use stop_token::future::FutureExt as StopTokenFutureExt;
struct DetectedPublicDialInfo { struct DetectedPublicDialInfo {
dial_info: DialInfo, dial_info: DialInfo,
@ -584,20 +585,32 @@ impl Network {
// Wait for all discovery futures to complete and collect contexts // Wait for all discovery futures to complete and collect contexts
let mut contexts = Vec::<DiscoveryContext>::new(); let mut contexts = Vec::<DiscoveryContext>::new();
let mut network_class = Option::<NetworkClass>::None; let mut network_class = Option::<NetworkClass>::None;
while let Some(ctxvec) = unord.next().await { loop {
if let Some(ctxvec) = ctxvec { match unord.next().timeout_at(stop_token.clone()).await {
for ctx in ctxvec { Ok(Some(ctxvec)) => {
if let Some(nc) = ctx.inner.lock().detected_network_class { if let Some(ctxvec) = ctxvec {
if let Some(last_nc) = network_class { for ctx in ctxvec {
if nc < last_nc { if let Some(nc) = ctx.inner.lock().detected_network_class {
network_class = Some(nc); if let Some(last_nc) = network_class {
if nc < last_nc {
network_class = Some(nc);
}
} else {
network_class = Some(nc);
}
} }
} else {
network_class = Some(nc); contexts.push(ctx);
} }
} }
}
contexts.push(ctx); Ok(None) => {
// Normal completion
break;
}
Err(_) => {
// Stop token, exit early without error propagation
return Ok(());
} }
} }
} }
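The discovery loop above now drains its FuturesUnordered under a stop token: Ok(Some(..)) is one future completing, Ok(None) is normal completion, Err(_) means the stop token fired and the loop exits early without treating it as an error. A small self-contained sketch of that loop shape, with toy futures standing in for the discovery contexts and the futures-util, stop-token, and async-std crates assumed:

// Illustrative only: drain a FuturesUnordered, bailing out quietly on cancellation.
use futures_util::stream::{FuturesUnordered, StreamExt};
use stop_token::future::FutureExt as _;
use stop_token::{StopSource, StopToken};

async fn drain_or_stop(stop_token: StopToken) -> Vec<u32> {
    let mut unord = FuturesUnordered::new();
    for i in 0..4u32 {
        unord.push(async move { i * i });
    }

    let mut results = Vec::new();
    loop {
        match unord.next().timeout_at(stop_token.clone()).await {
            Ok(Some(value)) => results.push(value), // one future finished
            Ok(None) => break,                      // normal completion: stream drained
            Err(_) => break,                        // stop token fired: exit early, no error
        }
    }
    results
}

fn main() {
    let stop_source = StopSource::new();
    let results = async_std::task::block_on(drain_or_stop(stop_source.token()));
    println!("{:?}", results);
}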


@ -65,8 +65,7 @@ impl Network {
ps.peek_exact(&mut first_packet), ps.peek_exact(&mut first_packet),
) )
.await .await
.map_err(map_to_string) .map_err(map_to_string)?;
.map_err(logthru_net!())?;
self.try_handlers(ps, tcp_stream, addr, protocol_handlers) self.try_handlers(ps, tcp_stream, addr, protocol_handlers)
.await .await
@ -82,8 +81,7 @@ impl Network {
for ah in protocol_accept_handlers.iter() { for ah in protocol_accept_handlers.iter() {
if let Some(nc) = ah if let Some(nc) = ah
.on_accept(stream.clone(), tcp_stream.clone(), addr) .on_accept(stream.clone(), tcp_stream.clone(), addr)
.await .await?
.map_err(logthru_net!())?
{ {
return Ok(Some(nc)); return Ok(Some(nc));
} }


@ -39,15 +39,11 @@ impl ProtocolNetworkConnection {
match dial_info.protocol_type() { match dial_info.protocol_type() {
ProtocolType::UDP => { ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr(); let peer_socket_addr = dial_info.to_socket_addr();
udp::RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data) udp::RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
.await
.map_err(logthru_net!())
} }
ProtocolType::TCP => { ProtocolType::TCP => {
let peer_socket_addr = dial_info.to_socket_addr(); let peer_socket_addr = dial_info.to_socket_addr();
tcp::RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data) tcp::RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
.await
.map_err(logthru_net!())
} }
ProtocolType::WS | ProtocolType::WSS => { ProtocolType::WS | ProtocolType::WSS => {
ws::WebsocketProtocolHandler::send_unbound_message(dial_info, data).await ws::WebsocketProtocolHandler::send_unbound_message(dial_info, data).await


@ -31,19 +31,14 @@ impl RawTcpNetworkConnection {
self.descriptor.clone() self.descriptor.clone()
} }
#[instrument(level = "trace", err, skip(self))]
pub async fn close(&self) -> Result<(), String> { pub async fn close(&self) -> Result<(), String> {
// Make an attempt to flush the stream // Make an attempt to flush the stream
self.stream self.stream.clone().close().await.map_err(map_to_string)?;
.clone()
.close()
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
// Then forcibly close the socket // Then forcibly close the socket
self.tcp_stream self.tcp_stream
.shutdown(Shutdown::Both) .shutdown(Shutdown::Both)
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!())
} }
async fn send_internal(mut stream: AsyncPeekStream, message: Vec<u8>) -> Result<(), String> { async fn send_internal(mut stream: AsyncPeekStream, message: Vec<u8>) -> Result<(), String> {
@ -54,23 +49,17 @@ impl RawTcpNetworkConnection {
let len = message.len() as u16; let len = message.len() as u16;
let header = [b'V', b'L', len as u8, (len >> 8) as u8]; let header = [b'V', b'L', len as u8, (len >> 8) as u8];
stream stream.write_all(&header).await.map_err(map_to_string)?;
.write_all(&header) stream.write_all(&message).await.map_err(map_to_string)
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
stream
.write_all(&message)
.await
.map_err(map_to_string)
.map_err(logthru_net!())
} }
#[instrument(level="trace", err, skip(self, message), fields(message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> { pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
let stream = self.stream.clone(); let stream = self.stream.clone();
Self::send_internal(stream, message).await Self::send_internal(stream, message).await
} }
#[instrument(level="trace", err, skip(self), fields(message.len))]
pub async fn recv(&self) -> Result<Vec<u8>, String> { pub async fn recv(&self) -> Result<Vec<u8>, String> {
let mut header = [0u8; 4]; let mut header = [0u8; 4];
@ -90,6 +79,8 @@ impl RawTcpNetworkConnection {
let mut out: Vec<u8> = vec![0u8; len]; let mut out: Vec<u8> = vec![0u8; len];
stream.read_exact(&mut out).await.map_err(map_to_string)?; stream.read_exact(&mut out).await.map_err(map_to_string)?;
tracing::Span::current().record("message.len", &out.len());
Ok(out) Ok(out)
} }
} }
@ -120,6 +111,7 @@ impl RawTcpProtocolHandler {
} }
} }
#[instrument(level = "trace", err, skip(self, stream, tcp_stream))]
async fn on_accept_async( async fn on_accept_async(
self, self,
stream: AsyncPeekStream, stream: AsyncPeekStream,
@ -151,6 +143,7 @@ impl RawTcpProtocolHandler {
Ok(Some(conn)) Ok(Some(conn))
} }
#[instrument(level = "trace", err)]
pub async fn connect( pub async fn connect(
local_address: Option<SocketAddr>, local_address: Option<SocketAddr>,
dial_info: DialInfo, dial_info: DialInfo,
@ -191,6 +184,7 @@ impl RawTcpProtocolHandler {
Ok(conn) Ok(conn)
} }
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
pub async fn send_unbound_message( pub async fn send_unbound_message(
socket_addr: SocketAddr, socket_addr: SocketAddr,
data: Vec<u8>, data: Vec<u8>,


@ -10,6 +10,7 @@ impl RawUdpProtocolHandler {
Self { socket } Self { socket }
} }
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.from))]
pub async fn recv_message( pub async fn recv_message(
&self, &self,
data: &mut [u8], data: &mut [u8],
@ -35,9 +36,13 @@ impl RawUdpProtocolHandler {
peer_addr, peer_addr,
SocketAddress::from_socket_addr(local_socket_addr), SocketAddress::from_socket_addr(local_socket_addr),
); );
tracing::Span::current().record("ret.len", &size);
tracing::Span::current().record("ret.from", &format!("{:?}", descriptor).as_str());
Ok((size, descriptor)) Ok((size, descriptor))
} }
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.from))]
pub async fn send_message(&self, data: Vec<u8>, socket_addr: SocketAddr) -> Result<(), String> { pub async fn send_message(&self, data: Vec<u8>, socket_addr: SocketAddr) -> Result<(), String> {
if data.len() > MAX_MESSAGE_SIZE { if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large UDP message".to_owned()).map_err(logthru_net!(error)); return Err("sending too large UDP message".to_owned()).map_err(logthru_net!(error));


@ -49,21 +49,17 @@ where
self.descriptor.clone() self.descriptor.clone()
} }
#[instrument(level = "trace", err, skip(self))]
pub async fn close(&self) -> Result<(), String> { pub async fn close(&self) -> Result<(), String> {
// Make an attempt to flush the stream // Make an attempt to flush the stream
self.stream self.stream.clone().close().await.map_err(map_to_string)?;
.clone()
.close()
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
// Then forcibly close the socket // Then forcibly close the socket
self.tcp_stream self.tcp_stream
.shutdown(Shutdown::Both) .shutdown(Shutdown::Both)
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!())
} }
#[instrument(level="trace", err, skip(self, message), fields(message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> { pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
if message.len() > MAX_MESSAGE_SIZE { if message.len() > MAX_MESSAGE_SIZE {
return Err("received too large WS message".to_owned()); return Err("received too large WS message".to_owned());
@ -76,6 +72,7 @@ where
.map_err(logthru_net!(error "failed to send websocket message")) .map_err(logthru_net!(error "failed to send websocket message"))
} }
#[instrument(level="trace", err, skip(self), fields(message.len))]
pub async fn recv(&self) -> Result<Vec<u8>, String> { pub async fn recv(&self) -> Result<Vec<u8>, String> {
let out = match self.stream.clone().next().await { let out = match self.stream.clone().next().await {
Some(Ok(Message::Binary(v))) => v, Some(Ok(Message::Binary(v))) => v,
@ -86,12 +83,13 @@ where
return Err(e.to_string()).map_err(logthru_net!(error)); return Err(e.to_string()).map_err(logthru_net!(error));
} }
None => { None => {
return Err("WS stream closed".to_owned()).map_err(logthru_net!()); return Err("WS stream closed".to_owned());
} }
}; };
if out.len() > MAX_MESSAGE_SIZE { if out.len() > MAX_MESSAGE_SIZE {
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error)) Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
} else { } else {
tracing::Span::current().record("message.len", &out.len());
Ok(out) Ok(out)
} }
} }
@ -137,6 +135,7 @@ impl WebsocketProtocolHandler {
} }
} }
#[instrument(level = "trace", err, skip(self, ps, tcp_stream))]
pub async fn on_accept_async( pub async fn on_accept_async(
self, self,
ps: AsyncPeekStream, ps: AsyncPeekStream,
@ -156,9 +155,9 @@ impl WebsocketProtocolHandler {
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
if e.kind() == io::ErrorKind::TimedOut { if e.kind() == io::ErrorKind::TimedOut {
return Err(e).map_err(map_to_string).map_err(logthru_net!()); return Err(e).map_err(map_to_string);
} }
return Err(e).map_err(map_to_string).map_err(logthru_net!(error)); return Err(e).map_err(map_to_string);
} }
} }
@ -237,10 +236,7 @@ impl WebsocketProtocolHandler {
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?; .map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
// See what local address we ended up with // See what local address we ended up with
let actual_local_addr = tcp_stream let actual_local_addr = tcp_stream.local_addr().map_err(map_to_string)?;
.local_addr()
.map_err(map_to_string)
.map_err(logthru_net!())?;
// Make our connection descriptor // Make our connection descriptor
let descriptor = ConnectionDescriptor::new( let descriptor = ConnectionDescriptor::new(
@ -274,6 +270,7 @@ impl WebsocketProtocolHandler {
} }
} }
#[instrument(level = "trace", err)]
pub async fn connect( pub async fn connect(
local_address: Option<SocketAddr>, local_address: Option<SocketAddr>,
dial_info: DialInfo, dial_info: DialInfo,
@ -281,6 +278,7 @@ impl WebsocketProtocolHandler {
Self::connect_internal(local_address, dial_info).await Self::connect_internal(local_address, dial_info).await
} }
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> { pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
if data.len() > MAX_MESSAGE_SIZE { if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large unbound WS message".to_owned()); return Err("sending too large unbound WS message".to_owned());


@ -89,6 +89,7 @@ pub struct NetworkConnection {
established_time: u64, established_time: u64,
stats: Arc<Mutex<NetworkConnectionStats>>, stats: Arc<Mutex<NetworkConnectionStats>>,
sender: flume::Sender<Vec<u8>>, sender: flume::Sender<Vec<u8>>,
stop_source: Option<StopSource>,
} }
impl NetworkConnection { impl NetworkConnection {
@ -105,12 +106,13 @@ impl NetworkConnection {
last_message_recv_time: None, last_message_recv_time: None,
})), })),
sender, sender,
stop_source: None,
} }
} }
pub(super) fn from_protocol( pub(super) fn from_protocol(
connection_manager: ConnectionManager, connection_manager: ConnectionManager,
stop_token: StopToken, manager_stop_token: StopToken,
protocol_connection: ProtocolNetworkConnection, protocol_connection: ProtocolNetworkConnection,
) -> Self { ) -> Self {
// Get timeout // Get timeout
@ -133,10 +135,14 @@ impl NetworkConnection {
last_message_recv_time: None, last_message_recv_time: None,
})); }));
let stop_source = StopSource::new();
let local_stop_token = stop_source.token();
// Spawn connection processor and pass in protocol connection // Spawn connection processor and pass in protocol connection
let processor = MustJoinHandle::new(intf::spawn_local(Self::process_connection( let processor = MustJoinHandle::new(intf::spawn_local(Self::process_connection(
connection_manager, connection_manager,
stop_token, local_stop_token,
manager_stop_token,
descriptor.clone(), descriptor.clone(),
receiver, receiver,
protocol_connection, protocol_connection,
@ -151,6 +157,7 @@ impl NetworkConnection {
established_time: intf::get_timestamp(), established_time: intf::get_timestamp(),
stats, stats,
sender, sender,
stop_source: Some(stop_source),
} }
} }
@ -162,6 +169,13 @@ impl NetworkConnection {
ConnectionHandle::new(self.descriptor.clone(), self.sender.clone()) ConnectionHandle::new(self.descriptor.clone(), self.sender.clone())
} }
pub fn close(&mut self) {
if let Some(stop_source) = self.stop_source.take() {
// drop the stopper
drop(stop_source);
}
}
async fn send_internal( async fn send_internal(
protocol_connection: &ProtocolNetworkConnection, protocol_connection: &ProtocolNetworkConnection,
stats: Arc<Mutex<NetworkConnectionStats>>, stats: Arc<Mutex<NetworkConnectionStats>>,
@ -200,7 +214,8 @@ impl NetworkConnection {
// Connection receiver loop // Connection receiver loop
fn process_connection( fn process_connection(
connection_manager: ConnectionManager, connection_manager: ConnectionManager,
stop_token: StopToken, local_stop_token: StopToken,
manager_stop_token: StopToken,
descriptor: ConnectionDescriptor, descriptor: ConnectionDescriptor,
receiver: flume::Receiver<Vec<u8>>, receiver: flume::Receiver<Vec<u8>>,
protocol_connection: ProtocolNetworkConnection, protocol_connection: ProtocolNetworkConnection,
@ -293,7 +308,13 @@ impl NetworkConnection {
} }
// Process futures // Process futures
match unord.next().timeout_at(stop_token.clone()).await { match unord
.next()
.timeout_at(local_stop_token.clone())
.timeout_at(manager_stop_token.clone())
.await
.and_then(std::convert::identity) // flatten
{
Ok(Some(RecvLoopAction::Send)) => { Ok(Some(RecvLoopAction::Send)) => {
// Don't reset inactivity timer if we're only sending // Don't reset inactivity timer if we're only sending
need_sender = true; need_sender = true;
@ -312,7 +333,7 @@ impl NetworkConnection {
unreachable!(); unreachable!();
} }
Err(_) => { Err(_) => {
// Stop token // Either one of the stop tokens
break; break;
} }
} }
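The receive loop above now has to respect two cancellation sources: the connection's own StopSource (dropped by close()) and the manager-wide stop token. It races them by nesting timeout_at() twice and flattening the doubled Result with and_then(identity). A minimal sketch of that exact call shape, with a toy future in place of the real select loop and the stop-token and async-std crates assumed:

// Illustrative only: one future raced against two stop tokens.
use std::convert::identity;
use stop_token::future::FutureExt as _;
use stop_token::{StopSource, StopToken};

async fn run_until_either_stops(local_stop: StopToken, manager_stop: StopToken) -> Option<u32> {
    let work = async { 42u32 };
    match work
        .timeout_at(local_stop)
        .timeout_at(manager_stop)
        .await
        .and_then(identity) // flatten Result<Result<u32, _>, _> into Result<u32, _>
    {
        Ok(value) => Some(value),
        Err(_) => None, // either the connection's or the manager's stop token fired
    }
}

fn main() {
    let local = StopSource::new();
    let manager = StopSource::new();
    let out = async_std::task::block_on(run_until_either_stops(local.token(), manager.token()));
    println!("{:?}", out);
}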


@ -82,7 +82,12 @@ pub async fn test_add_get_remove() {
assert_eq!(table.get_connection(a1), Some(c1h.clone())); assert_eq!(table.get_connection(a1), Some(c1h.clone()));
assert_eq!(table.get_connection(a1), Some(c1h.clone())); assert_eq!(table.get_connection(a1), Some(c1h.clone()));
assert_eq!(table.connection_count(), 1); assert_eq!(table.connection_count(), 1);
assert_eq!(table.remove_connection(a2), Ok(())); assert_eq!(
table
.remove_connection(a2)
.map(|c| c.connection_descriptor()),
Ok(a1)
);
assert_eq!(table.connection_count(), 0); assert_eq!(table.connection_count(), 0);
assert_err!(table.remove_connection(a2)); assert_err!(table.remove_connection(a2));
assert_eq!(table.connection_count(), 0); assert_eq!(table.connection_count(), 0);
@ -98,9 +103,24 @@ pub async fn test_add_get_remove() {
table.add_connection(c3).unwrap(); table.add_connection(c3).unwrap();
table.add_connection(c4).unwrap(); table.add_connection(c4).unwrap();
assert_eq!(table.connection_count(), 3); assert_eq!(table.connection_count(), 3);
assert_eq!(table.remove_connection(a2), Ok(())); assert_eq!(
assert_eq!(table.remove_connection(a3), Ok(())); table
assert_eq!(table.remove_connection(a4), Ok(())); .remove_connection(a2)
.map(|c| c.connection_descriptor()),
Ok(a2)
);
assert_eq!(
table
.remove_connection(a3)
.map(|c| c.connection_descriptor()),
Ok(a3)
);
assert_eq!(
table
.remove_connection(a4)
.map(|c| c.connection_descriptor()),
Ok(a4)
);
assert_eq!(table.connection_count(), 0); assert_eq!(table.connection_count(), 0);
} }


@ -69,7 +69,6 @@ impl Network {
ProtocolType::WS | ProtocolType::WSS => { ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data) WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data)
.await .await
.map_err(logthru_net!())
} }
}; };
if res.is_ok() { if res.is_ok() {
@ -102,7 +101,7 @@ impl Network {
// Try to send to the exact existing connection if one exists // Try to send to the exact existing connection if one exists
if let Some(conn) = self.connection_manager().get_connection(descriptor).await { if let Some(conn) = self.connection_manager().get_connection(descriptor).await {
// connection exists, send over it // connection exists, send over it
conn.send_async(data).await.map_err(logthru_net!())?; conn.send_async(data).await?;
// Network accounting // Network accounting
self.network_manager() self.network_manager()


@ -314,6 +314,7 @@ impl ReceiptManager {
} }
pub async fn shutdown(&self) { pub async fn shutdown(&self) {
debug!("starting receipt manager shutdown");
let network_manager = self.network_manager(); let network_manager = self.network_manager();
// Stop all tasks // Stop all tasks
@ -325,11 +326,13 @@ impl ReceiptManager {
}; };
// Wait for everything to stop // Wait for everything to stop
debug!("waiting for timeout task to stop");
if !timeout_task.join().await.is_ok() { if !timeout_task.join().await.is_ok() {
panic!("joining timeout task failed"); panic!("joining timeout task failed");
} }
*self.inner.lock() = Self::new_inner(network_manager); *self.inner.lock() = Self::new_inner(network_manager);
debug!("finished receipt manager shutdown");
} }
pub fn record_receipt( pub fn record_receipt(


@ -374,19 +374,26 @@ impl RoutingTable {
} }
pub async fn terminate(&self) { pub async fn terminate(&self) {
debug!("starting routing table terminate");
// Cancel all tasks being ticked // Cancel all tasks being ticked
debug!("stopping rolling transfers task");
if let Err(e) = self.unlocked_inner.rolling_transfers_task.stop().await { if let Err(e) = self.unlocked_inner.rolling_transfers_task.stop().await {
error!("rolling_transfers_task not stopped: {}", e); error!("rolling_transfers_task not stopped: {}", e);
} }
debug!("stopping bootstrap task");
if let Err(e) = self.unlocked_inner.bootstrap_task.stop().await { if let Err(e) = self.unlocked_inner.bootstrap_task.stop().await {
error!("bootstrap_task not stopped: {}", e); error!("bootstrap_task not stopped: {}", e);
} }
debug!("stopping peer minimum refresh task");
if let Err(e) = self.unlocked_inner.peer_minimum_refresh_task.stop().await { if let Err(e) = self.unlocked_inner.peer_minimum_refresh_task.stop().await {
error!("peer_minimum_refresh_task not stopped: {}", e); error!("peer_minimum_refresh_task not stopped: {}", e);
} }
debug!("stopping ping_validator task");
if let Err(e) = self.unlocked_inner.ping_validator_task.stop().await { if let Err(e) = self.unlocked_inner.ping_validator_task.stop().await {
error!("ping_validator_task not stopped: {}", e); error!("ping_validator_task not stopped: {}", e);
} }
debug!("stopping node info update singlefuture");
if self if self
.unlocked_inner .unlocked_inner
.node_info_update_single_future .node_info_update_single_future
@ -398,6 +405,8 @@ impl RoutingTable {
} }
*self.inner.lock() = Self::new_inner(self.network_manager()); *self.inner.lock() = Self::new_inner(self.network_manager());
debug!("finished routing table terminate");
} }
// Inform routing table entries that our dial info has changed // Inform routing table entries that our dial info has changed


@ -1428,23 +1428,30 @@ impl RPCProcessor {
} }
pub async fn shutdown(&self) { pub async fn shutdown(&self) {
debug!("starting rpc processor shutdown");
// Stop the rpc workers // Stop the rpc workers
let mut unord = FuturesUnordered::new(); let mut unord = FuturesUnordered::new();
{ {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
// drop the stop
drop(inner.stop_source.take());
// take the join handles out // take the join handles out
for h in inner.worker_join_handles.drain(..) { for h in inner.worker_join_handles.drain(..) {
unord.push(h); unord.push(h);
} }
// drop the stop
drop(inner.stop_source.take());
} }
debug!("stopping {} rpc worker tasks", unord.len());
// Wait for them to complete // Wait for them to complete
while unord.next().await.is_some() {} while unord.next().await.is_some() {}
debug!("resetting rpc processor state");
// Release the rpc processor // Release the rpc processor
*self.inner.lock() = Self::new_inner(self.network_manager()); *self.inner.lock() = Self::new_inner(self.network_manager());
debug!("finished rpc processor shutdown");
} }
pub fn enqueue_message( pub fn enqueue_message(


@ -518,10 +518,10 @@ pub async fn test_single_future() {
69 69
}) })
.await, .await,
Ok(None) Ok((None, true))
); );
assert_eq!(sf.check().await, Ok(None)); assert_eq!(sf.check().await, Ok(None));
assert_eq!(sf.single_spawn(async { panic!() }).await, Ok(None)); assert_eq!(sf.single_spawn(async { panic!() }).await, Ok((None, false)));
assert_eq!(sf.join().await, Ok(Some(69))); assert_eq!(sf.join().await, Ok(Some(69)));
assert_eq!( assert_eq!(
sf.single_spawn(async { sf.single_spawn(async {
@ -529,7 +529,7 @@ pub async fn test_single_future() {
37 37
}) })
.await, .await,
Ok(None) Ok((None, true))
); );
intf::sleep(2000).await; intf::sleep(2000).await;
assert_eq!( assert_eq!(
@ -538,7 +538,7 @@ pub async fn test_single_future() {
27 27
}) })
.await, .await,
Ok(Some(37)) Ok((Some(37), true))
); );
intf::sleep(2000).await; intf::sleep(2000).await;
assert_eq!(sf.join().await, Ok(Some(27))); assert_eq!(sf.join().await, Ok(Some(27)));
@ -555,10 +555,10 @@ pub async fn test_must_join_single_future() {
69 69
}) })
.await, .await,
Ok(None) Ok((None, true))
); );
assert_eq!(sf.check().await, Ok(None)); assert_eq!(sf.check().await, Ok(None));
assert_eq!(sf.single_spawn(async { panic!() }).await, Ok(None)); assert_eq!(sf.single_spawn(async { panic!() }).await, Ok((None, false)));
assert_eq!(sf.join().await, Ok(Some(69))); assert_eq!(sf.join().await, Ok(Some(69)));
assert_eq!( assert_eq!(
sf.single_spawn(async { sf.single_spawn(async {
@ -566,7 +566,7 @@ pub async fn test_must_join_single_future() {
37 37
}) })
.await, .await,
Ok(None) Ok((None, true))
); );
intf::sleep(2000).await; intf::sleep(2000).await;
assert_eq!( assert_eq!(
@ -575,7 +575,7 @@ pub async fn test_must_join_single_future() {
27 27
}) })
.await, .await,
Ok(Some(37)) Ok((Some(37), true))
); );
intf::sleep(2000).await; intf::sleep(2000).await;
assert_eq!(sf.join().await, Ok(Some(27))); assert_eq!(sf.join().await, Ok(Some(27)));


@@ -1,20 +1,19 @@
use async_executors::JoinHandle;
use core::future::Future;
use core::pin::Pin;
-use core::sync::atomic::{AtomicBool, Ordering};
use core::task::{Context, Poll};

#[derive(Debug)]
pub struct MustJoinHandle<T> {
    join_handle: JoinHandle<T>,
-    completed: AtomicBool,
+    completed: bool,
}

impl<T> MustJoinHandle<T> {
    pub fn new(join_handle: JoinHandle<T>) -> Self {
        Self {
            join_handle,
-            completed: AtomicBool::new(false),
+            completed: false,
        }
    }
}

@@ -22,7 +21,7 @@ impl<T> MustJoinHandle<T> {
impl<T> Drop for MustJoinHandle<T> {
    fn drop(&mut self) {
        // panic if we haven't completed
-        if !self.completed.load(Ordering::Relaxed) {
+        if !self.completed {
            panic!("MustJoinHandle was not completed upon drop. Add cooperative cancellation where appropriate to ensure this is completed before drop.")
        }
    }

@@ -34,7 +33,7 @@ impl<T: 'static> Future for MustJoinHandle<T> {
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match Pin::new(&mut self.join_handle).poll(cx) {
            Poll::Ready(t) => {
-                self.completed.store(true, Ordering::Relaxed);
+                self.completed = true;
                Poll::Ready(t)
            }
            Poll::Pending => Poll::Pending,
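The AtomicBool can become a plain bool here because both poll() and drop() take exclusive access to the handle, so there is no concurrent reader of the flag. A toy sketch of the same "must be joined before drop" guard, using made-up names (not veilid-core's MustJoinHandle) and a boxed future so it stands alone:

// Illustrative only: panic on drop unless the wrapped future was driven to completion.
use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll};

struct MustJoin<F: Future> {
    inner: Pin<Box<F>>,
    completed: bool,
}

impl<F: Future> MustJoin<F> {
    fn new(fut: F) -> Self {
        Self { inner: Box::pin(fut), completed: false }
    }
}

impl<F: Future> Drop for MustJoin<F> {
    fn drop(&mut self) {
        if !self.completed {
            panic!("MustJoin dropped before it was awaited");
        }
    }
}

impl<F: Future> Future for MustJoin<F> {
    type Output = F::Output;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.inner.as_mut().poll(cx) {
            Poll::Ready(t) => {
                // Plain bool is enough: poll() has &mut access, drop() has &mut access.
                self.completed = true;
                Poll::Ready(t)
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

fn main() {
    let out = async_std::task::block_on(MustJoin::new(async { 7 }));
    assert_eq!(out, 7);
}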


@ -131,7 +131,7 @@ where
pub async fn single_spawn( pub async fn single_spawn(
&self, &self,
future: impl Future<Output = T> + 'static, future: impl Future<Output = T> + 'static,
) -> Result<Option<T>, ()> { ) -> Result<(Option<T>,bool), ()> {
let mut out: Option<T> = None; let mut out: Option<T> = None;
// See if we have a result we can return // See if we have a result we can return
@ -164,7 +164,7 @@ where
} }
// Return the prior result if we have one // Return the prior result if we have one
Ok(out) Ok((out, run))
} }
} }
} }
@ -178,7 +178,7 @@ cfg_if! {
pub async fn single_spawn( pub async fn single_spawn(
&self, &self,
future: impl Future<Output = T> + Send + 'static, future: impl Future<Output = T> + Send + 'static,
) -> Result<Option<T>, ()> { ) -> Result<(Option<T>, bool), ()> {
let mut out: Option<T> = None; let mut out: Option<T> = None;
// See if we have a result we can return // See if we have a result we can return
let maybe_jh = match self.try_lock() { let maybe_jh = match self.try_lock() {
@ -206,7 +206,7 @@ cfg_if! {
self.unlock(Some(MustJoinHandle::new(spawn(future)))); self.unlock(Some(MustJoinHandle::new(spawn(future))));
} }
// Return the prior result if we have one // Return the prior result if we have one
Ok(out) Ok((out, run))
} }
} }
} }


@ -160,7 +160,7 @@ where
pub async fn single_spawn( pub async fn single_spawn(
&self, &self,
future: impl Future<Output = T> + 'static, future: impl Future<Output = T> + 'static,
) -> Result<Option<T>, ()> { ) -> Result<(Option<T>, bool), ()> {
let mut out: Option<T> = None; let mut out: Option<T> = None;
// See if we have a result we can return // See if we have a result we can return
@ -193,7 +193,7 @@ where
} }
// Return the prior result if we have one // Return the prior result if we have one
Ok(out) Ok((out, run))
} }
} }
} }
@ -207,7 +207,7 @@ cfg_if! {
pub async fn single_spawn( pub async fn single_spawn(
&self, &self,
future: impl Future<Output = T> + Send + 'static, future: impl Future<Output = T> + Send + 'static,
) -> Result<Option<T>, ()> { ) -> Result<(Option<T>, bool), ()> {
let mut out: Option<T> = None; let mut out: Option<T> = None;
// See if we have a result we can return // See if we have a result we can return
let maybe_jh = match self.try_lock() { let maybe_jh = match self.try_lock() {
@ -235,7 +235,7 @@ cfg_if! {
self.unlock(Some(spawn(future))); self.unlock(Some(spawn(future)));
} }
// Return the prior result if we have one // Return the prior result if we have one
Ok(out) Ok((out, run))
} }
} }
} }
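The single_spawn() return type now carries two pieces of information instead of one: the Option is the prior run's output, if one had completed, and the added bool reports whether this call actually spawned a new future. A small illustration of how a caller can read the new tuple (the function name here is made up; only the type shape comes from the diff):

// Illustrative decoder for the new Result<(Option<T>, bool), ()> contract.
fn describe_single_spawn<T>(result: Result<(Option<T>, bool), ()>) -> &'static str {
    match result {
        Ok((Some(_), true)) => "previous run finished; a new run was started",
        Ok((None, true)) => "no prior result; a new run was started",
        Ok((_, false)) => "previous run still in flight; nothing was spawned",
        Err(()) => "the future is currently being joined; call skipped",
    }
}

fn main() {
    // These mirror the expectations in the updated single-future tests earlier in this commit.
    println!("{}", describe_single_spawn::<u32>(Ok((None, true))));
    println!("{}", describe_single_spawn::<u32>(Ok((Some(37), true))));
    println!("{}", describe_single_spawn::<u32>(Ok((None, false))));
}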


@ -76,11 +76,13 @@ impl TickTask {
let opt_stop_source = &mut *self.stop_source.lock().await; let opt_stop_source = &mut *self.stop_source.lock().await;
if opt_stop_source.is_none() { if opt_stop_source.is_none() {
// already stopped, just return // already stopped, just return
trace!("tick task already stopped");
return Ok(()); return Ok(());
} }
*opt_stop_source = None; drop(opt_stop_source.take());
// wait for completion of the tick task // wait for completion of the tick task
trace!("stopping single future");
match self.single_future.join().await { match self.single_future.join().await {
Ok(Some(Err(err))) => Err(err), Ok(Some(Err(err))) => Err(err),
_ => Ok(()), _ => Ok(()),
@ -91,37 +93,61 @@ impl TickTask {
let now = get_timestamp(); let now = get_timestamp();
let last_timestamp_us = self.last_timestamp_us.load(Ordering::Acquire); let last_timestamp_us = self.last_timestamp_us.load(Ordering::Acquire);
if last_timestamp_us == 0u64 || (now - last_timestamp_us) >= self.tick_period_us { if last_timestamp_us != 0u64 && (now - last_timestamp_us) < self.tick_period_us {
// Run the singlefuture
let opt_stop_source = &mut *self.stop_source.lock().await;
let stop_source = StopSource::new();
match self
.single_future
.single_spawn(self.routine.get().unwrap()(
stop_source.token(),
last_timestamp_us,
now,
))
.await
{
// Single future ran this tick
Ok(Some(ret)) => {
// Set new timer
self.last_timestamp_us.store(now, Ordering::Release);
// Save new stopper
*opt_stop_source = Some(stop_source);
ret
}
// Single future did not run this tick
Ok(None) | Err(()) => {
// If the execution didn't happen this time because it was already running
// then we should try again the next tick and not reset the timestamp so we try as soon as possible
Ok(())
}
}
} else {
// It's not time yet // It's not time yet
Ok(()) return Ok(());
}
// Lock the stop source, tells us if we have ever started this future
let opt_stop_source = &mut *self.stop_source.lock().await;
if opt_stop_source.is_some() {
// See if the previous execution finished with an error
match self.single_future.check().await {
Ok(Some(Err(e))) => {
// We have an error result, which means the singlefuture ran but we need to propagate the error
return Err(e);
}
Ok(Some(Ok(()))) => {
// We have an ok result, which means the singlefuture ran, and we should run it again this tick
}
Ok(None) => {
// No prior result to return which means things are still running
// We can just return now, since the singlefuture will not run a second time
return Ok(());
}
Err(()) => {
// If we get this, it's because we are joining the singlefuture already
// Don't bother running but this is not an error in this case
return Ok(());
}
};
}
// Run the singlefuture
let stop_source = StopSource::new();
match self
.single_future
.single_spawn(self.routine.get().unwrap()(
stop_source.token(),
last_timestamp_us,
now,
))
.await
{
// We should have already consumed the result of the last run, or there was none
// and we should definitely have run, because the prior 'check()' operation
// should have ensured the singlefuture was ready to run
Ok((None, true)) => {
// Set new timer
self.last_timestamp_us.store(now, Ordering::Release);
// Save new stopper
*opt_stop_source = Some(stop_source);
Ok(())
}
// All other conditions should not be reachable
_ => {
unreachable!();
}
} }
} }
} }
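The rewritten tick() above separates three questions: has the interval elapsed, did the previous run leave behind an error or is it still in flight, and should a new run be spawned now. A toy model of that control flow with illustrative names and types (not veilid-core's TickTask or SingleFuture):

// Illustrative only: the decision logic, detached from the async machinery.
enum PriorRun {
    NeverStarted,
    StillRunning,
    Finished(Result<(), String>),
}

/// Returns Ok(true) when a new run should be spawned on this tick.
fn tick_decision(
    now_us: u64,
    last_us: u64,
    period_us: u64,
    prior: PriorRun,
) -> Result<bool, String> {
    // Not time yet: do nothing and keep the old timestamp.
    if last_us != 0 && now_us.saturating_sub(last_us) < period_us {
        return Ok(false);
    }
    match prior {
        // Propagate the previous run's error instead of silently dropping it.
        PriorRun::Finished(Err(e)) => Err(e),
        // Previous run completed (or never happened): start a new one and reset the timer.
        PriorRun::Finished(Ok(())) | PriorRun::NeverStarted => Ok(true),
        // Previous run is still going: skip this tick without resetting the timestamp,
        // so it is retried as soon as the next tick arrives.
        PriorRun::StillRunning => Ok(false),
    }
}

fn main() {
    assert_eq!(tick_decision(10, 0, 5, PriorRun::NeverStarted), Ok(true));
    assert_eq!(tick_decision(20, 10, 5, PriorRun::StillRunning), Ok(false));
    assert!(tick_decision(20, 10, 5, PriorRun::Finished(Err("boom".into()))).is_err());
}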


@@ -32,11 +32,13 @@ pub async fn run_veilid_server(settings: Settings, server_mode: ServerMode) -> R
    run_veilid_server_internal(settings, server_mode).await
}

-#[instrument(err)]
+#[instrument(err, skip_all)]
pub async fn run_veilid_server_internal(
    settings: Settings,
    server_mode: ServerMode,
) -> Result<(), String> {
+    trace!(?settings, ?server_mode);
    let settingsr = settings.read();

    // Create client api state change pipe