This commit is contained in:
John Smith
2022-08-19 12:49:18 -04:00
parent 568a308c82
commit 6e34bdd420
7 changed files with 116 additions and 573 deletions

View File

@@ -5,6 +5,7 @@ use crate::routing_table::*;
use connection_manager::*;
use protocol::ws::WebsocketProtocolHandler;
pub use protocol::*;
use std::io;
/////////////////////////////////////////////////////////////////
@@ -59,8 +60,8 @@ impl Network {
) -> EyreResult<NetworkResult<()>> {
let data_len = data.len();
let timeout_ms = {
let c = self.config().get();
c.network.connection_initial_timeout_ms;
let c = self.config.get();
c.network.connection_initial_timeout_ms
};
match dial_info.protocol_type() {
@@ -71,18 +72,19 @@ impl Network {
bail!("no support for TCP protocol")
}
ProtocolType::WS | ProtocolType::WSS => {
let pnc = WebsocketProtocolHandler::connect(None, &dial_info, timeout_ms)
.await
.wrap_err("connect failure")?;
pnc.send(data).await.wrap_err("send failure")?;
let pnc =
network_result_try!(WebsocketProtocolHandler::connect(&dial_info, timeout_ms)
.await
.wrap_err("connect failure")?);
network_result_try!(pnc.send(data).await.wrap_err("send failure")?);
}
};
// Network accounting
self.network_manager()
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
Ok(())
Ok(NetworkResult::Value(()))
}
// Send data to a dial info, unbound, using a new connection from a random port
@@ -90,7 +92,7 @@ impl Network {
// This creates a short-lived connection in the case of connection-oriented protocols
// for the purpose of sending this one message.
// This bypasses the connection table as it is not a 'node to node' connection.
#[instrument(level="trace", err, skip(self, data), fields(data.len = data.len(), ret.len))]
#[instrument(level="trace", err, skip(self, data), fields(data.len = data.len()))]
pub async fn send_recv_data_unbound_to_dial_info(
&self,
dial_info: DialInfo,
@@ -99,8 +101,8 @@ impl Network {
) -> EyreResult<NetworkResult<Vec<u8>>> {
let data_len = data.len();
let connect_timeout_ms = {
let c = self.config().get();
c.network.connection_initial_timeout_ms;
let c = self.config.get();
c.network.connection_initial_timeout_ms
};
match dial_info.protocol_type() {
@@ -111,40 +113,29 @@ impl Network {
bail!("no support for TCP protocol")
}
ProtocolType::WS | ProtocolType::WSS => {
let pnc = match dial_info.protocol_type() {
let pnc = network_result_try!(match dial_info.protocol_type() {
ProtocolType::UDP => unreachable!(),
ProtocolType::TCP => unreachable!(),
ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::connect(None, &dial_info, connect_timeout_ms)
WebsocketProtocolHandler::connect(&dial_info, connect_timeout_ms)
.await
.wrap_err("connect failure")?
}
};
});
pnc.send(data).await.wrap_err("send failure")?;
network_result_try!(pnc.send(data).await.wrap_err("send failure")?);
self.network_manager()
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
let out = timeout(timeout_ms, pnc.recv())
let out = network_result_try!(network_result_try!(timeout(timeout_ms, pnc.recv())
.await
.into_timeout_or()
.into_result()
.wrap_err("recv failure")?;
.into_network_result())
.wrap_err("recv failure")?);
tracing::Span::current().record(
"ret.timeout_or",
&match out {
TimeoutOr::<Vec<u8>>::Value(ref v) => format!("Value(len={})", v.len()),
TimeoutOr::<Vec<u8>>::Timeout => "Timeout".to_owned(),
},
);
self.network_manager()
.stats_packet_rcvd(dial_info.to_ip_addr(), out.len() as u64);
if let TimeoutOr::Value(out) = &out {
self.network_manager()
.stats_packet_rcvd(dial_info.to_ip_addr(), out.len() as u64);
}
Ok(out)
Ok(NetworkResult::Value(out))
}
}
}
@@ -171,19 +162,27 @@ impl Network {
// Try to send to the exact existing connection if one exists
if let Some(conn) = self.connection_manager().get_connection(descriptor).await {
// connection exists, send over it
conn.send_async(data).await?;
match conn.send_async(data).await {
ConnectionHandleSendResult::Sent => {
// Network accounting
self.network_manager().stats_packet_sent(
descriptor.remote().to_socket_addr().ip(),
data_len as u64,
);
// Network accounting
self.network_manager()
.stats_packet_sent(descriptor.remote().to_socket_addr().ip(), data_len as u64);
// Data was consumed
Ok(None)
} else {
// Connection or didn't exist
// Pass the data back out so we don't own it any more
Ok(Some(data))
// Data was consumed
return Ok(None);
}
ConnectionHandleSendResult::NotSent(data) => {
// Couldn't send
// Pass the data back out so we don't own it any more
return Ok(Some(data));
}
}
}
// Connection didn't exist
// Pass the data back out so we don't own it any more
Ok(Some(data))
}
#[instrument(level="trace", err, skip(self, data), fields(data.len = data.len()))]
@@ -203,7 +202,7 @@ impl Network {
// Handle connection-oriented protocols
let conn = network_result_try!(
self.connection_manager()
.get_or_create_connection(Some(local_addr), dial_info.clone())
.get_or_create_connection(None, dial_info.clone())
.await?
);
@@ -214,11 +213,11 @@ impl Network {
)));
}
let connection_descriptor = conn.connection_descriptor();
// Network accounting
self.network_manager()
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
Ok(NetworkResult::value(connection_descriptor))
}
@@ -228,8 +227,8 @@ impl Network {
// get protocol config
self.inner.lock().protocol_config = Some({
let c = self.config.get();
let inbound = ProtocolSet::new();
let mut outbound = ProtocolSet::new();
let inbound = ProtocolTypeSet::new();
let mut outbound = ProtocolTypeSet::new();
if c.network.protocol.ws.connect && c.capabilities.protocol_connect_ws {
outbound.insert(ProtocolType::WS);
@@ -239,10 +238,15 @@ impl Network {
}
// XXX: See issue #92
let family_global = AddressSet::all();
let family_local = AddressSet::all();
let family_global = AddressTypeSet::all();
let family_local = AddressTypeSet::all();
ProtocolConfig { inbound, outbound, family_global, family_local }
ProtocolConfig {
inbound,
outbound,
family_global,
family_local,
}
});
self.inner.lock().network_started = true;

View File

@@ -15,8 +15,9 @@ pub enum ProtocolNetworkConnection {
impl ProtocolNetworkConnection {
pub async fn connect(
local_address: Option<SocketAddr>,
_local_address: Option<SocketAddr>,
dial_info: &DialInfo,
timeout_ms: u32,
) -> io::Result<NetworkResult<ProtocolNetworkConnection>> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
@@ -26,7 +27,7 @@ impl ProtocolNetworkConnection {
panic!("TCP dial info is not supported on WASM targets");
}
ProtocolType::WS | ProtocolType::WSS => {
ws::WebsocketProtocolHandler::connect(local_address, dial_info).await
ws::WebsocketProtocolHandler::connect(dial_info, timeout_ms).await
}
}
}

View File

@@ -1,8 +1,8 @@
use super::*;
use futures_util::{SinkExt, StreamExt};
use send_wrapper::*;
use std::io;
use ws_stream_wasm::*;
use send_wrapper::*;
struct WebsocketNetworkConnectionInner {
_ws_meta: WsMeta,
@@ -45,33 +45,46 @@ impl WebsocketNetworkConnection {
// self.inner.ws_meta.close().await.map_err(to_io).map(drop)
// }
#[instrument(level = "trace", err, skip(self, message), fields(message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
#[instrument(level = "trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> io::Result<NetworkResult<()>> {
if message.len() > MAX_MESSAGE_SIZE {
bail_io_error_other!("sending too large WS message");
}
self.inner
.ws_stream
.clone()
.send(WsMessage::Binary(message))
.await
.map_err(to_io)
let out = SendWrapper::new(
self.inner
.ws_stream
.clone()
.send(WsMessage::Binary(message)),
)
.await
.map_err(to_io)
.into_network_result()?;
tracing::Span::current().record("network_result", &tracing::field::display(&out));
Ok(out)
}
#[instrument(level = "trace", err, skip(self), fields(ret.len))]
pub async fn recv(&self) -> io::Result<Vec<u8>> {
#[instrument(level = "trace", err, skip(self), fields(network_result, ret.len))]
pub async fn recv(&self) -> io::Result<NetworkResult<Vec<u8>>> {
let out = match SendWrapper::new(self.inner.ws_stream.clone().next()).await {
Some(WsMessage::Binary(v)) => v,
Some(_) => {
bail_io_error_other!("Unexpected WS message type");
Some(WsMessage::Binary(v)) => {
if v.len() > MAX_MESSAGE_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"too large ws message",
));
}
NetworkResult::Value(v)
}
Some(_) => NetworkResult::NoConnection(io::Error::new(
io::ErrorKind::ConnectionReset,
"Unexpected WS message type",
)),
None => {
bail_io_error_other!("WS stream closed");
}
};
if out.len() > MAX_MESSAGE_SIZE {
bail_io_error_other!("sending too large WS message")
}
tracing::Span::current().record("network_result", &tracing::field::display(&out));
Ok(out)
}
}
@@ -84,13 +97,11 @@ pub struct WebsocketProtocolHandler {}
impl WebsocketProtocolHandler {
#[instrument(level = "trace", err)]
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: &DialInfo,
) -> io::Result<ProtocolNetworkConnection> {
assert!(local_address.is_none());
timeout_ms: u32,
) -> io::Result<NetworkResult<ProtocolNetworkConnection>> {
// Split dial info up
let (_tls, scheme) = match dial_info {
let (tls, scheme) = match dial_info {
DialInfo::WS(_) => (false, "ws"),
DialInfo::WSS(_) => (true, "wss"),
_ => panic!("invalid dialinfo for WS/WSS protocol"),
@@ -101,15 +112,23 @@ impl WebsocketProtocolHandler {
bail_io_error_other!("invalid websocket url scheme");
}
let fut = spawn_local(WsMeta::connect(request, None));
let (wsmeta, wsio) = fut.await.map_err(to_io)?;
let fut = SendWrapper::new(timeout(timeout_ms, async move {
WsMeta::connect(request, None).await.map_err(to_io)
}));
let (wsmeta, wsio) = network_result_try!(network_result_try!(fut
.await
.into_network_result())
.into_network_result()?);
// Make our connection descriptor
Ok(WebsocketNetworkConnection::new(
let wnc = WebsocketNetworkConnection::new(
ConnectionDescriptor::new_no_local(dial_info.to_peer_address())
.map_err(|e| io::Error::new(io::ErrorKind::AddrNotAvailable, e))?,
wsmeta,
wsio,
))
);
Ok(NetworkResult::Value(ProtocolNetworkConnection::Ws(wnc)))
}
}

View File

@@ -39,6 +39,7 @@ impl<T> IoNetworkResultExt<T> for io::Result<T> {
},
#[cfg(not(feature = "io_error_more"))]
Err(e) => {
#[cfg(not(target_arch = "wasm32"))]
if let Some(os_err) = e.raw_os_error() {
if os_err == libc::EHOSTUNREACH || os_err == libc::ENETUNREACH {
return Ok(NetworkResult::NoConnection(e));
@@ -93,6 +94,7 @@ impl<T> FoldedNetworkResultExt<T> for io::Result<TimeoutOr<T>> {
},
#[cfg(not(feature = "io_error_more"))]
Err(e) => {
#[cfg(not(target_arch = "wasm32"))]
if let Some(os_err) = e.raw_os_error() {
if os_err == libc::EHOSTUNREACH || os_err == libc::ENETUNREACH {
return Ok(NetworkResult::NoConnection(e));
@@ -126,6 +128,7 @@ impl<T> FoldedNetworkResultExt<T> for io::Result<NetworkResult<T>> {
},
#[cfg(not(feature = "io_error_more"))]
Err(e) => {
#[cfg(not(target_arch = "wasm32"))]
if let Some(os_err) = e.raw_os_error() {
if os_err == libc::EHOSTUNREACH || os_err == libc::ENETUNREACH {
return Ok(NetworkResult::NoConnection(e));