refactor
This commit is contained in:
@@ -61,7 +61,7 @@ struct NetworkUnlockedInner {
|
||||
network_manager: NetworkManager,
|
||||
connection_manager: ConnectionManager,
|
||||
// Background processes
|
||||
update_network_class_task: TickTask,
|
||||
update_network_class_task: TickTask<EyreReport>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -266,7 +266,7 @@ impl Network {
|
||||
// See if our interface addresses have changed, if so we need to punt the network
|
||||
// and redo all our addresses. This is overkill, but anything more accurate
|
||||
// would require inspection of routing tables that we dont want to bother with
|
||||
pub async fn check_interface_addresses(&self) -> Result<bool, String> {
|
||||
pub async fn check_interface_addresses(&self) -> EyreResult<bool> {
|
||||
let mut inner = self.inner.lock();
|
||||
if !inner.interfaces.refresh().await? {
|
||||
return Ok(false);
|
||||
@@ -286,7 +286,7 @@ impl Network {
|
||||
&self,
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
let data_len = data.len();
|
||||
let res = match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
@@ -300,7 +300,8 @@ impl Network {
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data).await
|
||||
}
|
||||
};
|
||||
}
|
||||
.wrap_err("low level network error");
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@@ -320,7 +321,7 @@ impl Network {
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
timeout_ms: u32,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
) -> EyreResult<Vec<u8>> {
|
||||
let data_len = data.len();
|
||||
let out = match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
@@ -358,7 +359,7 @@ impl Network {
|
||||
&self,
|
||||
descriptor: ConnectionDescriptor,
|
||||
data: Vec<u8>,
|
||||
) -> Result<Option<Vec<u8>>, String> {
|
||||
) -> EyreResult<Option<Vec<u8>>> {
|
||||
let data_len = data.len();
|
||||
|
||||
// Handle connectionless protocol
|
||||
@@ -369,12 +370,10 @@ impl Network {
|
||||
&peer_socket_addr,
|
||||
&descriptor.local().map(|sa| sa.to_socket_addr()),
|
||||
) {
|
||||
log_net!(
|
||||
"send_data_to_existing_connection connectionless to {:?}",
|
||||
descriptor
|
||||
);
|
||||
|
||||
ph.clone().send_message(data, peer_socket_addr).await?;
|
||||
ph.clone()
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.wrap_err("sending data to existing conection")?;
|
||||
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@@ -389,10 +388,10 @@ impl Network {
|
||||
|
||||
// Try to send to the exact existing connection if one exists
|
||||
if let Some(conn) = self.connection_manager().get_connection(descriptor).await {
|
||||
log_net!("send_data_to_existing_connection to {:?}", descriptor);
|
||||
|
||||
// connection exists, send over it
|
||||
conn.send_async(data).await?;
|
||||
conn.send_async(data)
|
||||
.await
|
||||
.wrap_err("sending data to existing connection")?;
|
||||
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@@ -413,13 +412,16 @@ impl Network {
|
||||
&self,
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
let data_len = data.len();
|
||||
// Handle connectionless protocol
|
||||
if dial_info.protocol_type() == ProtocolType::UDP {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
|
||||
let res = ph.send_message(data, peer_socket_addr).await;
|
||||
let res = ph
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.wrap_err("failed to send data to dial info");
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@@ -427,7 +429,7 @@ impl Network {
|
||||
}
|
||||
return res;
|
||||
}
|
||||
return Err("no appropriate UDP protocol handler for dial_info".to_owned());
|
||||
bail!("no appropriate UDP protocol handler for dial_info");
|
||||
}
|
||||
|
||||
// Handle connection-oriented protocols
|
||||
@@ -453,7 +455,7 @@ impl Network {
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", err, skip_all)]
|
||||
pub async fn startup(&self) -> Result<(), String> {
|
||||
pub async fn startup(&self) -> EyreResult<()> {
|
||||
// initialize interfaces
|
||||
let mut interfaces = NetworkInterfaces::new();
|
||||
interfaces.refresh().await?;
|
||||
@@ -604,7 +606,7 @@ impl Network {
|
||||
|
||||
//////////////////////////////////////////
|
||||
|
||||
pub async fn tick(&self) -> Result<(), String> {
|
||||
pub async fn tick(&self) -> EyreResult<()> {
|
||||
let network_class = self.get_network_class().unwrap_or(NetworkClass::Invalid);
|
||||
let routing_table = self.routing_table();
|
||||
|
||||
|
||||
@@ -250,7 +250,7 @@ impl DiscoveryContext {
|
||||
|
||||
// If we know we are not behind NAT, check our firewall status
|
||||
#[instrument(level = "trace", skip(self), err)]
|
||||
pub async fn protocol_process_no_nat(&self) -> Result<(), String> {
|
||||
pub async fn protocol_process_no_nat(&self) -> EyreResult<()> {
|
||||
let (node_1, external_1_dial_info) = {
|
||||
let inner = self.inner.lock();
|
||||
(
|
||||
@@ -281,7 +281,7 @@ impl DiscoveryContext {
|
||||
|
||||
// If we know we are behind NAT check what kind
|
||||
#[instrument(level = "trace", skip(self), ret, err)]
|
||||
pub async fn protocol_process_nat(&self) -> Result<bool, String> {
|
||||
pub async fn protocol_process_nat(&self) -> EyreResult<bool> {
|
||||
let (node_1, external_1_dial_info, external_1_address, protocol_type, address_type) = {
|
||||
let inner = self.inner.lock();
|
||||
(
|
||||
@@ -375,7 +375,7 @@ impl Network {
|
||||
&self,
|
||||
context: &DiscoveryContext,
|
||||
protocol_type: ProtocolType,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
let mut retry_count = {
|
||||
let c = self.config.get();
|
||||
c.network.restricted_nat_retries
|
||||
@@ -437,7 +437,7 @@ impl Network {
|
||||
&self,
|
||||
context: &DiscoveryContext,
|
||||
protocol_type: ProtocolType,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
// Start doing ipv6 protocol
|
||||
context.protocol_begin(protocol_type, AddressType::IPV6);
|
||||
|
||||
@@ -479,7 +479,7 @@ impl Network {
|
||||
stop_token: StopToken,
|
||||
_l: u64,
|
||||
_t: u64,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
// Ensure we aren't trying to update this without clearing it first
|
||||
let old_network_class = self.inner.lock().network_class;
|
||||
assert_eq!(old_network_class, None);
|
||||
|
||||
@@ -25,14 +25,14 @@ impl ListenerState {
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
impl Network {
|
||||
fn get_or_create_tls_acceptor(&self) -> Result<TlsAcceptor, String> {
|
||||
fn get_or_create_tls_acceptor(&self) -> EyreResult<TlsAcceptor> {
|
||||
if let Some(ts) = self.inner.lock().tls_acceptor.as_ref() {
|
||||
return Ok(ts.clone());
|
||||
}
|
||||
|
||||
let server_config = self
|
||||
.load_server_config()
|
||||
.map_err(|e| format!("Couldn't create TLS configuration: {}", e))?;
|
||||
.wrap_err("Couldn't create TLS configuration")?;
|
||||
let acceptor = TlsAcceptor::from(Arc::new(server_config));
|
||||
self.inner.lock().tls_acceptor = Some(acceptor.clone());
|
||||
Ok(acceptor)
|
||||
@@ -45,12 +45,11 @@ impl Network {
|
||||
addr: SocketAddr,
|
||||
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
|
||||
tls_connection_initial_timeout_ms: u32,
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
) -> EyreResult<Option<ProtocolNetworkConnection>> {
|
||||
let tls_stream = tls_acceptor
|
||||
.accept(stream)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(debug "TLS stream failed handshake"))?;
|
||||
.wrap_err("TLS stream failed handshake")?;
|
||||
let ps = AsyncPeekStream::new(tls_stream);
|
||||
let mut first_packet = [0u8; PEEK_DETECT_LEN];
|
||||
|
||||
@@ -63,8 +62,8 @@ impl Network {
|
||||
ps.peek_exact(&mut first_packet),
|
||||
)
|
||||
.await
|
||||
.map_err(map_to_string)?
|
||||
.map_err(map_to_string)?;
|
||||
.wrap_err("tls initial timeout")?
|
||||
.wrap_err("failed to peek tls stream")?;
|
||||
|
||||
self.try_handlers(ps, addr, protocol_handlers).await
|
||||
}
|
||||
@@ -74,9 +73,13 @@ impl Network {
|
||||
stream: AsyncPeekStream,
|
||||
addr: SocketAddr,
|
||||
protocol_accept_handlers: &[Box<dyn ProtocolAcceptHandler>],
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
) -> EyreResult<Option<ProtocolNetworkConnection>> {
|
||||
for ah in protocol_accept_handlers.iter() {
|
||||
if let Some(nc) = ah.on_accept(stream.clone(), addr).await? {
|
||||
if let Some(nc) = ah
|
||||
.on_accept(stream.clone(), addr)
|
||||
.await
|
||||
.wrap_err("io error")?
|
||||
{
|
||||
return Ok(Some(nc));
|
||||
}
|
||||
}
|
||||
@@ -114,7 +117,7 @@ impl Network {
|
||||
return;
|
||||
}
|
||||
};
|
||||
// XXX limiting
|
||||
// XXX limiting here instead for connection table? may be faster and avoids tls negotiation
|
||||
|
||||
log_net!("TCP connection from: {}", addr);
|
||||
|
||||
@@ -185,7 +188,7 @@ impl Network {
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
|
||||
async fn spawn_socket_listener(&self, addr: SocketAddr) -> EyreResult<()> {
|
||||
// Get config
|
||||
let (connection_initial_timeout_ms, tls_connection_initial_timeout_ms) = {
|
||||
let c = self.config.get();
|
||||
@@ -196,11 +199,12 @@ impl Network {
|
||||
};
|
||||
|
||||
// Create a reusable socket with no linger time, and no delay
|
||||
let socket = new_bound_shared_tcp_socket(addr)?;
|
||||
let socket = new_bound_shared_tcp_socket(addr)
|
||||
.wrap_err("failed to create bound shared tcp socket")?;
|
||||
// Listen on the socket
|
||||
socket
|
||||
.listen(128)
|
||||
.map_err(|e| format!("Couldn't listen on TCP socket: {}", e))?;
|
||||
.wrap_err("Couldn't listen on TCP socket")?;
|
||||
|
||||
// Make an async tcplistener from the socket2 socket
|
||||
let std_listener: std::net::TcpListener = socket.into();
|
||||
@@ -209,7 +213,7 @@ impl Network {
|
||||
let listener = TcpListener::from(std_listener);
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
std_listener.set_nonblocking(true).expect("failed to set nonblocking");
|
||||
let listener = TcpListener::from_std(std_listener).map_err(map_to_string)?;
|
||||
let listener = TcpListener::from_std(std_listener).wrap_err("failed to create tokio tcp listener")?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,7 +283,7 @@ impl Network {
|
||||
port: u16,
|
||||
is_tls: bool,
|
||||
new_protocol_accept_handler: Box<NewProtocolAcceptHandler>,
|
||||
) -> Result<Vec<SocketAddress>, String> {
|
||||
) -> EyreResult<Vec<SocketAddress>> {
|
||||
let mut out = Vec::<SocketAddress>::new();
|
||||
|
||||
for ip_addr in ip_addrs {
|
||||
|
||||
@@ -3,7 +3,7 @@ use sockets::*;
|
||||
use stop_token::future::FutureExt;
|
||||
|
||||
impl Network {
|
||||
pub(super) async fn create_udp_listener_tasks(&self) -> Result<(), String> {
|
||||
pub(super) async fn create_udp_listener_tasks(&self) -> EyreResult<()> {
|
||||
// Spawn socket tasks
|
||||
let mut task_count = {
|
||||
let c = self.config.get();
|
||||
@@ -73,7 +73,7 @@ impl Network {
|
||||
.on_recv_envelope(&data[..size], descriptor)
|
||||
.await
|
||||
{
|
||||
log_net!(error "failed to process received udp envelope: {}", e);
|
||||
log_net!(debug "failed to process received udp envelope: {}", e);
|
||||
}
|
||||
}
|
||||
Ok(Err(_)) => {
|
||||
@@ -110,7 +110,7 @@ impl Network {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) async fn create_udp_outbound_sockets(&self) -> Result<(), String> {
|
||||
pub(super) async fn create_udp_outbound_sockets(&self) -> EyreResult<()> {
|
||||
let mut inner = self.inner.lock();
|
||||
let mut port = inner.udp_port;
|
||||
// v4
|
||||
@@ -119,9 +119,9 @@ impl Network {
|
||||
// Pull the port if we randomly bound, so v6 can be on the same port
|
||||
port = socket
|
||||
.local_addr()
|
||||
.map_err(map_to_string)?
|
||||
.wrap_err("failed to get local address")?
|
||||
.as_socket_ipv4()
|
||||
.ok_or_else(|| "expected ipv4 address type".to_owned())?
|
||||
.ok_or_else(|| eyre!("expected ipv4 address type"))?
|
||||
.port();
|
||||
|
||||
// Make an async UdpSocket from the socket2 socket
|
||||
@@ -131,7 +131,7 @@ impl Network {
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
std_udp_socket.set_nonblocking(true).expect("failed to set nonblocking");
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket).map_err(map_to_string)?;
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket).wrap_err("failed to make outbound v4 tokio udpsocket")?;
|
||||
}
|
||||
}
|
||||
let socket_arc = Arc::new(udp_socket);
|
||||
@@ -152,7 +152,7 @@ impl Network {
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
std_udp_socket.set_nonblocking(true).expect("failed to set nonblocking");
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket).map_err(map_to_string)?;
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket).wrap_err("failed to make outbound v6 tokio udpsocket")?;
|
||||
}
|
||||
}
|
||||
let socket_arc = Arc::new(udp_socket);
|
||||
@@ -166,7 +166,7 @@ impl Network {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_udp_inbound_socket(&self, addr: SocketAddr) -> Result<(), String> {
|
||||
async fn create_udp_inbound_socket(&self, addr: SocketAddr) -> EyreResult<()> {
|
||||
log_net!("create_udp_inbound_socket on {:?}", &addr);
|
||||
|
||||
// Create a reusable socket
|
||||
@@ -179,7 +179,7 @@ impl Network {
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
std_udp_socket.set_nonblocking(true).expect("failed to set nonblocking");
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket).map_err(map_to_string)?;
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket).wrap_err("failed to make inbound tokio udpsocket")?;
|
||||
}
|
||||
}
|
||||
let socket_arc = Arc::new(udp_socket);
|
||||
@@ -200,7 +200,7 @@ impl Network {
|
||||
&self,
|
||||
ip_addrs: Vec<IpAddr>,
|
||||
port: u16,
|
||||
) -> Result<Vec<DialInfo>, String> {
|
||||
) -> EyreResult<Vec<DialInfo>> {
|
||||
let mut out = Vec::<DialInfo>::new();
|
||||
|
||||
for ip_addr in ip_addrs {
|
||||
|
||||
@@ -6,6 +6,7 @@ pub mod ws;
|
||||
|
||||
use super::*;
|
||||
use crate::xx::*;
|
||||
use std::io;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ProtocolNetworkConnection {
|
||||
@@ -21,7 +22,7 @@ impl ProtocolNetworkConnection {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
) -> io::Result<ProtocolNetworkConnection> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
panic!("Should not connect to UDP dialinfo");
|
||||
@@ -35,7 +36,7 @@ impl ProtocolNetworkConnection {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
|
||||
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> io::Result<()> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
@@ -55,7 +56,7 @@ impl ProtocolNetworkConnection {
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
timeout_ms: u32,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
) -> io::Result<Vec<u8>> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
@@ -102,7 +103,7 @@ impl ProtocolNetworkConnection {
|
||||
// }
|
||||
// }
|
||||
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.send(message),
|
||||
Self::RawTcp(t) => t.send(message).await,
|
||||
@@ -111,7 +112,7 @@ impl ProtocolNetworkConnection {
|
||||
Self::Wss(w) => w.send(message).await,
|
||||
}
|
||||
}
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
pub async fn recv(&self) -> io::Result<Vec<u8>> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.recv(),
|
||||
Self::RawTcp(t) => t.recv().await,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use crate::xx::*;
|
||||
use crate::*;
|
||||
use async_io::Async;
|
||||
use std::io;
|
||||
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
pub use async_std::net::{TcpStream, TcpListener, Shutdown, UdpSocket};
|
||||
@@ -19,12 +21,12 @@ cfg_if! {
|
||||
use winapi::ctypes::c_int;
|
||||
use std::os::windows::io::AsRawSocket;
|
||||
|
||||
fn set_exclusiveaddruse(socket: &Socket) -> Result<(), String> {
|
||||
fn set_exclusiveaddruse(socket: &Socket) -> io::Result<()> {
|
||||
unsafe {
|
||||
let optval:c_int = 1;
|
||||
if setsockopt(socket.as_raw_socket().try_into().unwrap(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (&optval as *const c_int).cast(),
|
||||
std::mem::size_of::<c_int>() as c_int) == SOCKET_ERROR {
|
||||
return Err("Unable to SO_EXCLUSIVEADDRUSE".to_owned());
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -32,49 +34,37 @@ cfg_if! {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_unbound_shared_udp_socket(domain: Domain) -> Result<Socket, String> {
|
||||
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))
|
||||
.map_err(|e| format!("Couldn't create UDP socket: {}", e))?;
|
||||
pub fn new_unbound_shared_udp_socket(domain: Domain) -> io::Result<Socket> {
|
||||
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
|
||||
if domain == Domain::IPV6 {
|
||||
socket
|
||||
.set_only_v6(true)
|
||||
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
socket
|
||||
.set_reuse_address(true)
|
||||
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
|
||||
socket.set_reuse_address(true)?;
|
||||
|
||||
cfg_if! {
|
||||
if #[cfg(unix)] {
|
||||
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
|
||||
socket.set_reuse_port(true)?;
|
||||
}
|
||||
}
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_bound_shared_udp_socket(local_address: SocketAddr) -> Result<Socket, String> {
|
||||
pub fn new_bound_shared_udp_socket(local_address: SocketAddr) -> io::Result<Socket> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
let socket = new_unbound_shared_udp_socket(domain)?;
|
||||
let socket2_addr = SockAddr::from(local_address);
|
||||
socket.bind(&socket2_addr).map_err(|e| {
|
||||
format!(
|
||||
"failed to bind UDP socket to '{}' in domain '{:?}': {} ",
|
||||
local_address, domain, e
|
||||
)
|
||||
})?;
|
||||
socket.bind(&socket2_addr)?;
|
||||
|
||||
log_net!("created bound shared udp socket on {:?}", &local_address);
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_bound_first_udp_socket(local_address: SocketAddr) -> Result<Socket, String> {
|
||||
pub fn new_bound_first_udp_socket(local_address: SocketAddr) -> io::Result<Socket> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))
|
||||
.map_err(|e| format!("Couldn't create UDP socket: {}", e))?;
|
||||
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
|
||||
if domain == Domain::IPV6 {
|
||||
socket
|
||||
.set_only_v6(true)
|
||||
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
// Bind the socket -first- before turning on 'reuse address' this way it will
|
||||
// fail if the port is already taken
|
||||
@@ -87,18 +77,15 @@ pub fn new_bound_first_udp_socket(local_address: SocketAddr) -> Result<Socket, S
|
||||
}
|
||||
}
|
||||
|
||||
socket
|
||||
.bind(&socket2_addr)
|
||||
.map_err(|e| format!("failed to bind UDP socket: {}", e))?;
|
||||
socket.bind(&socket2_addr)?;
|
||||
|
||||
// Set 'reuse address' so future binds to this port will succeed
|
||||
// This does not work on Windows, where reuse options can not be set after the bind
|
||||
cfg_if! {
|
||||
if #[cfg(unix)] {
|
||||
socket
|
||||
.set_reuse_address(true)
|
||||
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
|
||||
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
|
||||
.set_reuse_address(true)?;
|
||||
socket.set_reuse_port(true)?;
|
||||
}
|
||||
}
|
||||
log_net!("created bound first udp socket on {:?}", &local_address);
|
||||
@@ -106,10 +93,8 @@ pub fn new_bound_first_udp_socket(local_address: SocketAddr) -> Result<Socket, S
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_unbound_shared_tcp_socket(domain: Domain) -> Result<Socket, String> {
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("failed to create TCP socket"))?;
|
||||
pub fn new_unbound_shared_tcp_socket(domain: Domain) -> io::Result<Socket> {
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
|
||||
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
|
||||
log_net!(error "Couldn't set TCP linger: {}", e);
|
||||
}
|
||||
@@ -117,43 +102,33 @@ pub fn new_unbound_shared_tcp_socket(domain: Domain) -> Result<Socket, String> {
|
||||
log_net!(error "Couldn't set TCP nodelay: {}", e);
|
||||
}
|
||||
if domain == Domain::IPV6 {
|
||||
socket
|
||||
.set_only_v6(true)
|
||||
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
socket
|
||||
.set_reuse_address(true)
|
||||
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
|
||||
socket.set_reuse_address(true)?;
|
||||
cfg_if! {
|
||||
if #[cfg(unix)] {
|
||||
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
|
||||
socket.set_reuse_port(true)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> Result<Socket, String> {
|
||||
pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> io::Result<Socket> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
|
||||
let socket = new_unbound_shared_tcp_socket(domain)?;
|
||||
|
||||
let socket2_addr = SockAddr::from(local_address);
|
||||
socket
|
||||
.bind(&socket2_addr)
|
||||
.map_err(|e| format!("failed to bind TCP socket: {}", e))?;
|
||||
socket.bind(&socket2_addr)?;
|
||||
|
||||
log_net!("created bound shared tcp socket on {:?}", &local_address);
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> Result<Socket, String> {
|
||||
pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> io::Result<Socket> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("failed to create TCP socket"))?;
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
|
||||
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
|
||||
log_net!(error "Couldn't set TCP linger: {}", e);
|
||||
}
|
||||
@@ -161,9 +136,7 @@ pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> Result<Socket, S
|
||||
log_net!(error "Couldn't set TCP nodelay: {}", e);
|
||||
}
|
||||
if domain == Domain::IPV6 {
|
||||
socket
|
||||
.set_only_v6(true)
|
||||
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
|
||||
// On windows, do SO_EXCLUSIVEADDRUSE before the bind to ensure the port is fully available
|
||||
@@ -176,18 +149,15 @@ pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> Result<Socket, S
|
||||
// Bind the socket -first- before turning on 'reuse address' this way it will
|
||||
// fail if the port is already taken
|
||||
let socket2_addr = SockAddr::from(local_address);
|
||||
socket
|
||||
.bind(&socket2_addr)
|
||||
.map_err(|e| format!("failed to bind TCP socket: {}", e))?;
|
||||
socket.bind(&socket2_addr)?;
|
||||
|
||||
// Set 'reuse address' so future binds to this port will succeed
|
||||
// This does not work on Windows, where reuse options can not be set after the bind
|
||||
cfg_if! {
|
||||
if #[cfg(unix)] {
|
||||
socket
|
||||
.set_reuse_address(true)
|
||||
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
|
||||
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
|
||||
.set_reuse_address(true)?;
|
||||
socket.set_reuse_port(true)?;
|
||||
}
|
||||
}
|
||||
log_net!("created bound first tcp socket on {:?}", &local_address);
|
||||
@@ -196,7 +166,7 @@ pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> Result<Socket, S
|
||||
}
|
||||
|
||||
// Non-blocking connect is tricky when you want to start with a prepared socket
|
||||
pub async fn nonblocking_connect(socket: Socket, addr: SocketAddr) -> std::io::Result<TcpStream> {
|
||||
pub async fn nonblocking_connect(socket: Socket, addr: SocketAddr) -> io::Result<TcpStream> {
|
||||
// Set for non blocking connect
|
||||
socket.set_nonblocking(true)?;
|
||||
|
||||
|
||||
@@ -42,47 +42,45 @@ impl RawTcpNetworkConnection {
|
||||
// }
|
||||
// }
|
||||
|
||||
async fn send_internal(stream: &mut AsyncPeekStream, message: Vec<u8>) -> Result<(), String> {
|
||||
async fn send_internal(stream: &mut AsyncPeekStream, message: Vec<u8>) -> io::Result<()> {
|
||||
log_net!("sending TCP message of size {}", message.len());
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large TCP message".to_owned());
|
||||
bail_io_error_other!("sending too large TCP message");
|
||||
}
|
||||
let len = message.len() as u16;
|
||||
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
|
||||
|
||||
stream.write_all(&header).await.map_err(map_to_string)?;
|
||||
stream.write_all(&message).await.map_err(map_to_string)
|
||||
stream.write_all(&header).await?;
|
||||
stream.write_all(&message).await
|
||||
}
|
||||
|
||||
#[instrument(level="trace", err, skip(self, message), fields(message.len = message.len()))]
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
|
||||
let mut stream = self.stream.clone();
|
||||
Self::send_internal(&mut stream, message).await
|
||||
}
|
||||
|
||||
pub async fn recv_internal(stream: &mut AsyncPeekStream) -> Result<Vec<u8>, String> {
|
||||
pub async fn recv_internal(stream: &mut AsyncPeekStream) -> io::Result<Vec<u8>> {
|
||||
let mut header = [0u8; 4];
|
||||
|
||||
stream
|
||||
.read_exact(&mut header)
|
||||
.await
|
||||
.map_err(|e| format!("TCP recv error: {}", e))?;
|
||||
stream.read_exact(&mut header).await?;
|
||||
|
||||
if header[0] != b'V' || header[1] != b'L' {
|
||||
return Err("received invalid TCP frame header".to_owned());
|
||||
bail_io_error_other!("received invalid TCP frame header");
|
||||
}
|
||||
let len = ((header[3] as usize) << 8) | (header[2] as usize);
|
||||
if len > MAX_MESSAGE_SIZE {
|
||||
return Err("received too large TCP frame".to_owned());
|
||||
bail_io_error_other!("received too large TCP frame");
|
||||
}
|
||||
|
||||
let mut out: Vec<u8> = vec![0u8; len];
|
||||
stream.read_exact(&mut out).await.map_err(map_to_string)?;
|
||||
stream.read_exact(&mut out).await?;
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
#[instrument(level="trace", err, skip(self), fields(ret.len))]
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
pub async fn recv(&self) -> io::Result<Vec<u8>> {
|
||||
let mut stream = self.stream.clone();
|
||||
let out = Self::recv_internal(&mut stream).await?;
|
||||
tracing::Span::current().record("ret.len", &out.len());
|
||||
@@ -121,14 +119,10 @@ impl RawTcpProtocolHandler {
|
||||
self,
|
||||
stream: AsyncPeekStream,
|
||||
socket_addr: SocketAddr,
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
) -> io::Result<Option<ProtocolNetworkConnection>> {
|
||||
log_net!("TCP: on_accept_async: enter");
|
||||
let mut peekbuf: [u8; PEEK_DETECT_LEN] = [0u8; PEEK_DETECT_LEN];
|
||||
let peeklen = stream
|
||||
.peek(&mut peekbuf)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("could not peek tcp stream"))?;
|
||||
let peeklen = stream.peek(&mut peekbuf).await?;
|
||||
assert_eq!(peeklen, PEEK_DETECT_LEN);
|
||||
|
||||
let peer_addr = PeerAddress::new(
|
||||
@@ -150,7 +144,7 @@ impl RawTcpProtocolHandler {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
) -> io::Result<ProtocolNetworkConnection> {
|
||||
// Get remote socket address to connect to
|
||||
let remote_socket_addr = dial_info.to_socket_addr();
|
||||
|
||||
@@ -163,15 +157,10 @@ impl RawTcpProtocolHandler {
|
||||
};
|
||||
|
||||
// Non-blocking connect to remote address
|
||||
let ts = nonblocking_connect(socket, remote_socket_addr).await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
|
||||
let ts = nonblocking_connect(socket, remote_socket_addr).await?;
|
||||
|
||||
// See what local address we ended up with and turn this into a stream
|
||||
let actual_local_address = ts
|
||||
.local_addr()
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("could not get local address from TCP stream"))?;
|
||||
let actual_local_address = ts.local_addr()?;
|
||||
#[cfg(feature = "rt-tokio")]
|
||||
let ts = ts.compat();
|
||||
let ps = AsyncPeekStream::new(ts);
|
||||
@@ -189,12 +178,9 @@ impl RawTcpProtocolHandler {
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
|
||||
pub async fn send_unbound_message(
|
||||
socket_addr: SocketAddr,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
pub async fn send_unbound_message(socket_addr: SocketAddr, data: Vec<u8>) -> io::Result<()> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound TCP message".to_owned());
|
||||
bail_io_error_other!("sending too large unbound TCP message");
|
||||
}
|
||||
trace!(
|
||||
"sending unbound message of length {} to {}",
|
||||
@@ -206,10 +192,7 @@ impl RawTcpProtocolHandler {
|
||||
let socket = new_unbound_shared_tcp_socket(socket2::Domain::for_address(socket_addr))?;
|
||||
|
||||
// Non-blocking connect to remote address
|
||||
let ts = nonblocking_connect(socket, socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "remote_addr={}", socket_addr))?;
|
||||
let ts = nonblocking_connect(socket, socket_addr).await?;
|
||||
|
||||
// See what local address we ended up with and turn this into a stream
|
||||
// let actual_local_address = ts
|
||||
@@ -231,9 +214,9 @@ impl RawTcpProtocolHandler {
|
||||
socket_addr: SocketAddr,
|
||||
data: Vec<u8>,
|
||||
timeout_ms: u32,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
) -> io::Result<Vec<u8>> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound TCP message".to_owned());
|
||||
bail_io_error_other!("sending too large unbound TCP message");
|
||||
}
|
||||
trace!(
|
||||
"sending unbound message of length {} to {}",
|
||||
@@ -245,10 +228,7 @@ impl RawTcpProtocolHandler {
|
||||
let socket = new_unbound_shared_tcp_socket(socket2::Domain::for_address(socket_addr))?;
|
||||
|
||||
// Non-blocking connect to remote address
|
||||
let ts = nonblocking_connect(socket, socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "remote_addr={}", socket_addr))?;
|
||||
let ts = nonblocking_connect(socket, socket_addr).await?;
|
||||
|
||||
// See what local address we ended up with and turn this into a stream
|
||||
// let actual_local_address = ts
|
||||
@@ -265,7 +245,7 @@ impl RawTcpProtocolHandler {
|
||||
|
||||
let out = timeout(timeout_ms, RawTcpNetworkConnection::recv_internal(&mut ps))
|
||||
.await
|
||||
.map_err(map_to_string)??;
|
||||
.map_err(|e| e.to_io())??;
|
||||
|
||||
tracing::Span::current().record("ret.len", &out.len());
|
||||
Ok(out)
|
||||
@@ -277,7 +257,7 @@ impl ProtocolAcceptHandler for RawTcpProtocolHandler {
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<core::result::Result<Option<ProtocolNetworkConnection>, String>> {
|
||||
) -> SystemPinBoxFuture<io::Result<Option<ProtocolNetworkConnection>>> {
|
||||
Box::pin(self.clone().on_accept_async(stream, peer_addr))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,27 +12,30 @@ impl RawUdpProtocolHandler {
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.from))]
|
||||
pub async fn recv_message(
|
||||
&self,
|
||||
data: &mut [u8],
|
||||
) -> Result<(usize, ConnectionDescriptor), String> {
|
||||
let (size, remote_addr) = self.socket.recv_from(data).await.map_err(map_to_string)?;
|
||||
|
||||
if size > MAX_MESSAGE_SIZE {
|
||||
return Err("received too large UDP message".to_owned());
|
||||
}
|
||||
|
||||
trace!(
|
||||
"receiving UDP message of length {} from {}",
|
||||
size,
|
||||
remote_addr
|
||||
);
|
||||
pub async fn recv_message(&self, data: &mut [u8]) -> io::Result<(usize, ConnectionDescriptor)> {
|
||||
let (size, remote_addr) = loop {
|
||||
match self.socket.recv_from(data).await {
|
||||
Ok((size, remote_addr)) => {
|
||||
if size > MAX_MESSAGE_SIZE {
|
||||
bail_io_error_other!("received too large UDP message");
|
||||
}
|
||||
break (size, remote_addr);
|
||||
}
|
||||
Err(e) => {
|
||||
if e.kind() == io::ErrorKind::ConnectionReset {
|
||||
// Ignore icmp
|
||||
} else {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let peer_addr = PeerAddress::new(
|
||||
SocketAddress::from_socket_addr(remote_addr),
|
||||
ProtocolType::UDP,
|
||||
);
|
||||
let local_socket_addr = self.socket.local_addr().map_err(map_to_string)?;
|
||||
let local_socket_addr = self.socket.local_addr()?;
|
||||
let descriptor = ConnectionDescriptor::new(
|
||||
peer_addr,
|
||||
SocketAddress::from_socket_addr(local_socket_addr),
|
||||
@@ -44,45 +47,24 @@ impl RawUdpProtocolHandler {
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.from))]
|
||||
pub async fn send_message(&self, data: Vec<u8>, socket_addr: SocketAddr) -> Result<(), String> {
|
||||
pub async fn send_message(&self, data: Vec<u8>, socket_addr: SocketAddr) -> io::Result<()> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large UDP message".to_owned()).map_err(logthru_net!(error));
|
||||
bail_io_error_other!("sending too large UDP message");
|
||||
}
|
||||
|
||||
log_net!(
|
||||
"sending UDP message of length {} to {}",
|
||||
data.len(),
|
||||
socket_addr
|
||||
);
|
||||
|
||||
let len = self
|
||||
.socket
|
||||
.send_to(&data, socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed udp send: addr={}", socket_addr))?;
|
||||
|
||||
let len = self.socket.send_to(&data, socket_addr).await?;
|
||||
if len != data.len() {
|
||||
Err("UDP partial send".to_owned()).map_err(logthru_net!(error))
|
||||
} else {
|
||||
Ok(())
|
||||
bail_io_error_other!("UDP partial send")
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
|
||||
pub async fn send_unbound_message(
|
||||
socket_addr: SocketAddr,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
pub async fn send_unbound_message(socket_addr: SocketAddr, data: Vec<u8>) -> io::Result<()> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound UDP message".to_owned())
|
||||
.map_err(logthru_net!(error));
|
||||
bail_io_error_other!("sending too large unbound UDP message");
|
||||
}
|
||||
log_net!(
|
||||
"sending unbound message of length {} to {}",
|
||||
data.len(),
|
||||
socket_addr
|
||||
);
|
||||
|
||||
// get local wildcard address for bind
|
||||
let local_socket_addr = match socket_addr {
|
||||
@@ -91,20 +73,13 @@ impl RawUdpProtocolHandler {
|
||||
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
|
||||
}
|
||||
};
|
||||
let socket = UdpSocket::bind(local_socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed to bind unbound udp socket"))?;
|
||||
let len = socket
|
||||
.send_to(&data, socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed unbound udp send: addr={}", socket_addr))?;
|
||||
let socket = UdpSocket::bind(local_socket_addr).await?;
|
||||
let len = socket.send_to(&data, socket_addr).await?;
|
||||
if len != data.len() {
|
||||
Err("UDP partial unbound send".to_owned()).map_err(logthru_net!(error))
|
||||
} else {
|
||||
Ok(())
|
||||
bail_io_error_other!("UDP partial unbound send")
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len(), ret.len))]
|
||||
@@ -112,16 +87,10 @@ impl RawUdpProtocolHandler {
|
||||
socket_addr: SocketAddr,
|
||||
data: Vec<u8>,
|
||||
timeout_ms: u32,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
) -> io::Result<Vec<u8>> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound UDP message".to_owned())
|
||||
.map_err(logthru_net!(error));
|
||||
bail_io_error_other!("sending too large unbound UDP message");
|
||||
}
|
||||
log_net!(
|
||||
"sending unbound message of length {} to {}",
|
||||
data.len(),
|
||||
socket_addr
|
||||
);
|
||||
|
||||
// get local wildcard address for bind
|
||||
let local_socket_addr = match socket_addr {
|
||||
@@ -132,29 +101,21 @@ impl RawUdpProtocolHandler {
|
||||
};
|
||||
|
||||
// get unspecified bound socket
|
||||
let socket = UdpSocket::bind(local_socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed to bind unbound udp socket"))?;
|
||||
let len = socket
|
||||
.send_to(&data, socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed unbound udp send: addr={}", socket_addr))?;
|
||||
let socket = UdpSocket::bind(local_socket_addr).await?;
|
||||
let len = socket.send_to(&data, socket_addr).await?;
|
||||
if len != data.len() {
|
||||
return Err("UDP partial unbound send".to_owned()).map_err(logthru_net!(error));
|
||||
bail_io_error_other!("UDP partial unbound send");
|
||||
}
|
||||
|
||||
// receive single response
|
||||
let mut out = vec![0u8; MAX_MESSAGE_SIZE];
|
||||
let (len, from_addr) = timeout(timeout_ms, socket.recv_from(&mut out))
|
||||
.await
|
||||
.map_err(map_to_string)?
|
||||
.map_err(map_to_string)?;
|
||||
.map_err(|e| e.to_io())??;
|
||||
|
||||
// if the from address is not the same as the one we sent to, then drop this
|
||||
if from_addr != socket_addr {
|
||||
return Err(format!(
|
||||
bail_io_error_other!(format!(
|
||||
"Unbound response received from wrong address: addr={}",
|
||||
from_addr,
|
||||
));
|
||||
|
||||
@@ -17,6 +17,25 @@ cfg_if! {
|
||||
}
|
||||
}
|
||||
|
||||
fn to_io(err: async_tungstenite::tungstenite::Error) -> io::Error {
|
||||
let kind = match err {
|
||||
async_tungstenite::tungstenite::Error::ConnectionClosed => io::ErrorKind::ConnectionReset,
|
||||
async_tungstenite::tungstenite::Error::AlreadyClosed => io::ErrorKind::NotConnected,
|
||||
async_tungstenite::tungstenite::Error::Io(x) => {
|
||||
return x;
|
||||
}
|
||||
async_tungstenite::tungstenite::Error::Tls(_) => io::ErrorKind::InvalidData,
|
||||
async_tungstenite::tungstenite::Error::Capacity(_) => io::ErrorKind::Other,
|
||||
async_tungstenite::tungstenite::Error::Protocol(_) => io::ErrorKind::Other,
|
||||
async_tungstenite::tungstenite::Error::SendQueueFull(_) => io::ErrorKind::Other,
|
||||
async_tungstenite::tungstenite::Error::Utf8 => io::ErrorKind::Other,
|
||||
async_tungstenite::tungstenite::Error::Url(_) => io::ErrorKind::Other,
|
||||
async_tungstenite::tungstenite::Error::Http(_) => io::ErrorKind::Other,
|
||||
async_tungstenite::tungstenite::Error::HttpFormat(_) => io::ErrorKind::Other,
|
||||
};
|
||||
io::Error::new(kind, err)
|
||||
}
|
||||
|
||||
pub type WebSocketNetworkConnectionAccepted = WebsocketNetworkConnection<AsyncPeekStream>;
|
||||
|
||||
pub struct WebsocketNetworkConnection<T>
|
||||
@@ -62,41 +81,49 @@ where
|
||||
// }
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, message), fields(message.len = message.len()))]
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("received too large WS message".to_owned());
|
||||
bail_io_error_other!("received too large WS message");
|
||||
}
|
||||
self.stream
|
||||
.clone()
|
||||
.send(Message::binary(message))
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed to send websocket message"))
|
||||
.map_err(to_io)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self), fields(ret.len))]
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
pub async fn recv(&self) -> io::Result<Vec<u8>> {
|
||||
let out = match self.stream.clone().next().await {
|
||||
Some(Ok(Message::Binary(v))) => v,
|
||||
Some(Ok(Message::Close(e))) => {
|
||||
return Err(format!("WS connection closed: {:?}", e));
|
||||
Some(Ok(Message::Binary(v))) => {
|
||||
if v.len() > MAX_MESSAGE_SIZE {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionReset,
|
||||
"too large ws message",
|
||||
));
|
||||
}
|
||||
v
|
||||
}
|
||||
Some(Ok(Message::Close(_))) => {
|
||||
return Err(io::Error::new(io::ErrorKind::ConnectionReset, "closeframe"))
|
||||
}
|
||||
Some(Ok(x)) => {
|
||||
return Err(format!("Unexpected WS message type: {:?}", x));
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
return Err(e.to_string()).map_err(logthru_net!(error));
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Unexpected WS message type: {:?}", x),
|
||||
));
|
||||
}
|
||||
Some(Err(e)) => return Err(to_io(e)),
|
||||
None => {
|
||||
return Err("WS stream closed".to_owned());
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionReset,
|
||||
"connection ended",
|
||||
))
|
||||
}
|
||||
};
|
||||
if out.len() > MAX_MESSAGE_SIZE {
|
||||
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
|
||||
} else {
|
||||
tracing::Span::current().record("ret.len", &out.len());
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
tracing::Span::current().record("ret.len", &out.len());
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,21 +172,18 @@ impl WebsocketProtocolHandler {
|
||||
self,
|
||||
ps: AsyncPeekStream,
|
||||
socket_addr: SocketAddr,
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
) -> io::Result<Option<ProtocolNetworkConnection>> {
|
||||
log_net!("WS: on_accept_async: enter");
|
||||
let request_path_len = self.arc.request_path.len() + 2;
|
||||
|
||||
let mut peekbuf: Vec<u8> = vec![0u8; request_path_len];
|
||||
match timeout(
|
||||
if let Err(_) = timeout(
|
||||
self.arc.connection_initial_timeout_ms,
|
||||
ps.peek_exact(&mut peekbuf),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(e) => {
|
||||
return Err(e.to_string());
|
||||
}
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Check for websocket path
|
||||
@@ -169,15 +193,12 @@ impl WebsocketProtocolHandler {
|
||||
&& peekbuf[request_path_len - 1] == b' '));
|
||||
|
||||
if !matches_path {
|
||||
log_net!("WS: not websocket");
|
||||
return Ok(None);
|
||||
}
|
||||
log_net!("WS: found websocket");
|
||||
|
||||
let ws_stream = accept_async(ps)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("failed websockets handshake"))?;
|
||||
.map_err(|e| io_error_other!(format!("failed websockets handshake: {}", e)))?;
|
||||
|
||||
// Wrap the websocket in a NetworkConnection and register it
|
||||
let protocol_type = if self.arc.tls {
|
||||
@@ -205,7 +226,7 @@ impl WebsocketProtocolHandler {
|
||||
async fn connect_internal(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
) -> io::Result<ProtocolNetworkConnection> {
|
||||
// Split dial info up
|
||||
let (tls, scheme) = match &dial_info {
|
||||
DialInfo::WS(_) => (false, "ws"),
|
||||
@@ -213,9 +234,9 @@ impl WebsocketProtocolHandler {
|
||||
_ => panic!("invalid dialinfo for WS/WSS protocol"),
|
||||
};
|
||||
let request = dial_info.request().unwrap();
|
||||
let split_url = SplitUrl::from_str(&request)?;
|
||||
let split_url = SplitUrl::from_str(&request).map_err(to_io_error_other)?;
|
||||
if split_url.scheme != scheme {
|
||||
return Err("invalid websocket url scheme".to_string());
|
||||
bail_io_error_other!("invalid websocket url scheme");
|
||||
}
|
||||
let domain = split_url.host.clone();
|
||||
|
||||
@@ -231,12 +252,10 @@ impl WebsocketProtocolHandler {
|
||||
};
|
||||
|
||||
// Non-blocking connect to remote address
|
||||
let tcp_stream = nonblocking_connect(socket, remote_socket_addr).await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
|
||||
let tcp_stream = nonblocking_connect(socket, remote_socket_addr).await?;
|
||||
|
||||
// See what local address we ended up with
|
||||
let actual_local_addr = tcp_stream.local_addr().map_err(map_to_string)?;
|
||||
let actual_local_addr = tcp_stream.local_addr()?;
|
||||
|
||||
#[cfg(feature = "rt-tokio")]
|
||||
let tcp_stream = tcp_stream.compat();
|
||||
@@ -249,15 +268,10 @@ impl WebsocketProtocolHandler {
|
||||
// Negotiate TLS if this is WSS
|
||||
if tls {
|
||||
let connector = TlsConnector::default();
|
||||
let tls_stream = connector
|
||||
.connect(domain.to_string(), tcp_stream)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
let tls_stream = connector.connect(domain.to_string(), tcp_stream).await?;
|
||||
let (ws_stream, _response) = client_async(request, tls_stream)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
.map_err(to_io_error_other)?;
|
||||
|
||||
Ok(ProtocolNetworkConnection::Wss(
|
||||
WebsocketNetworkConnection::new(descriptor, ws_stream),
|
||||
@@ -265,8 +279,7 @@ impl WebsocketProtocolHandler {
|
||||
} else {
|
||||
let (ws_stream, _response) = client_async(request, tcp_stream)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
.map_err(to_io_error_other)?;
|
||||
Ok(ProtocolNetworkConnection::Ws(
|
||||
WebsocketNetworkConnection::new(descriptor, ws_stream),
|
||||
))
|
||||
@@ -277,19 +290,17 @@ impl WebsocketProtocolHandler {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
) -> io::Result<ProtocolNetworkConnection> {
|
||||
Self::connect_internal(local_address, dial_info).await
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
|
||||
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
|
||||
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> io::Result<()> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound WS message".to_owned());
|
||||
bail_io_error_other!("sending too large unbound WS message");
|
||||
}
|
||||
|
||||
let protconn = Self::connect_internal(None, dial_info.clone())
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?;
|
||||
let protconn = Self::connect_internal(None, dial_info.clone()).await?;
|
||||
|
||||
protconn.send(data).await
|
||||
}
|
||||
@@ -299,19 +310,17 @@ impl WebsocketProtocolHandler {
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
timeout_ms: u32,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
) -> io::Result<Vec<u8>> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound WS message".to_owned());
|
||||
bail_io_error_other!("sending too large unbound WS message");
|
||||
}
|
||||
|
||||
let protconn = Self::connect_internal(None, dial_info.clone())
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?;
|
||||
let protconn = Self::connect_internal(None, dial_info.clone()).await?;
|
||||
|
||||
protconn.send(data).await?;
|
||||
let out = timeout(timeout_ms, protconn.recv())
|
||||
.await
|
||||
.map_err(map_to_string)??;
|
||||
.map_err(|e| e.to_io())??;
|
||||
|
||||
tracing::Span::current().record("ret.len", &out.len());
|
||||
Ok(out)
|
||||
@@ -323,7 +332,7 @@ impl ProtocolAcceptHandler for WebsocketProtocolHandler {
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<Result<Option<ProtocolNetworkConnection>, String>> {
|
||||
) -> SystemPinBoxFuture<io::Result<Option<ProtocolNetworkConnection>>> {
|
||||
Box::pin(self.clone().on_accept_async(stream, peer_addr))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,7 +164,7 @@ impl Network {
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
fn find_available_udp_port(&self) -> Result<u16, String> {
|
||||
fn find_available_udp_port(&self) -> EyreResult<u16> {
|
||||
// If the address is empty, iterate ports until we find one we can use.
|
||||
let mut udp_port = 5150u16;
|
||||
loop {
|
||||
@@ -175,14 +175,14 @@ impl Network {
|
||||
break;
|
||||
}
|
||||
if udp_port == 65535 {
|
||||
return Err("Could not find free udp port to listen on".to_owned());
|
||||
bail!("Could not find free udp port to listen on");
|
||||
}
|
||||
udp_port += 1;
|
||||
}
|
||||
Ok(udp_port)
|
||||
}
|
||||
|
||||
fn find_available_tcp_port(&self) -> Result<u16, String> {
|
||||
fn find_available_tcp_port(&self) -> EyreResult<u16> {
|
||||
// If the address is empty, iterate ports until we find one we can use.
|
||||
let mut tcp_port = 5150u16;
|
||||
loop {
|
||||
@@ -193,17 +193,14 @@ impl Network {
|
||||
break;
|
||||
}
|
||||
if tcp_port == 65535 {
|
||||
return Err("Could not find free tcp port to listen on".to_owned());
|
||||
bail!("Could not find free tcp port to listen on");
|
||||
}
|
||||
tcp_port += 1;
|
||||
}
|
||||
Ok(tcp_port)
|
||||
}
|
||||
|
||||
async fn allocate_udp_port(
|
||||
&self,
|
||||
listen_address: String,
|
||||
) -> Result<(u16, Vec<IpAddr>), String> {
|
||||
async fn allocate_udp_port(&self, listen_address: String) -> EyreResult<(u16, Vec<IpAddr>)> {
|
||||
if listen_address.is_empty() {
|
||||
// If listen address is empty, find us a port iteratively
|
||||
let port = self.find_available_udp_port()?;
|
||||
@@ -217,21 +214,17 @@ impl Network {
|
||||
// If the address is specified, only use the specified port and fail otherwise
|
||||
let sockaddrs = listen_address_to_socket_addrs(&listen_address)?;
|
||||
if sockaddrs.is_empty() {
|
||||
return Err(format!("No valid listen address: {}", listen_address));
|
||||
bail!("No valid listen address: {}", listen_address);
|
||||
}
|
||||
let port = sockaddrs[0].port();
|
||||
if self.bind_first_udp_port(port) {
|
||||
Ok((port, sockaddrs.iter().map(|s| s.ip()).collect()))
|
||||
} else {
|
||||
Err("Could not find free udp port to listen on".to_owned())
|
||||
if !self.bind_first_udp_port(port) {
|
||||
bail!("Could not find free udp port to listen on");
|
||||
}
|
||||
Ok((port, sockaddrs.iter().map(|s| s.ip()).collect()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn allocate_tcp_port(
|
||||
&self,
|
||||
listen_address: String,
|
||||
) -> Result<(u16, Vec<IpAddr>), String> {
|
||||
async fn allocate_tcp_port(&self, listen_address: String) -> EyreResult<(u16, Vec<IpAddr>)> {
|
||||
if listen_address.is_empty() {
|
||||
// If listen address is empty, find us a port iteratively
|
||||
let port = self.find_available_tcp_port()?;
|
||||
@@ -245,20 +238,19 @@ impl Network {
|
||||
// If the address is specified, only use the specified port and fail otherwise
|
||||
let sockaddrs = listen_address_to_socket_addrs(&listen_address)?;
|
||||
if sockaddrs.is_empty() {
|
||||
return Err(format!("No valid listen address: {}", listen_address));
|
||||
bail!("No valid listen address: {}", listen_address);
|
||||
}
|
||||
let port = sockaddrs[0].port();
|
||||
if self.bind_first_tcp_port(port) {
|
||||
Ok((port, sockaddrs.iter().map(|s| s.ip()).collect()))
|
||||
} else {
|
||||
Err("Could not find free tcp port to listen on".to_owned())
|
||||
if !self.bind_first_tcp_port(port) {
|
||||
bail!("Could not find free tcp port to listen on");
|
||||
}
|
||||
Ok((port, sockaddrs.iter().map(|s| s.ip()).collect()))
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
pub(super) async fn start_udp_listeners(&self) -> Result<(), String> {
|
||||
pub(super) async fn start_udp_listeners(&self) -> EyreResult<()> {
|
||||
trace!("starting udp listeners");
|
||||
let routing_table = self.routing_table();
|
||||
let (listen_address, public_address, enable_local_peer_scope) = {
|
||||
@@ -319,7 +311,7 @@ impl Network {
|
||||
// Resolve statically configured public dialinfo
|
||||
let mut public_sockaddrs = public_address
|
||||
.to_socket_addrs()
|
||||
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
|
||||
.wrap_err(format!("Unable to resolve address: {}", public_address))?;
|
||||
|
||||
// Add all resolved addresses as public dialinfo
|
||||
for pdi_addr in &mut public_sockaddrs {
|
||||
@@ -364,7 +356,7 @@ impl Network {
|
||||
self.create_udp_listener_tasks().await
|
||||
}
|
||||
|
||||
pub(super) async fn start_ws_listeners(&self) -> Result<(), String> {
|
||||
pub(super) async fn start_ws_listeners(&self) -> EyreResult<()> {
|
||||
trace!("starting ws listeners");
|
||||
let routing_table = self.routing_table();
|
||||
let (listen_address, url, path, enable_local_peer_scope) = {
|
||||
@@ -405,9 +397,9 @@ impl Network {
|
||||
|
||||
// Add static public dialinfo if it's configured
|
||||
if let Some(url) = url.as_ref() {
|
||||
let mut split_url = SplitUrl::from_str(url)?;
|
||||
let mut split_url = SplitUrl::from_str(url).wrap_err("couldn't split url")?;
|
||||
if split_url.scheme.to_ascii_lowercase() != "ws" {
|
||||
return Err("WS URL must use 'ws://' scheme".to_owned());
|
||||
bail!("WS URL must use 'ws://' scheme");
|
||||
}
|
||||
split_url.scheme = "ws".to_owned();
|
||||
|
||||
@@ -415,13 +407,11 @@ impl Network {
|
||||
let global_socket_addrs = split_url
|
||||
.host_port(80)
|
||||
.to_socket_addrs()
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
.wrap_err("failed to resolve ws url")?;
|
||||
|
||||
for gsa in global_socket_addrs {
|
||||
let pdi = DialInfo::try_ws(SocketAddress::from_socket_addr(gsa), url.clone())
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
.wrap_err("try_ws failed")?;
|
||||
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::PublicInternet,
|
||||
@@ -458,9 +448,7 @@ impl Network {
|
||||
}
|
||||
// Build dial info request url
|
||||
let local_url = format!("ws://{}/{}", socket_address, path);
|
||||
let local_di = DialInfo::try_ws(socket_address, local_url)
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
let local_di = DialInfo::try_ws(socket_address, local_url).wrap_err("try_ws failed")?;
|
||||
|
||||
if url.is_none() && (socket_address.address().is_global() || enable_local_peer_scope) {
|
||||
// Register public dial info
|
||||
@@ -490,7 +478,7 @@ impl Network {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) async fn start_wss_listeners(&self) -> Result<(), String> {
|
||||
pub(super) async fn start_wss_listeners(&self) -> EyreResult<()> {
|
||||
trace!("starting wss listeners");
|
||||
|
||||
let routing_table = self.routing_table();
|
||||
@@ -538,7 +526,7 @@ impl Network {
|
||||
// Add static public dialinfo if it's configured
|
||||
let mut split_url = SplitUrl::from_str(url)?;
|
||||
if split_url.scheme.to_ascii_lowercase() != "wss" {
|
||||
return Err("WSS URL must use 'wss://' scheme".to_owned());
|
||||
bail!("WSS URL must use 'wss://' scheme");
|
||||
}
|
||||
split_url.scheme = "wss".to_owned();
|
||||
|
||||
@@ -546,13 +534,10 @@ impl Network {
|
||||
let global_socket_addrs = split_url
|
||||
.host_port(443)
|
||||
.to_socket_addrs()
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
|
||||
.wrap_err("failed to resolve wss url")?;
|
||||
for gsa in global_socket_addrs {
|
||||
let pdi = DialInfo::try_wss(SocketAddress::from_socket_addr(gsa), url.clone())
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
.wrap_err("try_wss failed")?;
|
||||
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::PublicInternet,
|
||||
@@ -581,7 +566,7 @@ impl Network {
|
||||
registered_addresses.insert(gsa.ip());
|
||||
}
|
||||
} else {
|
||||
return Err("WSS URL must be specified due to TLS requirements".to_owned());
|
||||
bail!("WSS URL must be specified due to TLS requirements");
|
||||
}
|
||||
|
||||
if static_public {
|
||||
@@ -594,7 +579,7 @@ impl Network {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) async fn start_tcp_listeners(&self) -> Result<(), String> {
|
||||
pub(super) async fn start_tcp_listeners(&self) -> EyreResult<()> {
|
||||
trace!("starting tcp listeners");
|
||||
|
||||
let routing_table = self.routing_table();
|
||||
@@ -659,7 +644,7 @@ impl Network {
|
||||
// Resolve statically configured public dialinfo
|
||||
let mut public_sockaddrs = public_address
|
||||
.to_socket_addrs()
|
||||
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
|
||||
.wrap_err("failed to resolve tcp address")?;
|
||||
|
||||
// Add all resolved addresses as public dialinfo
|
||||
for pdi_addr in &mut public_sockaddrs {
|
||||
|
||||
Reference in New Issue
Block a user