fix cooperative cancellation

This commit is contained in:
John Smith 2022-06-15 14:05:04 -04:00
parent 180628beef
commit c33f78ac8b
24 changed files with 520 additions and 299 deletions

34
.vscode/launch.json vendored
View File

@ -29,13 +29,17 @@
"type": "lldb",
"request": "launch",
"name": "Launch veilid-cli",
"args": ["--debug"],
"args": [
"--debug"
],
"program": "${workspaceFolder}/target/debug/veilid-cli",
"windows": {
"program": "${workspaceFolder}/target/debug/veilid-cli.exe"
},
"cwd": "${workspaceFolder}/target/debug/",
"sourceLanguages": ["rust"],
"sourceLanguages": [
"rust"
],
"terminal": "console"
},
// {
@ -48,20 +52,21 @@
// "args": ["--trace"],
// "cwd": "${workspaceFolder}/veilid-server"
// }
{
"type": "lldb",
"request": "launch",
"name": "Debug veilid-server",
"program": "${workspaceFolder}/target/debug/veilid-server",
"args": ["--trace", "--attach=true"],
"args": [
"--trace",
"--attach=true"
],
"cwd": "${workspaceFolder}/target/debug/",
"env": {
"RUST_BACKTRACE": "1"
},
"terminal": "console"
},
{
"type": "lldb",
"request": "launch",
@ -78,10 +83,11 @@
"name": "veilid-core"
}
},
"args": ["${selectedText}"],
"args": [
"${selectedText}"
],
"cwd": "${workspaceFolder}/target/debug/"
},
{
"type": "lldb",
"request": "launch",
@ -98,10 +104,11 @@
"name": "veilid-server"
}
},
"args": ["${selectedText}"],
"args": [
"${selectedText}"
],
"cwd": "${workspaceFolder}/veilid-server"
},
{
"type": "lldb",
"request": "launch",
@ -118,10 +125,11 @@
"name": "keyvaluedb-sqlite"
}
},
"args": ["${selectedText}"],
"args": [
"${selectedText}"
],
"cwd": "${workspaceFolder}/external/keyvaluedb/keyvaluedb-sqlite"
},
{
"type": "lldb",
"request": "launch",
@ -138,7 +146,9 @@
"name": "keyring"
}
},
"args": ["${selectedText}"],
"args": [
"${selectedText}"
],
"cwd": "${workspaceFolder}/external/keyring-rs"
}
]

View File

