fix cooperative cancellation
This commit is contained in:
@@ -279,20 +279,14 @@ impl Network {
|
||||
let res = match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
|
||||
}
|
||||
ProtocolType::TCP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
|
||||
}
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data).await
|
||||
}
|
||||
};
|
||||
if res.is_ok() {
|
||||
@@ -324,10 +318,7 @@ impl Network {
|
||||
descriptor
|
||||
);
|
||||
|
||||
ph.clone()
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.map_err(logthru_net!())?;
|
||||
ph.clone().send_message(data, peer_socket_addr).await?;
|
||||
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@@ -345,7 +336,7 @@ impl Network {
|
||||
log_net!("send_data_to_existing_connection to {:?}", descriptor);
|
||||
|
||||
// connection exists, send over it
|
||||
conn.send_async(data).await.map_err(logthru_net!())?;
|
||||
conn.send_async(data).await?;
|
||||
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@@ -372,10 +363,7 @@ impl Network {
|
||||
if dial_info.protocol_type() == ProtocolType::UDP {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
|
||||
let res = ph
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.map_err(logthru_net!());
|
||||
let res = ph.send_message(data, peer_socket_addr).await;
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@@ -383,8 +371,7 @@ impl Network {
|
||||
}
|
||||
return res;
|
||||
}
|
||||
return Err("no appropriate UDP protocol handler for dial_info".to_owned())
|
||||
.map_err(logthru_net!(error));
|
||||
return Err("no appropriate UDP protocol handler for dial_info".to_owned());
|
||||
}
|
||||
|
||||
// Handle connection-oriented protocols
|
||||
@@ -394,7 +381,7 @@ impl Network {
|
||||
.get_or_create_connection(Some(local_addr), dial_info.clone())
|
||||
.await?;
|
||||
|
||||
let res = conn.send_async(data).await.map_err(logthru_net!(error));
|
||||
let res = conn.send_async(data).await;
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@@ -414,43 +401,51 @@ impl Network {
|
||||
// initialize interfaces
|
||||
let mut interfaces = NetworkInterfaces::new();
|
||||
interfaces.refresh().await?;
|
||||
self.inner.lock().interfaces = interfaces;
|
||||
|
||||
// get protocol config
|
||||
let protocol_config = {
|
||||
let c = self.config.get();
|
||||
let mut inbound = ProtocolSet::new();
|
||||
let mut inner = self.inner.lock();
|
||||
|
||||
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
|
||||
inbound.insert(ProtocolType::UDP);
|
||||
}
|
||||
if c.network.protocol.tcp.listen && c.capabilities.protocol_accept_tcp {
|
||||
inbound.insert(ProtocolType::TCP);
|
||||
}
|
||||
if c.network.protocol.ws.listen && c.capabilities.protocol_accept_ws {
|
||||
inbound.insert(ProtocolType::WS);
|
||||
}
|
||||
if c.network.protocol.wss.listen && c.capabilities.protocol_accept_wss {
|
||||
inbound.insert(ProtocolType::WSS);
|
||||
}
|
||||
// Create stop source
|
||||
inner.stop_source = Some(StopSource::new());
|
||||
inner.interfaces = interfaces;
|
||||
|
||||
let mut outbound = ProtocolSet::new();
|
||||
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
|
||||
outbound.insert(ProtocolType::UDP);
|
||||
}
|
||||
if c.network.protocol.tcp.connect && c.capabilities.protocol_connect_tcp {
|
||||
outbound.insert(ProtocolType::TCP);
|
||||
}
|
||||
if c.network.protocol.ws.connect && c.capabilities.protocol_connect_ws {
|
||||
outbound.insert(ProtocolType::WS);
|
||||
}
|
||||
if c.network.protocol.wss.connect && c.capabilities.protocol_connect_wss {
|
||||
outbound.insert(ProtocolType::WSS);
|
||||
}
|
||||
// get protocol config
|
||||
let protocol_config = {
|
||||
let c = self.config.get();
|
||||
let mut inbound = ProtocolSet::new();
|
||||
|
||||
ProtocolConfig { inbound, outbound }
|
||||
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
|
||||
inbound.insert(ProtocolType::UDP);
|
||||
}
|
||||
if c.network.protocol.tcp.listen && c.capabilities.protocol_accept_tcp {
|
||||
inbound.insert(ProtocolType::TCP);
|
||||
}
|
||||
if c.network.protocol.ws.listen && c.capabilities.protocol_accept_ws {
|
||||
inbound.insert(ProtocolType::WS);
|
||||
}
|
||||
if c.network.protocol.wss.listen && c.capabilities.protocol_accept_wss {
|
||||
inbound.insert(ProtocolType::WSS);
|
||||
}
|
||||
|
||||
let mut outbound = ProtocolSet::new();
|
||||
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
|
||||
outbound.insert(ProtocolType::UDP);
|
||||
}
|
||||
if c.network.protocol.tcp.connect && c.capabilities.protocol_connect_tcp {
|
||||
outbound.insert(ProtocolType::TCP);
|
||||
}
|
||||
if c.network.protocol.ws.connect && c.capabilities.protocol_connect_ws {
|
||||
outbound.insert(ProtocolType::WS);
|
||||
}
|
||||
if c.network.protocol.wss.connect && c.capabilities.protocol_connect_wss {
|
||||
outbound.insert(ProtocolType::WSS);
|
||||
}
|
||||
|
||||
ProtocolConfig { inbound, outbound }
|
||||
};
|
||||
inner.protocol_config = Some(protocol_config);
|
||||
protocol_config
|
||||
};
|
||||
self.inner.lock().protocol_config = Some(protocol_config);
|
||||
|
||||
// start listeners
|
||||
if protocol_config.inbound.contains(ProtocolType::UDP) {
|
||||
@@ -503,28 +498,32 @@ impl Network {
|
||||
|
||||
#[instrument(level = "debug", skip_all)]
|
||||
pub async fn shutdown(&self) {
|
||||
info!("stopping network");
|
||||
debug!("starting low level network shutdown");
|
||||
|
||||
let network_manager = self.network_manager();
|
||||
let routing_table = self.routing_table();
|
||||
|
||||
// Stop all tasks
|
||||
debug!("stopping update network class task");
|
||||
if let Err(e) = self.unlocked_inner.update_network_class_task.stop().await {
|
||||
error!("update_network_class_task not cancelled: {}", e);
|
||||
}
|
||||
|
||||
let mut unord = FuturesUnordered::new();
|
||||
{
|
||||
let mut inner = self.inner.lock();
|
||||
// Drop the stop
|
||||
drop(inner.stop_source.take());
|
||||
// take the join handles out
|
||||
for h in inner.join_handles.drain(..) {
|
||||
unord.push(h);
|
||||
}
|
||||
// Drop the stop
|
||||
drop(inner.stop_source.take());
|
||||
}
|
||||
debug!("stopping {} low level network tasks", unord.len());
|
||||
// Wait for everything to stop
|
||||
while unord.next().await.is_some() {}
|
||||
|
||||
debug!("clearing dial info");
|
||||
// Drop all dial info
|
||||
routing_table.clear_dial_info_details(RoutingDomain::PublicInternet);
|
||||
routing_table.clear_dial_info_details(RoutingDomain::LocalNetwork);
|
||||
@@ -532,7 +531,7 @@ impl Network {
|
||||
// Reset state including network class
|
||||
*self.inner.lock() = Self::new_inner(network_manager);
|
||||
|
||||
info!("network stopped");
|
||||
debug!("finished low level network shutdown");
|
||||
}
|
||||
|
||||
//////////////////////////////////////////
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use super::*;
|
||||
use futures_util::stream::FuturesUnordered;
|
||||
use futures_util::FutureExt;
|
||||
use stop_token::future::FutureExt as StopTokenFutureExt;
|
||||
|
||||
struct DetectedPublicDialInfo {
|
||||
dial_info: DialInfo,
|
||||
@@ -584,20 +585,32 @@ impl Network {
|
||||
// Wait for all discovery futures to complete and collect contexts
|
||||
let mut contexts = Vec::<DiscoveryContext>::new();
|
||||
let mut network_class = Option::<NetworkClass>::None;
|
||||
while let Some(ctxvec) = unord.next().await {
|
||||
if let Some(ctxvec) = ctxvec {
|
||||
for ctx in ctxvec {
|
||||
if let Some(nc) = ctx.inner.lock().detected_network_class {
|
||||
if let Some(last_nc) = network_class {
|
||||
if nc < last_nc {
|
||||
network_class = Some(nc);
|
||||
loop {
|
||||
match unord.next().timeout_at(stop_token.clone()).await {
|
||||
Ok(Some(ctxvec)) => {
|
||||
if let Some(ctxvec) = ctxvec {
|
||||
for ctx in ctxvec {
|
||||
if let Some(nc) = ctx.inner.lock().detected_network_class {
|
||||
if let Some(last_nc) = network_class {
|
||||
if nc < last_nc {
|
||||
network_class = Some(nc);
|
||||
}
|
||||
} else {
|
||||
network_class = Some(nc);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
network_class = Some(nc);
|
||||
|
||||
contexts.push(ctx);
|
||||
}
|
||||
}
|
||||
|
||||
contexts.push(ctx);
|
||||
}
|
||||
Ok(None) => {
|
||||
// Normal completion
|
||||
break;
|
||||
}
|
||||
Err(_) => {
|
||||
// Stop token, exit early without error propagation
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,8 +65,7 @@ impl Network {
|
||||
ps.peek_exact(&mut first_packet),
|
||||
)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
.map_err(map_to_string)?;
|
||||
|
||||
self.try_handlers(ps, tcp_stream, addr, protocol_handlers)
|
||||
.await
|
||||
@@ -82,8 +81,7 @@ impl Network {
|
||||
for ah in protocol_accept_handlers.iter() {
|
||||
if let Some(nc) = ah
|
||||
.on_accept(stream.clone(), tcp_stream.clone(), addr)
|
||||
.await
|
||||
.map_err(logthru_net!())?
|
||||
.await?
|
||||
{
|
||||
return Ok(Some(nc));
|
||||
}
|
||||
|
||||
@@ -39,15 +39,11 @@ impl ProtocolNetworkConnection {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
udp::RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
udp::RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
|
||||
}
|
||||
ProtocolType::TCP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
tcp::RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
tcp::RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data).await
|
||||
}
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
ws::WebsocketProtocolHandler::send_unbound_message(dial_info, data).await
|
||||
|
||||
@@ -31,19 +31,14 @@ impl RawTcpNetworkConnection {
|
||||
self.descriptor.clone()
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self))]
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
// Make an attempt to flush the stream
|
||||
self.stream
|
||||
.clone()
|
||||
.close()
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
self.stream.clone().close().await.map_err(map_to_string)?;
|
||||
// Then forcibly close the socket
|
||||
self.tcp_stream
|
||||
.shutdown(Shutdown::Both)
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
|
||||
async fn send_internal(mut stream: AsyncPeekStream, message: Vec<u8>) -> Result<(), String> {
|
||||
@@ -54,23 +49,17 @@ impl RawTcpNetworkConnection {
|
||||
let len = message.len() as u16;
|
||||
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
|
||||
|
||||
stream
|
||||
.write_all(&header)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
stream
|
||||
.write_all(&message)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())
|
||||
stream.write_all(&header).await.map_err(map_to_string)?;
|
||||
stream.write_all(&message).await.map_err(map_to_string)
|
||||
}
|
||||
|
||||
#[instrument(level="trace", err, skip(self, message), fields(message.len = message.len()))]
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
let stream = self.stream.clone();
|
||||
Self::send_internal(stream, message).await
|
||||
}
|
||||
|
||||
#[instrument(level="trace", err, skip(self), fields(message.len))]
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
let mut header = [0u8; 4];
|
||||
|
||||
@@ -90,6 +79,8 @@ impl RawTcpNetworkConnection {
|
||||
|
||||
let mut out: Vec<u8> = vec![0u8; len];
|
||||
stream.read_exact(&mut out).await.map_err(map_to_string)?;
|
||||
|
||||
tracing::Span::current().record("message.len", &out.len());
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
@@ -120,6 +111,7 @@ impl RawTcpProtocolHandler {
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, stream, tcp_stream))]
|
||||
async fn on_accept_async(
|
||||
self,
|
||||
stream: AsyncPeekStream,
|
||||
@@ -151,6 +143,7 @@ impl RawTcpProtocolHandler {
|
||||
Ok(Some(conn))
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err)]
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
@@ -191,6 +184,7 @@ impl RawTcpProtocolHandler {
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
|
||||
pub async fn send_unbound_message(
|
||||
socket_addr: SocketAddr,
|
||||
data: Vec<u8>,
|
||||
|
||||
@@ -10,6 +10,7 @@ impl RawUdpProtocolHandler {
|
||||
Self { socket }
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.from))]
|
||||
pub async fn recv_message(
|
||||
&self,
|
||||
data: &mut [u8],
|
||||
@@ -35,9 +36,13 @@ impl RawUdpProtocolHandler {
|
||||
peer_addr,
|
||||
SocketAddress::from_socket_addr(local_socket_addr),
|
||||
);
|
||||
|
||||
tracing::Span::current().record("ret.len", &size);
|
||||
tracing::Span::current().record("ret.from", &format!("{:?}", descriptor).as_str());
|
||||
Ok((size, descriptor))
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.from))]
|
||||
pub async fn send_message(&self, data: Vec<u8>, socket_addr: SocketAddr) -> Result<(), String> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large UDP message".to_owned()).map_err(logthru_net!(error));
|
||||
|
||||
@@ -49,21 +49,17 @@ where
|
||||
self.descriptor.clone()
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self))]
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
// Make an attempt to flush the stream
|
||||
self.stream
|
||||
.clone()
|
||||
.close()
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
self.stream.clone().close().await.map_err(map_to_string)?;
|
||||
// Then forcibly close the socket
|
||||
self.tcp_stream
|
||||
.shutdown(Shutdown::Both)
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
|
||||
#[instrument(level="trace", err, skip(self, message), fields(message.len = message.len()))]
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("received too large WS message".to_owned());
|
||||
@@ -76,6 +72,7 @@ where
|
||||
.map_err(logthru_net!(error "failed to send websocket message"))
|
||||
}
|
||||
|
||||
#[instrument(level="trace", err, skip(self), fields(message.len))]
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
let out = match self.stream.clone().next().await {
|
||||
Some(Ok(Message::Binary(v))) => v,
|
||||
@@ -86,12 +83,13 @@ where
|
||||
return Err(e.to_string()).map_err(logthru_net!(error));
|
||||
}
|
||||
None => {
|
||||
return Err("WS stream closed".to_owned()).map_err(logthru_net!());
|
||||
return Err("WS stream closed".to_owned());
|
||||
}
|
||||
};
|
||||
if out.len() > MAX_MESSAGE_SIZE {
|
||||
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
|
||||
} else {
|
||||
tracing::Span::current().record("message.len", &out.len());
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
@@ -137,6 +135,7 @@ impl WebsocketProtocolHandler {
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, ps, tcp_stream))]
|
||||
pub async fn on_accept_async(
|
||||
self,
|
||||
ps: AsyncPeekStream,
|
||||
@@ -156,9 +155,9 @@ impl WebsocketProtocolHandler {
|
||||
Ok(_) => (),
|
||||
Err(e) => {
|
||||
if e.kind() == io::ErrorKind::TimedOut {
|
||||
return Err(e).map_err(map_to_string).map_err(logthru_net!());
|
||||
return Err(e).map_err(map_to_string);
|
||||
}
|
||||
return Err(e).map_err(map_to_string).map_err(logthru_net!(error));
|
||||
return Err(e).map_err(map_to_string);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,10 +236,7 @@ impl WebsocketProtocolHandler {
|
||||
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
|
||||
|
||||
// See what local address we ended up with
|
||||
let actual_local_addr = tcp_stream
|
||||
.local_addr()
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
let actual_local_addr = tcp_stream.local_addr().map_err(map_to_string)?;
|
||||
|
||||
// Make our connection descriptor
|
||||
let descriptor = ConnectionDescriptor::new(
|
||||
@@ -274,6 +270,7 @@ impl WebsocketProtocolHandler {
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err)]
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
@@ -281,6 +278,7 @@ impl WebsocketProtocolHandler {
|
||||
Self::connect_internal(local_address, dial_info).await
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
|
||||
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound WS message".to_owned());
|
||||
|
||||
Reference in New Issue
Block a user