refactor for cooperative cancellation

This commit is contained in:
John Smith
2022-06-12 20:58:02 -04:00
parent bcc1bfc1a3
commit 180628beef
19 changed files with 810 additions and 228 deletions

View File

@@ -42,7 +42,8 @@ struct NetworkInner {
protocol_config: Option<ProtocolConfig>,
static_public_dialinfo: ProtocolSet,
network_class: Option<NetworkClass>,
join_handles: Vec<JoinHandle<()>>,
join_handles: Vec<MustJoinHandle<()>>,
stop_source: Option<StopSource>,
udp_port: u16,
tcp_port: u16,
ws_port: u16,
@@ -82,6 +83,7 @@ impl Network {
static_public_dialinfo: ProtocolSet::empty(),
network_class: None,
join_handles: Vec::new(),
stop_source: None,
udp_port: 0u16,
tcp_port: 0u16,
ws_port: 0u16,
@@ -115,8 +117,8 @@ impl Network {
let this2 = this.clone();
this.unlocked_inner
.update_network_class_task
.set_routine(move |l, t| {
Box::pin(this2.clone().update_network_class_task_routine(l, t))
.set_routine(move |s, l, t| {
Box::pin(this2.clone().update_network_class_task_routine(s, l, t))
});
}
@@ -200,7 +202,7 @@ impl Network {
Ok(config)
}
fn add_to_join_handles(&self, jh: JoinHandle<()>) {
fn add_to_join_handles(&self, jh: MustJoinHandle<()>) {
let mut inner = self.inner.lock();
inner.join_handles.push(jh);
}
@@ -506,17 +508,28 @@ impl Network {
let network_manager = self.network_manager();
let routing_table = self.routing_table();
// Cancel all tasks
if let Err(e) = self.unlocked_inner.update_network_class_task.cancel().await {
warn!("update_network_class_task not cancelled: {}", e);
// Stop all tasks
if let Err(e) = self.unlocked_inner.update_network_class_task.stop().await {
error!("update_network_class_task not cancelled: {}", e);
}
let mut unord = FuturesUnordered::new();
{
let mut inner = self.inner.lock();
// Drop the stop
drop(inner.stop_source.take());
// take the join handles out
for h in inner.join_handles.drain(..) {
unord.push(h);
}
}
// Wait for everything to stop
while unord.next().await.is_some() {}
// Drop all dial info
routing_table.clear_dial_info_details(RoutingDomain::PublicInternet);
routing_table.clear_dial_info_details(RoutingDomain::LocalNetwork);
// Reset state including network class
// Cancels all async background tasks by dropping join handles
*self.inner.lock() = Self::new_inner(network_manager);
info!("network stopped");

View File

@@ -465,7 +465,12 @@ impl Network {
}
#[instrument(level = "trace", skip(self), err)]
pub async fn update_network_class_task_routine(self, _l: u64, _t: u64) -> Result<(), String> {
pub async fn update_network_class_task_routine(
self,
stop_token: StopToken,
_l: u64,
_t: u64,
) -> Result<(), String> {
// Ensure we aren't trying to update this without clearing it first
let old_network_class = self.inner.lock().network_class;
assert_eq!(old_network_class, None);

View File

@@ -2,6 +2,7 @@ use super::*;
use crate::intf::*;
use async_tls::TlsAcceptor;
use sockets::*;
use stop_token::future::FutureExt;
/////////////////////////////////////////////////////////////////
@@ -91,6 +92,106 @@ impl Network {
Ok(None)
}
async fn tcp_acceptor(
self,
tcp_stream: async_std::io::Result<TcpStream>,
listener_state: Arc<RwLock<ListenerState>>,
connection_manager: ConnectionManager,
connection_initial_timeout: u64,
tls_connection_initial_timeout: u64,
) {
let tcp_stream = match tcp_stream {
Ok(v) => v,
Err(_) => {
// If this happened our low-level listener socket probably died
// so it's time to restart the network
self.inner.lock().network_needs_restart = true;
return;
}
};
let listener_state = listener_state.clone();
let connection_manager = connection_manager.clone();
// Limit the number of connections from the same IP address
// and the number of total connections
let addr = match tcp_stream.peer_addr() {
Ok(addr) => addr,
Err(e) => {
log_net!(error "failed to get peer address: {}", e);
return;
}
};
// XXX limiting
log_net!("TCP connection from: {}", addr);
// Create a stream we can peek on
let ps = AsyncPeekStream::new(tcp_stream.clone());
/////////////////////////////////////////////////////////////
let mut first_packet = [0u8; PEEK_DETECT_LEN];
// read a chunk of the stream
if io::timeout(
Duration::from_micros(connection_initial_timeout),
ps.peek_exact(&mut first_packet),
)
.await
.is_err()
{
// If we fail to get a packet within the connection initial timeout
// then we punt this connection
log_net!(warn "connection initial timeout from: {:?}", addr);
return;
}
// Run accept handlers on accepted stream
// Check is this could be TLS
let ls = listener_state.read().clone();
let conn = if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 {
self.try_tls_handlers(
ls.tls_acceptor.as_ref().unwrap(),
ps,
tcp_stream,
addr,
&ls.tls_protocol_handlers,
tls_connection_initial_timeout,
)
.await
} else {
self.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers)
.await
};
let conn = match conn {
Ok(Some(c)) => {
log_net!("protocol handler found for {:?}: {:?}", addr, c);
c
}
Ok(None) => {
// No protocol handlers matched? drop it.
log_net!(warn "no protocol handler for connection from {:?}", addr);
return;
}
Err(e) => {
// Failed to negotiate connection? drop it.
log_net!(warn "failed to negotiate connection from {:?}: {}", addr, e);
return;
}
};
// Register the new connection in the connection manager
if let Err(e) = connection_manager
.on_accepted_protocol_network_connection(conn)
.await
{
log_net!(error "failed to register new connection: {}", e);
}
}
async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
// Get config
let (connection_initial_timeout, tls_connection_initial_timeout) = {
@@ -123,111 +224,40 @@ impl Network {
// Spawn the socket task
let this = self.clone();
let stop_token = self.inner.lock().stop_source.as_ref().unwrap().token();
let connection_manager = self.connection_manager();
////////////////////////////////////////////////////////////
let jh = spawn(async move {
// moves listener object in and get incoming iterator
// when this task exists, the listener will close the socket
listener
let _ = listener
.incoming()
.for_each_concurrent(None, |tcp_stream| async {
let tcp_stream = tcp_stream.unwrap();
.for_each_concurrent(None, |tcp_stream| {
let this = this.clone();
let listener_state = listener_state.clone();
let connection_manager = connection_manager.clone();
// Limit the number of connections from the same IP address
// and the number of total connections
let addr = match tcp_stream.peer_addr() {
Ok(addr) => addr,
Err(e) => {
log_net!(error "failed to get peer address: {}", e);
return;
}
};
// XXX limiting
log_net!("TCP connection from: {}", addr);
// Create a stream we can peek on
let ps = AsyncPeekStream::new(tcp_stream.clone());
/////////////////////////////////////////////////////////////
let mut first_packet = [0u8; PEEK_DETECT_LEN];
// read a chunk of the stream
if io::timeout(
Duration::from_micros(connection_initial_timeout),
ps.peek_exact(&mut first_packet),
Self::tcp_acceptor(
this,
tcp_stream,
listener_state,
connection_manager,
connection_initial_timeout,
tls_connection_initial_timeout,
)
.await
.is_err()
{
// If we fail to get a packet within the connection initial timeout
// then we punt this connection
log_net!(warn "connection initial timeout from: {:?}", addr);
return;
}
// Run accept handlers on accepted stream
// Check is this could be TLS
let ls = listener_state.read().clone();
let conn = if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 {
this.try_tls_handlers(
ls.tls_acceptor.as_ref().unwrap(),
ps,
tcp_stream,
addr,
&ls.tls_protocol_handlers,
tls_connection_initial_timeout,
)
.await
} else {
this.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers)
.await
};
let conn = match conn {
Ok(Some(c)) => {
log_net!("protocol handler found for {:?}: {:?}", addr, c);
c
}
Ok(None) => {
// No protocol handlers matched? drop it.
log_net!(warn "no protocol handler for connection from {:?}", addr);
return;
}
Err(e) => {
// Failed to negotiate connection? drop it.
log_net!(warn "failed to negotiate connection from {:?}: {}", addr, e);
return;
}
};
// Register the new connection in the connection manager
if let Err(e) = connection_manager
.on_accepted_protocol_network_connection(conn)
.await
{
log_net!(error "failed to register new connection: {}", e);
}
})
.timeout_at(stop_token)
.await;
log_net!(debug "exited incoming loop for {}", addr);
// Remove our listener state from this address if we're stopping
this.inner.lock().listener_states.remove(&addr);
log_net!(debug "listener state removed for {}", addr);
// If this happened our low-level listener socket probably died
// so it's time to restart the network
this.inner.lock().network_needs_restart = true;
});
////////////////////////////////////////////////////////////
// Add to join handles
self.add_to_join_handles(jh);
self.add_to_join_handles(MustJoinHandle::new(jh));
Ok(())
}

View File

@@ -1,5 +1,6 @@
use super::*;
use sockets::*;
use stop_token::future::FutureExt;
impl Network {
pub(super) async fn create_udp_listener_tasks(&self) -> Result<(), String> {
@@ -43,47 +44,75 @@ impl Network {
// Spawn a local async task for each socket
let mut protocol_handlers_unordered = FuturesUnordered::new();
let network_manager = this.network_manager();
let stop_token = this.inner.lock().stop_source.as_ref().unwrap().token();
for ph in protocol_handlers {
let network_manager = network_manager.clone();
let stop_token = stop_token.clone();
let jh = spawn_local(async move {
let mut data = vec![0u8; 65536];
while let Ok((size, descriptor)) = ph.recv_message(&mut data).await {
// XXX: Limit the number of packets from the same IP address?
log_net!("UDP packet: {:?}", descriptor);
// Network accounting
network_manager.stats_packet_rcvd(
descriptor.remote_address().to_ip_addr(),
size as u64,
);
// Pass it up for processing
if let Err(e) = network_manager
.on_recv_envelope(&data[..size], descriptor)
loop {
match ph
.recv_message(&mut data)
.timeout_at(stop_token.clone())
.await
{
log_net!(error "failed to process received udp envelope: {}", e);
Ok(Ok((size, descriptor))) => {
// XXX: Limit the number of packets from the same IP address?
log_net!("UDP packet: {:?}", descriptor);
// Network accounting
network_manager.stats_packet_rcvd(
descriptor.remote_address().to_ip_addr(),
size as u64,
);
// Pass it up for processing
if let Err(e) = network_manager
.on_recv_envelope(&data[..size], descriptor)
.await
{
log_net!(error "failed to process received udp envelope: {}", e);
}
}
Ok(Err(_)) => {
return false;
}
Err(_) => {
return true;
}
}
}
});
protocol_handlers_unordered.push(jh);
}
// Now we wait for any join handle to exit,
// which would indicate an error needing
// Now we wait for join handles to exit,
// if any error out it indicates an error needing
// us to completely restart the network
let _ = protocol_handlers_unordered.next().await;
loop {
match protocol_handlers_unordered.next().await {
Some(v) => {
// true = stopped, false = errored
if !v {
// If any protocol handler fails, our socket died and we need to restart the network
this.inner.lock().network_needs_restart = true;
}
}
None => {
// All protocol handlers exited
break;
}
}
}
trace!("UDP listener task stopped");
// If this loop fails, our socket died and we need to restart the network
this.inner.lock().network_needs_restart = true;
});
////////////////////////////////////////////////////////////
// Add to join handle
self.add_to_join_handles(jh);
self.add_to_join_handles(MustJoinHandle::new(jh));
}
Ok(())