refactor for cooperative cancellation
This commit is contained in:
@@ -42,7 +42,8 @@ struct NetworkInner {
|
||||
protocol_config: Option<ProtocolConfig>,
|
||||
static_public_dialinfo: ProtocolSet,
|
||||
network_class: Option<NetworkClass>,
|
||||
join_handles: Vec<JoinHandle<()>>,
|
||||
join_handles: Vec<MustJoinHandle<()>>,
|
||||
stop_source: Option<StopSource>,
|
||||
udp_port: u16,
|
||||
tcp_port: u16,
|
||||
ws_port: u16,
|
||||
@@ -82,6 +83,7 @@ impl Network {
|
||||
static_public_dialinfo: ProtocolSet::empty(),
|
||||
network_class: None,
|
||||
join_handles: Vec::new(),
|
||||
stop_source: None,
|
||||
udp_port: 0u16,
|
||||
tcp_port: 0u16,
|
||||
ws_port: 0u16,
|
||||
@@ -115,8 +117,8 @@ impl Network {
|
||||
let this2 = this.clone();
|
||||
this.unlocked_inner
|
||||
.update_network_class_task
|
||||
.set_routine(move |l, t| {
|
||||
Box::pin(this2.clone().update_network_class_task_routine(l, t))
|
||||
.set_routine(move |s, l, t| {
|
||||
Box::pin(this2.clone().update_network_class_task_routine(s, l, t))
|
||||
});
|
||||
}
|
||||
|
||||
@@ -200,7 +202,7 @@ impl Network {
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn add_to_join_handles(&self, jh: JoinHandle<()>) {
|
||||
fn add_to_join_handles(&self, jh: MustJoinHandle<()>) {
|
||||
let mut inner = self.inner.lock();
|
||||
inner.join_handles.push(jh);
|
||||
}
|
||||
@@ -506,17 +508,28 @@ impl Network {
|
||||
let network_manager = self.network_manager();
|
||||
let routing_table = self.routing_table();
|
||||
|
||||
// Cancel all tasks
|
||||
if let Err(e) = self.unlocked_inner.update_network_class_task.cancel().await {
|
||||
warn!("update_network_class_task not cancelled: {}", e);
|
||||
// Stop all tasks
|
||||
if let Err(e) = self.unlocked_inner.update_network_class_task.stop().await {
|
||||
error!("update_network_class_task not cancelled: {}", e);
|
||||
}
|
||||
let mut unord = FuturesUnordered::new();
|
||||
{
|
||||
let mut inner = self.inner.lock();
|
||||
// Drop the stop
|
||||
drop(inner.stop_source.take());
|
||||
// take the join handles out
|
||||
for h in inner.join_handles.drain(..) {
|
||||
unord.push(h);
|
||||
}
|
||||
}
|
||||
// Wait for everything to stop
|
||||
while unord.next().await.is_some() {}
|
||||
|
||||
// Drop all dial info
|
||||
routing_table.clear_dial_info_details(RoutingDomain::PublicInternet);
|
||||
routing_table.clear_dial_info_details(RoutingDomain::LocalNetwork);
|
||||
|
||||
// Reset state including network class
|
||||
// Cancels all async background tasks by dropping join handles
|
||||
*self.inner.lock() = Self::new_inner(network_manager);
|
||||
|
||||
info!("network stopped");
|
||||
|
@@ -465,7 +465,12 @@ impl Network {
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(self), err)]
|
||||
pub async fn update_network_class_task_routine(self, _l: u64, _t: u64) -> Result<(), String> {
|
||||
pub async fn update_network_class_task_routine(
|
||||
self,
|
||||
stop_token: StopToken,
|
||||
_l: u64,
|
||||
_t: u64,
|
||||
) -> Result<(), String> {
|
||||
// Ensure we aren't trying to update this without clearing it first
|
||||
let old_network_class = self.inner.lock().network_class;
|
||||
assert_eq!(old_network_class, None);
|
||||
|
@@ -2,6 +2,7 @@ use super::*;
|
||||
use crate::intf::*;
|
||||
use async_tls::TlsAcceptor;
|
||||
use sockets::*;
|
||||
use stop_token::future::FutureExt;
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -91,6 +92,106 @@ impl Network {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn tcp_acceptor(
|
||||
self,
|
||||
tcp_stream: async_std::io::Result<TcpStream>,
|
||||
listener_state: Arc<RwLock<ListenerState>>,
|
||||
connection_manager: ConnectionManager,
|
||||
connection_initial_timeout: u64,
|
||||
tls_connection_initial_timeout: u64,
|
||||
) {
|
||||
let tcp_stream = match tcp_stream {
|
||||
Ok(v) => v,
|
||||
Err(_) => {
|
||||
// If this happened our low-level listener socket probably died
|
||||
// so it's time to restart the network
|
||||
self.inner.lock().network_needs_restart = true;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let listener_state = listener_state.clone();
|
||||
let connection_manager = connection_manager.clone();
|
||||
|
||||
// Limit the number of connections from the same IP address
|
||||
// and the number of total connections
|
||||
let addr = match tcp_stream.peer_addr() {
|
||||
Ok(addr) => addr,
|
||||
Err(e) => {
|
||||
log_net!(error "failed to get peer address: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
// XXX limiting
|
||||
|
||||
log_net!("TCP connection from: {}", addr);
|
||||
|
||||
// Create a stream we can peek on
|
||||
let ps = AsyncPeekStream::new(tcp_stream.clone());
|
||||
|
||||
/////////////////////////////////////////////////////////////
|
||||
let mut first_packet = [0u8; PEEK_DETECT_LEN];
|
||||
|
||||
// read a chunk of the stream
|
||||
if io::timeout(
|
||||
Duration::from_micros(connection_initial_timeout),
|
||||
ps.peek_exact(&mut first_packet),
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
// If we fail to get a packet within the connection initial timeout
|
||||
// then we punt this connection
|
||||
log_net!(warn "connection initial timeout from: {:?}", addr);
|
||||
return;
|
||||
}
|
||||
|
||||
// Run accept handlers on accepted stream
|
||||
|
||||
// Check is this could be TLS
|
||||
let ls = listener_state.read().clone();
|
||||
|
||||
let conn = if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 {
|
||||
self.try_tls_handlers(
|
||||
ls.tls_acceptor.as_ref().unwrap(),
|
||||
ps,
|
||||
tcp_stream,
|
||||
addr,
|
||||
&ls.tls_protocol_handlers,
|
||||
tls_connection_initial_timeout,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
self.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers)
|
||||
.await
|
||||
};
|
||||
|
||||
let conn = match conn {
|
||||
Ok(Some(c)) => {
|
||||
log_net!("protocol handler found for {:?}: {:?}", addr, c);
|
||||
c
|
||||
}
|
||||
Ok(None) => {
|
||||
// No protocol handlers matched? drop it.
|
||||
log_net!(warn "no protocol handler for connection from {:?}", addr);
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
// Failed to negotiate connection? drop it.
|
||||
log_net!(warn "failed to negotiate connection from {:?}: {}", addr, e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Register the new connection in the connection manager
|
||||
if let Err(e) = connection_manager
|
||||
.on_accepted_protocol_network_connection(conn)
|
||||
.await
|
||||
{
|
||||
log_net!(error "failed to register new connection: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
|
||||
// Get config
|
||||
let (connection_initial_timeout, tls_connection_initial_timeout) = {
|
||||
@@ -123,111 +224,40 @@ impl Network {
|
||||
|
||||
// Spawn the socket task
|
||||
let this = self.clone();
|
||||
let stop_token = self.inner.lock().stop_source.as_ref().unwrap().token();
|
||||
let connection_manager = self.connection_manager();
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
let jh = spawn(async move {
|
||||
// moves listener object in and get incoming iterator
|
||||
// when this task exists, the listener will close the socket
|
||||
listener
|
||||
let _ = listener
|
||||
.incoming()
|
||||
.for_each_concurrent(None, |tcp_stream| async {
|
||||
let tcp_stream = tcp_stream.unwrap();
|
||||
.for_each_concurrent(None, |tcp_stream| {
|
||||
let this = this.clone();
|
||||
let listener_state = listener_state.clone();
|
||||
let connection_manager = connection_manager.clone();
|
||||
|
||||
// Limit the number of connections from the same IP address
|
||||
// and the number of total connections
|
||||
let addr = match tcp_stream.peer_addr() {
|
||||
Ok(addr) => addr,
|
||||
Err(e) => {
|
||||
log_net!(error "failed to get peer address: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
// XXX limiting
|
||||
|
||||
log_net!("TCP connection from: {}", addr);
|
||||
|
||||
// Create a stream we can peek on
|
||||
let ps = AsyncPeekStream::new(tcp_stream.clone());
|
||||
|
||||
/////////////////////////////////////////////////////////////
|
||||
let mut first_packet = [0u8; PEEK_DETECT_LEN];
|
||||
|
||||
// read a chunk of the stream
|
||||
if io::timeout(
|
||||
Duration::from_micros(connection_initial_timeout),
|
||||
ps.peek_exact(&mut first_packet),
|
||||
Self::tcp_acceptor(
|
||||
this,
|
||||
tcp_stream,
|
||||
listener_state,
|
||||
connection_manager,
|
||||
connection_initial_timeout,
|
||||
tls_connection_initial_timeout,
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
// If we fail to get a packet within the connection initial timeout
|
||||
// then we punt this connection
|
||||
log_net!(warn "connection initial timeout from: {:?}", addr);
|
||||
return;
|
||||
}
|
||||
|
||||
// Run accept handlers on accepted stream
|
||||
|
||||
// Check is this could be TLS
|
||||
let ls = listener_state.read().clone();
|
||||
|
||||
let conn = if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 {
|
||||
this.try_tls_handlers(
|
||||
ls.tls_acceptor.as_ref().unwrap(),
|
||||
ps,
|
||||
tcp_stream,
|
||||
addr,
|
||||
&ls.tls_protocol_handlers,
|
||||
tls_connection_initial_timeout,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
this.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers)
|
||||
.await
|
||||
};
|
||||
|
||||
let conn = match conn {
|
||||
Ok(Some(c)) => {
|
||||
log_net!("protocol handler found for {:?}: {:?}", addr, c);
|
||||
c
|
||||
}
|
||||
Ok(None) => {
|
||||
// No protocol handlers matched? drop it.
|
||||
log_net!(warn "no protocol handler for connection from {:?}", addr);
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
// Failed to negotiate connection? drop it.
|
||||
log_net!(warn "failed to negotiate connection from {:?}: {}", addr, e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Register the new connection in the connection manager
|
||||
if let Err(e) = connection_manager
|
||||
.on_accepted_protocol_network_connection(conn)
|
||||
.await
|
||||
{
|
||||
log_net!(error "failed to register new connection: {}", e);
|
||||
}
|
||||
})
|
||||
.timeout_at(stop_token)
|
||||
.await;
|
||||
|
||||
log_net!(debug "exited incoming loop for {}", addr);
|
||||
// Remove our listener state from this address if we're stopping
|
||||
this.inner.lock().listener_states.remove(&addr);
|
||||
log_net!(debug "listener state removed for {}", addr);
|
||||
|
||||
// If this happened our low-level listener socket probably died
|
||||
// so it's time to restart the network
|
||||
this.inner.lock().network_needs_restart = true;
|
||||
});
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// Add to join handles
|
||||
self.add_to_join_handles(jh);
|
||||
self.add_to_join_handles(MustJoinHandle::new(jh));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@@ -1,5 +1,6 @@
|
||||
use super::*;
|
||||
use sockets::*;
|
||||
use stop_token::future::FutureExt;
|
||||
|
||||
impl Network {
|
||||
pub(super) async fn create_udp_listener_tasks(&self) -> Result<(), String> {
|
||||
@@ -43,47 +44,75 @@ impl Network {
|
||||
// Spawn a local async task for each socket
|
||||
let mut protocol_handlers_unordered = FuturesUnordered::new();
|
||||
let network_manager = this.network_manager();
|
||||
let stop_token = this.inner.lock().stop_source.as_ref().unwrap().token();
|
||||
|
||||
for ph in protocol_handlers {
|
||||
let network_manager = network_manager.clone();
|
||||
let stop_token = stop_token.clone();
|
||||
let jh = spawn_local(async move {
|
||||
let mut data = vec![0u8; 65536];
|
||||
|
||||
while let Ok((size, descriptor)) = ph.recv_message(&mut data).await {
|
||||
// XXX: Limit the number of packets from the same IP address?
|
||||
log_net!("UDP packet: {:?}", descriptor);
|
||||
|
||||
// Network accounting
|
||||
network_manager.stats_packet_rcvd(
|
||||
descriptor.remote_address().to_ip_addr(),
|
||||
size as u64,
|
||||
);
|
||||
|
||||
// Pass it up for processing
|
||||
if let Err(e) = network_manager
|
||||
.on_recv_envelope(&data[..size], descriptor)
|
||||
loop {
|
||||
match ph
|
||||
.recv_message(&mut data)
|
||||
.timeout_at(stop_token.clone())
|
||||
.await
|
||||
{
|
||||
log_net!(error "failed to process received udp envelope: {}", e);
|
||||
Ok(Ok((size, descriptor))) => {
|
||||
// XXX: Limit the number of packets from the same IP address?
|
||||
log_net!("UDP packet: {:?}", descriptor);
|
||||
|
||||
// Network accounting
|
||||
network_manager.stats_packet_rcvd(
|
||||
descriptor.remote_address().to_ip_addr(),
|
||||
size as u64,
|
||||
);
|
||||
|
||||
// Pass it up for processing
|
||||
if let Err(e) = network_manager
|
||||
.on_recv_envelope(&data[..size], descriptor)
|
||||
.await
|
||||
{
|
||||
log_net!(error "failed to process received udp envelope: {}", e);
|
||||
}
|
||||
}
|
||||
Ok(Err(_)) => {
|
||||
return false;
|
||||
}
|
||||
Err(_) => {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
protocol_handlers_unordered.push(jh);
|
||||
}
|
||||
// Now we wait for any join handle to exit,
|
||||
// which would indicate an error needing
|
||||
// Now we wait for join handles to exit,
|
||||
// if any error out it indicates an error needing
|
||||
// us to completely restart the network
|
||||
let _ = protocol_handlers_unordered.next().await;
|
||||
loop {
|
||||
match protocol_handlers_unordered.next().await {
|
||||
Some(v) => {
|
||||
// true = stopped, false = errored
|
||||
if !v {
|
||||
// If any protocol handler fails, our socket died and we need to restart the network
|
||||
this.inner.lock().network_needs_restart = true;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// All protocol handlers exited
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trace!("UDP listener task stopped");
|
||||
// If this loop fails, our socket died and we need to restart the network
|
||||
this.inner.lock().network_needs_restart = true;
|
||||
});
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// Add to join handle
|
||||
self.add_to_join_handles(jh);
|
||||
self.add_to_join_handles(MustJoinHandle::new(jh));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
Reference in New Issue
Block a user