refactor for cooperative cancellation
This commit is contained in:
@@ -9,11 +9,12 @@ use network_connection::*;
|
||||
// State the connection manager holds while it is started.
#[derive(Debug)]
|
||||
struct ConnectionManagerInner {
|
||||
// Table of live connections; `shutdown` awaits `connection_table.join()` on it.
connection_table: ConnectionTable,
|
||||
// Source of the stop tokens passed to new connections. `shutdown` takes and drops it
// ("Stop all the connections"); once it is `None`, new connections are refused.
stop_source: Option<StopSource>,
|
||||
}
|
||||
|
||||
// State reached through `self.arc` by the `ConnectionManager` methods.
// NOTE(review): this diff rendering lists `inner` twice -- presumably the pre-change
// declaration followed by the post-change one. `startup`/`shutdown` use the `Option`
// form, where `None` means "not started"; confirm against the real file.
struct ConnectionManagerArc {
|
||||
network_manager: NetworkManager,
|
||||
inner: AsyncMutex<ConnectionManagerInner>,
|
||||
inner: AsyncMutex<Option<ConnectionManagerInner>>,
|
||||
}
|
||||
impl core::fmt::Debug for ConnectionManagerArc {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
@@ -31,14 +32,14 @@ pub struct ConnectionManager {
|
||||
impl ConnectionManager {
|
||||
// Builds a fresh inner state: a new `StopSource` and a `ConnectionTable`
// constructed from `config`.
fn new_inner(config: VeilidConfig) -> ConnectionManagerInner {
|
||||
ConnectionManagerInner {
|
||||
// Always starts as `Some`; it only becomes `None` when `shutdown` takes it.
stop_source: Some(StopSource::new()),
|
||||
connection_table: ConnectionTable::new(config),
|
||||
}
|
||||
}
|
||||
// Builds the state object for a connection manager bound to `network_manager`.
// NOTE(review): two `inner:` initialisers appear in this diff rendering -- presumably
// the pre-change line (eagerly built inner) followed by the post-change line (`None`
// until `startup` runs). With only the `None` form, `config` looks unused; confirm.
fn new_arc(network_manager: NetworkManager) -> ConnectionManagerArc {
|
||||
let config = network_manager.config();
|
||||
ConnectionManagerArc {
|
||||
network_manager,
|
||||
inner: AsyncMutex::new(Self::new_inner(config)),
|
||||
inner: AsyncMutex::new(None),
|
||||
}
|
||||
}
|
||||
pub fn new(network_manager: NetworkManager) -> Self {
|
||||
@@ -53,12 +54,32 @@ impl ConnectionManager {
|
||||
|
||||
// Starts the connection manager by installing a fresh inner state under the lock.
// Panics if an inner state is already present, i.e. `startup` was called twice
// without an intervening `shutdown`.
pub async fn startup(&self) {
|
||||
trace!("startup connection manager");
|
||||
//let mut inner = self.arc.inner.lock().await;
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
// `Some` means we are already started
if inner.is_some() {
|
||||
panic!("shouldn't start connection manager twice without shutting it down first");
|
||||
}
|
||||
|
||||
// Create the stop source and connection table for this run
*inner = Some(Self::new_inner(self.network_manager().config()));
|
||||
}
|
||||
|
||||
// Shuts the connection manager down: removes the inner state from the lock, drops the
// stop source to signal the connections, then awaits the connection table's `join()`.
// Panics with "not started" if no inner state is present.
// NOTE(review): the first two rows below (the "Drops connection table" comment and the
// one-line reset via `new_inner`) are presumably the pre-change implementation shown
// by this diff rendering; the rest is the post-change sequence. Confirm against the file.
pub async fn shutdown(&self) {
|
||||
// Drops connection table, which drops all connections in it
|
||||
*self.arc.inner.lock().await = Self::new_inner(self.arc.network_manager.config());
|
||||
// Remove the inner from the lock
|
||||
// The lock guard lives only inside this block, so it is released before the awaits below
let mut inner = {
|
||||
let mut inner_lock = self.arc.inner.lock().await;
|
||||
let inner = match inner_lock.take() {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
panic!("not started");
|
||||
}
|
||||
};
|
||||
inner
|
||||
};
|
||||
|
||||
// Stop all the connections
|
||||
drop(inner.stop_source.take());
|
||||
|
||||
// Wait for the connections to complete
|
||||
inner.connection_table.join().await;
|
||||
}
|
||||
|
||||
// Returns a network connection if one already is established
|
||||
@@ -67,6 +88,12 @@ impl ConnectionManager {
|
||||
descriptor: ConnectionDescriptor,
|
||||
) -> Option<ConnectionHandle> {
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
let inner = match &mut *inner {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
panic!("not started");
|
||||
}
|
||||
};
|
||||
inner.connection_table.get_connection(descriptor)
|
||||
}
|
||||
|
||||
@@ -81,24 +108,18 @@ impl ConnectionManager {
|
||||
log_net!("on_new_protocol_network_connection: {:?}", conn);
|
||||
|
||||
// Wrap with NetworkConnection object to start the connection processing loop
|
||||
let conn = NetworkConnection::from_protocol(self.clone(), conn);
|
||||
let stop_token = match &inner.stop_source {
|
||||
Some(ss) => ss.token(),
|
||||
None => return Err("not creating connection because we are stopping".to_owned()),
|
||||
};
|
||||
|
||||
let conn = NetworkConnection::from_protocol(self.clone(), stop_token, conn);
|
||||
let handle = conn.get_handle();
|
||||
// Add to the connection table
|
||||
inner.connection_table.add_connection(conn)?;
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
// Called by low-level network when any connection-oriented protocol connection appears
|
||||
// from incoming connections.
|
||||
pub(super) async fn on_accepted_protocol_network_connection(
|
||||
&self,
|
||||
conn: ProtocolNetworkConnection,
|
||||
) -> Result<(), String> {
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
self.on_new_protocol_network_connection(&mut *inner, conn)
|
||||
.map(drop)
|
||||
}
|
||||
|
||||
// Called when we want to create a new connection or get the current one that already exists
|
||||
// This will kill off any connections that are in conflict with the new connection to be made
|
||||
// in order to make room for the new connection in the system's connection table
|
||||
@@ -107,6 +128,14 @@ impl ConnectionManager {
|
||||
local_addr: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<ConnectionHandle, String> {
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
let inner = match &mut *inner {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
panic!("not started");
|
||||
}
|
||||
};
|
||||
|
||||
log_net!(
|
||||
"== get_or_create_connection local_addr={:?} dial_info={:?}",
|
||||
local_addr.green(),
|
||||
@@ -123,7 +152,6 @@ impl ConnectionManager {
|
||||
|
||||
// If any connection to this remote exists that has the same protocol, return it
|
||||
// Any connection will do, we don't have to match the local address
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
|
||||
if let Some(conn) = inner
|
||||
.connection_table
|
||||
@@ -197,10 +225,39 @@ impl ConnectionManager {
|
||||
self.on_new_protocol_network_connection(&mut *inner, conn)
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Callbacks
|
||||
|
||||
// Called by low-level network when any connection-oriented protocol connection appears
|
||||
// from incoming connections.
// Returns `Ok(())` without registering anything when the manager has no inner state.
// Otherwise registers `conn` via `on_new_protocol_network_connection` and discards
// the returned handle, passing any error through.
|
||||
pub(super) async fn on_accepted_protocol_network_connection(
|
||||
&self,
|
||||
conn: ProtocolNetworkConnection,
|
||||
) -> Result<(), String> {
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
let inner = match &mut *inner {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
// If we are shutting down, just drop this and return
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
self.on_new_protocol_network_connection(inner, conn)
|
||||
// Only success or failure matters to the caller; drop the handle
.map(drop)
|
||||
}
|
||||
|
||||
// Callback from network connection receive loop when it exits
|
||||
// cleans up the entry in the connection table
|
||||
pub(super) async fn report_connection_finished(&self, descriptor: ConnectionDescriptor) {
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
let inner = match &mut *inner {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
// If we're shutting down, do nothing here
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = inner.connection_table.remove_connection(descriptor) {
|
||||
log_net!(error e);
|
||||
}
|
||||
|
@@ -1,5 +1,6 @@
|
||||
use super::*;
|
||||
use alloc::collections::btree_map::Entry;
|
||||
use futures_util::StreamExt;
|
||||
use hashlink::LruCache;
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -41,6 +42,16 @@ impl ConnectionTable {
|
||||
}
|
||||
}
|
||||
|
||||
// Drains every entry out of each table in `conn_by_descriptor` and awaits all of the
// drained values, returning once each of them has completed.
// NOTE(review): the drained values are presumably `NetworkConnection`s, whose `Future`
// impl resolves when the connection loop ends -- confirm against the table's value type.
pub async fn join(&mut self) {
|
||||
let mut unord = FuturesUnordered::new();
|
||||
for table in &mut self.conn_by_descriptor {
|
||||
for (_, v) in table.drain() {
|
||||
unord.push(v);
|
||||
}
|
||||
}
|
||||
// `next()` yields `None` once every pushed future has finished
while unord.next().await.is_some() {}
|
||||
}
|
||||
|
||||
pub fn add_connection(&mut self, conn: NetworkConnection) -> Result<(), String> {
|
||||
let descriptor = conn.connection_descriptor();
|
||||
let ip_addr = descriptor.remote_address().to_ip_addr();
|
||||
|
@@ -171,8 +171,8 @@ impl NetworkManager {
|
||||
let this2 = this.clone();
|
||||
this.unlocked_inner
|
||||
.rolling_transfers_task
|
||||
.set_routine(move |l, t| {
|
||||
Box::pin(this2.clone().rolling_transfers_task_routine(l, t))
|
||||
.set_routine(move |s, l, t| {
|
||||
Box::pin(this2.clone().rolling_transfers_task_routine(s, l, t))
|
||||
});
|
||||
}
|
||||
// Set relay management tick task
|
||||
@@ -180,8 +180,8 @@ impl NetworkManager {
|
||||
let this2 = this.clone();
|
||||
this.unlocked_inner
|
||||
.relay_management_task
|
||||
.set_routine(move |l, t| {
|
||||
Box::pin(this2.clone().relay_management_task_routine(l, t))
|
||||
.set_routine(move |s, l, t| {
|
||||
Box::pin(this2.clone().relay_management_task_routine(s, l, t))
|
||||
});
|
||||
}
|
||||
this
|
||||
@@ -275,10 +275,10 @@ impl NetworkManager {
|
||||
});
|
||||
|
||||
// Start network components
|
||||
connection_manager.startup().await;
|
||||
net.startup().await?;
|
||||
rpc_processor.startup().await?;
|
||||
receipt_manager.startup().await?;
|
||||
net.startup().await?;
|
||||
connection_manager.startup().await;
|
||||
|
||||
trace!("NetworkManager::internal_startup end");
|
||||
|
||||
@@ -302,20 +302,20 @@ impl NetworkManager {
|
||||
trace!("NetworkManager::shutdown begin");
|
||||
|
||||
// Cancel all tasks
|
||||
if let Err(e) = self.unlocked_inner.rolling_transfers_task.cancel().await {
|
||||
warn!("rolling_transfers_task not cancelled: {}", e);
|
||||
if let Err(e) = self.unlocked_inner.rolling_transfers_task.stop().await {
|
||||
warn!("rolling_transfers_task not stopped: {}", e);
|
||||
}
|
||||
if let Err(e) = self.unlocked_inner.relay_management_task.cancel().await {
|
||||
warn!("relay_management_task not cancelled: {}", e);
|
||||
if let Err(e) = self.unlocked_inner.relay_management_task.stop().await {
|
||||
warn!("relay_management_task not stopped: {}", e);
|
||||
}
|
||||
|
||||
// Shutdown network components if they started up
|
||||
let components = self.inner.lock().components.clone();
|
||||
if let Some(components) = components {
|
||||
components.connection_manager.shutdown().await;
|
||||
components.net.shutdown().await;
|
||||
components.receipt_manager.shutdown().await;
|
||||
components.rpc_processor.shutdown().await;
|
||||
components.net.shutdown().await;
|
||||
components.connection_manager.shutdown().await;
|
||||
}
|
||||
|
||||
// reset the state
|
||||
@@ -1202,7 +1202,12 @@ impl NetworkManager {
|
||||
|
||||
// Keep relays assigned and accessible
|
||||
#[instrument(level = "trace", skip(self), err)]
|
||||
async fn relay_management_task_routine(self, _last_ts: u64, cur_ts: u64) -> Result<(), String> {
|
||||
async fn relay_management_task_routine(
|
||||
self,
|
||||
stop_token: StopToken,
|
||||
_last_ts: u64,
|
||||
cur_ts: u64,
|
||||
) -> Result<(), String> {
|
||||
// log_net!("--- network manager relay_management task");
|
||||
|
||||
// Get our node's current node info and network class and do the right thing
|
||||
@@ -1255,7 +1260,12 @@ impl NetworkManager {
|
||||
|
||||
// Compute transfer statistics for the low level network
|
||||
#[instrument(level = "trace", skip(self), err)]
|
||||
async fn rolling_transfers_task_routine(self, last_ts: u64, cur_ts: u64) -> Result<(), String> {
|
||||
async fn rolling_transfers_task_routine(
|
||||
self,
|
||||
stop_token: StopToken,
|
||||
last_ts: u64,
|
||||
cur_ts: u64,
|
||||
) -> Result<(), String> {
|
||||
// log_net!("--- network manager rolling_transfers task");
|
||||
{
|
||||
let inner = &mut *self.inner.lock();
|
||||
|
@@ -42,7 +42,8 @@ struct NetworkInner {
|
||||
protocol_config: Option<ProtocolConfig>,
|
||||
static_public_dialinfo: ProtocolSet,
|
||||
network_class: Option<NetworkClass>,
|
||||
join_handles: Vec<JoinHandle<()>>,
|
||||
join_handles: Vec<MustJoinHandle<()>>,
|
||||
stop_source: Option<StopSource>,
|
||||
udp_port: u16,
|
||||
tcp_port: u16,
|
||||
ws_port: u16,
|
||||
@@ -82,6 +83,7 @@ impl Network {
|
||||
static_public_dialinfo: ProtocolSet::empty(),
|
||||
network_class: None,
|
||||
join_handles: Vec::new(),
|
||||
stop_source: None,
|
||||
udp_port: 0u16,
|
||||
tcp_port: 0u16,
|
||||
ws_port: 0u16,
|
||||
@@ -115,8 +117,8 @@ impl Network {
|
||||
let this2 = this.clone();
|
||||
this.unlocked_inner
|
||||
.update_network_class_task
|
||||
.set_routine(move |l, t| {
|
||||
Box::pin(this2.clone().update_network_class_task_routine(l, t))
|
||||
.set_routine(move |s, l, t| {
|
||||
Box::pin(this2.clone().update_network_class_task_routine(s, l, t))
|
||||
});
|
||||
}
|
||||
|
||||
@@ -200,7 +202,7 @@ impl Network {
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
// Records a spawned task's join handle in `inner.join_handles`.
// NOTE(review): two signatures appear in this diff rendering -- presumably the
// pre-change (`JoinHandle`) row followed by the post-change (`MustJoinHandle`) row;
// the callers shown wrap handles with `MustJoinHandle::new`. Confirm against the file.
fn add_to_join_handles(&self, jh: JoinHandle<()>) {
|
||||
fn add_to_join_handles(&self, jh: MustJoinHandle<()>) {
|
||||
let mut inner = self.inner.lock();
|
||||
inner.join_handles.push(jh);
|
||||
}
|
||||
@@ -506,17 +508,28 @@ impl Network {
|
||||
let network_manager = self.network_manager();
|
||||
let routing_table = self.routing_table();
|
||||
|
||||
// Cancel all tasks
|
||||
if let Err(e) = self.unlocked_inner.update_network_class_task.cancel().await {
|
||||
warn!("update_network_class_task not cancelled: {}", e);
|
||||
// Stop all tasks
|
||||
if let Err(e) = self.unlocked_inner.update_network_class_task.stop().await {
|
||||
error!("update_network_class_task not cancelled: {}", e);
|
||||
}
|
||||
let mut unord = FuturesUnordered::new();
|
||||
{
|
||||
let mut inner = self.inner.lock();
|
||||
// Drop the stop
|
||||
drop(inner.stop_source.take());
|
||||
// take the join handles out
|
||||
for h in inner.join_handles.drain(..) {
|
||||
unord.push(h);
|
||||
}
|
||||
}
|
||||
// Wait for everything to stop
|
||||
while unord.next().await.is_some() {}
|
||||
|
||||
// Drop all dial info
|
||||
routing_table.clear_dial_info_details(RoutingDomain::PublicInternet);
|
||||
routing_table.clear_dial_info_details(RoutingDomain::LocalNetwork);
|
||||
|
||||
// Reset state including network class
|
||||
// Cancels all async background tasks by dropping join handles
|
||||
*self.inner.lock() = Self::new_inner(network_manager);
|
||||
|
||||
info!("network stopped");
|
||||
|
@@ -465,7 +465,12 @@ impl Network {
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(self), err)]
|
||||
pub async fn update_network_class_task_routine(self, _l: u64, _t: u64) -> Result<(), String> {
|
||||
pub async fn update_network_class_task_routine(
|
||||
self,
|
||||
stop_token: StopToken,
|
||||
_l: u64,
|
||||
_t: u64,
|
||||
) -> Result<(), String> {
|
||||
// Ensure we aren't trying to update this without clearing it first
|
||||
let old_network_class = self.inner.lock().network_class;
|
||||
assert_eq!(old_network_class, None);
|
||||
|
@@ -2,6 +2,7 @@ use super::*;
|
||||
use crate::intf::*;
|
||||
use async_tls::TlsAcceptor;
|
||||
use sockets::*;
|
||||
use stop_token::future::FutureExt;
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -91,6 +92,106 @@ impl Network {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
// Handles one item yielded by the TCP listener's incoming stream.
//
// Steps: unwrap the accept result, look up the peer address, peek the first
// PEEK_DETECT_LEN bytes within the initial timeout, pick TLS or plain protocol
// handlers, and register the resulting connection with the connection manager.
// Every failure is logged and the connection dropped; only an accept error also
// flags the network for restart.
//
// * `tcp_stream` - result of the accept; an `Err` sets `network_needs_restart`
// * `listener_state` - source of the TLS acceptor and the protocol handler lists
// * `connection_manager` - receives the negotiated connection
// * `connection_initial_timeout` - microseconds allowed for the first peek
// * `tls_connection_initial_timeout` - passed through to `try_tls_handlers`
//   (presumably the same unit -- TODO confirm)
async fn tcp_acceptor(
|
||||
self,
|
||||
tcp_stream: async_std::io::Result<TcpStream>,
|
||||
listener_state: Arc<RwLock<ListenerState>>,
|
||||
connection_manager: ConnectionManager,
|
||||
connection_initial_timeout: u64,
|
||||
tls_connection_initial_timeout: u64,
|
||||
) {
|
||||
let tcp_stream = match tcp_stream {
|
||||
Ok(v) => v,
|
||||
Err(_) => {
|
||||
// If this happened our low-level listener socket probably died
|
||||
// so it's time to restart the network
|
||||
self.inner.lock().network_needs_restart = true;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let listener_state = listener_state.clone();
|
||||
let connection_manager = connection_manager.clone();
|
||||
|
||||
// Limit the number of connections from the same IP address
|
||||
// and the number of total connections
// NOTE(review): no limiting is performed here yet; only the peer address is fetched
|
||||
let addr = match tcp_stream.peer_addr() {
|
||||
Ok(addr) => addr,
|
||||
Err(e) => {
|
||||
log_net!(error "failed to get peer address: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
// XXX limiting
|
||||
|
||||
log_net!("TCP connection from: {}", addr);
|
||||
|
||||
// Create a stream we can peek on
|
||||
let ps = AsyncPeekStream::new(tcp_stream.clone());
|
||||
|
||||
/////////////////////////////////////////////////////////////
|
||||
let mut first_packet = [0u8; PEEK_DETECT_LEN];
|
||||
|
||||
// read a chunk of the stream
// Any error from the timed peek is treated the same way: the connection is dropped
|
||||
if io::timeout(
|
||||
Duration::from_micros(connection_initial_timeout),
|
||||
ps.peek_exact(&mut first_packet),
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
// If we fail to get a packet within the connection initial timeout
|
||||
// then we punt this connection
|
||||
log_net!(warn "connection initial timeout from: {:?}", addr);
|
||||
return;
|
||||
}
|
||||
|
||||
// Run accept handlers on accepted stream
|
||||
|
||||
// Check if this could be TLS
|
||||
let ls = listener_state.read().clone();
|
||||
|
||||
// 0x16 is the TLS handshake record type; the `unwrap` is guarded by `is_some()`
let conn = if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 {
|
||||
self.try_tls_handlers(
|
||||
ls.tls_acceptor.as_ref().unwrap(),
|
||||
ps,
|
||||
tcp_stream,
|
||||
addr,
|
||||
&ls.tls_protocol_handlers,
|
||||
tls_connection_initial_timeout,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
self.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers)
|
||||
.await
|
||||
};
|
||||
|
||||
let conn = match conn {
|
||||
Ok(Some(c)) => {
|
||||
log_net!("protocol handler found for {:?}: {:?}", addr, c);
|
||||
c
|
||||
}
|
||||
Ok(None) => {
|
||||
// No protocol handlers matched? drop it.
|
||||
log_net!(warn "no protocol handler for connection from {:?}", addr);
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
// Failed to negotiate connection? drop it.
|
||||
log_net!(warn "failed to negotiate connection from {:?}: {}", addr, e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Register the new connection in the connection manager
|
||||
if let Err(e) = connection_manager
|
||||
.on_accepted_protocol_network_connection(conn)
|
||||
.await
|
||||
{
|
||||
log_net!(error "failed to register new connection: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
|
||||
// Get config
|
||||
let (connection_initial_timeout, tls_connection_initial_timeout) = {
|
||||
@@ -123,111 +224,40 @@ impl Network {
|
||||
|
||||
// Spawn the socket task
|
||||
let this = self.clone();
|
||||
let stop_token = self.inner.lock().stop_source.as_ref().unwrap().token();
|
||||
let connection_manager = self.connection_manager();
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
let jh = spawn(async move {
|
||||
// moves listener object in and get incoming iterator
|
||||
// when this task exits, the listener will close the socket
|
||||
listener
|
||||
let _ = listener
|
||||
.incoming()
|
||||
.for_each_concurrent(None, |tcp_stream| async {
|
||||
let tcp_stream = tcp_stream.unwrap();
|
||||
.for_each_concurrent(None, |tcp_stream| {
|
||||
let this = this.clone();
|
||||
let listener_state = listener_state.clone();
|
||||
let connection_manager = connection_manager.clone();
|
||||
|
||||
// Limit the number of connections from the same IP address
|
||||
// and the number of total connections
|
||||
let addr = match tcp_stream.peer_addr() {
|
||||
Ok(addr) => addr,
|
||||
Err(e) => {
|
||||
log_net!(error "failed to get peer address: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
// XXX limiting
|
||||
|
||||
log_net!("TCP connection from: {}", addr);
|
||||
|
||||
// Create a stream we can peek on
|
||||
let ps = AsyncPeekStream::new(tcp_stream.clone());
|
||||
|
||||
/////////////////////////////////////////////////////////////
|
||||
let mut first_packet = [0u8; PEEK_DETECT_LEN];
|
||||
|
||||
// read a chunk of the stream
|
||||
if io::timeout(
|
||||
Duration::from_micros(connection_initial_timeout),
|
||||
ps.peek_exact(&mut first_packet),
|
||||
Self::tcp_acceptor(
|
||||
this,
|
||||
tcp_stream,
|
||||
listener_state,
|
||||
connection_manager,
|
||||
connection_initial_timeout,
|
||||
tls_connection_initial_timeout,
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
// If we fail to get a packet within the connection initial timeout
|
||||
// then we punt this connection
|
||||
log_net!(warn "connection initial timeout from: {:?}", addr);
|
||||
return;
|
||||
}
|
||||
|
||||
// Run accept handlers on accepted stream
|
||||
|
||||
// Check if this could be TLS
|
||||
let ls = listener_state.read().clone();
|
||||
|
||||
let conn = if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 {
|
||||
this.try_tls_handlers(
|
||||
ls.tls_acceptor.as_ref().unwrap(),
|
||||
ps,
|
||||
tcp_stream,
|
||||
addr,
|
||||
&ls.tls_protocol_handlers,
|
||||
tls_connection_initial_timeout,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
this.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers)
|
||||
.await
|
||||
};
|
||||
|
||||
let conn = match conn {
|
||||
Ok(Some(c)) => {
|
||||
log_net!("protocol handler found for {:?}: {:?}", addr, c);
|
||||
c
|
||||
}
|
||||
Ok(None) => {
|
||||
// No protocol handlers matched? drop it.
|
||||
log_net!(warn "no protocol handler for connection from {:?}", addr);
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
// Failed to negotiate connection? drop it.
|
||||
log_net!(warn "failed to negotiate connection from {:?}: {}", addr, e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Register the new connection in the connection manager
|
||||
if let Err(e) = connection_manager
|
||||
.on_accepted_protocol_network_connection(conn)
|
||||
.await
|
||||
{
|
||||
log_net!(error "failed to register new connection: {}", e);
|
||||
}
|
||||
})
|
||||
.timeout_at(stop_token)
|
||||
.await;
|
||||
|
||||
log_net!(debug "exited incoming loop for {}", addr);
|
||||
// Remove our listener state from this address if we're stopping
|
||||
this.inner.lock().listener_states.remove(&addr);
|
||||
log_net!(debug "listener state removed for {}", addr);
|
||||
|
||||
// If this happened our low-level listener socket probably died
|
||||
// so it's time to restart the network
|
||||
this.inner.lock().network_needs_restart = true;
|
||||
});
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// Add to join handles
|
||||
self.add_to_join_handles(jh);
|
||||
self.add_to_join_handles(MustJoinHandle::new(jh));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@@ -1,5 +1,6 @@
|
||||
use super::*;
|
||||
use sockets::*;
|
||||
use stop_token::future::FutureExt;
|
||||
|
||||
impl Network {
|
||||
pub(super) async fn create_udp_listener_tasks(&self) -> Result<(), String> {
|
||||
@@ -43,47 +44,75 @@ impl Network {
|
||||
// Spawn a local async task for each socket
|
||||
let mut protocol_handlers_unordered = FuturesUnordered::new();
|
||||
let network_manager = this.network_manager();
|
||||
let stop_token = this.inner.lock().stop_source.as_ref().unwrap().token();
|
||||
|
||||
for ph in protocol_handlers {
|
||||
let network_manager = network_manager.clone();
|
||||
let stop_token = stop_token.clone();
|
||||
let jh = spawn_local(async move {
|
||||
let mut data = vec![0u8; 65536];
|
||||
|
||||
while let Ok((size, descriptor)) = ph.recv_message(&mut data).await {
|
||||
// XXX: Limit the number of packets from the same IP address?
|
||||
log_net!("UDP packet: {:?}", descriptor);
|
||||
|
||||
// Network accounting
|
||||
network_manager.stats_packet_rcvd(
|
||||
descriptor.remote_address().to_ip_addr(),
|
||||
size as u64,
|
||||
);
|
||||
|
||||
// Pass it up for processing
|
||||
if let Err(e) = network_manager
|
||||
.on_recv_envelope(&data[..size], descriptor)
|
||||
loop {
|
||||
match ph
|
||||
.recv_message(&mut data)
|
||||
.timeout_at(stop_token.clone())
|
||||
.await
|
||||
{
|
||||
log_net!(error "failed to process received udp envelope: {}", e);
|
||||
Ok(Ok((size, descriptor))) => {
|
||||
// XXX: Limit the number of packets from the same IP address?
|
||||
log_net!("UDP packet: {:?}", descriptor);
|
||||
|
||||
// Network accounting
|
||||
network_manager.stats_packet_rcvd(
|
||||
descriptor.remote_address().to_ip_addr(),
|
||||
size as u64,
|
||||
);
|
||||
|
||||
// Pass it up for processing
|
||||
if let Err(e) = network_manager
|
||||
.on_recv_envelope(&data[..size], descriptor)
|
||||
.await
|
||||
{
|
||||
log_net!(error "failed to process received udp envelope: {}", e);
|
||||
}
|
||||
}
|
||||
Ok(Err(_)) => {
|
||||
return false;
|
||||
}
|
||||
Err(_) => {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
protocol_handlers_unordered.push(jh);
|
||||
}
|
||||
// Now we wait for any join handle to exit,
|
||||
// which would indicate an error needing
|
||||
// Now we wait for join handles to exit,
|
||||
// if any error out it indicates an error needing
|
||||
// us to completely restart the network
|
||||
let _ = protocol_handlers_unordered.next().await;
|
||||
loop {
|
||||
match protocol_handlers_unordered.next().await {
|
||||
Some(v) => {
|
||||
// true = stopped, false = errored
|
||||
if !v {
|
||||
// If any protocol handler fails, our socket died and we need to restart the network
|
||||
this.inner.lock().network_needs_restart = true;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// All protocol handlers exited
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trace!("UDP listener task stopped");
|
||||
// If this loop fails, our socket died and we need to restart the network
|
||||
this.inner.lock().network_needs_restart = true;
|
||||
});
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// Add to join handle
|
||||
self.add_to_join_handles(jh);
|
||||
self.add_to_join_handles(MustJoinHandle::new(jh));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@@ -1,5 +1,6 @@
|
||||
use super::*;
|
||||
use futures_util::{FutureExt, StreamExt};
|
||||
use stop_token::prelude::*;
|
||||
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(target_arch = "wasm32")] {
|
||||
@@ -84,7 +85,7 @@ pub struct NetworkConnectionStats {
|
||||
#[derive(Debug)]
|
||||
pub struct NetworkConnection {
|
||||
descriptor: ConnectionDescriptor,
|
||||
_processor: Option<JoinHandle<()>>,
|
||||
processor: Option<MustJoinHandle<()>>,
|
||||
established_time: u64,
|
||||
stats: Arc<Mutex<NetworkConnectionStats>>,
|
||||
sender: flume::Sender<Vec<u8>>,
|
||||
@@ -97,7 +98,7 @@ impl NetworkConnection {
|
||||
|
||||
Self {
|
||||
descriptor,
|
||||
_processor: None,
|
||||
processor: None,
|
||||
established_time: intf::get_timestamp(),
|
||||
stats: Arc::new(Mutex::new(NetworkConnectionStats {
|
||||
last_message_sent_time: None,
|
||||
@@ -109,6 +110,7 @@ impl NetworkConnection {
|
||||
|
||||
pub(super) fn from_protocol(
|
||||
connection_manager: ConnectionManager,
|
||||
stop_token: StopToken,
|
||||
protocol_connection: ProtocolNetworkConnection,
|
||||
) -> Self {
|
||||
// Get timeout
|
||||
@@ -132,19 +134,20 @@ impl NetworkConnection {
|
||||
}));
|
||||
|
||||
// Spawn connection processor and pass in protocol connection
|
||||
let processor = intf::spawn_local(Self::process_connection(
|
||||
let processor = MustJoinHandle::new(intf::spawn_local(Self::process_connection(
|
||||
connection_manager,
|
||||
stop_token,
|
||||
descriptor.clone(),
|
||||
receiver,
|
||||
protocol_connection,
|
||||
inactivity_timeout,
|
||||
stats.clone(),
|
||||
));
|
||||
)));
|
||||
|
||||
// Return the connection
|
||||
Self {
|
||||
descriptor,
|
||||
_processor: Some(processor),
|
||||
processor: Some(processor),
|
||||
established_time: intf::get_timestamp(),
|
||||
stats,
|
||||
sender,
|
||||
@@ -197,6 +200,7 @@ impl NetworkConnection {
|
||||
// Connection receiver loop
|
||||
fn process_connection(
|
||||
connection_manager: ConnectionManager,
|
||||
stop_token: StopToken,
|
||||
descriptor: ConnectionDescriptor,
|
||||
receiver: flume::Receiver<Vec<u8>>,
|
||||
protocol_connection: ProtocolNetworkConnection,
|
||||
@@ -289,26 +293,28 @@ impl NetworkConnection {
|
||||
}
|
||||
|
||||
// Process futures
|
||||
match unord.next().await {
|
||||
Some(RecvLoopAction::Send) => {
|
||||
match unord.next().timeout_at(stop_token.clone()).await {
|
||||
Ok(Some(RecvLoopAction::Send)) => {
|
||||
// Don't reset inactivity timer if we're only sending
|
||||
|
||||
need_sender = true;
|
||||
}
|
||||
Some(RecvLoopAction::Recv) => {
|
||||
Ok(Some(RecvLoopAction::Recv)) => {
|
||||
// Reset inactivity timer since we got something from this connection
|
||||
timer.set(new_timer());
|
||||
|
||||
need_receiver = true;
|
||||
}
|
||||
Some(RecvLoopAction::Finish) | Some(RecvLoopAction::Timeout) => {
|
||||
Ok(Some(RecvLoopAction::Finish) | Some(RecvLoopAction::Timeout)) => {
|
||||
break;
|
||||
}
|
||||
|
||||
None => {
|
||||
Ok(None) => {
|
||||
// Should not happen
|
||||
unreachable!();
|
||||
}
|
||||
Err(_) => {
|
||||
// Stop token
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -317,9 +323,23 @@ impl NetworkConnection {
|
||||
descriptor.green()
|
||||
);
|
||||
|
||||
// Let the connection manager know the receive loop exited
|
||||
connection_manager
|
||||
.report_connection_finished(descriptor)
|
||||
.await
|
||||
.await;
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Resolves ready when the connection loop has terminated
// Awaiting a `NetworkConnection` polls its `processor` join handle.
|
||||
impl Future for NetworkConnection {
|
||||
type Output = ();
|
||||
|
||||
// Delegates to the processor handle when one exists; a connection whose `processor`
// is `None` is immediately ready.
fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
|
||||
if let Some(mut processor) = self.processor.as_mut() {
|
||||
// `processor` is a `&mut` to the handle here, so `Pin::new` pins that reference
Pin::new(&mut processor).poll(cx)
|
||||
} else {
|
||||
task::Poll::Ready(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user