fix cooperative cancellation
This commit is contained in:
parent
180628beef
commit
c33f78ac8b
38
.vscode/launch.json
vendored
38
.vscode/launch.json
vendored
@ -29,13 +29,17 @@
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Launch veilid-cli",
|
||||
"args": ["--debug"],
|
||||
"args": [
|
||||
"--debug"
|
||||
],
|
||||
"program": "${workspaceFolder}/target/debug/veilid-cli",
|
||||
"windows": {
|
||||
"program": "${workspaceFolder}/target/debug/veilid-cli.exe"
|
||||
},
|
||||
"cwd": "${workspaceFolder}/target/debug/",
|
||||
"sourceLanguages": ["rust"],
|
||||
"sourceLanguages": [
|
||||
"rust"
|
||||
],
|
||||
"terminal": "console"
|
||||
},
|
||||
// {
|
||||
@ -48,20 +52,21 @@
|
||||
// "args": ["--trace"],
|
||||
// "cwd": "${workspaceFolder}/veilid-server"
|
||||
// }
|
||||
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug veilid-server",
|
||||
"program": "${workspaceFolder}/target/debug/veilid-server",
|
||||
"args": ["--trace", "--attach=true"],
|
||||
"args": [
|
||||
"--trace",
|
||||
"--attach=true"
|
||||
],
|
||||
"cwd": "${workspaceFolder}/target/debug/",
|
||||
"env": {
|
||||
"RUST_BACKTRACE": "1"
|
||||
},
|
||||
"terminal": "console"
|
||||
},
|
||||
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
@ -78,10 +83,11 @@
|
||||
"name": "veilid-core"
|
||||
}
|
||||
},
|
||||
"args": ["${selectedText}"],
|
||||
"args": [
|
||||
"${selectedText}"
|
||||
],
|
||||
"cwd": "${workspaceFolder}/target/debug/"
|
||||
},
|
||||
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
@ -98,10 +104,11 @@
|
||||
"name": "veilid-server"
|
||||
}
|
||||
},
|
||||
"args": ["${selectedText}"],
|
||||
"args": [
|
||||
"${selectedText}"
|
||||
],
|
||||
"cwd": "${workspaceFolder}/veilid-server"
|
||||
},
|
||||
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
@ -118,10 +125,11 @@
|
||||
"name": "keyvaluedb-sqlite"
|
||||
}
|
||||
},
|
||||
"args": ["${selectedText}"],
|
||||
"args": [
|
||||
"${selectedText}"
|
||||
],
|
||||
"cwd": "${workspaceFolder}/external/keyvaluedb/keyvaluedb-sqlite"
|
||||
},
|
||||
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
@ -138,8 +146,10 @@
|
||||
"name": "keyring"
|
||||
}
|
||||
},
|
||||
"args": ["${selectedText}"],
|
||||
"args": [
|
||||
"${selectedText}"
|
||||
],
|
||||
"cwd": "${workspaceFolder}/external/keyring-rs"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
@ -334,7 +334,7 @@ impl NetworkInterfaces {
|
||||
self.valid = false;
|
||||
let last_interfaces = core::mem::take(&mut self.interfaces);
|
||||
|
||||
let mut platform_support = PlatformSupport::new().map_err(logthru_net!())?;
|
||||
let mut platform_support = PlatformSupport::new()?;
|
||||
platform_support
|
||||
.get_interfaces(&mut self.interfaces)
|
||||
.await?;
|
||||
|
@ -2,19 +2,28 @@ use super::*;
|
||||
use crate::xx::*;
|
||||
use connection_table::*;
|
||||
use network_connection::*;
|
||||
use stop_token::future::FutureExt;
|
||||
|
||||
///////////////////////////////////////////////////////////
|
||||
// Connection manager
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ConnectionManagerEvent {
|
||||
Accepted(ProtocolNetworkConnection),
|
||||
Finished(ConnectionDescriptor),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ConnectionManagerInner {
|
||||
connection_table: ConnectionTable,
|
||||
sender: flume::Sender<ConnectionManagerEvent>,
|
||||
async_processor_jh: Option<MustJoinHandle<()>>,
|
||||
stop_source: Option<StopSource>,
|
||||
}
|
||||
|
||||
struct ConnectionManagerArc {
|
||||
network_manager: NetworkManager,
|
||||
inner: AsyncMutex<Option<ConnectionManagerInner>>,
|
||||
inner: Mutex<Option<ConnectionManagerInner>>,
|
||||
}
|
||||
impl core::fmt::Debug for ConnectionManagerArc {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
@ -30,16 +39,23 @@ pub struct ConnectionManager {
|
||||
}
|
||||
|
||||
impl ConnectionManager {
|
||||
fn new_inner(config: VeilidConfig) -> ConnectionManagerInner {
|
||||
fn new_inner(
|
||||
config: VeilidConfig,
|
||||
stop_source: StopSource,
|
||||
sender: flume::Sender<ConnectionManagerEvent>,
|
||||
async_processor_jh: JoinHandle<()>,
|
||||
) -> ConnectionManagerInner {
|
||||
ConnectionManagerInner {
|
||||
stop_source: Some(StopSource::new()),
|
||||
stop_source: Some(stop_source),
|
||||
sender: sender,
|
||||
async_processor_jh: Some(MustJoinHandle::new(async_processor_jh)),
|
||||
connection_table: ConnectionTable::new(config),
|
||||
}
|
||||
}
|
||||
fn new_arc(network_manager: NetworkManager) -> ConnectionManagerArc {
|
||||
ConnectionManagerArc {
|
||||
network_manager,
|
||||
inner: AsyncMutex::new(None),
|
||||
inner: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
pub fn new(network_manager: NetworkManager) -> Self {
|
||||
@ -54,18 +70,34 @@ impl ConnectionManager {
|
||||
|
||||
pub async fn startup(&self) {
|
||||
trace!("startup connection manager");
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
let mut inner = self.arc.inner.lock();
|
||||
if inner.is_some() {
|
||||
panic!("shouldn't start connection manager twice without shutting it down first");
|
||||
}
|
||||
|
||||
*inner = Some(Self::new_inner(self.network_manager().config()));
|
||||
// Create channel for async_processor to receive notifications of networking events
|
||||
let (sender, receiver) = flume::unbounded();
|
||||
|
||||
// Create the stop source we'll use to stop the processor and the connection table
|
||||
let stop_source = StopSource::new();
|
||||
|
||||
// Spawn the async processor
|
||||
let async_processor = spawn(self.clone().async_processor(stop_source.token(), receiver));
|
||||
|
||||
// Store in the inner object
|
||||
*inner = Some(Self::new_inner(
|
||||
self.network_manager().config(),
|
||||
stop_source,
|
||||
sender,
|
||||
async_processor,
|
||||
));
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) {
|
||||
debug!("starting connection manager shutdown");
|
||||
// Remove the inner from the lock
|
||||
let mut inner = {
|
||||
let mut inner_lock = self.arc.inner.lock().await;
|
||||
let mut inner_lock = self.arc.inner.lock();
|
||||
let inner = match inner_lock.take() {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
@ -75,11 +107,17 @@ impl ConnectionManager {
|
||||
inner
|
||||
};
|
||||
|
||||
// Stop all the connections
|
||||
// Stop all the connections and the async processor
|
||||
debug!("stopping async processor task");
|
||||
drop(inner.stop_source.take());
|
||||
|
||||
let async_processor_jh = inner.async_processor_jh.take().unwrap();
|
||||
// wait for the async processor to stop
|
||||
debug!("waiting for async processor to stop");
|
||||
async_processor_jh.await;
|
||||
// Wait for the connections to complete
|
||||
debug!("waiting for connection handlers to complete");
|
||||
inner.connection_table.join().await;
|
||||
debug!("finished connection manager shutdown");
|
||||
}
|
||||
|
||||
// Returns a network connection if one already is established
|
||||
@ -87,7 +125,7 @@ impl ConnectionManager {
|
||||
&self,
|
||||
descriptor: ConnectionDescriptor,
|
||||
) -> Option<ConnectionHandle> {
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
let mut inner = self.arc.inner.lock();
|
||||
let inner = match &mut *inner {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
@ -128,86 +166,95 @@ impl ConnectionManager {
|
||||
local_addr: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<ConnectionHandle, String> {
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
let inner = match &mut *inner {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
panic!("not started");
|
||||
}
|
||||
};
|
||||
let killed = {
|
||||
let mut inner = self.arc.inner.lock();
|
||||
let inner = match &mut *inner {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
panic!("not started");
|
||||
}
|
||||
};
|
||||
|
||||
log_net!(
|
||||
"== get_or_create_connection local_addr={:?} dial_info={:?}",
|
||||
local_addr.green(),
|
||||
dial_info.green()
|
||||
);
|
||||
|
||||
let peer_address = dial_info.to_peer_address();
|
||||
let descriptor = match local_addr {
|
||||
Some(la) => {
|
||||
ConnectionDescriptor::new(peer_address, SocketAddress::from_socket_addr(la))
|
||||
}
|
||||
None => ConnectionDescriptor::new_no_local(peer_address),
|
||||
};
|
||||
|
||||
// If any connection to this remote exists that has the same protocol, return it
|
||||
// Any connection will do, we don't have to match the local address
|
||||
|
||||
if let Some(conn) = inner
|
||||
.connection_table
|
||||
.get_last_connection_by_remote(descriptor.remote())
|
||||
{
|
||||
log_net!(
|
||||
"== Returning existing connection local_addr={:?} peer_address={:?}",
|
||||
"== get_or_create_connection local_addr={:?} dial_info={:?}",
|
||||
local_addr.green(),
|
||||
peer_address.green()
|
||||
dial_info.green()
|
||||
);
|
||||
|
||||
return Ok(conn);
|
||||
}
|
||||
let peer_address = dial_info.to_peer_address();
|
||||
let descriptor = match local_addr {
|
||||
Some(la) => {
|
||||
ConnectionDescriptor::new(peer_address, SocketAddress::from_socket_addr(la))
|
||||
}
|
||||
None => ConnectionDescriptor::new_no_local(peer_address),
|
||||
};
|
||||
|
||||
// Drop any other protocols connections to this remote that have the same local addr
|
||||
// otherwise this connection won't succeed due to binding
|
||||
let mut killed = false;
|
||||
if let Some(local_addr) = local_addr {
|
||||
if local_addr.port() != 0 {
|
||||
for pt in [ProtocolType::TCP, ProtocolType::WS, ProtocolType::WSS] {
|
||||
let pa = PeerAddress::new(descriptor.remote_address().clone(), pt);
|
||||
for prior_descriptor in inner
|
||||
.connection_table
|
||||
.get_connection_descriptors_by_remote(pa)
|
||||
{
|
||||
let mut kill = false;
|
||||
// See if the local address would collide
|
||||
if let Some(prior_local) = prior_descriptor.local() {
|
||||
if (local_addr.ip().is_unspecified()
|
||||
|| prior_local.to_ip_addr().is_unspecified()
|
||||
|| (local_addr.ip() == prior_local.to_ip_addr()))
|
||||
&& prior_local.port() == local_addr.port()
|
||||
{
|
||||
kill = true;
|
||||
// If any connection to this remote exists that has the same protocol, return it
|
||||
// Any connection will do, we don't have to match the local address
|
||||
|
||||
if let Some(conn) = inner
|
||||
.connection_table
|
||||
.get_last_connection_by_remote(descriptor.remote())
|
||||
{
|
||||
log_net!(
|
||||
"== Returning existing connection local_addr={:?} peer_address={:?}",
|
||||
local_addr.green(),
|
||||
peer_address.green()
|
||||
);
|
||||
|
||||
return Ok(conn);
|
||||
}
|
||||
|
||||
// Drop any other protocols connections to this remote that have the same local addr
|
||||
// otherwise this connection won't succeed due to binding
|
||||
let mut killed = Vec::<NetworkConnection>::new();
|
||||
if let Some(local_addr) = local_addr {
|
||||
if local_addr.port() != 0 {
|
||||
for pt in [ProtocolType::TCP, ProtocolType::WS, ProtocolType::WSS] {
|
||||
let pa = PeerAddress::new(descriptor.remote_address().clone(), pt);
|
||||
for prior_descriptor in inner
|
||||
.connection_table
|
||||
.get_connection_descriptors_by_remote(pa)
|
||||
{
|
||||
let mut kill = false;
|
||||
// See if the local address would collide
|
||||
if let Some(prior_local) = prior_descriptor.local() {
|
||||
if (local_addr.ip().is_unspecified()
|
||||
|| prior_local.to_ip_addr().is_unspecified()
|
||||
|| (local_addr.ip() == prior_local.to_ip_addr()))
|
||||
&& prior_local.port() == local_addr.port()
|
||||
{
|
||||
kill = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if kill {
|
||||
log_net!(debug
|
||||
">< Terminating connection prior_descriptor={:?}",
|
||||
prior_descriptor
|
||||
);
|
||||
if let Err(e) =
|
||||
inner.connection_table.remove_connection(prior_descriptor)
|
||||
{
|
||||
log_net!(error e);
|
||||
if kill {
|
||||
log_net!(debug
|
||||
">< Terminating connection prior_descriptor={:?}",
|
||||
prior_descriptor
|
||||
);
|
||||
let mut conn = inner
|
||||
.connection_table
|
||||
.remove_connection(prior_descriptor)
|
||||
.expect("connection not in table");
|
||||
|
||||
conn.close();
|
||||
|
||||
killed.push(conn);
|
||||
}
|
||||
killed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
killed
|
||||
};
|
||||
|
||||
// Wait for the killed connections to end their recv loops
|
||||
let mut retry_count = if !killed.is_empty() { 2 } else { 0 };
|
||||
for k in killed {
|
||||
k.await;
|
||||
}
|
||||
|
||||
// Attempt new connection
|
||||
let mut retry_count = if killed { 2 } else { 0 };
|
||||
|
||||
let conn = loop {
|
||||
match ProtocolNetworkConnection::connect(local_addr, dial_info.clone()).await {
|
||||
Ok(v) => break Ok(v),
|
||||
@ -222,44 +269,113 @@ impl ConnectionManager {
|
||||
}
|
||||
}?;
|
||||
|
||||
self.on_new_protocol_network_connection(&mut *inner, conn)
|
||||
// Add to the connection table
|
||||
let mut inner = self.arc.inner.lock();
|
||||
let inner = match &mut *inner {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
return Err("shutting down".to_owned());
|
||||
}
|
||||
};
|
||||
self.on_new_protocol_network_connection(inner, conn)
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Callbacks
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
async fn async_processor(
|
||||
self,
|
||||
stop_token: StopToken,
|
||||
receiver: flume::Receiver<ConnectionManagerEvent>,
|
||||
) {
|
||||
// Process async commands
|
||||
while let Ok(Ok(event)) = receiver.recv_async().timeout_at(stop_token.clone()).await {
|
||||
match event {
|
||||
ConnectionManagerEvent::Accepted(conn) => {
|
||||
let mut inner = self.arc.inner.lock();
|
||||
match &mut *inner {
|
||||
Some(inner) => {
|
||||
// Register the connection
|
||||
// We don't care if this fails, since nobody here asked for the inbound connection.
|
||||
// If it does, we just drop the connection
|
||||
let _ = self.on_new_protocol_network_connection(inner, conn);
|
||||
}
|
||||
None => {
|
||||
// If this somehow happens, we're shutting down
|
||||
}
|
||||
};
|
||||
}
|
||||
ConnectionManagerEvent::Finished(desc) => {
|
||||
let conn = {
|
||||
let mut inner_lock = self.arc.inner.lock();
|
||||
match &mut *inner_lock {
|
||||
Some(inner) => {
|
||||
// Remove the connection and wait for the connection loop to terminate
|
||||
if let Ok(conn) = inner.connection_table.remove_connection(desc) {
|
||||
// Must close and wait to ensure things join
|
||||
Some(conn)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
}
|
||||
};
|
||||
if let Some(mut conn) = conn {
|
||||
conn.close();
|
||||
conn.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Called by low-level network when any connection-oriented protocol connection appears
|
||||
// either from incoming connections.
|
||||
pub(super) async fn on_accepted_protocol_network_connection(
|
||||
&self,
|
||||
conn: ProtocolNetworkConnection,
|
||||
) -> Result<(), String> {
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
let inner = match &mut *inner {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
// If we are shutting down, just drop this and return
|
||||
return Ok(());
|
||||
}
|
||||
// Get channel sender
|
||||
let sender = {
|
||||
let mut inner = self.arc.inner.lock();
|
||||
let inner = match &mut *inner {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
// If we are shutting down, just drop this and return
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
inner.sender.clone()
|
||||
};
|
||||
self.on_new_protocol_network_connection(inner, conn)
|
||||
.map(drop)
|
||||
|
||||
// Inform the processor of the event
|
||||
let _ = sender
|
||||
.send_async(ConnectionManagerEvent::Accepted(conn))
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Callback from network connection receive loop when it exits
|
||||
// cleans up the entry in the connection table
|
||||
pub(super) async fn report_connection_finished(&self, descriptor: ConnectionDescriptor) {
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
let inner = match &mut *inner {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
// If we're shutting down, do nothing here
|
||||
return;
|
||||
}
|
||||
// Get channel sender
|
||||
let sender = {
|
||||
let mut inner = self.arc.inner.lock();
|
||||
let inner = match &mut *inner {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
// If we are shutting down, just drop this and return
|
||||
return;
|
||||
}
|
||||
};
|
||||
inner.sender.clone()
|
||||
};
|
||||
|
||||
if let Err(e) = inner.connection_table.remove_connection(descriptor) {
|
||||
log_net!(error e);
|
||||
}
|
||||
// Inform the processor of the event
|
||||
let _ = sender
|
||||
.send_async(ConnectionManagerEvent::Finished(descriptor))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
@ -160,13 +160,16 @@ impl ConnectionTable {
|
||||
.expect("Inconsistency in connection table");
|
||||
}
|
||||
|
||||
pub fn remove_connection(&mut self, descriptor: ConnectionDescriptor) -> Result<(), String> {
|
||||
pub fn remove_connection(
|
||||
&mut self,
|
||||
descriptor: ConnectionDescriptor,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
let index = protocol_to_index(descriptor.protocol_type());
|
||||
let _ = self.conn_by_descriptor[index]
|
||||
let conn = self.conn_by_descriptor[index]
|
||||
.remove(&descriptor)
|
||||
.ok_or_else(|| format!("Connection not in table: {:?}", descriptor))?;
|
||||
|
||||
self.remove_connection_records(descriptor);
|
||||
Ok(())
|
||||
Ok(conn)
|
||||
}
|
||||
}
|
||||
|
@ -299,26 +299,31 @@ impl NetworkManager {
|
||||
|
||||
#[instrument(level = "debug", skip_all)]
|
||||
pub async fn shutdown(&self) {
|
||||
trace!("NetworkManager::shutdown begin");
|
||||
debug!("starting network manager shutdown");
|
||||
|
||||
// Cancel all tasks
|
||||
debug!("stopping rolling transfers task");
|
||||
if let Err(e) = self.unlocked_inner.rolling_transfers_task.stop().await {
|
||||
warn!("rolling_transfers_task not stopped: {}", e);
|
||||
}
|
||||
debug!("stopping relay management task task");
|
||||
if let Err(e) = self.unlocked_inner.relay_management_task.stop().await {
|
||||
warn!("relay_management_task not stopped: {}", e);
|
||||
}
|
||||
|
||||
// Shutdown network components if they started up
|
||||
debug!("shutting down network components");
|
||||
|
||||
let components = self.inner.lock().components.clone();
|
||||
if let Some(components) = components {
|
||||
components.receipt_manager.shutdown().await;
|
||||
components.rpc_processor.shutdown().await;
|
||||
components.net.shutdown().await;
|
||||
components.connection_manager.shutdown().await;
|
||||
components.rpc_processor.shutdown().await;
|
||||
components.receipt_manager.shutdown().await;
|
||||
}
|
||||
|
||||
// reset the state
|
||||
debug!("resetting network manager state");
|
||||
{
|
||||
let mut inner = self.inner.lock();
|
||||
inner.components = None;
|
||||
@ -326,9 +331,10 @@ impl NetworkManager {
|
||||
}
|
||||
|
||||
// send update
|
||||
debug!("sending network state update");
|
||||
self.send_network_update();
|
||||
|
||||
trace!("NetworkManager::shutdown end");
|
||||
debug!("finished network manager shutdown");
|
||||
}
|
||||
|
||||
pub fn update_client_whitelist(&self, client: DHTKey) {
|
||||
@ -883,8 +889,7 @@ impl NetworkManager {
|
||||
match self
|
||||
.net()
|
||||
.send_data_to_existing_connection(descriptor, data)
|
||||
.await
|
||||
.map_err(logthru_net!())?
|
||||
.await?
|
||||
{
|
||||
None => Ok(()),
|
||||
Some(_) => Err("unable to send over reverse connection".to_owned()),
|
||||
@ -973,8 +978,7 @@ impl NetworkManager {
|
||||
match self
|
||||
.net()
|
||||
.send_data_to_existing_connection(descriptor, data)
|
||||
.await
|
||||
.map_err(logthru_net!())?
|
||||
.await?
|
||||
{
|
||||
None => Ok(()),
|
||||
Some(_) => Err("unable to send over hole punch".to_owned()),
|
||||
@ -1005,8 +1009,7 @@ impl NetworkManager {
|
||||
match this
|
||||
.net()
|
||||
.send_data_to_existing_connection(descriptor, data)
|
||||
.await
|
||||
.map_err(logthru_net!())?
|
||||
.await?
|
||||
{
|
||||
None => {
|
||||
return Ok(if descriptor.matches_peer_scope(PeerScope::Local) {
|
||||
@ -1150,7 +1153,7 @@ impl NetworkManager {
|
||||
"failed to resolve recipient node for relay, dropping outbound relayed packet...: {:?}",
|
||||
e
|
||||
)
|
||||
}).map_err(logthru_net!())?
|
||||
})?
|
||||
} else {
|
||||
// If this is not a node in the client whitelist, only allow inbound relay
|
||||
// which only performs a lightweight lookup before passing the packet back out
|
||||
|
@ -279,20 +279,14 @@ impl Network {
|
||||
let res = match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
|
||||
}
|
||||
ProtocolType::TCP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
|
||||
}
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data).await
|
||||
}
|
||||
};
|
||||
if res.is_ok() {
|
||||
@ -324,10 +318,7 @@ impl Network {
|
||||
descriptor
|
||||
);
|
||||
|
||||
ph.clone()
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.map_err(logthru_net!())?;
|
||||
ph.clone().send_message(data, peer_socket_addr).await?;
|
||||
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@ -345,7 +336,7 @@ impl Network {
|
||||
log_net!("send_data_to_existing_connection to {:?}", descriptor);
|
||||
|
||||
// connection exists, send over it
|
||||
conn.send_async(data).await.map_err(logthru_net!())?;
|
||||
conn.send_async(data).await?;
|
||||
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@ -372,10 +363,7 @@ impl Network {
|
||||
if dial_info.protocol_type() == ProtocolType::UDP {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
|
||||
let res = ph
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.map_err(logthru_net!());
|
||||
let res = ph.send_message(data, peer_socket_addr).await;
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@ -383,8 +371,7 @@ impl Network {
|
||||
}
|
||||
return res;
|
||||
}
|
||||
return Err("no appropriate UDP protocol handler for dial_info".to_owned())
|
||||
.map_err(logthru_net!(error));
|
||||
return Err("no appropriate UDP protocol handler for dial_info".to_owned());
|
||||
}
|
||||
|
||||
// Handle connection-oriented protocols
|
||||
@ -394,7 +381,7 @@ impl Network {
|
||||
.get_or_create_connection(Some(local_addr), dial_info.clone())
|
||||
.await?;
|
||||
|
||||
let res = conn.send_async(data).await.map_err(logthru_net!(error));
|
||||
let res = conn.send_async(data).await;
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@ -414,43 +401,51 @@ impl Network {
|
||||
// initialize interfaces
|
||||
let mut interfaces = NetworkInterfaces::new();
|
||||
interfaces.refresh().await?;
|
||||
self.inner.lock().interfaces = interfaces;
|
||||
|
||||
// get protocol config
|
||||
let protocol_config = {
|
||||
let c = self.config.get();
|
||||
let mut inbound = ProtocolSet::new();
|
||||
let mut inner = self.inner.lock();
|
||||
|
||||
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
|
||||
inbound.insert(ProtocolType::UDP);
|
||||
}
|
||||
if c.network.protocol.tcp.listen && c.capabilities.protocol_accept_tcp {
|
||||
inbound.insert(ProtocolType::TCP);
|
||||
}
|
||||
if c.network.protocol.ws.listen && c.capabilities.protocol_accept_ws {
|
||||
inbound.insert(ProtocolType::WS);
|
||||
}
|
||||
if c.network.protocol.wss.listen && c.capabilities.protocol_accept_wss {
|
||||
inbound.insert(ProtocolType::WSS);
|
||||
}
|
||||
// Create stop source
|
||||
inner.stop_source = Some(StopSource::new());
|
||||
inner.interfaces = interfaces;
|
||||
|
||||
let mut outbound = ProtocolSet::new();
|
||||
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
|
||||
outbound.insert(ProtocolType::UDP);
|
||||
}
|
||||
if c.network.protocol.tcp.connect && c.capabilities.protocol_connect_tcp {
|
||||
outbound.insert(ProtocolType::TCP);
|
||||
}
|
||||
if c.network.protocol.ws.connect && c.capabilities.protocol_connect_ws {
|
||||
outbound.insert(ProtocolType::WS);
|
||||
}
|
||||
if c.network.protocol.wss.connect && c.capabilities.protocol_connect_wss {
|
||||
outbound.insert(ProtocolType::WSS);
|
||||
}
|
||||
// get protocol config
|
||||
let protocol_config = {
|
||||
let c = self.config.get();
|
||||
let mut inbound = ProtocolSet::new();
|
||||
|
||||
ProtocolConfig { inbound, outbound }
|
||||
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
|
||||
inbound.insert(ProtocolType::UDP);
|
||||
}
|
||||
if c.network.protocol.tcp.listen && c.capabilities.protocol_accept_tcp {
|
||||
inbound.insert(ProtocolType::TCP);
|
||||
}
|
||||
if c.network.protocol.ws.listen && c.capabilities.protocol_accept_ws {
|
||||
inbound.insert(ProtocolType::WS);
|
||||
}
|
||||
if c.network.protocol.wss.listen && c.capabilities.protocol_accept_wss {
|
||||
inbound.insert(ProtocolType::WSS);
|
||||
}
|
||||
|
||||
let mut outbound = ProtocolSet::new();
|
||||
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
|
||||
outbound.insert(ProtocolType::UDP);
|
||||
}
|
||||
if c.network.protocol.tcp.connect && c.capabilities.protocol_connect_tcp {
|
||||
outbound.insert(ProtocolType::TCP);
|
||||
}
|
||||
if c.network.protocol.ws.connect && c.capabilities.protocol_connect_ws {
|
||||
outbound.insert(ProtocolType::WS);
|
||||
}
|
||||
if c.network.protocol.wss.connect && c.capabilities.protocol_connect_wss {
|
||||
outbound.insert(ProtocolType::WSS);
|
||||
}
|
||||
|
||||
ProtocolConfig { inbound, outbound }
|
||||
};
|
||||
inner.protocol_config = Some(protocol_config);
|
||||
protocol_config
|
||||
};
|
||||
self.inner.lock().protocol_config = Some(protocol_config);
|
||||
|
||||
// start listeners
|
||||
if protocol_config.inbound.contains(ProtocolType::UDP) {
|
||||
@ -503,28 +498,32 @@ impl Network {
|
||||
|
||||
#[instrument(level = "debug", skip_all)]
|
||||
pub async fn shutdown(&self) {
|
||||
info!("stopping network");
|
||||
debug!("starting low level network shutdown");
|
||||
|
||||
let network_manager = self.network_manager();
|
||||
let routing_table = self.routing_table();
|
||||
|
||||
// Stop all tasks
|
||||
debug!("stopping update network class task");
|
||||
if let Err(e) = self.unlocked_inner.update_network_class_task.stop().await {
|
||||
error!("update_network_class_task not cancelled: {}", e);
|
||||
}
|
||||
|
||||
let mut unord = FuturesUnordered::new();
|
||||
{
|
||||
let mut inner = self.inner.lock();
|
||||
// Drop the stop
|
||||
drop(inner.stop_source.take());
|
||||
// take the join handles out
|
||||
for h in inner.join_handles.drain(..) {
|
||||
unord.push(h);
|
||||
}
|
||||
// Drop the stop
|
||||
drop(inner.stop_source.take());
|
||||
}
|
||||
debug!("stopping {} low level network tasks", unord.len());
|
||||
// Wait for everything to stop
|
||||
while unord.next().await.is_some() {}
|
||||
|
||||
debug!("clearing dial info");
|
||||
// Drop all dial info
|
||||
routing_table.clear_dial_info_details(RoutingDomain::PublicInternet);
|
||||
routing_table.clear_dial_info_details(RoutingDomain::LocalNetwork);
|
||||
@ -532,7 +531,7 @@ impl Network {
|
||||
// Reset state including network class
|
||||
*self.inner.lock() = Self::new_inner(network_manager);
|
||||
|
||||
info!("network stopped");
|
||||
debug!("finished low level network shutdown");
|
||||
}
|
||||
|
||||
//////////////////////////////////////////
|
||||
|
@ -1,6 +1,7 @@
|
||||
use super::*;
|
||||
use futures_util::stream::FuturesUnordered;
|
||||
use futures_util::FutureExt;
|
||||
use stop_token::future::FutureExt as StopTokenFutureExt;
|
||||
|
||||
struct DetectedPublicDialInfo {
|
||||
dial_info: DialInfo,
|
||||
@ -584,20 +585,32 @@ impl Network {
|
||||
// Wait for all discovery futures to complete and collect contexts
|
||||
let mut contexts = Vec::<DiscoveryContext>::new();
|
||||
let mut network_class = Option::<NetworkClass>::None;
|
||||
while let Some(ctxvec) = unord.next().await {
|
||||
if let Some(ctxvec) = ctxvec {
|
||||
for ctx in ctxvec {
|
||||
if let Some(nc) = ctx.inner.lock().detected_network_class {
|
||||
if let Some(last_nc) = network_class {
|
||||
if nc < last_nc {
|
||||
network_class = Some(nc);
|
||||
loop {
|
||||
match unord.next().timeout_at(stop_token.clone()).await {
|
||||
Ok(Some(ctxvec)) => {
|
||||
if let Some(ctxvec) = ctxvec {
|
||||
for ctx in ctxvec {
|
||||
if let Some(nc) = ctx.inner.lock().detected_network_class {
|
||||
if let Some(last_nc) = network_class {
|
||||
if nc < last_nc {
|
||||
network_class = Some(nc);
|
||||
}
|
||||
} else {
|
||||
network_class = Some(nc);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
network_class = Some(nc);
|
||||
|
||||
contexts.push(ctx);
|
||||
}
|
||||
}
|
||||
|
||||
contexts.push(ctx);
|
||||
}
|
||||
Ok(None) => {
|
||||
// Normal completion
|
||||
break;
|
||||
}
|
||||
Err(_) => {
|
||||
// Stop token, exit early without error propagation
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -65,8 +65,7 @@ impl Network {
|
||||
ps.peek_exact(&mut first_packet),
|
||||
)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
.map_err(map_to_string)?;
|
||||
|
||||
self.try_handlers(ps, tcp_stream, addr, protocol_handlers)
|
||||
.await
|
||||
@ -82,8 +81,7 @@ impl Network {
|
||||
for ah in protocol_accept_handlers.iter() {
|
||||
if let Some(nc) = ah
|
||||
.on_accept(stream.clone(), tcp_stream.clone(), addr)
|
||||
.await
|
||||
.map_err(logthru_net!())?
|
||||
.await?
|
||||
{
|
||||
return Ok(Some(nc));
|
||||
}
|
||||
|
@ -39,15 +39,11 @@ impl ProtocolNetworkConnection {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
udp::RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
udp::RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
|
||||
}
|
||||
ProtocolType::TCP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
tcp::RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
tcp::RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
|
||||
}
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
ws::WebsocketProtocolHandler::send_unbound_message(dial_info, data).await
|
||||
|
@ -31,19 +31,14 @@ impl RawTcpNetworkConnection {
|
||||
self.descriptor.clone()
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self))]
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
// Make an attempt to flush the stream
|
||||
self.stream
|
||||
.clone()
|
||||
.close()
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
self.stream.clone().close().await.map_err(map_to_string)?;
|
||||
// Then forcibly close the socket
|
||||
self.tcp_stream
|
||||
.shutdown(Shutdown::Both)
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
|
||||
async fn send_internal(mut stream: AsyncPeekStream, message: Vec<u8>) -> Result<(), String> {
|
||||
@ -54,23 +49,17 @@ impl RawTcpNetworkConnection {
|
||||
let len = message.len() as u16;
|
||||
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
|
||||
|
||||
stream
|
||||
.write_all(&header)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
stream
|
||||
.write_all(&message)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())
|
||||
stream.write_all(&header).await.map_err(map_to_string)?;
|
||||
stream.write_all(&message).await.map_err(map_to_string)
|
||||
}
|
||||
|
||||
#[instrument(level="trace", err, skip(self, message), fields(message.len = message.len()))]
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
let stream = self.stream.clone();
|
||||
Self::send_internal(stream, message).await
|
||||
}
|
||||
|
||||
#[instrument(level="trace", err, skip(self), fields(message.len))]
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
let mut header = [0u8; 4];
|
||||
|
||||
@ -90,6 +79,8 @@ impl RawTcpNetworkConnection {
|
||||
|
||||
let mut out: Vec<u8> = vec![0u8; len];
|
||||
stream.read_exact(&mut out).await.map_err(map_to_string)?;
|
||||
|
||||
tracing::Span::current().record("message.len", &out.len());
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
@ -120,6 +111,7 @@ impl RawTcpProtocolHandler {
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, stream, tcp_stream))]
|
||||
async fn on_accept_async(
|
||||
self,
|
||||
stream: AsyncPeekStream,
|
||||
@ -151,6 +143,7 @@ impl RawTcpProtocolHandler {
|
||||
Ok(Some(conn))
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err)]
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
@ -191,6 +184,7 @@ impl RawTcpProtocolHandler {
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
|
||||
pub async fn send_unbound_message(
|
||||
socket_addr: SocketAddr,
|
||||
data: Vec<u8>,
|
||||
|
@ -10,6 +10,7 @@ impl RawUdpProtocolHandler {
|
||||
Self { socket }
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.from))]
|
||||
pub async fn recv_message(
|
||||
&self,
|
||||
data: &mut [u8],
|
||||
@ -35,9 +36,13 @@ impl RawUdpProtocolHandler {
|
||||
peer_addr,
|
||||
SocketAddress::from_socket_addr(local_socket_addr),
|
||||
);
|
||||
|
||||
tracing::Span::current().record("ret.len", &size);
|
||||
tracing::Span::current().record("ret.from", &format!("{:?}", descriptor).as_str());
|
||||
Ok((size, descriptor))
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.from))]
|
||||
pub async fn send_message(&self, data: Vec<u8>, socket_addr: SocketAddr) -> Result<(), String> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large UDP message".to_owned()).map_err(logthru_net!(error));
|
||||
|
@ -49,21 +49,17 @@ where
|
||||
self.descriptor.clone()
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self))]
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
// Make an attempt to flush the stream
|
||||
self.stream
|
||||
.clone()
|
||||
.close()
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
self.stream.clone().close().await.map_err(map_to_string)?;
|
||||
// Then forcibly close the socket
|
||||
self.tcp_stream
|
||||
.shutdown(Shutdown::Both)
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
|
||||
#[instrument(level="trace", err, skip(self, message), fields(message.len = message.len()))]
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("received too large WS message".to_owned());
|
||||
@ -76,6 +72,7 @@ where
|
||||
.map_err(logthru_net!(error "failed to send websocket message"))
|
||||
}
|
||||
|
||||
#[instrument(level="trace", err, skip(self), fields(message.len))]
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
let out = match self.stream.clone().next().await {
|
||||
Some(Ok(Message::Binary(v))) => v,
|
||||
@ -86,12 +83,13 @@ where
|
||||
return Err(e.to_string()).map_err(logthru_net!(error));
|
||||
}
|
||||
None => {
|
||||
return Err("WS stream closed".to_owned()).map_err(logthru_net!());
|
||||
return Err("WS stream closed".to_owned());
|
||||
}
|
||||
};
|
||||
if out.len() > MAX_MESSAGE_SIZE {
|
||||
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
|
||||
} else {
|
||||
tracing::Span::current().record("message.len", &out.len());
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
@ -137,6 +135,7 @@ impl WebsocketProtocolHandler {
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, ps, tcp_stream))]
|
||||
pub async fn on_accept_async(
|
||||
self,
|
||||
ps: AsyncPeekStream,
|
||||
@ -156,9 +155,9 @@ impl WebsocketProtocolHandler {
|
||||
Ok(_) => (),
|
||||
Err(e) => {
|
||||
if e.kind() == io::ErrorKind::TimedOut {
|
||||
return Err(e).map_err(map_to_string).map_err(logthru_net!());
|
||||
return Err(e).map_err(map_to_string);
|
||||
}
|
||||
return Err(e).map_err(map_to_string).map_err(logthru_net!(error));
|
||||
return Err(e).map_err(map_to_string);
|
||||
}
|
||||
}
|
||||
|
||||
@ -237,10 +236,7 @@ impl WebsocketProtocolHandler {
|
||||
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
|
||||
|
||||
// See what local address we ended up with
|
||||
let actual_local_addr = tcp_stream
|
||||
.local_addr()
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
let actual_local_addr = tcp_stream.local_addr().map_err(map_to_string)?;
|
||||
|
||||
// Make our connection descriptor
|
||||
let descriptor = ConnectionDescriptor::new(
|
||||
@ -274,6 +270,7 @@ impl WebsocketProtocolHandler {
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err)]
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
@ -281,6 +278,7 @@ impl WebsocketProtocolHandler {
|
||||
Self::connect_internal(local_address, dial_info).await
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
|
||||
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound WS message".to_owned());
|
||||
|
@ -89,6 +89,7 @@ pub struct NetworkConnection {
|
||||
established_time: u64,
|
||||
stats: Arc<Mutex<NetworkConnectionStats>>,
|
||||
sender: flume::Sender<Vec<u8>>,
|
||||
stop_source: Option<StopSource>,
|
||||
}
|
||||
|
||||
impl NetworkConnection {
|
||||
@ -105,12 +106,13 @@ impl NetworkConnection {
|
||||
last_message_recv_time: None,
|
||||
})),
|
||||
sender,
|
||||
stop_source: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn from_protocol(
|
||||
connection_manager: ConnectionManager,
|
||||
stop_token: StopToken,
|
||||
manager_stop_token: StopToken,
|
||||
protocol_connection: ProtocolNetworkConnection,
|
||||
) -> Self {
|
||||
// Get timeout
|
||||
@ -133,10 +135,14 @@ impl NetworkConnection {
|
||||
last_message_recv_time: None,
|
||||
}));
|
||||
|
||||
let stop_source = StopSource::new();
|
||||
let local_stop_token = stop_source.token();
|
||||
|
||||
// Spawn connection processor and pass in protocol connection
|
||||
let processor = MustJoinHandle::new(intf::spawn_local(Self::process_connection(
|
||||
connection_manager,
|
||||
stop_token,
|
||||
local_stop_token,
|
||||
manager_stop_token,
|
||||
descriptor.clone(),
|
||||
receiver,
|
||||
protocol_connection,
|
||||
@ -151,6 +157,7 @@ impl NetworkConnection {
|
||||
established_time: intf::get_timestamp(),
|
||||
stats,
|
||||
sender,
|
||||
stop_source: Some(stop_source),
|
||||
}
|
||||
}
|
||||
|
||||
@ -162,6 +169,13 @@ impl NetworkConnection {
|
||||
ConnectionHandle::new(self.descriptor.clone(), self.sender.clone())
|
||||
}
|
||||
|
||||
pub fn close(&mut self) {
|
||||
if let Some(stop_source) = self.stop_source.take() {
|
||||
// drop the stopper
|
||||
drop(stop_source);
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_internal(
|
||||
protocol_connection: &ProtocolNetworkConnection,
|
||||
stats: Arc<Mutex<NetworkConnectionStats>>,
|
||||
@ -200,7 +214,8 @@ impl NetworkConnection {
|
||||
// Connection receiver loop
|
||||
fn process_connection(
|
||||
connection_manager: ConnectionManager,
|
||||
stop_token: StopToken,
|
||||
local_stop_token: StopToken,
|
||||
manager_stop_token: StopToken,
|
||||
descriptor: ConnectionDescriptor,
|
||||
receiver: flume::Receiver<Vec<u8>>,
|
||||
protocol_connection: ProtocolNetworkConnection,
|
||||
@ -293,7 +308,13 @@ impl NetworkConnection {
|
||||
}
|
||||
|
||||
// Process futures
|
||||
match unord.next().timeout_at(stop_token.clone()).await {
|
||||
match unord
|
||||
.next()
|
||||
.timeout_at(local_stop_token.clone())
|
||||
.timeout_at(manager_stop_token.clone())
|
||||
.await
|
||||
.and_then(std::convert::identity) // flatten
|
||||
{
|
||||
Ok(Some(RecvLoopAction::Send)) => {
|
||||
// Don't reset inactivity timer if we're only sending
|
||||
need_sender = true;
|
||||
@ -312,7 +333,7 @@ impl NetworkConnection {
|
||||
unreachable!();
|
||||
}
|
||||
Err(_) => {
|
||||
// Stop token
|
||||
// Either one of the stop tokens
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -82,7 +82,12 @@ pub async fn test_add_get_remove() {
|
||||
assert_eq!(table.get_connection(a1), Some(c1h.clone()));
|
||||
assert_eq!(table.get_connection(a1), Some(c1h.clone()));
|
||||
assert_eq!(table.connection_count(), 1);
|
||||
assert_eq!(table.remove_connection(a2), Ok(()));
|
||||
assert_eq!(
|
||||
table
|
||||
.remove_connection(a2)
|
||||
.map(|c| c.connection_descriptor()),
|
||||
Ok(a1)
|
||||
);
|
||||
assert_eq!(table.connection_count(), 0);
|
||||
assert_err!(table.remove_connection(a2));
|
||||
assert_eq!(table.connection_count(), 0);
|
||||
@ -98,9 +103,24 @@ pub async fn test_add_get_remove() {
|
||||
table.add_connection(c3).unwrap();
|
||||
table.add_connection(c4).unwrap();
|
||||
assert_eq!(table.connection_count(), 3);
|
||||
assert_eq!(table.remove_connection(a2), Ok(()));
|
||||
assert_eq!(table.remove_connection(a3), Ok(()));
|
||||
assert_eq!(table.remove_connection(a4), Ok(()));
|
||||
assert_eq!(
|
||||
table
|
||||
.remove_connection(a2)
|
||||
.map(|c| c.connection_descriptor()),
|
||||
Ok(a2)
|
||||
);
|
||||
assert_eq!(
|
||||
table
|
||||
.remove_connection(a3)
|
||||
.map(|c| c.connection_descriptor()),
|
||||
Ok(a3)
|
||||
);
|
||||
assert_eq!(
|
||||
table
|
||||
.remove_connection(a4)
|
||||
.map(|c| c.connection_descriptor()),
|
||||
Ok(a4)
|
||||
);
|
||||
assert_eq!(table.connection_count(), 0);
|
||||
}
|
||||
|
||||
|
@ -69,7 +69,6 @@ impl Network {
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
};
|
||||
if res.is_ok() {
|
||||
@ -102,7 +101,7 @@ impl Network {
|
||||
// Try to send to the exact existing connection if one exists
|
||||
if let Some(conn) = self.connection_manager().get_connection(descriptor).await {
|
||||
// connection exists, send over it
|
||||
conn.send_async(data).await.map_err(logthru_net!())?;
|
||||
conn.send_async(data).await?;
|
||||
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
|
@ -314,6 +314,7 @@ impl ReceiptManager {
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) {
|
||||
debug!("starting receipt manager shutdown");
|
||||
let network_manager = self.network_manager();
|
||||
|
||||
// Stop all tasks
|
||||
@ -325,11 +326,13 @@ impl ReceiptManager {
|
||||
};
|
||||
|
||||
// Wait for everything to stop
|
||||
debug!("waiting for timeout task to stop");
|
||||
if !timeout_task.join().await.is_ok() {
|
||||
panic!("joining timeout task failed");
|
||||
}
|
||||
|
||||
*self.inner.lock() = Self::new_inner(network_manager);
|
||||
debug!("finished receipt manager shutdown");
|
||||
}
|
||||
|
||||
pub fn record_receipt(
|
||||
|
@ -374,19 +374,26 @@ impl RoutingTable {
|
||||
}
|
||||
|
||||
pub async fn terminate(&self) {
|
||||
debug!("starting routing table terminate");
|
||||
|
||||
// Cancel all tasks being ticked
|
||||
debug!("stopping rolling transfers task");
|
||||
if let Err(e) = self.unlocked_inner.rolling_transfers_task.stop().await {
|
||||
error!("rolling_transfers_task not stopped: {}", e);
|
||||
}
|
||||
debug!("stopping bootstrap task");
|
||||
if let Err(e) = self.unlocked_inner.bootstrap_task.stop().await {
|
||||
error!("bootstrap_task not stopped: {}", e);
|
||||
}
|
||||
debug!("stopping peer minimum refresh task");
|
||||
if let Err(e) = self.unlocked_inner.peer_minimum_refresh_task.stop().await {
|
||||
error!("peer_minimum_refresh_task not stopped: {}", e);
|
||||
}
|
||||
debug!("stopping ping_validator task");
|
||||
if let Err(e) = self.unlocked_inner.ping_validator_task.stop().await {
|
||||
error!("ping_validator_task not stopped: {}", e);
|
||||
}
|
||||
debug!("stopping node info update singlefuture");
|
||||
if self
|
||||
.unlocked_inner
|
||||
.node_info_update_single_future
|
||||
@ -398,6 +405,8 @@ impl RoutingTable {
|
||||
}
|
||||
|
||||
*self.inner.lock() = Self::new_inner(self.network_manager());
|
||||
|
||||
debug!("finished routing table terminate");
|
||||
}
|
||||
|
||||
// Inform routing table entries that our dial info has changed
|
||||
|
@ -1428,23 +1428,30 @@ impl RPCProcessor {
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) {
|
||||
debug!("starting rpc processor shutdown");
|
||||
|
||||
// Stop the rpc workers
|
||||
let mut unord = FuturesUnordered::new();
|
||||
{
|
||||
let mut inner = self.inner.lock();
|
||||
// drop the stop
|
||||
drop(inner.stop_source.take());
|
||||
// take the join handles out
|
||||
for h in inner.worker_join_handles.drain(..) {
|
||||
unord.push(h);
|
||||
}
|
||||
// drop the stop
|
||||
drop(inner.stop_source.take());
|
||||
}
|
||||
debug!("stopping {} rpc worker tasks", unord.len());
|
||||
|
||||
// Wait for them to complete
|
||||
while unord.next().await.is_some() {}
|
||||
|
||||
|
||||
debug!("resetting rpc processor state");
|
||||
|
||||
// Release the rpc processor
|
||||
*self.inner.lock() = Self::new_inner(self.network_manager());
|
||||
|
||||
debug!("finished rpc processor shutdown");
|
||||
}
|
||||
|
||||
pub fn enqueue_message(
|
||||
|
@ -518,10 +518,10 @@ pub async fn test_single_future() {
|
||||
69
|
||||
})
|
||||
.await,
|
||||
Ok(None)
|
||||
Ok((None, true))
|
||||
);
|
||||
assert_eq!(sf.check().await, Ok(None));
|
||||
assert_eq!(sf.single_spawn(async { panic!() }).await, Ok(None));
|
||||
assert_eq!(sf.single_spawn(async { panic!() }).await, Ok((None, false)));
|
||||
assert_eq!(sf.join().await, Ok(Some(69)));
|
||||
assert_eq!(
|
||||
sf.single_spawn(async {
|
||||
@ -529,7 +529,7 @@ pub async fn test_single_future() {
|
||||
37
|
||||
})
|
||||
.await,
|
||||
Ok(None)
|
||||
Ok((None, true))
|
||||
);
|
||||
intf::sleep(2000).await;
|
||||
assert_eq!(
|
||||
@ -538,7 +538,7 @@ pub async fn test_single_future() {
|
||||
27
|
||||
})
|
||||
.await,
|
||||
Ok(Some(37))
|
||||
Ok((Some(37), true))
|
||||
);
|
||||
intf::sleep(2000).await;
|
||||
assert_eq!(sf.join().await, Ok(Some(27)));
|
||||
@ -555,10 +555,10 @@ pub async fn test_must_join_single_future() {
|
||||
69
|
||||
})
|
||||
.await,
|
||||
Ok(None)
|
||||
Ok((None, true))
|
||||
);
|
||||
assert_eq!(sf.check().await, Ok(None));
|
||||
assert_eq!(sf.single_spawn(async { panic!() }).await, Ok(None));
|
||||
assert_eq!(sf.single_spawn(async { panic!() }).await, Ok((None, false)));
|
||||
assert_eq!(sf.join().await, Ok(Some(69)));
|
||||
assert_eq!(
|
||||
sf.single_spawn(async {
|
||||
@ -566,7 +566,7 @@ pub async fn test_must_join_single_future() {
|
||||
37
|
||||
})
|
||||
.await,
|
||||
Ok(None)
|
||||
Ok((None, true))
|
||||
);
|
||||
intf::sleep(2000).await;
|
||||
assert_eq!(
|
||||
@ -575,7 +575,7 @@ pub async fn test_must_join_single_future() {
|
||||
27
|
||||
})
|
||||
.await,
|
||||
Ok(Some(37))
|
||||
Ok((Some(37), true))
|
||||
);
|
||||
intf::sleep(2000).await;
|
||||
assert_eq!(sf.join().await, Ok(Some(27)));
|
||||
|
@ -1,20 +1,19 @@
|
||||
use async_executors::JoinHandle;
|
||||
use core::future::Future;
|
||||
use core::pin::Pin;
|
||||
use core::sync::atomic::{AtomicBool, Ordering};
|
||||
use core::task::{Context, Poll};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MustJoinHandle<T> {
|
||||
join_handle: JoinHandle<T>,
|
||||
completed: AtomicBool,
|
||||
completed: bool,
|
||||
}
|
||||
|
||||
impl<T> MustJoinHandle<T> {
|
||||
pub fn new(join_handle: JoinHandle<T>) -> Self {
|
||||
Self {
|
||||
join_handle,
|
||||
completed: AtomicBool::new(false),
|
||||
completed: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -22,7 +21,7 @@ impl<T> MustJoinHandle<T> {
|
||||
impl<T> Drop for MustJoinHandle<T> {
|
||||
fn drop(&mut self) {
|
||||
// panic if we haven't completed
|
||||
if !self.completed.load(Ordering::Relaxed) {
|
||||
if !self.completed {
|
||||
panic!("MustJoinHandle was not completed upon drop. Add cooperative cancellation where appropriate to ensure this is completed before drop.")
|
||||
}
|
||||
}
|
||||
@ -34,7 +33,7 @@ impl<T: 'static> Future for MustJoinHandle<T> {
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match Pin::new(&mut self.join_handle).poll(cx) {
|
||||
Poll::Ready(t) => {
|
||||
self.completed.store(true, Ordering::Relaxed);
|
||||
self.completed = true;
|
||||
Poll::Ready(t)
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
|
@ -131,7 +131,7 @@ where
|
||||
pub async fn single_spawn(
|
||||
&self,
|
||||
future: impl Future<Output = T> + 'static,
|
||||
) -> Result<Option<T>, ()> {
|
||||
) -> Result<(Option<T>,bool), ()> {
|
||||
let mut out: Option<T> = None;
|
||||
|
||||
// See if we have a result we can return
|
||||
@ -164,7 +164,7 @@ where
|
||||
}
|
||||
|
||||
// Return the prior result if we have one
|
||||
Ok(out)
|
||||
Ok((out, run))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -178,7 +178,7 @@ cfg_if! {
|
||||
pub async fn single_spawn(
|
||||
&self,
|
||||
future: impl Future<Output = T> + Send + 'static,
|
||||
) -> Result<Option<T>, ()> {
|
||||
) -> Result<(Option<T>, bool), ()> {
|
||||
let mut out: Option<T> = None;
|
||||
// See if we have a result we can return
|
||||
let maybe_jh = match self.try_lock() {
|
||||
@ -206,7 +206,7 @@ cfg_if! {
|
||||
self.unlock(Some(MustJoinHandle::new(spawn(future))));
|
||||
}
|
||||
// Return the prior result if we have one
|
||||
Ok(out)
|
||||
Ok((out, run))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -160,7 +160,7 @@ where
|
||||
pub async fn single_spawn(
|
||||
&self,
|
||||
future: impl Future<Output = T> + 'static,
|
||||
) -> Result<Option<T>, ()> {
|
||||
) -> Result<(Option<T>, bool), ()> {
|
||||
let mut out: Option<T> = None;
|
||||
|
||||
// See if we have a result we can return
|
||||
@ -193,7 +193,7 @@ where
|
||||
}
|
||||
|
||||
// Return the prior result if we have one
|
||||
Ok(out)
|
||||
Ok((out, run))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -207,7 +207,7 @@ cfg_if! {
|
||||
pub async fn single_spawn(
|
||||
&self,
|
||||
future: impl Future<Output = T> + Send + 'static,
|
||||
) -> Result<Option<T>, ()> {
|
||||
) -> Result<(Option<T>, bool), ()> {
|
||||
let mut out: Option<T> = None;
|
||||
// See if we have a result we can return
|
||||
let maybe_jh = match self.try_lock() {
|
||||
@ -235,7 +235,7 @@ cfg_if! {
|
||||
self.unlock(Some(spawn(future)));
|
||||
}
|
||||
// Return the prior result if we have one
|
||||
Ok(out)
|
||||
Ok((out, run))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -76,11 +76,13 @@ impl TickTask {
|
||||
let opt_stop_source = &mut *self.stop_source.lock().await;
|
||||
if opt_stop_source.is_none() {
|
||||
// already stopped, just return
|
||||
trace!("tick task already stopped");
|
||||
return Ok(());
|
||||
}
|
||||
*opt_stop_source = None;
|
||||
drop(opt_stop_source.take());
|
||||
|
||||
// wait for completion of the tick task
|
||||
trace!("stopping single future");
|
||||
match self.single_future.join().await {
|
||||
Ok(Some(Err(err))) => Err(err),
|
||||
_ => Ok(()),
|
||||
@ -91,37 +93,61 @@ impl TickTask {
|
||||
let now = get_timestamp();
|
||||
let last_timestamp_us = self.last_timestamp_us.load(Ordering::Acquire);
|
||||
|
||||
if last_timestamp_us == 0u64 || (now - last_timestamp_us) >= self.tick_period_us {
|
||||
// Run the singlefuture
|
||||
let opt_stop_source = &mut *self.stop_source.lock().await;
|
||||
let stop_source = StopSource::new();
|
||||
match self
|
||||
.single_future
|
||||
.single_spawn(self.routine.get().unwrap()(
|
||||
stop_source.token(),
|
||||
last_timestamp_us,
|
||||
now,
|
||||
))
|
||||
.await
|
||||
{
|
||||
// Single future ran this tick
|
||||
Ok(Some(ret)) => {
|
||||
// Set new timer
|
||||
self.last_timestamp_us.store(now, Ordering::Release);
|
||||
// Save new stopper
|
||||
*opt_stop_source = Some(stop_source);
|
||||
ret
|
||||
}
|
||||
// Single future did not run this tick
|
||||
Ok(None) | Err(()) => {
|
||||
// If the execution didn't happen this time because it was already running
|
||||
// then we should try again the next tick and not reset the timestamp so we try as soon as possible
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if last_timestamp_us != 0u64 && (now - last_timestamp_us) < self.tick_period_us {
|
||||
// It's not time yet
|
||||
Ok(())
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Lock the stop source, tells us if we have ever started this future
|
||||
let opt_stop_source = &mut *self.stop_source.lock().await;
|
||||
if opt_stop_source.is_some() {
|
||||
// See if the previous execution finished with an error
|
||||
match self.single_future.check().await {
|
||||
Ok(Some(Err(e))) => {
|
||||
// We have an error result, which means the singlefuture ran but we need to propagate the error
|
||||
return Err(e);
|
||||
}
|
||||
Ok(Some(Ok(()))) => {
|
||||
// We have an ok result, which means the singlefuture ran, and we should run it again this tick
|
||||
}
|
||||
Ok(None) => {
|
||||
// No prior result to return which means things are still running
|
||||
// We can just return now, since the singlefuture will not run a second time
|
||||
return Ok(());
|
||||
}
|
||||
Err(()) => {
|
||||
// If we get this, it's because we are joining the singlefuture already
|
||||
// Don't bother running but this is not an error in this case
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Run the singlefuture
|
||||
let stop_source = StopSource::new();
|
||||
match self
|
||||
.single_future
|
||||
.single_spawn(self.routine.get().unwrap()(
|
||||
stop_source.token(),
|
||||
last_timestamp_us,
|
||||
now,
|
||||
))
|
||||
.await
|
||||
{
|
||||
// We should have already consumed the result of the last run, or there was none
|
||||
// and we should definitely have run, because the prior 'check()' operation
|
||||
// should have ensured the singlefuture was ready to run
|
||||
Ok((None, true)) => {
|
||||
// Set new timer
|
||||
self.last_timestamp_us.store(now, Ordering::Release);
|
||||
// Save new stopper
|
||||
*opt_stop_source = Some(stop_source);
|
||||
Ok(())
|
||||
}
|
||||
// All other conditions should not be reachable
|
||||
_ => {
|
||||
unreachable!();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -32,11 +32,13 @@ pub async fn run_veilid_server(settings: Settings, server_mode: ServerMode) -> R
|
||||
run_veilid_server_internal(settings, server_mode).await
|
||||
}
|
||||
|
||||
#[instrument(err)]
|
||||
#[instrument(err, skip_all)]
|
||||
pub async fn run_veilid_server_internal(
|
||||
settings: Settings,
|
||||
server_mode: ServerMode,
|
||||
) -> Result<(), String> {
|
||||
trace!(?settings, ?server_mode);
|
||||
|
||||
let settingsr = settings.read();
|
||||
|
||||
// Create client api state change pipe
|
||||
|
Loading…
Reference in New Issue
Block a user