@ -334,7 +334,7 @@ impl NetworkInterfaces {
self.valid = false;
let last_interfaces = core::mem::take(&mut self.interfaces);
let mut platform_support = PlatformSupport::new().map_err(logthru_net!())?;
let mut platform_support = PlatformSupport::new()?;
platform_support
.get_interfaces(&mut self.interfaces)
.await?;

View File

@ -2,19 +2,28 @@ use super::*;
use crate::xx::*;
use connection_table::*;
use network_connection::*;
use stop_token::future::FutureExt;
///////////////////////////////////////////////////////////
// Connection manager
#[derive(Debug)]
enum ConnectionManagerEvent {
Accepted(ProtocolNetworkConnection),
Finished(ConnectionDescriptor),
}
#[derive(Debug)]
struct ConnectionManagerInner {
connection_table: ConnectionTable,
sender: flume::Sender<ConnectionManagerEvent>,
async_processor_jh: Option<MustJoinHandle<()>>,
stop_source: Option<StopSource>,
}
struct ConnectionManagerArc {
network_manager: NetworkManager,
inner: AsyncMutex<Option<ConnectionManagerInner>>,
inner: Mutex<Option<ConnectionManagerInner>>,
}
impl core::fmt::Debug for ConnectionManagerArc {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
@ -30,16 +39,23 @@ pub struct ConnectionManager {
}
impl ConnectionManager {
fn new_inner(config: VeilidConfig) -> ConnectionManagerInner {
fn new_inner(
config: VeilidConfig,
stop_source: StopSource,
sender: flume::Sender<ConnectionManagerEvent>,
async_processor_jh: JoinHandle<()>,
) -> ConnectionManagerInner {
ConnectionManagerInner {
stop_source: Some(StopSource::new()),
stop_source: Some(stop_source),
sender: sender,
async_processor_jh: Some(MustJoinHandle::new(async_processor_jh)),
connection_table: ConnectionTable::new(config),
}
}
fn new_arc(network_manager: NetworkManager) -> ConnectionManagerArc {
ConnectionManagerArc {
network_manager,
inner: AsyncMutex::new(None),
inner: Mutex::new(None),
}
}
pub fn new(network_manager: NetworkManager) -> Self {
@ -54,18 +70,34 @@ impl ConnectionManager {
pub async fn startup(&self) {
trace!("startup connection manager");
let mut inner = self.arc.inner.lock().await;
let mut inner = self.arc.inner.lock();
if inner.is_some() {
panic!("shouldn't start connection manager twice without shutting it down first");
}
*inner = Some(Self::new_inner(self.network_manager().config()));
// Create channel for async_processor to receive notifications of networking events
let (sender, receiver) = flume::unbounded();
// Create the stop source we'll use to stop the processor and the connection table
let stop_source = StopSource::new();
// Spawn the async processor
let async_processor = spawn(self.clone().async_processor(stop_source.token(), receiver));
// Store in the inner object
*inner = Some(Self::new_inner(
self.network_manager().config(),
stop_source,
sender,
async_processor,
));
}
pub async fn shutdown(&self) {
debug!("starting connection manager shutdown");
// Remove the inner from the lock
let mut inner = {
let mut inner_lock = self.arc.inner.lock().await;
let mut inner_lock = self.arc.inner.lock();
let inner = match inner_lock.take() {
Some(v) => v,
None => {
@ -75,11 +107,17 @@ impl ConnectionManager {
inner
};
// Stop all the connections
// Stop all the connections and the async processor
debug!("stopping async processor task");
drop(inner.stop_source.take());
let async_processor_jh = inner.async_processor_jh.take().unwrap();
// wait for the async processor to stop
debug!("waiting for async processor to stop");
async_processor_jh.await;
// Wait for the connections to complete
debug!("waiting for connection handlers to complete");
inner.connection_table.join().await;
debug!("finished connection manager shutdown");
}
// Returns a network connection if one already is established
@ -87,7 +125,7 @@ impl ConnectionManager {
&self,
descriptor: ConnectionDescriptor,
) -> Option<ConnectionHandle> {
let mut inner = self.arc.inner.lock().await;
let mut inner = self.arc.inner.lock();
let inner = match &mut *inner {
Some(v) => v,
None => {
@ -128,7 +166,8 @@ impl ConnectionManager {
local_addr: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<ConnectionHandle, String> {
let mut inner = self.arc.inner.lock().await;
let killed = {
let mut inner = self.arc.inner.lock();
let inner = match &mut *inner {
Some(v) => v,
None => {
@ -168,7 +207,7 @@ impl ConnectionManager {
// Drop any other protocols connections to this remote that have the same local addr
// otherwise this connection won't succeed due to binding
let mut killed = false;
let mut killed = Vec::<NetworkConnection>::new();
if let Some(local_addr) = local_addr {
if local_addr.port() != 0 {
for pt in [ProtocolType::TCP, ProtocolType::WS, ProtocolType::WSS] {
@ -193,21 +232,29 @@ impl ConnectionManager {
">< Terminating connection prior_descriptor={:?}",
prior_descriptor
);
if let Err(e) =
inner.connection_table.remove_connection(prior_descriptor)
{
log_net!(error e);
}
killed = true;
let mut conn = inner
.connection_table
.remove_connection(prior_descriptor)
.expect("connection not in table");
conn.close();
killed.push(conn);
}
}
}
}
}
killed
};
// Wait for the killed connections to end their recv loops
let mut retry_count = if !killed.is_empty() { 2 } else { 0 };
for k in killed {
k.await;
}
// Attempt new connection
let mut retry_count = if killed { 2 } else { 0 };
let conn = loop {
match ProtocolNetworkConnection::connect(local_addr, dial_info.clone()).await {
Ok(v) => break Ok(v),
@ -222,19 +269,77 @@ impl ConnectionManager {
}
}?;
self.on_new_protocol_network_connection(&mut *inner, conn)
// Add to the connection table
let mut inner = self.arc.inner.lock();
let inner = match &mut *inner {
Some(v) => v,
None => {
return Err("shutting down".to_owned());
}
};
self.on_new_protocol_network_connection(inner, conn)
}
///////////////////////////////////////////////////////////////////////////////////////////////////////
/// Callbacks
#[instrument(level = "trace", skip_all)]
async fn async_processor(
self,
stop_token: StopToken,
receiver: flume::Receiver<ConnectionManagerEvent>,
) {
// Process async commands
while let Ok(Ok(event)) = receiver.recv_async().timeout_at(stop_token.clone()).await {
match event {
ConnectionManagerEvent::Accepted(conn) => {
let mut inner = self.arc.inner.lock();
match &mut *inner {
Some(inner) => {
// Register the connection
// We don't care if this fails, since nobody here asked for the inbound connection.
// If it does, we just drop the connection
let _ = self.on_new_protocol_network_connection(inner, conn);
}
None => {
// If this somehow happens, we're shutting down
}
};
}
ConnectionManagerEvent::Finished(desc) => {
let conn = {
let mut inner_lock = self.arc.inner.lock();
match &mut *inner_lock {
Some(inner) => {
// Remove the connection and wait for the connection loop to terminate
if let Ok(conn) = inner.connection_table.remove_connection(desc) {
// Must close and wait to ensure things join
Some(conn)
} else {
None
}
}
None => None,
}
};
if let Some(mut conn) = conn {
conn.close();
conn.await;
}
}
}
}
}
// Called by low-level network when any connection-oriented protocol connection appears
// either from incoming connections.
pub(super) async fn on_accepted_protocol_network_connection(
&self,
conn: ProtocolNetworkConnection,
) -> Result<(), String> {
let mut inner = self.arc.inner.lock().await;
// Get channel sender
let sender = {
let mut inner = self.arc.inner.lock();
let inner = match &mut *inner {
Some(v) => v,
None => {
@ -242,24 +347,35 @@ impl ConnectionManager {
return Ok(());
}
};
self.on_new_protocol_network_connection(inner, conn)
.map(drop)
inner.sender.clone()
};
// Inform the processor of the event
let _ = sender
.send_async(ConnectionManagerEvent::Accepted(conn))
.await;
Ok(())
}
// Callback from network connection receive loop when it exits
// cleans up the entry in the connection table
pub(super) async fn report_connection_finished(&self, descriptor: ConnectionDescriptor) {
let mut inner = self.arc.inner.lock().await;
// Get channel sender
let sender = {
let mut inner = self.arc.inner.lock();
let inner = match &mut *inner {
Some(v) => v,
None => {
// If we're shutting down, do nothing here
// If we are shutting down, just drop this and return
return;
}
};
inner.sender.clone()
};
if let Err(e) = inner.connection_table.remove_connection(descriptor) {
log_net!(error e);
}
// Inform the processor of the event
let _ = sender
.send_async(ConnectionManagerEvent::Finished(descriptor))
.await;
}
}

View File

@ -160,13 +160,16 @@ impl ConnectionTable {
.expect("Inconsistency in connection table");
}
pub fn remove_connection(&mut self, descriptor: ConnectionDescriptor) -> Result<(), String> {
pub fn remove_connection(
&mut self,
descriptor: ConnectionDescriptor,
) -> Result<NetworkConnection, String> {
let index = protocol_to_index(descriptor.protocol_type());
let _ = self.conn_by_descriptor[index]
let conn = self.conn_by_descriptor[index]
.remove(&descriptor)
.ok_or_else(|| format!("Connection not in table: {:?}", descriptor))?;
self.remove_connection_records(descriptor);
Ok(())
Ok(conn)
}
}

View File

@ -299,26 +299,31 @@ impl NetworkManager {
#[instrument(level = "debug", skip_all)]
pub async fn shutdown(&self) {
trace!("NetworkManager::shutdown begin");
debug!("starting network manager shutdown");
// Cancel all tasks
debug!("stopping rolling transfers task");
if let Err(e) = self.unlocked_inner.rolling_transfers_task.stop().await {
warn!("rolling_transfers_task not stopped: {}", e);
}
debug!("stopping relay management task task");
if let Err(e) = self.unlocked_inner.relay_management_task.stop().await {
warn!("relay_management_task not stopped: {}", e);
}
// Shutdown network components if they started up
debug!("shutting down network components");
let components = self.inner.lock().components.clone();
if let Some(components) = components {
components.receipt_manager.shutdown().await;
components.rpc_processor.shutdown().await;
components.net.shutdown().await;
components.connection_manager.shutdown().await;
components.rpc_processor.shutdown().await;
components.receipt_manager.shutdown().await;
}
// reset the state
debug!("resetting network manager state");
{
let mut inner = self.inner.lock();
inner.components = None;
@ -326,9 +331,10 @@ impl NetworkManager {
}
// send update
debug!("sending network state update");
self.send_network_update();
trace!("NetworkManager::shutdown end");
debug!("finished network manager shutdown");
}
pub fn update_client_whitelist(&self, client: DHTKey) {
@ -883,8 +889,7 @@ impl NetworkManager {
match self
.net()
.send_data_to_existing_connection(descriptor, data)
.await
.map_err(logthru_net!())?
.await?
{
None => Ok(()),
Some(_) => Err("unable to send over reverse connection".to_owned()),
@ -973,8 +978,7 @@ impl NetworkManager {
match self
.net()
.send_data_to_existing_connection(descriptor, data)
.await
.map_err(logthru_net!())?
.await?
{
None => Ok(()),
Some(_) => Err("unable to send over hole punch".to_owned()),
@ -1005,8 +1009,7 @@ impl NetworkManager {
match this
.net()
.send_data_to_existing_connection(descriptor, data)
.await
.map_err(logthru_net!())?
.await?
{
None => {
return Ok(if descriptor.matches_peer_scope(PeerScope::Local) {
@ -1150,7 +1153,7 @@ impl NetworkManager {
"failed to resolve recipient node for relay, dropping outbound relayed packet...: {:?}",
e
)
}).map_err(logthru_net!())?
})?
} else {
// If this is not a node in the client whitelist, only allow inbound relay
// which only performs a lightweight lookup before passing the packet back out

View File

@ -279,20 +279,14 @@ impl Network {
let res = match dial_info.protocol_type() {
ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr();
RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data)
.await
.map_err(logthru_net!())
RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
}
ProtocolType::TCP => {
let peer_socket_addr = dial_info.to_socket_addr();
RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data)
.await
.map_err(logthru_net!())
RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
}
ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data)
.await
.map_err(logthru_net!())
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data).await
}
};
if res.is_ok() {
@ -324,10 +318,7 @@ impl Network {
descriptor
);
ph.clone()
.send_message(data, peer_socket_addr)
.await
.map_err(logthru_net!())?;
ph.clone().send_message(data, peer_socket_addr).await?;
// Network accounting
self.network_manager()
@ -345,7 +336,7 @@ impl Network {
log_net!("send_data_to_existing_connection to {:?}", descriptor);
// connection exists, send over it
conn.send_async(data).await.map_err(logthru_net!())?;
conn.send_async(data).await?;
// Network accounting
self.network_manager()
@ -372,10 +363,7 @@ impl Network {
if dial_info.protocol_type() == ProtocolType::UDP {
let peer_socket_addr = dial_info.to_socket_addr();
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
let res = ph
.send_message(data, peer_socket_addr)
.await
.map_err(logthru_net!());
let res = ph.send_message(data, peer_socket_addr).await;
if res.is_ok() {
// Network accounting
self.network_manager()
@ -383,8 +371,7 @@ impl Network {
}
return res;
}
return Err("no appropriate UDP protocol handler for dial_info".to_owned())
.map_err(logthru_net!(error));
return Err("no appropriate UDP protocol handler for dial_info".to_owned());
}
// Handle connection-oriented protocols
@ -394,7 +381,7 @@ impl Network {
.get_or_create_connection(Some(local_addr), dial_info.clone())
.await?;
let res = conn.send_async(data).await.map_err(logthru_net!(error));
let res = conn.send_async(data).await;
if res.is_ok() {
// Network accounting
self.network_manager()
@ -414,7 +401,13 @@ impl Network {
// initialize interfaces
let mut interfaces = NetworkInterfaces::new();
interfaces.refresh().await?;
self.inner.lock().interfaces = interfaces;
let protocol_config = {
let mut inner = self.inner.lock();
// Create stop source
inner.stop_source = Some(StopSource::new());
inner.interfaces = interfaces;
// get protocol config
let protocol_config = {
@ -450,7 +443,9 @@ impl Network {
ProtocolConfig { inbound, outbound }
};
self.inner.lock().protocol_config = Some(protocol_config);
inner.protocol_config = Some(protocol_config);
protocol_config
};
// start listeners
if protocol_config.inbound.contains(ProtocolType::UDP) {
@ -503,28 +498,32 @@ impl Network {
#[instrument(level = "debug", skip_all)]
pub async fn shutdown(&self) {
info!("stopping network");
debug!("starting low level network shutdown");
let network_manager = self.network_manager();
let routing_table = self.routing_table();
// Stop all tasks
debug!("stopping update network class task");
if let Err(e) = self.unlocked_inner.update_network_class_task.stop().await {
error!("update_network_class_task not cancelled: {}", e);
}
let mut unord = FuturesUnordered::new();
{
let mut inner = self.inner.lock();
// Drop the stop
drop(inner.stop_source.take());
// take the join handles out
for h in inner.join_handles.drain(..) {
unord.push(h);
}
// Drop the stop
drop(inner.stop_source.take());
}
debug!("stopping {} low level network tasks", unord.len());
// Wait for everything to stop
while unord.next().await.is_some() {}
debug!("clearing dial info");
// Drop all dial info
routing_table.clear_dial_info_details(RoutingDomain::PublicInternet);
routing_table.clear_dial_info_details(RoutingDomain::LocalNetwork);
@ -532,7 +531,7 @@ impl Network {
// Reset state including network class
*self.inner.lock() = Self::new_inner(network_manager);
info!("network stopped");
debug!("finished low level network shutdown");
}
//////////////////////////////////////////

View File

@ -1,6 +1,7 @@
use super::*;
use futures_util::stream::FuturesUnordered;
use futures_util::FutureExt;
use stop_token::future::FutureExt as StopTokenFutureExt;
struct DetectedPublicDialInfo {
dial_info: DialInfo,
@ -584,7 +585,9 @@ impl Network {
// Wait for all discovery futures to complete and collect contexts
let mut contexts = Vec::<DiscoveryContext>::new();
let mut network_class = Option::<NetworkClass>::None;
while let Some(ctxvec) = unord.next().await {
loop {
match unord.next().timeout_at(stop_token.clone()).await {
Ok(Some(ctxvec)) => {
if let Some(ctxvec) = ctxvec {
for ctx in ctxvec {
if let Some(nc) = ctx.inner.lock().detected_network_class {
@ -601,6 +604,16 @@ impl Network {
}
}
}
Ok(None) => {
// Normal completion
break;
}
Err(_) => {
// Stop token, exit early without error propagation
return Ok(());
}
}
}
// Get best network class
if network_class.is_some() {

View File

@ -65,8 +65,7 @@ impl Network {
ps.peek_exact(&mut first_packet),
)
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
.map_err(map_to_string)?;
self.try_handlers(ps, tcp_stream, addr, protocol_handlers)
.await
@ -82,8 +81,7 @@ impl Network {
for ah in protocol_accept_handlers.iter() {
if let Some(nc) = ah
.on_accept(stream.clone(), tcp_stream.clone(), addr)
.await
.map_err(logthru_net!())?
.await?
{
return Ok(Some(nc));
}

View File

@ -39,15 +39,11 @@ impl ProtocolNetworkConnection {
match dial_info.protocol_type() {
ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr();
udp::RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data)
.await
.map_err(logthru_net!())
udp::RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
}
ProtocolType::TCP => {
let peer_socket_addr = dial_info.to_socket_addr();
tcp::RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data)
.await
.map_err(logthru_net!())
tcp::RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
}
ProtocolType::WS | ProtocolType::WSS => {
ws::WebsocketProtocolHandler::send_unbound_message(dial_info, data).await

View File

@ -31,19 +31,14 @@ impl RawTcpNetworkConnection {
self.descriptor.clone()
}
#[instrument(level = "trace", err, skip(self))]
pub async fn close(&self) -> Result<(), String> {
// Make an attempt to flush the stream
self.stream
.clone()
.close()
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
self.stream.clone().close().await.map_err(map_to_string)?;
// Then forcibly close the socket
self.tcp_stream
.shutdown(Shutdown::Both)
.map_err(map_to_string)
.map_err(logthru_net!())
}
async fn send_internal(mut stream: AsyncPeekStream, message: Vec<u8>) -> Result<(), String> {
@ -54,23 +49,17 @@ impl RawTcpNetworkConnection {
let len = message.len() as u16;
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
stream
.write_all(&header)
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
stream
.write_all(&message)
.await
.map_err(map_to_string)
.map_err(logthru_net!())
stream.write_all(&header).await.map_err(map_to_string)?;
stream.write_all(&message).await.map_err(map_to_string)
}
#[instrument(level="trace", err, skip(self, message), fields(message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
let stream = self.stream.clone();
Self::send_internal(stream, message).await
}
#[instrument(level="trace", err, skip(self), fields(message.len))]
pub async fn recv(&self) -> Result<Vec<u8>, String> {
let mut header = [0u8; 4];
@ -90,6 +79,8 @@ impl RawTcpNetworkConnection {
let mut out: Vec<u8> = vec![0u8; len];
stream.read_exact(&mut out).await.map_err(map_to_string)?;
tracing::Span::current().record("message.len", &out.len());
Ok(out)
}
}
@ -120,6 +111,7 @@ impl RawTcpProtocolHandler {
}
}
#[instrument(level = "trace", err, skip(self, stream, tcp_stream))]
async fn on_accept_async(
self,
stream: AsyncPeekStream,
@ -151,6 +143,7 @@ impl RawTcpProtocolHandler {
Ok(Some(conn))
}
#[instrument(level = "trace", err)]
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
@ -191,6 +184,7 @@ impl RawTcpProtocolHandler {
Ok(conn)
}
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
pub async fn send_unbound_message(
socket_addr: SocketAddr,
data: Vec<u8>,

View File

@ -10,6 +10,7 @@ impl RawUdpProtocolHandler {
Self { socket }
}
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.from))]
pub async fn recv_message(
&self,
data: &mut [u8],
@ -35,9 +36,13 @@ impl RawUdpProtocolHandler {
peer_addr,
SocketAddress::from_socket_addr(local_socket_addr),
);
tracing::Span::current().record("ret.len", &size);
tracing::Span::current().record("ret.from", &format!("{:?}", descriptor).as_str());
Ok((size, descriptor))
}
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.from))]
pub async fn send_message(&self, data: Vec<u8>, socket_addr: SocketAddr) -> Result<(), String> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large UDP message".to_owned()).map_err(logthru_net!(error));

View File

@ -49,21 +49,17 @@ where
self.descriptor.clone()
}
#[instrument(level = "trace", err, skip(self))]
pub async fn close(&self) -> Result<(), String> {
// Make an attempt to flush the stream
self.stream
.clone()
.close()
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
self.stream.clone().close().await.map_err(map_to_string)?;
// Then forcibly close the socket
self.tcp_stream
.shutdown(Shutdown::Both)
.map_err(map_to_string)
.map_err(logthru_net!())
}
#[instrument(level="trace", err, skip(self, message), fields(message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
if message.len() > MAX_MESSAGE_SIZE {
return Err("received too large WS message".to_owned());
@ -76,6 +72,7 @@ where
.map_err(logthru_net!(error "failed to send websocket message"))
}
#[instrument(level="trace", err, skip(self), fields(message.len))]
pub async fn recv(&self) -> Result<Vec<u8>, String> {
let out = match self.stream.clone().next().await {
Some(Ok(Message::Binary(v))) => v,
@ -86,12 +83,13 @@ where
return Err(e.to_string()).map_err(logthru_net!(error));
}
None => {
return Err("WS stream closed".to_owned()).map_err(logthru_net!());
return Err("WS stream closed".to_owned());
}
};
if out.len() > MAX_MESSAGE_SIZE {
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
} else {
tracing::Span::current().record("message.len", &out.len());
Ok(out)
}
}
@ -137,6 +135,7 @@ impl WebsocketProtocolHandler {
}
}
#[instrument(level = "trace", err, skip(self, ps, tcp_stream))]
pub async fn on_accept_async(
self,
ps: AsyncPeekStream,
@ -156,9 +155,9 @@ impl WebsocketProtocolHandler {
Ok(_) => (),
Err(e) => {
if e.kind() == io::ErrorKind::TimedOut {
return Err(e).map_err(map_to_string).map_err(logthru_net!());
return Err(e).map_err(map_to_string);
}
return Err(e).map_err(map_to_string).map_err(logthru_net!(error));
return Err(e).map_err(map_to_string);
}
}
@ -237,10 +236,7 @@ impl WebsocketProtocolHandler {
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
// See what local address we ended up with
let actual_local_addr = tcp_stream
.local_addr()
.map_err(map_to_string)
.map_err(logthru_net!())?;
let actual_local_addr = tcp_stream.local_addr().map_err(map_to_string)?;
// Make our connection descriptor
let descriptor = ConnectionDescriptor::new(
@ -274,6 +270,7 @@ impl WebsocketProtocolHandler {
}
}
#[instrument(level = "trace", err)]
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
@ -281,6 +278,7 @@ impl WebsocketProtocolHandler {
Self::connect_internal(local_address, dial_info).await
}
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large unbound WS message".to_owned());

View File

@ -89,6 +89,7 @@ pub struct NetworkConnection {
established_time: u64,
stats: Arc<Mutex<NetworkConnectionStats>>,
sender: flume::Sender<Vec<u8>>,
stop_source: Option<StopSource>,
}
impl NetworkConnection {
@ -105,12 +106,13 @@ impl NetworkConnection {
last_message_recv_time: None,
})),
sender,
stop_source: None,
}
}
pub(super) fn from_protocol(
connection_manager: ConnectionManager,
stop_token: StopToken,
manager_stop_token: StopToken,
protocol_connection: ProtocolNetworkConnection,
) -> Self {
// Get timeout
@ -133,10 +135,14 @@ impl NetworkConnection {
last_message_recv_time: None,
}));
let stop_source = StopSource::new();
let local_stop_token = stop_source.token();
// Spawn connection processor and pass in protocol connection
let processor = MustJoinHandle::new(intf::spawn_local(Self::process_connection(
connection_manager,
stop_token,
local_stop_token,
manager_stop_token,
descriptor.clone(),
receiver,
protocol_connection,
@ -151,6 +157,7 @@ impl NetworkConnection {
established_time: intf::get_timestamp(),
stats,
sender,
stop_source: Some(stop_source),
}
}
@ -162,6 +169,13 @@ impl NetworkConnection {
ConnectionHandle::new(self.descriptor.clone(), self.sender.clone())
}
pub fn close(&mut self) {
if let Some(stop_source) = self.stop_source.take() {
// drop the stopper
drop(stop_source);
}
}
async fn send_internal(
protocol_connection: &ProtocolNetworkConnection,
stats: Arc<Mutex<NetworkConnectionStats>>,
@ -200,7 +214,8 @@ impl NetworkConnection {
// Connection receiver loop
fn process_connection(
connection_manager: ConnectionManager,
stop_token: StopToken,
local_stop_token: StopToken,
manager_stop_token: StopToken,
descriptor: ConnectionDescriptor,
receiver: flume::Receiver<Vec<u8>>,
protocol_connection: ProtocolNetworkConnection,
@ -293,7 +308,13 @@ impl NetworkConnection {
}
// Process futures
match unord.next().timeout_at(stop_token.clone()).await {
match unord
.next()
.timeout_at(local_stop_token.clone())
.timeout_at(manager_stop_token.clone())
.await
.and_then(std::convert::identity) // flatten
{
Ok(Some(RecvLoopAction::Send)) => {
// Don't reset inactivity timer if we're only sending
need_sender = true;
@ -312,7 +333,7 @@ impl NetworkConnection {
unreachable!();
}
Err(_) => {
// Stop token
// Either one of the stop tokens
break;
}
}

View File

@ -82,7 +82,12 @@ pub async fn test_add_get_remove() {
assert_eq!(table.get_connection(a1), Some(c1h.clone()));
assert_eq!(table.get_connection(a1), Some(c1h.clone()));
assert_eq!(table.connection_count(), 1);
assert_eq!(table.remove_connection(a2), Ok(()));
assert_eq!(
table
.remove_connection(a2)
.map(|c| c.connection_descriptor()),
Ok(a1)
);
assert_eq!(table.connection_count(), 0);
assert_err!(table.remove_connection(a2));
assert_eq!(table.connection_count(), 0);
@ -98,9 +103,24 @@ pub async fn test_add_get_remove() {
table.add_connection(c3).unwrap();
table.add_connection(c4).unwrap();
assert_eq!(table.connection_count(), 3);
assert_eq!(table.remove_connection(a2), Ok(()));
assert_eq!(table.remove_connection(a3), Ok(()));
assert_eq!(table.remove_connection(a4), Ok(()));
assert_eq!(
table
.remove_connection(a2)
.map(|c| c.connection_descriptor()),
Ok(a2)
);
assert_eq!(
table
.remove_connection(a3)
.map(|c| c.connection_descriptor()),
Ok(a3)
);
assert_eq!(
table
.remove_connection(a4)
.map(|c| c.connection_descriptor()),
Ok(a4)
);
assert_eq!(table.connection_count(), 0);
}

View File

@ -69,7 +69,6 @@ impl Network {
ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data)
.await
.map_err(logthru_net!())
}
};
if res.is_ok() {
@ -102,7 +101,7 @@ impl Network {
// Try to send to the exact existing connection if one exists
if let Some(conn) = self.connection_manager().get_connection(descriptor).await {
// connection exists, send over it
conn.send_async(data).await.map_err(logthru_net!())?;
conn.send_async(data).await?;
// Network accounting
self.network_manager()

View File

@ -314,6 +314,7 @@ impl ReceiptManager {
}
pub async fn shutdown(&self) {
debug!("starting receipt manager shutdown");
let network_manager = self.network_manager();
// Stop all tasks
@ -325,11 +326,13 @@ impl ReceiptManager {
};
// Wait for everything to stop
debug!("waiting for timeout task to stop");
if !timeout_task.join().await.is_ok() {
panic!("joining timeout task failed");
}
*self.inner.lock() = Self::new_inner(network_manager);
debug!("finished receipt manager shutdown");
}
pub fn record_receipt(

View File

@ -374,19 +374,26 @@ impl RoutingTable {
}
pub async fn terminate(&self) {
debug!("starting routing table terminate");
// Cancel all tasks being ticked
debug!("stopping rolling transfers task");
if let Err(e) = self.unlocked_inner.rolling_transfers_task.stop().await {
error!("rolling_transfers_task not stopped: {}", e);
}
debug!("stopping bootstrap task");
if let Err(e) = self.unlocked_inner.bootstrap_task.stop().await {
error!("bootstrap_task not stopped: {}", e);
}
debug!("stopping peer minimum refresh task");
if let Err(e) = self.unlocked_inner.peer_minimum_refresh_task.stop().await {
error!("peer_minimum_refresh_task not stopped: {}", e);
}
debug!("stopping ping_validator task");
if let Err(e) = self.unlocked_inner.ping_validator_task.stop().await {
error!("ping_validator_task not stopped: {}", e);
}
debug!("stopping node info update singlefuture");
if self
.unlocked_inner
.node_info_update_single_future
@ -398,6 +405,8 @@ impl RoutingTable {
}
*self.inner.lock() = Self::new_inner(self.network_manager());
debug!("finished routing table terminate");
}
// Inform routing table entries that our dial info has changed

View File

@ -1428,23 +1428,30 @@ impl RPCProcessor {
}
pub async fn shutdown(&self) {
debug!("starting rpc processor shutdown");
// Stop the rpc workers
let mut unord = FuturesUnordered::new();
{
let mut inner = self.inner.lock();
// drop the stop
drop(inner.stop_source.take());
// take the join handles out
for h in inner.worker_join_handles.drain(..) {
unord.push(h);
}
// drop the stop
drop(inner.stop_source.take());
}
debug!("stopping {} rpc worker tasks", unord.len());
// Wait for them to complete
while unord.next().await.is_some() {}
debug!("resetting rpc processor state");
// Release the rpc processor
*self.inner.lock() = Self::new_inner(self.network_manager());
debug!("finished rpc processor shutdown");
}
pub fn enqueue_message(

View File

@ -518,10 +518,10 @@ pub async fn test_single_future() {
69
})
.await,
Ok(None)
Ok((None, true))
);
assert_eq!(sf.check().await, Ok(None));
assert_eq!(sf.single_spawn(async { panic!() }).await, Ok(None));
assert_eq!(sf.single_spawn(async { panic!() }).await, Ok((None, false)));
assert_eq!(sf.join().await, Ok(Some(69)));
assert_eq!(
sf.single_spawn(async {
@ -529,7 +529,7 @@ pub async fn test_single_future() {
37
})
.await,
Ok(None)
Ok((None, true))
);
intf::sleep(2000).await;
assert_eq!(
@ -538,7 +538,7 @@ pub async fn test_single_future() {
27
})
.await,
Ok(Some(37))
Ok((Some(37), true))
);
intf::sleep(2000).await;
assert_eq!(sf.join().await, Ok(Some(27)));
@ -555,10 +555,10 @@ pub async fn test_must_join_single_future() {
69
})
.await,
Ok(None)
Ok((None, true))
);
assert_eq!(sf.check().await, Ok(None));
assert_eq!(sf.single_spawn(async { panic!() }).await, Ok(None));
assert_eq!(sf.single_spawn(async { panic!() }).await, Ok((None, false)));
assert_eq!(sf.join().await, Ok(Some(69)));
assert_eq!(
sf.single_spawn(async {
@ -566,7 +566,7 @@ pub async fn test_must_join_single_future() {
37
})
.await,
Ok(None)
Ok((None, true))
);
intf::sleep(2000).await;
assert_eq!(
@ -575,7 +575,7 @@ pub async fn test_must_join_single_future() {
27
})
.await,
Ok(Some(37))
Ok((Some(37), true))
);
intf::sleep(2000).await;
assert_eq!(sf.join().await, Ok(Some(27)));

View File

@ -1,20 +1,19 @@
use async_executors::JoinHandle;
use core::future::Future;
use core::pin::Pin;
use core::sync::atomic::{AtomicBool, Ordering};
use core::task::{Context, Poll};
#[derive(Debug)]
pub struct MustJoinHandle<T> {
join_handle: JoinHandle<T>,
completed: AtomicBool,
completed: bool,
}
impl<T> MustJoinHandle<T> {
pub fn new(join_handle: JoinHandle<T>) -> Self {
Self {
join_handle,
completed: AtomicBool::new(false),
completed: false,
}
}
}
@ -22,7 +21,7 @@ impl<T> MustJoinHandle<T> {
impl<T> Drop for MustJoinHandle<T> {
fn drop(&mut self) {
// panic if we haven't completed
if !self.completed.load(Ordering::Relaxed) {
if !self.completed {
panic!("MustJoinHandle was not completed upon drop. Add cooperative cancellation where appropriate to ensure this is completed before drop.")
}
}
@ -34,7 +33,7 @@ impl<T: 'static> Future for MustJoinHandle<T> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match Pin::new(&mut self.join_handle).poll(cx) {
Poll::Ready(t) => {
self.completed.store(true, Ordering::Relaxed);
self.completed = true;
Poll::Ready(t)
}
Poll::Pending => Poll::Pending,

View File

@ -131,7 +131,7 @@ where
pub async fn single_spawn(
&self,
future: impl Future<Output = T> + 'static,
) -> Result<Option<T>, ()> {
) -> Result<(Option<T>,bool), ()> {
let mut out: Option<T> = None;
// See if we have a result we can return
@ -164,7 +164,7 @@ where
}
// Return the prior result if we have one
Ok(out)
Ok((out, run))
}
}
}
@ -178,7 +178,7 @@ cfg_if! {
pub async fn single_spawn(
&self,
future: impl Future<Output = T> + Send + 'static,
) -> Result<Option<T>, ()> {
) -> Result<(Option<T>, bool), ()> {
let mut out: Option<T> = None;
// See if we have a result we can return
let maybe_jh = match self.try_lock() {
@ -206,7 +206,7 @@ cfg_if! {
self.unlock(Some(MustJoinHandle::new(spawn(future))));
}
// Return the prior result if we have one
Ok(out)
Ok((out, run))
}
}
}

View File

@ -160,7 +160,7 @@ where
pub async fn single_spawn(
&self,
future: impl Future<Output = T> + 'static,
) -> Result<Option<T>, ()> {
) -> Result<(Option<T>, bool), ()> {
let mut out: Option<T> = None;
// See if we have a result we can return
@ -193,7 +193,7 @@ where
}
// Return the prior result if we have one
Ok(out)
Ok((out, run))
}
}
}
@ -207,7 +207,7 @@ cfg_if! {
pub async fn single_spawn(
&self,
future: impl Future<Output = T> + Send + 'static,
) -> Result<Option<T>, ()> {
) -> Result<(Option<T>, bool), ()> {
let mut out: Option<T> = None;
// See if we have a result we can return
let maybe_jh = match self.try_lock() {
@ -235,7 +235,7 @@ cfg_if! {
self.unlock(Some(spawn(future)));
}
// Return the prior result if we have one
Ok(out)
Ok((out, run))
}
}
}

View File

@ -76,11 +76,13 @@ impl TickTask {
let opt_stop_source = &mut *self.stop_source.lock().await;
if opt_stop_source.is_none() {
// already stopped, just return
trace!("tick task already stopped");
return Ok(());
}
*opt_stop_source = None;
drop(opt_stop_source.take());
// wait for completion of the tick task
trace!("stopping single future");
match self.single_future.join().await {
Ok(Some(Err(err))) => Err(err),
_ => Ok(()),
@ -91,9 +93,37 @@ impl TickTask {
let now = get_timestamp();
let last_timestamp_us = self.last_timestamp_us.load(Ordering::Acquire);
if last_timestamp_us == 0u64 || (now - last_timestamp_us) >= self.tick_period_us {
// Run the singlefuture
if last_timestamp_us != 0u64 && (now - last_timestamp_us) < self.tick_period_us {
// It's not time yet
return Ok(());
}
// Lock the stop source, tells us if we have ever started this future
let opt_stop_source = &mut *self.stop_source.lock().await;
if opt_stop_source.is_some() {
// See if the previous execution finished with an error
match self.single_future.check().await {
Ok(Some(Err(e))) => {
// We have an error result, which means the singlefuture ran but we need to propagate the error
return Err(e);
}
Ok(Some(Ok(()))) => {
// We have an ok result, which means the singlefuture ran, and we should run it again this tick
}
Ok(None) => {
// No prior result to return which means things are still running
// We can just return now, since the singlefuture will not run a second time
return Ok(());
}
Err(()) => {
// If we get this, it's because we are joining the singlefuture already
// Don't bother running but this is not an error in this case
return Ok(());
}
};
}
// Run the singlefuture
let stop_source = StopSource::new();
match self
.single_future
@ -104,24 +134,20 @@ impl TickTask {
))
.await
{
// Single future ran this tick
Ok(Some(ret)) => {
// We should have already consumed the result of the last run, or there was none
// and we should definitely have run, because the prior 'check()' operation
// should have ensured the singlefuture was ready to run
Ok((None, true)) => {
// Set new timer
self.last_timestamp_us.store(now, Ordering::Release);
// Save new stopper
*opt_stop_source = Some(stop_source);
ret
}
// Single future did not run this tick
Ok(None) | Err(()) => {
// If the execution didn't happen this time because it was already running
// then we should try again the next tick and not reset the timestamp so we try as soon as possible
Ok(())
}
}
} else {
// It's not time yet
Ok(())
// All other conditions should not be reachable
_ => {
unreachable!();
}
}
}
}

View File

@ -32,11 +32,13 @@ pub async fn run_veilid_server(settings: Settings, server_mode: ServerMode) -> R
run_veilid_server_internal(settings, server_mode).await
}
#[instrument(err)]
#[instrument(err, skip_all)]
pub async fn run_veilid_server_internal(
settings: Settings,
server_mode: ServerMode,
) -> Result<(), String> {
trace!(?settings, ?server_mode);
let settingsr = settings.read();
// Create client api state change pipe