refactor
This commit is contained in:
@@ -43,7 +43,6 @@ struct NetworkInner {
|
||||
tcp_port: u16,
|
||||
ws_port: u16,
|
||||
wss_port: u16,
|
||||
interfaces: NetworkInterfaces,
|
||||
// udp
|
||||
bound_first_udp: BTreeMap<u16, Option<(socket2::Socket, socket2::Socket)>>,
|
||||
inbound_udp_protocol_handlers: BTreeMap<SocketAddr, RawUdpProtocolHandler>,
|
||||
@@ -60,8 +59,11 @@ struct NetworkUnlockedInner {
|
||||
routing_table: RoutingTable,
|
||||
network_manager: NetworkManager,
|
||||
connection_manager: ConnectionManager,
|
||||
// Network
|
||||
interfaces: NetworkInterfaces,
|
||||
// Background processes
|
||||
update_network_class_task: TickTask<EyreReport>,
|
||||
network_interfaces_task: TickTask<EyreReport>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -85,7 +87,6 @@ impl Network {
|
||||
tcp_port: 0u16,
|
||||
ws_port: 0u16,
|
||||
wss_port: 0u16,
|
||||
interfaces: NetworkInterfaces::new(),
|
||||
bound_first_udp: BTreeMap::new(),
|
||||
inbound_udp_protocol_handlers: BTreeMap::new(),
|
||||
outbound_udpv4_protocol_handler: None,
|
||||
@@ -105,7 +106,9 @@ impl Network {
|
||||
network_manager,
|
||||
routing_table,
|
||||
connection_manager,
|
||||
interfaces: NetworkInterfaces::new(),
|
||||
update_network_class_task: TickTask::new(1),
|
||||
network_interfaces_task: TickTask::new(5),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,6 +136,15 @@ impl Network {
|
||||
Box::pin(this2.clone().update_network_class_task_routine(s, l, t))
|
||||
});
|
||||
}
|
||||
// Set network interfaces tick task
|
||||
{
|
||||
let this2 = this.clone();
|
||||
this.unlocked_inner
|
||||
.network_interfaces_task
|
||||
.set_routine(move |s, l, t| {
|
||||
Box::pin(this2.clone().network_interfaces_task_routine(s, l, t))
|
||||
});
|
||||
}
|
||||
|
||||
this
|
||||
}
|
||||
@@ -219,11 +231,11 @@ impl Network {
|
||||
inner.join_handles.push(jh);
|
||||
}
|
||||
|
||||
fn translate_unspecified_address(inner: &NetworkInner, from: &SocketAddr) -> Vec<SocketAddr> {
|
||||
fn translate_unspecified_address(&self, from: &SocketAddr) -> Vec<SocketAddr> {
|
||||
if !from.ip().is_unspecified() {
|
||||
vec![*from]
|
||||
} else {
|
||||
inner
|
||||
self.unlocked_inner
|
||||
.interfaces
|
||||
.best_addresses()
|
||||
.iter()
|
||||
@@ -259,19 +271,17 @@ impl Network {
|
||||
where
|
||||
F: FnOnce(&[IpAddr]) -> R,
|
||||
{
|
||||
let inner = self.inner.lock();
|
||||
inner.interfaces.with_best_addresses(f)
|
||||
self.unlocked_inner.interfaces.with_best_addresses(f)
|
||||
}
|
||||
|
||||
// See if our interface addresses have changed, if so we need to punt the network
|
||||
// and redo all our addresses. This is overkill, but anything more accurate
|
||||
// would require inspection of routing tables that we dont want to bother with
|
||||
pub async fn check_interface_addresses(&self) -> EyreResult<bool> {
|
||||
let mut inner = self.inner.lock();
|
||||
if !inner.interfaces.refresh().await? {
|
||||
async fn check_interface_addresses(&self) -> EyreResult<bool> {
|
||||
if !self.unlocked_inner.interfaces.refresh().await? {
|
||||
return Ok(false);
|
||||
}
|
||||
inner.network_needs_restart = true;
|
||||
self.inner.lock().network_needs_restart = true;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
@@ -286,8 +296,13 @@ impl Network {
|
||||
&self,
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
) -> EyreResult<()> {
|
||||
) -> EyreResult<NetworkResult<()>> {
|
||||
let data_len = data.len();
|
||||
let connect_timeout_ms = {
|
||||
let c = self.config.get();
|
||||
c.network.connection_initial_timeout_ms
|
||||
};
|
||||
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
@@ -296,19 +311,28 @@ impl Network {
|
||||
.wrap_err("create socket failure")?;
|
||||
h.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.map(NetworkResult::Value)
|
||||
.wrap_err("send message failure")?;
|
||||
}
|
||||
ProtocolType::TCP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
let pnc = RawTcpProtocolHandler::connect(None, peer_socket_addr)
|
||||
.await
|
||||
.wrap_err("connect failure")?;
|
||||
let pnc = network_result_try!(RawTcpProtocolHandler::connect(
|
||||
None,
|
||||
peer_socket_addr,
|
||||
connect_timeout_ms
|
||||
)
|
||||
.await
|
||||
.wrap_err("connect failure")?);
|
||||
pnc.send(data).await.wrap_err("send failure")?;
|
||||
}
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
let pnc = WebsocketProtocolHandler::connect(None, &dial_info)
|
||||
.await
|
||||
.wrap_err("connect failure")?;
|
||||
let pnc = network_result_try!(WebsocketProtocolHandler::connect(
|
||||
None,
|
||||
&dial_info,
|
||||
connect_timeout_ms
|
||||
)
|
||||
.await
|
||||
.wrap_err("connect failure")?);
|
||||
pnc.send(data).await.wrap_err("send failure")?;
|
||||
}
|
||||
}
|
||||
@@ -316,7 +340,7 @@ impl Network {
|
||||
self.network_manager()
|
||||
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
|
||||
|
||||
Ok(())
|
||||
Ok(NetworkResult::Value(()))
|
||||
}
|
||||
|
||||
// Send data to a dial info, unbound, using a new connection from a random port
|
||||
@@ -324,14 +348,19 @@ impl Network {
|
||||
// This creates a short-lived connection in the case of connection-oriented protocols
|
||||
// for the purpose of sending this one message.
|
||||
// This bypasses the connection table as it is not a 'node to node' connection.
|
||||
#[instrument(level="trace", err, skip(self, data), fields(ret.timeout_or, data.len = data.len()))]
|
||||
#[instrument(level="trace", err, skip(self, data), fields(data.len = data.len()))]
|
||||
pub async fn send_recv_data_unbound_to_dial_info(
|
||||
&self,
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
timeout_ms: u32,
|
||||
) -> EyreResult<TimeoutOr<Vec<u8>>> {
|
||||
) -> EyreResult<NetworkResult<Vec<u8>>> {
|
||||
let data_len = data.len();
|
||||
let connect_timeout_ms = {
|
||||
let c = self.config.get();
|
||||
c.network.connection_initial_timeout_ms
|
||||
};
|
||||
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
@@ -346,18 +375,11 @@ impl Network {
|
||||
|
||||
// receive single response
|
||||
let mut out = vec![0u8; MAX_MESSAGE_SIZE];
|
||||
let timeout_or_ret = timeout(timeout_ms, h.recv_message(&mut out))
|
||||
.await
|
||||
.into_timeout_or()
|
||||
.into_result()
|
||||
let (recv_len, recv_addr) =
|
||||
network_result_try!(timeout(timeout_ms, h.recv_message(&mut out))
|
||||
.await
|
||||
.into_network_result())
|
||||
.wrap_err("recv_message failure")?;
|
||||
let (recv_len, recv_addr) = match timeout_or_ret {
|
||||
TimeoutOr::Value(v) => v,
|
||||
TimeoutOr::Timeout => {
|
||||
tracing::Span::current().record("ret.timeout_or", &"Timeout".to_owned());
|
||||
return Ok(TimeoutOr::Timeout);
|
||||
}
|
||||
};
|
||||
|
||||
let recv_socket_addr = recv_addr.remote_address().to_socket_addr();
|
||||
self.network_manager()
|
||||
@@ -368,48 +390,37 @@ impl Network {
|
||||
bail!("wrong address");
|
||||
}
|
||||
out.resize(recv_len, 0u8);
|
||||
Ok(TimeoutOr::Value(out))
|
||||
Ok(NetworkResult::Value(out))
|
||||
}
|
||||
ProtocolType::TCP | ProtocolType::WS | ProtocolType::WSS => {
|
||||
let pnc = match dial_info.protocol_type() {
|
||||
let pnc = network_result_try!(match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => unreachable!(),
|
||||
ProtocolType::TCP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
RawTcpProtocolHandler::connect(None, peer_socket_addr)
|
||||
RawTcpProtocolHandler::connect(None, peer_socket_addr, connect_timeout_ms)
|
||||
.await
|
||||
.wrap_err("connect failure")?
|
||||
}
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
WebsocketProtocolHandler::connect(None, &dial_info)
|
||||
WebsocketProtocolHandler::connect(None, &dial_info, connect_timeout_ms)
|
||||
.await
|
||||
.wrap_err("connect failure")?
|
||||
}
|
||||
};
|
||||
});
|
||||
|
||||
pnc.send(data).await.wrap_err("send failure")?;
|
||||
self.network_manager()
|
||||
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
|
||||
|
||||
let out = timeout(timeout_ms, pnc.recv())
|
||||
let out = network_result_try!(network_result_try!(timeout(timeout_ms, pnc.recv())
|
||||
.await
|
||||
.into_timeout_or()
|
||||
.into_result()
|
||||
.wrap_err("recv failure")?;
|
||||
.into_network_result())
|
||||
.wrap_err("recv failure")?);
|
||||
|
||||
tracing::Span::current().record(
|
||||
"ret.timeout_or",
|
||||
&match out {
|
||||
TimeoutOr::<Vec<u8>>::Value(ref v) => format!("Value(len={})", v.len()),
|
||||
TimeoutOr::<Vec<u8>>::Timeout => "Timeout".to_owned(),
|
||||
},
|
||||
);
|
||||
self.network_manager()
|
||||
.stats_packet_rcvd(dial_info.to_ip_addr(), out.len() as u64);
|
||||
|
||||
if let TimeoutOr::Value(out) = &out {
|
||||
self.network_manager()
|
||||
.stats_packet_rcvd(dial_info.to_ip_addr(), out.len() as u64);
|
||||
}
|
||||
|
||||
Ok(out)
|
||||
Ok(NetworkResult::Value(out))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -449,21 +460,27 @@ impl Network {
|
||||
// Try to send to the exact existing connection if one exists
|
||||
if let Some(conn) = self.connection_manager().get_connection(descriptor).await {
|
||||
// connection exists, send over it
|
||||
conn.send_async(data)
|
||||
.await
|
||||
.wrap_err("sending data to existing connection")?;
|
||||
match conn.send_async(data).await {
|
||||
ConnectionHandleSendResult::Sent => {
|
||||
// Network accounting
|
||||
self.network_manager().stats_packet_sent(
|
||||
descriptor.remote().to_socket_addr().ip(),
|
||||
data_len as u64,
|
||||
);
|
||||
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
.stats_packet_sent(descriptor.remote().to_socket_addr().ip(), data_len as u64);
|
||||
|
||||
// Data was consumed
|
||||
Ok(None)
|
||||
} else {
|
||||
// Connection or didn't exist
|
||||
// Pass the data back out so we don't own it any more
|
||||
Ok(Some(data))
|
||||
// Data was consumed
|
||||
return Ok(None);
|
||||
}
|
||||
ConnectionHandleSendResult::NotSent(data) => {
|
||||
// Couldn't send
|
||||
// Pass the data back out so we don't own it any more
|
||||
return Ok(Some(data));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Connection didn't exist
|
||||
// Pass the data back out so we don't own it any more
|
||||
Ok(Some(data))
|
||||
}
|
||||
|
||||
// Send data directly to a dial info, possibly without knowing which node it is going to
|
||||
@@ -472,40 +489,42 @@ impl Network {
|
||||
&self,
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
) -> EyreResult<()> {
|
||||
) -> EyreResult<NetworkResult<()>> {
|
||||
let data_len = data.len();
|
||||
// Handle connectionless protocol
|
||||
if dial_info.protocol_type() == ProtocolType::UDP {
|
||||
// Handle connectionless protocol
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
|
||||
let res = ph
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.wrap_err("failed to send data to dial info");
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
.stats_packet_sent(peer_socket_addr.ip(), data_len as u64);
|
||||
}
|
||||
return res;
|
||||
let ph = match self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
|
||||
Some(ph) => ph,
|
||||
None => bail!("no appropriate UDP protocol handler for dial_info"),
|
||||
};
|
||||
network_result_try!(ph
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.into_network_result()
|
||||
.wrap_err("failed to send data to dial info")?);
|
||||
} else {
|
||||
// Handle connection-oriented protocols
|
||||
let local_addr = self.get_preferred_local_address(&dial_info);
|
||||
let conn = network_result_try!(
|
||||
self.connection_manager()
|
||||
.get_or_create_connection(Some(local_addr), dial_info.clone())
|
||||
.await?
|
||||
);
|
||||
|
||||
if let ConnectionHandleSendResult::NotSent(_) = conn.send_async(data).await {
|
||||
return Ok(NetworkResult::NoConnection(io::Error::new(
|
||||
io::ErrorKind::ConnectionReset,
|
||||
"failed to send",
|
||||
)));
|
||||
}
|
||||
bail!("no appropriate UDP protocol handler for dial_info");
|
||||
}
|
||||
|
||||
// Handle connection-oriented protocols
|
||||
let local_addr = self.get_preferred_local_address(&dial_info);
|
||||
let conn = self
|
||||
.connection_manager()
|
||||
.get_or_create_connection(Some(local_addr), dial_info.clone())
|
||||
.await?;
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
|
||||
|
||||
let res = conn.send_async(data).await;
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
|
||||
}
|
||||
res
|
||||
Ok(NetworkResult::Value(()))
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
@@ -517,15 +536,13 @@ impl Network {
|
||||
#[instrument(level = "debug", err, skip_all)]
|
||||
pub async fn startup(&self) -> EyreResult<()> {
|
||||
// initialize interfaces
|
||||
let mut interfaces = NetworkInterfaces::new();
|
||||
interfaces.refresh().await?;
|
||||
self.unlocked_inner.interfaces.refresh().await?;
|
||||
|
||||
let protocol_config = {
|
||||
let mut inner = self.inner.lock();
|
||||
|
||||
// Create stop source
|
||||
inner.stop_source = Some(StopSource::new());
|
||||
inner.interfaces = interfaces;
|
||||
|
||||
// get protocol config
|
||||
let protocol_config = {
|
||||
@@ -666,6 +683,19 @@ impl Network {
|
||||
|
||||
//////////////////////////////////////////
|
||||
|
||||
#[instrument(level = "trace", skip(self), err)]
|
||||
pub async fn network_interfaces_task_routine(
|
||||
self,
|
||||
stop_token: StopToken,
|
||||
_l: u64,
|
||||
_t: u64,
|
||||
) -> EyreResult<()> {
|
||||
if self.check_interface_addresses().await? {
|
||||
info!("interface addresses changed, restarting network");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn tick(&self) -> EyreResult<()> {
|
||||
let network_class = self.get_network_class().unwrap_or(NetworkClass::Invalid);
|
||||
let routing_table = self.routing_table();
|
||||
@@ -680,6 +710,12 @@ impl Network {
|
||||
}
|
||||
}
|
||||
|
||||
// If we aren't resetting the network already,
|
||||
// check our network interfaces to see if they have changed
|
||||
if !self.needs_restart() {
|
||||
self.unlocked_inner.network_interfaces_task.tick().await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@@ -288,7 +288,7 @@ impl Network {
|
||||
|
||||
for ip_addr in ip_addrs {
|
||||
let addr = SocketAddr::new(ip_addr, port);
|
||||
let idi_addrs = Self::translate_unspecified_address(&*(self.inner.lock()), &addr);
|
||||
let idi_addrs = self.translate_unspecified_address(&addr);
|
||||
|
||||
// see if we've already bound to this already
|
||||
// if not, spawn a listener
|
||||
|
@@ -214,7 +214,7 @@ impl Network {
|
||||
.inbound_udp_protocol_handlers
|
||||
.contains_key(&addr)
|
||||
{
|
||||
let idi_addrs = Self::translate_unspecified_address(&*self.inner.lock(), &addr);
|
||||
let idi_addrs = self.translate_unspecified_address(&addr);
|
||||
|
||||
self.clone().create_udp_inbound_socket(addr).await?;
|
||||
|
||||
|
@@ -22,16 +22,22 @@ impl ProtocolNetworkConnection {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: &DialInfo,
|
||||
) -> io::Result<ProtocolNetworkConnection> {
|
||||
timeout_ms: u32,
|
||||
) -> io::Result<NetworkResult<ProtocolNetworkConnection>> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
panic!("Should not connect to UDP dialinfo");
|
||||
}
|
||||
ProtocolType::TCP => {
|
||||
tcp::RawTcpProtocolHandler::connect(local_address, dial_info.to_socket_addr()).await
|
||||
tcp::RawTcpProtocolHandler::connect(
|
||||
local_address,
|
||||
dial_info.to_socket_addr(),
|
||||
timeout_ms,
|
||||
)
|
||||
.await
|
||||
}
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
ws::WebsocketProtocolHandler::connect(local_address, dial_info).await
|
||||
ws::WebsocketProtocolHandler::connect(local_address, dial_info, timeout_ms).await
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -46,7 +52,7 @@ impl ProtocolNetworkConnection {
|
||||
}
|
||||
}
|
||||
|
||||
// pub async fn close(&self) -> io::Result<()> {
|
||||
// pub async fn close(&self) -> io::Result<NetworkResult<()>> {
|
||||
// match self {
|
||||
// Self::Dummy(d) => d.close(),
|
||||
// Self::RawTcp(t) => t.close().await,
|
||||
@@ -56,7 +62,7 @@ impl ProtocolNetworkConnection {
|
||||
// }
|
||||
// }
|
||||
|
||||
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
|
||||
pub async fn send(&self, message: Vec<u8>) -> io::Result<NetworkResult<()>> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.send(message),
|
||||
Self::RawTcp(t) => t.send(message).await,
|
||||
@@ -65,7 +71,7 @@ impl ProtocolNetworkConnection {
|
||||
Self::Wss(w) => w.send(message).await,
|
||||
}
|
||||
}
|
||||
pub async fn recv(&self) -> io::Result<Vec<u8>> {
|
||||
pub async fn recv(&self) -> io::Result<NetworkResult<Vec<u8>>> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.recv(),
|
||||
Self::RawTcp(t) => t.recv().await,
|
||||
|
@@ -166,7 +166,11 @@ pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> io::Result<Socke
|
||||
}
|
||||
|
||||
// Non-blocking connect is tricky when you want to start with a prepared socket
|
||||
pub async fn nonblocking_connect(socket: Socket, addr: SocketAddr) -> io::Result<TcpStream> {
|
||||
pub async fn nonblocking_connect(
|
||||
socket: Socket,
|
||||
addr: SocketAddr,
|
||||
timeout_ms: u32,
|
||||
) -> io::Result<TimeoutOr<TcpStream>> {
|
||||
// Set for non blocking connect
|
||||
socket.set_nonblocking(true)?;
|
||||
|
||||
@@ -185,9 +189,10 @@ pub async fn nonblocking_connect(socket: Socket, addr: SocketAddr) -> io::Result
|
||||
let async_stream = Async::new(std::net::TcpStream::from(socket))?;
|
||||
|
||||
// The stream becomes writable when connected
|
||||
intf::timeout(2000, async_stream.writable())
|
||||
intf::timeout(timeout_ms, async_stream.writable())
|
||||
.await
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::TimedOut, e))??;
|
||||
.into_timeout_or()
|
||||
.into_result()?;
|
||||
|
||||
// Check low level error
|
||||
let async_stream = match async_stream.get_ref().take_error()? {
|
||||
@@ -198,9 +203,9 @@ pub async fn nonblocking_connect(socket: Socket, addr: SocketAddr) -> io::Result
|
||||
// Convert back to inner and then return async version
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
Ok(TcpStream::from(async_stream.into_inner()?))
|
||||
Ok(TimeoutOr::Value(TcpStream::from(async_stream.into_inner()?)))
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
Ok(TcpStream::from_std(async_stream.into_inner()?)?)
|
||||
Ok(TimeoutOr::Value(TcpStream::from_std(async_stream.into_inner()?)?))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -23,7 +23,7 @@ impl RawTcpNetworkConnection {
|
||||
}
|
||||
|
||||
// #[instrument(level = "trace", err, skip(self))]
|
||||
// pub async fn close(&mut self) -> io::Result<()> {
|
||||
// pub async fn close(&mut self) -> io::Result<NetworkResult<()>> {
|
||||
// // Make an attempt to flush the stream
|
||||
// self.stream.clone().close().await?;
|
||||
// // Then shut down the write side of the socket to effect a clean close
|
||||
@@ -40,7 +40,10 @@ impl RawTcpNetworkConnection {
|
||||
// }
|
||||
// }
|
||||
|
||||
async fn send_internal(stream: &mut AsyncPeekStream, message: Vec<u8>) -> io::Result<()> {
|
||||
async fn send_internal(
|
||||
stream: &mut AsyncPeekStream,
|
||||
message: Vec<u8>,
|
||||
) -> io::Result<NetworkResult<()>> {
|
||||
log_net!("sending TCP message of size {}", message.len());
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
bail_io_error_other!("sending too large TCP message");
|
||||
@@ -48,20 +51,29 @@ impl RawTcpNetworkConnection {
|
||||
let len = message.len() as u16;
|
||||
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
|
||||
|
||||
stream.write_all(&header).await?;
|
||||
stream.write_all(&message).await
|
||||
stream.write_all(&header).await.into_network_result()?;
|
||||
stream.write_all(&message).await.into_network_result()
|
||||
}
|
||||
|
||||
#[instrument(level="trace", err, skip(self, message), fields(message.len = message.len()))]
|
||||
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
|
||||
#[instrument(level="trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
|
||||
pub async fn send(&self, message: Vec<u8>) -> io::Result<NetworkResult<()>> {
|
||||
let mut stream = self.stream.clone();
|
||||
Self::send_internal(&mut stream, message).await
|
||||
let out = Self::send_internal(&mut stream, message).await?;
|
||||
tracing::Span::current().record(
|
||||
"network_result",
|
||||
&match &out {
|
||||
NetworkResult::Timeout => "Timeout".to_owned(),
|
||||
NetworkResult::NoConnection(e) => format!("No connection: {}", e),
|
||||
NetworkResult::Value(()) => "Value(())".to_owned(),
|
||||
},
|
||||
);
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
async fn recv_internal(stream: &mut AsyncPeekStream) -> io::Result<Vec<u8>> {
|
||||
async fn recv_internal(stream: &mut AsyncPeekStream) -> io::Result<NetworkResult<Vec<u8>>> {
|
||||
let mut header = [0u8; 4];
|
||||
|
||||
stream.read_exact(&mut header).await?;
|
||||
stream.read_exact(&mut header).await.into_network_result()?;
|
||||
|
||||
if header[0] != b'V' || header[1] != b'L' {
|
||||
bail_io_error_other!("received invalid TCP frame header");
|
||||
@@ -72,16 +84,23 @@ impl RawTcpNetworkConnection {
|
||||
}
|
||||
|
||||
let mut out: Vec<u8> = vec![0u8; len];
|
||||
stream.read_exact(&mut out).await?;
|
||||
stream.read_exact(&mut out).await.into_network_result()?;
|
||||
|
||||
Ok(out)
|
||||
Ok(NetworkResult::Value(out))
|
||||
}
|
||||
|
||||
#[instrument(level="trace", err, skip(self), fields(ret.len))]
|
||||
pub async fn recv(&self) -> io::Result<Vec<u8>> {
|
||||
#[instrument(level = "trace", err, skip(self), fields(network_result))]
|
||||
pub async fn recv(&self) -> io::Result<NetworkResult<Vec<u8>>> {
|
||||
let mut stream = self.stream.clone();
|
||||
let out = Self::recv_internal(&mut stream).await?;
|
||||
tracing::Span::current().record("ret.len", &out.len());
|
||||
tracing::Span::current().record(
|
||||
"network_result",
|
||||
&match &out {
|
||||
NetworkResult::Timeout => "Timeout".to_owned(),
|
||||
NetworkResult::NoConnection(e) => format!("No connection: {}", e),
|
||||
NetworkResult::Value(v) => format!("Value(len={})", v.len()),
|
||||
},
|
||||
);
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
@@ -142,7 +161,8 @@ impl RawTcpProtocolHandler {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
socket_addr: SocketAddr,
|
||||
) -> io::Result<ProtocolNetworkConnection> {
|
||||
timeout_ms: u32,
|
||||
) -> io::Result<NetworkResult<ProtocolNetworkConnection>> {
|
||||
// Make a shared socket
|
||||
let socket = match local_address {
|
||||
Some(a) => new_bound_shared_tcp_socket(a)?,
|
||||
@@ -150,7 +170,9 @@ impl RawTcpProtocolHandler {
|
||||
};
|
||||
|
||||
// Non-blocking connect to remote address
|
||||
let ts = nonblocking_connect(socket, socket_addr).await?;
|
||||
let ts = network_result_try!(nonblocking_connect(socket, socket_addr, timeout_ms)
|
||||
.await
|
||||
.folded()?);
|
||||
|
||||
// See what local address we ended up with and turn this into a stream
|
||||
let actual_local_address = ts.local_addr()?;
|
||||
@@ -170,77 +192,8 @@ impl RawTcpProtocolHandler {
|
||||
ps,
|
||||
));
|
||||
|
||||
Ok(conn)
|
||||
Ok(NetworkResult::Value(conn))
|
||||
}
|
||||
|
||||
// #[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
|
||||
// pub async fn send_unbound_message(socket_addr: SocketAddr, data: Vec<u8>) -> io::Result<()> {
|
||||
// if data.len() > MAX_MESSAGE_SIZE {
|
||||
// bail_io_error_other!("sending too large unbound TCP message");
|
||||
// }
|
||||
// // Make a shared socket
|
||||
// let socket = new_unbound_shared_tcp_socket(socket2::Domain::for_address(socket_addr))?;
|
||||
|
||||
// // Non-blocking connect to remote address
|
||||
// let ts = nonblocking_connect(socket, socket_addr).await?;
|
||||
|
||||
// // See what local address we ended up with and turn this into a stream
|
||||
// // let actual_local_address = ts
|
||||
// // .local_addr()
|
||||
// // .map_err(map_to_string)
|
||||
// // .map_err(logthru_net!("could not get local address from TCP stream"))?;
|
||||
|
||||
// #[cfg(feature = "rt-tokio")]
|
||||
// let ts = ts.compat();
|
||||
// let mut ps = AsyncPeekStream::new(ts);
|
||||
|
||||
// // Send directly from the raw network connection
|
||||
// // this builds the connection and tears it down immediately after the send
|
||||
// RawTcpNetworkConnection::send_internal(&mut ps, data).await
|
||||
// }
|
||||
|
||||
// #[instrument(level = "trace", err, skip(data), fields(data.len = data.len(), ret.timeout_or))]
|
||||
// pub async fn send_recv_unbound_message(
|
||||
// socket_addr: SocketAddr,
|
||||
// data: Vec<u8>,
|
||||
// timeout_ms: u32,
|
||||
// ) -> io::Result<TimeoutOr<Vec<u8>>> {
|
||||
// if data.len() > MAX_MESSAGE_SIZE {
|
||||
// bail_io_error_other!("sending too large unbound TCP message");
|
||||
// }
|
||||
|
||||
// // Make a shared socket
|
||||
// let socket = new_unbound_shared_tcp_socket(socket2::Domain::for_address(socket_addr))?;
|
||||
|
||||
// // Non-blocking connect to remote address
|
||||
// let ts = nonblocking_connect(socket, socket_addr).await?;
|
||||
|
||||
// // See what local address we ended up with and turn this into a stream
|
||||
// // let actual_local_address = ts
|
||||
// // .local_addr()
|
||||
// // .map_err(map_to_string)
|
||||
// // .map_err(logthru_net!("could not get local address from TCP stream"))?;
|
||||
// #[cfg(feature = "rt-tokio")]
|
||||
// let ts = ts.compat();
|
||||
// let mut ps = AsyncPeekStream::new(ts);
|
||||
|
||||
// // Send directly from the raw network connection
|
||||
// // this builds the connection and tears it down immediately after the send
|
||||
// RawTcpNetworkConnection::send_internal(&mut ps, data).await?;
|
||||
// let out = timeout(timeout_ms, RawTcpNetworkConnection::recv_internal(&mut ps))
|
||||
// .await
|
||||
// .into_timeout_or()
|
||||
// .into_result()?;
|
||||
|
||||
// tracing::Span::current().record(
|
||||
// "ret.timeout_or",
|
||||
// &match out {
|
||||
// TimeoutOr::<Vec<u8>>::Value(ref v) => format!("Value(len={})", v.len()),
|
||||
// TimeoutOr::<Vec<u8>>::Timeout => "Timeout".to_owned(),
|
||||
// },
|
||||
// );
|
||||
// Ok(out)
|
||||
// }
|
||||
}
|
||||
|
||||
impl ProtocolAcceptHandler for RawTcpProtocolHandler {
|
||||
@@ -248,7 +201,7 @@ impl ProtocolAcceptHandler for RawTcpProtocolHandler {
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<io::Result<Option<ProtocolNetworkConnection>>> {
|
||||
) -> SendPinBoxFuture<io::Result<Option<ProtocolNetworkConnection>>> {
|
||||
Box::pin(self.clone().on_accept_async(stream, peer_addr))
|
||||
}
|
||||
}
|
||||
|
@@ -69,56 +69,4 @@ impl RawUdpProtocolHandler {
|
||||
let socket = UdpSocket::bind(local_socket_addr).await?;
|
||||
Ok(RawUdpProtocolHandler::new(Arc::new(socket)))
|
||||
}
|
||||
|
||||
// #[instrument(level = "trace", err, skip(data), fields(data.len = data.len(), ret.timeout_or))]
|
||||
// pub async fn send_recv_unbound_message(
|
||||
// socket_addr: SocketAddr,
|
||||
// data: Vec<u8>,
|
||||
// timeout_ms: u32,
|
||||
// ) -> io::Result<TimeoutOr<Vec<u8>>> {
|
||||
// if data.len() > MAX_MESSAGE_SIZE {
|
||||
// bail_io_error_other!("sending too large unbound UDP message");
|
||||
// }
|
||||
|
||||
// // get local wildcard address for bind
|
||||
// let local_socket_addr = match socket_addr {
|
||||
// SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
|
||||
// SocketAddr::V6(_) => {
|
||||
// SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
|
||||
// }
|
||||
// };
|
||||
|
||||
// // get unspecified bound socket
|
||||
// let socket = UdpSocket::bind(local_socket_addr).await?;
|
||||
// let len = socket.send_to(&data, socket_addr).await?;
|
||||
// if len != data.len() {
|
||||
// bail_io_error_other!("UDP partial unbound send");
|
||||
// }
|
||||
|
||||
// // receive single response
|
||||
// let mut out = vec![0u8; MAX_MESSAGE_SIZE];
|
||||
// let timeout_or_ret = timeout(timeout_ms, socket.recv_from(&mut out))
|
||||
// .await
|
||||
// .into_timeout_or()
|
||||
// .into_result()?;
|
||||
// let (len, from_addr) = match timeout_or_ret {
|
||||
// TimeoutOr::Value(v) => v,
|
||||
// TimeoutOr::Timeout => {
|
||||
// tracing::Span::current().record("ret.timeout_or", &"Timeout".to_owned());
|
||||
// return Ok(TimeoutOr::Timeout);
|
||||
// }
|
||||
// };
|
||||
|
||||
// // if the from address is not the same as the one we sent to, then drop this
|
||||
// if from_addr != socket_addr {
|
||||
// bail_io_error_other!(format!(
|
||||
// "Unbound response received from wrong address: addr={}",
|
||||
// from_addr,
|
||||
// ));
|
||||
// }
|
||||
// out.resize(len, 0u8);
|
||||
|
||||
// tracing::Span::current().record("ret.timeout_or", &format!("Value(len={})", out.len()));
|
||||
// Ok(TimeoutOr::Value(out))
|
||||
// }
|
||||
}
|
||||
|
@@ -80,33 +80,45 @@ where
|
||||
// .map_err(to_io)
|
||||
// }
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, message), fields(message.len = message.len()))]
|
||||
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
|
||||
#[instrument(level = "trace", err, skip(self, message), fields(network_result, message.len = message.len()))]
|
||||
pub async fn send(&self, message: Vec<u8>) -> io::Result<NetworkResult<()>> {
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
bail_io_error_other!("received too large WS message");
|
||||
}
|
||||
self.stream
|
||||
let out = self
|
||||
.stream
|
||||
.clone()
|
||||
.send(Message::binary(message))
|
||||
.await
|
||||
.map_err(to_io)
|
||||
.into_network_result()?;
|
||||
tracing::Span::current().record(
|
||||
"network_result",
|
||||
&match &out {
|
||||
NetworkResult::Timeout => "Timeout".to_owned(),
|
||||
NetworkResult::NoConnection(e) => format!("No connection: {}", e),
|
||||
NetworkResult::Value(()) => "Value(())".to_owned(),
|
||||
},
|
||||
);
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self), fields(ret.len))]
|
||||
pub async fn recv(&self) -> io::Result<Vec<u8>> {
|
||||
#[instrument(level = "trace", err, skip(self), fields(network_result, ret.len))]
|
||||
pub async fn recv(&self) -> io::Result<NetworkResult<Vec<u8>>> {
|
||||
let out = match self.stream.clone().next().await {
|
||||
Some(Ok(Message::Binary(v))) => {
|
||||
if v.len() > MAX_MESSAGE_SIZE {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionReset,
|
||||
io::ErrorKind::InvalidData,
|
||||
"too large ws message",
|
||||
));
|
||||
}
|
||||
v
|
||||
}
|
||||
Some(Ok(Message::Close(_))) => {
|
||||
return Err(io::Error::new(io::ErrorKind::ConnectionReset, "closeframe"))
|
||||
NetworkResult::Value(v)
|
||||
}
|
||||
Some(Ok(Message::Close(_))) => NetworkResult::NoConnection(io::Error::new(
|
||||
io::ErrorKind::ConnectionReset,
|
||||
"closeframe",
|
||||
)),
|
||||
Some(Ok(x)) => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
@@ -114,15 +126,20 @@ where
|
||||
));
|
||||
}
|
||||
Some(Err(e)) => return Err(to_io(e)),
|
||||
None => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionReset,
|
||||
"connection ended",
|
||||
))
|
||||
}
|
||||
None => NetworkResult::NoConnection(io::Error::new(
|
||||
io::ErrorKind::ConnectionReset,
|
||||
"connection ended",
|
||||
)),
|
||||
};
|
||||
|
||||
tracing::Span::current().record("ret.len", &out.len());
|
||||
tracing::Span::current().record(
|
||||
"network_result",
|
||||
&match &out {
|
||||
NetworkResult::Timeout => "Timeout".to_owned(),
|
||||
NetworkResult::NoConnection(e) => format!("No connection: {}", e),
|
||||
NetworkResult::Value(v) => format!("Value(len={})", v.len()),
|
||||
},
|
||||
);
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
@@ -227,7 +244,8 @@ impl WebsocketProtocolHandler {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: &DialInfo,
|
||||
) -> io::Result<ProtocolNetworkConnection> {
|
||||
timeout_ms: u32,
|
||||
) -> io::Result<NetworkResult<ProtocolNetworkConnection>> {
|
||||
// Split dial info up
|
||||
let (tls, scheme) = match dial_info {
|
||||
DialInfo::WS(_) => (false, "ws"),
|
||||
@@ -253,7 +271,10 @@ impl WebsocketProtocolHandler {
|
||||
};
|
||||
|
||||
// Non-blocking connect to remote address
|
||||
let tcp_stream = nonblocking_connect(socket, remote_socket_addr).await?;
|
||||
let tcp_stream =
|
||||
network_result_try!(nonblocking_connect(socket, remote_socket_addr, timeout_ms)
|
||||
.await
|
||||
.folded()?);
|
||||
|
||||
// See what local address we ended up with
|
||||
let actual_local_addr = tcp_stream.local_addr()?;
|
||||
@@ -274,16 +295,16 @@ impl WebsocketProtocolHandler {
|
||||
.await
|
||||
.map_err(to_io_error_other)?;
|
||||
|
||||
Ok(ProtocolNetworkConnection::Wss(
|
||||
Ok(NetworkResult::Value(ProtocolNetworkConnection::Wss(
|
||||
WebsocketNetworkConnection::new(descriptor, ws_stream),
|
||||
))
|
||||
)))
|
||||
} else {
|
||||
let (ws_stream, _response) = client_async(request, tcp_stream)
|
||||
.await
|
||||
.map_err(to_io_error_other)?;
|
||||
Ok(ProtocolNetworkConnection::Ws(
|
||||
Ok(NetworkResult::Value(ProtocolNetworkConnection::Ws(
|
||||
WebsocketNetworkConnection::new(descriptor, ws_stream),
|
||||
))
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -293,7 +314,7 @@ impl ProtocolAcceptHandler for WebsocketProtocolHandler {
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<io::Result<Option<ProtocolNetworkConnection>>> {
|
||||
) -> SendPinBoxFuture<io::Result<Option<ProtocolNetworkConnection>>> {
|
||||
Box::pin(self.clone().on_accept_async(stream, peer_addr))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user