fix cooperative cancellation

This commit is contained in:
John Smith
2022-06-15 14:05:04 -04:00
parent 180628beef
commit c33f78ac8b
24 changed files with 520 additions and 299 deletions

View File

@@ -279,20 +279,14 @@ impl Network {
let res = match dial_info.protocol_type() {
ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr();
RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data)
.await
.map_err(logthru_net!())
RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
}
ProtocolType::TCP => {
let peer_socket_addr = dial_info.to_socket_addr();
RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data)
.await
.map_err(logthru_net!())
RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
}
ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data)
.await
.map_err(logthru_net!())
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data).await
}
};
if res.is_ok() {
@@ -324,10 +318,7 @@ impl Network {
descriptor
);
ph.clone()
.send_message(data, peer_socket_addr)
.await
.map_err(logthru_net!())?;
ph.clone().send_message(data, peer_socket_addr).await?;
// Network accounting
self.network_manager()
@@ -345,7 +336,7 @@ impl Network {
log_net!("send_data_to_existing_connection to {:?}", descriptor);
// connection exists, send over it
conn.send_async(data).await.map_err(logthru_net!())?;
conn.send_async(data).await?;
// Network accounting
self.network_manager()
@@ -372,10 +363,7 @@ impl Network {
if dial_info.protocol_type() == ProtocolType::UDP {
let peer_socket_addr = dial_info.to_socket_addr();
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
let res = ph
.send_message(data, peer_socket_addr)
.await
.map_err(logthru_net!());
let res = ph.send_message(data, peer_socket_addr).await;
if res.is_ok() {
// Network accounting
self.network_manager()
@@ -383,8 +371,7 @@ impl Network {
}
return res;
}
return Err("no appropriate UDP protocol handler for dial_info".to_owned())
.map_err(logthru_net!(error));
return Err("no appropriate UDP protocol handler for dial_info".to_owned());
}
// Handle connection-oriented protocols
@@ -394,7 +381,7 @@ impl Network {
.get_or_create_connection(Some(local_addr), dial_info.clone())
.await?;
let res = conn.send_async(data).await.map_err(logthru_net!(error));
let res = conn.send_async(data).await;
if res.is_ok() {
// Network accounting
self.network_manager()
@@ -414,43 +401,51 @@ impl Network {
// initialize interfaces
let mut interfaces = NetworkInterfaces::new();
interfaces.refresh().await?;
self.inner.lock().interfaces = interfaces;
// get protocol config
let protocol_config = {
let c = self.config.get();
let mut inbound = ProtocolSet::new();
let mut inner = self.inner.lock();
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
inbound.insert(ProtocolType::UDP);
}
if c.network.protocol.tcp.listen && c.capabilities.protocol_accept_tcp {
inbound.insert(ProtocolType::TCP);
}
if c.network.protocol.ws.listen && c.capabilities.protocol_accept_ws {
inbound.insert(ProtocolType::WS);
}
if c.network.protocol.wss.listen && c.capabilities.protocol_accept_wss {
inbound.insert(ProtocolType::WSS);
}
// Create stop source
inner.stop_source = Some(StopSource::new());
inner.interfaces = interfaces;
let mut outbound = ProtocolSet::new();
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
outbound.insert(ProtocolType::UDP);
}
if c.network.protocol.tcp.connect && c.capabilities.protocol_connect_tcp {
outbound.insert(ProtocolType::TCP);
}
if c.network.protocol.ws.connect && c.capabilities.protocol_connect_ws {
outbound.insert(ProtocolType::WS);
}
if c.network.protocol.wss.connect && c.capabilities.protocol_connect_wss {
outbound.insert(ProtocolType::WSS);
}
// get protocol config
let protocol_config = {
let c = self.config.get();
let mut inbound = ProtocolSet::new();
ProtocolConfig { inbound, outbound }
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
inbound.insert(ProtocolType::UDP);
}
if c.network.protocol.tcp.listen && c.capabilities.protocol_accept_tcp {
inbound.insert(ProtocolType::TCP);
}
if c.network.protocol.ws.listen && c.capabilities.protocol_accept_ws {
inbound.insert(ProtocolType::WS);
}
if c.network.protocol.wss.listen && c.capabilities.protocol_accept_wss {
inbound.insert(ProtocolType::WSS);
}
let mut outbound = ProtocolSet::new();
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
outbound.insert(ProtocolType::UDP);
}
if c.network.protocol.tcp.connect && c.capabilities.protocol_connect_tcp {
outbound.insert(ProtocolType::TCP);
}
if c.network.protocol.ws.connect && c.capabilities.protocol_connect_ws {
outbound.insert(ProtocolType::WS);
}
if c.network.protocol.wss.connect && c.capabilities.protocol_connect_wss {
outbound.insert(ProtocolType::WSS);
}
ProtocolConfig { inbound, outbound }
};
inner.protocol_config = Some(protocol_config);
protocol_config
};
self.inner.lock().protocol_config = Some(protocol_config);
// start listeners
if protocol_config.inbound.contains(ProtocolType::UDP) {
@@ -503,28 +498,32 @@ impl Network {
#[instrument(level = "debug", skip_all)]
pub async fn shutdown(&self) {
info!("stopping network");
debug!("starting low level network shutdown");
let network_manager = self.network_manager();
let routing_table = self.routing_table();
// Stop all tasks
debug!("stopping update network class task");
if let Err(e) = self.unlocked_inner.update_network_class_task.stop().await {
error!("update_network_class_task not cancelled: {}", e);
}
let mut unord = FuturesUnordered::new();
{
let mut inner = self.inner.lock();
// Drop the stop
drop(inner.stop_source.take());
// take the join handles out
for h in inner.join_handles.drain(..) {
unord.push(h);
}
// Drop the stop
drop(inner.stop_source.take());
}
debug!("stopping {} low level network tasks", unord.len());
// Wait for everything to stop
while unord.next().await.is_some() {}
debug!("clearing dial info");
// Drop all dial info
routing_table.clear_dial_info_details(RoutingDomain::PublicInternet);
routing_table.clear_dial_info_details(RoutingDomain::LocalNetwork);
@@ -532,7 +531,7 @@ impl Network {
// Reset state including network class
*self.inner.lock() = Self::new_inner(network_manager);
info!("network stopped");
debug!("finished low level network shutdown");
}
//////////////////////////////////////////

View File

@@ -1,6 +1,7 @@
use super::*;
use futures_util::stream::FuturesUnordered;
use futures_util::FutureExt;
use stop_token::future::FutureExt as StopTokenFutureExt;
struct DetectedPublicDialInfo {
dial_info: DialInfo,
@@ -584,20 +585,32 @@ impl Network {
// Wait for all discovery futures to complete and collect contexts
let mut contexts = Vec::<DiscoveryContext>::new();
let mut network_class = Option::<NetworkClass>::None;
while let Some(ctxvec) = unord.next().await {
if let Some(ctxvec) = ctxvec {
for ctx in ctxvec {
if let Some(nc) = ctx.inner.lock().detected_network_class {
if let Some(last_nc) = network_class {
if nc < last_nc {
network_class = Some(nc);
loop {
match unord.next().timeout_at(stop_token.clone()).await {
Ok(Some(ctxvec)) => {
if let Some(ctxvec) = ctxvec {
for ctx in ctxvec {
if let Some(nc) = ctx.inner.lock().detected_network_class {
if let Some(last_nc) = network_class {
if nc < last_nc {
network_class = Some(nc);
}
} else {
network_class = Some(nc);
}
}
} else {
network_class = Some(nc);
contexts.push(ctx);
}
}
contexts.push(ctx);
}
Ok(None) => {
// Normal completion
break;
}
Err(_) => {
// Stop token, exit early without error propagation
return Ok(());
}
}
}

View File

@@ -65,8 +65,7 @@ impl Network {
ps.peek_exact(&mut first_packet),
)
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
.map_err(map_to_string)?;
self.try_handlers(ps, tcp_stream, addr, protocol_handlers)
.await
@@ -82,8 +81,7 @@ impl Network {
for ah in protocol_accept_handlers.iter() {
if let Some(nc) = ah
.on_accept(stream.clone(), tcp_stream.clone(), addr)
.await
.map_err(logthru_net!())?
.await?
{
return Ok(Some(nc));
}

View File

@@ -39,15 +39,11 @@ impl ProtocolNetworkConnection {
match dial_info.protocol_type() {
ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr();
udp::RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data)
.await
.map_err(logthru_net!())
udp::RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
}
ProtocolType::TCP => {
let peer_socket_addr = dial_info.to_socket_addr();
tcp::RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data)
.await
.map_err(logthru_net!())
tcp::RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
}
ProtocolType::WS | ProtocolType::WSS => {
ws::WebsocketProtocolHandler::send_unbound_message(dial_info, data).await

View File

@@ -31,19 +31,14 @@ impl RawTcpNetworkConnection {
self.descriptor.clone()
}
#[instrument(level = "trace", err, skip(self))]
pub async fn close(&self) -> Result<(), String> {
// Make an attempt to flush the stream
self.stream
.clone()
.close()
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
self.stream.clone().close().await.map_err(map_to_string)?;
// Then forcibly close the socket
self.tcp_stream
.shutdown(Shutdown::Both)
.map_err(map_to_string)
.map_err(logthru_net!())
}
async fn send_internal(mut stream: AsyncPeekStream, message: Vec<u8>) -> Result<(), String> {
@@ -54,23 +49,17 @@ impl RawTcpNetworkConnection {
let len = message.len() as u16;
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
stream
.write_all(&header)
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
stream
.write_all(&message)
.await
.map_err(map_to_string)
.map_err(logthru_net!())
stream.write_all(&header).await.map_err(map_to_string)?;
stream.write_all(&message).await.map_err(map_to_string)
}
#[instrument(level="trace", err, skip(self, message), fields(message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
let stream = self.stream.clone();
Self::send_internal(stream, message).await
}
#[instrument(level="trace", err, skip(self), fields(message.len))]
pub async fn recv(&self) -> Result<Vec<u8>, String> {
let mut header = [0u8; 4];
@@ -90,6 +79,8 @@ impl RawTcpNetworkConnection {
let mut out: Vec<u8> = vec![0u8; len];
stream.read_exact(&mut out).await.map_err(map_to_string)?;
tracing::Span::current().record("message.len", &out.len());
Ok(out)
}
}
@@ -120,6 +111,7 @@ impl RawTcpProtocolHandler {
}
}
#[instrument(level = "trace", err, skip(self, stream, tcp_stream))]
async fn on_accept_async(
self,
stream: AsyncPeekStream,
@@ -151,6 +143,7 @@ impl RawTcpProtocolHandler {
Ok(Some(conn))
}
#[instrument(level = "trace", err)]
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
@@ -191,6 +184,7 @@ impl RawTcpProtocolHandler {
Ok(conn)
}
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
pub async fn send_unbound_message(
socket_addr: SocketAddr,
data: Vec<u8>,

View File

@@ -10,6 +10,7 @@ impl RawUdpProtocolHandler {
Self { socket }
}
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.from))]
pub async fn recv_message(
&self,
data: &mut [u8],
@@ -35,9 +36,13 @@ impl RawUdpProtocolHandler {
peer_addr,
SocketAddress::from_socket_addr(local_socket_addr),
);
tracing::Span::current().record("ret.len", &size);
tracing::Span::current().record("ret.from", &format!("{:?}", descriptor).as_str());
Ok((size, descriptor))
}
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.from))]
pub async fn send_message(&self, data: Vec<u8>, socket_addr: SocketAddr) -> Result<(), String> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large UDP message".to_owned()).map_err(logthru_net!(error));

View File

@@ -49,21 +49,17 @@ where
self.descriptor.clone()
}
#[instrument(level = "trace", err, skip(self))]
pub async fn close(&self) -> Result<(), String> {
// Make an attempt to flush the stream
self.stream
.clone()
.close()
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
self.stream.clone().close().await.map_err(map_to_string)?;
// Then forcibly close the socket
self.tcp_stream
.shutdown(Shutdown::Both)
.map_err(map_to_string)
.map_err(logthru_net!())
}
#[instrument(level="trace", err, skip(self, message), fields(message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
if message.len() > MAX_MESSAGE_SIZE {
return Err("received too large WS message".to_owned());
@@ -76,6 +72,7 @@ where
.map_err(logthru_net!(error "failed to send websocket message"))
}
#[instrument(level="trace", err, skip(self), fields(message.len))]
pub async fn recv(&self) -> Result<Vec<u8>, String> {
let out = match self.stream.clone().next().await {
Some(Ok(Message::Binary(v))) => v,
@@ -86,12 +83,13 @@ where
return Err(e.to_string()).map_err(logthru_net!(error));
}
None => {
return Err("WS stream closed".to_owned()).map_err(logthru_net!());
return Err("WS stream closed".to_owned());
}
};
if out.len() > MAX_MESSAGE_SIZE {
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
} else {
tracing::Span::current().record("message.len", &out.len());
Ok(out)
}
}
@@ -137,6 +135,7 @@ impl WebsocketProtocolHandler {
}
}
#[instrument(level = "trace", err, skip(self, ps, tcp_stream))]
pub async fn on_accept_async(
self,
ps: AsyncPeekStream,
@@ -156,9 +155,9 @@ impl WebsocketProtocolHandler {
Ok(_) => (),
Err(e) => {
if e.kind() == io::ErrorKind::TimedOut {
return Err(e).map_err(map_to_string).map_err(logthru_net!());
return Err(e).map_err(map_to_string);
}
return Err(e).map_err(map_to_string).map_err(logthru_net!(error));
return Err(e).map_err(map_to_string);
}
}
@@ -237,10 +236,7 @@ impl WebsocketProtocolHandler {
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
// See what local address we ended up with
let actual_local_addr = tcp_stream
.local_addr()
.map_err(map_to_string)
.map_err(logthru_net!())?;
let actual_local_addr = tcp_stream.local_addr().map_err(map_to_string)?;
// Make our connection descriptor
let descriptor = ConnectionDescriptor::new(
@@ -274,6 +270,7 @@ impl WebsocketProtocolHandler {
}
}
#[instrument(level = "trace", err)]
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
@@ -281,6 +278,7 @@ impl WebsocketProtocolHandler {
Self::connect_internal(local_address, dial_info).await
}
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large unbound WS message".to_owned());