executor work
This commit is contained in:
@@ -5,7 +5,6 @@ mod protocol;
|
||||
mod start_protocols;
|
||||
|
||||
use super::*;
|
||||
use crate::intf::*;
|
||||
use crate::routing_table::*;
|
||||
use connection_manager::*;
|
||||
use network_tcp::*;
|
||||
@@ -15,10 +14,9 @@ use protocol::ws::WebsocketProtocolHandler;
|
||||
pub use protocol::*;
|
||||
use utils::network_interfaces::*;
|
||||
|
||||
use async_std::io;
|
||||
use async_std::net::*;
|
||||
use async_tls::TlsAcceptor;
|
||||
use futures_util::StreamExt;
|
||||
use std::io;
|
||||
// xxx: rustls ^0.20
|
||||
//use rustls::{server::NoClientAuth, Certificate, PrivateKey, ServerConfig};
|
||||
use rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
|
||||
@@ -26,7 +24,6 @@ use rustls_pemfile::{certs, pkcs8_private_keys, rsa_private_keys};
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::Duration;
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -558,12 +555,13 @@ impl Network {
|
||||
let mut inner = self.inner.lock();
|
||||
// take the join handles out
|
||||
for h in inner.join_handles.drain(..) {
|
||||
trace!("joining: {:?}", h);
|
||||
unord.push(h);
|
||||
}
|
||||
// Drop the stop
|
||||
drop(inner.stop_source.take());
|
||||
}
|
||||
debug!("stopping {} low level network tasks", unord.len());
|
||||
debug!("stopping {} low level network tasks", unord.len(),);
|
||||
// Wait for everything to stop
|
||||
while unord.next().await.is_some() {}
|
||||
|
||||
|
@@ -1,5 +1,4 @@
|
||||
use super::*;
|
||||
use crate::intf::*;
|
||||
use async_tls::TlsAcceptor;
|
||||
use sockets::*;
|
||||
use stop_token::future::FutureExt;
|
||||
@@ -43,46 +42,41 @@ impl Network {
|
||||
&self,
|
||||
tls_acceptor: &TlsAcceptor,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
addr: SocketAddr,
|
||||
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
|
||||
tls_connection_initial_timeout: u64,
|
||||
tls_connection_initial_timeout_ms: u32,
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
let ts = tls_acceptor
|
||||
let tls_stream = tls_acceptor
|
||||
.accept(stream)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(debug "TLS stream failed handshake"))?;
|
||||
let ps = AsyncPeekStream::new(CloneStream::new(ts));
|
||||
let ps = AsyncPeekStream::new(tls_stream);
|
||||
let mut first_packet = [0u8; PEEK_DETECT_LEN];
|
||||
|
||||
// Try the handlers but first get a chunk of data for them to process
|
||||
// Don't waste more than N seconds getting it though, in case someone
|
||||
// is trying to DoS us with a bunch of connections or something
|
||||
// read a chunk of the stream
|
||||
io::timeout(
|
||||
Duration::from_micros(tls_connection_initial_timeout),
|
||||
intf::timeout(
|
||||
tls_connection_initial_timeout_ms,
|
||||
ps.peek_exact(&mut first_packet),
|
||||
)
|
||||
.await
|
||||
.map_err(map_to_string)?
|
||||
.map_err(map_to_string)?;
|
||||
|
||||
self.try_handlers(ps, tcp_stream, addr, protocol_handlers)
|
||||
.await
|
||||
self.try_handlers(ps, addr, protocol_handlers).await
|
||||
}
|
||||
|
||||
async fn try_handlers(
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
addr: SocketAddr,
|
||||
protocol_accept_handlers: &[Box<dyn ProtocolAcceptHandler>],
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
for ah in protocol_accept_handlers.iter() {
|
||||
if let Some(nc) = ah
|
||||
.on_accept(stream.clone(), tcp_stream.clone(), addr)
|
||||
.await?
|
||||
{
|
||||
if let Some(nc) = ah.on_accept(stream.clone(), addr).await? {
|
||||
return Ok(Some(nc));
|
||||
}
|
||||
}
|
||||
@@ -92,11 +86,11 @@ impl Network {
|
||||
|
||||
async fn tcp_acceptor(
|
||||
self,
|
||||
tcp_stream: async_std::io::Result<TcpStream>,
|
||||
tcp_stream: io::Result<TcpStream>,
|
||||
listener_state: Arc<RwLock<ListenerState>>,
|
||||
connection_manager: ConnectionManager,
|
||||
connection_initial_timeout: u64,
|
||||
tls_connection_initial_timeout: u64,
|
||||
connection_initial_timeout_ms: u32,
|
||||
tls_connection_initial_timeout_ms: u32,
|
||||
) {
|
||||
let tcp_stream = match tcp_stream {
|
||||
Ok(v) => v,
|
||||
@@ -125,14 +119,16 @@ impl Network {
|
||||
log_net!("TCP connection from: {}", addr);
|
||||
|
||||
// Create a stream we can peek on
|
||||
let ps = AsyncPeekStream::new(tcp_stream.clone());
|
||||
#[cfg(feature = "rt-tokio")]
|
||||
let tcp_stream = tcp_stream.compat();
|
||||
let ps = AsyncPeekStream::new(tcp_stream);
|
||||
|
||||
/////////////////////////////////////////////////////////////
|
||||
let mut first_packet = [0u8; PEEK_DETECT_LEN];
|
||||
|
||||
// read a chunk of the stream
|
||||
if io::timeout(
|
||||
Duration::from_micros(connection_initial_timeout),
|
||||
if timeout(
|
||||
connection_initial_timeout_ms,
|
||||
ps.peek_exact(&mut first_packet),
|
||||
)
|
||||
.await
|
||||
@@ -153,14 +149,13 @@ impl Network {
|
||||
self.try_tls_handlers(
|
||||
ls.tls_acceptor.as_ref().unwrap(),
|
||||
ps,
|
||||
tcp_stream,
|
||||
addr,
|
||||
&ls.tls_protocol_handlers,
|
||||
tls_connection_initial_timeout,
|
||||
tls_connection_initial_timeout_ms,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
self.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers)
|
||||
self.try_handlers(ps, addr, &ls.protocol_accept_handlers)
|
||||
.await
|
||||
};
|
||||
|
||||
@@ -192,11 +187,11 @@ impl Network {
|
||||
|
||||
async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
|
||||
// Get config
|
||||
let (connection_initial_timeout, tls_connection_initial_timeout) = {
|
||||
let (connection_initial_timeout_ms, tls_connection_initial_timeout_ms) = {
|
||||
let c = self.config.get();
|
||||
(
|
||||
ms_to_us(c.network.connection_initial_timeout_ms),
|
||||
ms_to_us(c.network.tls.connection_initial_timeout_ms),
|
||||
c.network.connection_initial_timeout_ms,
|
||||
c.network.tls.connection_initial_timeout_ms,
|
||||
)
|
||||
};
|
||||
|
||||
@@ -209,7 +204,13 @@ impl Network {
|
||||
|
||||
// Make an async tcplistener from the socket2 socket
|
||||
let std_listener: std::net::TcpListener = socket.into();
|
||||
let listener = TcpListener::from(std_listener);
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
let listener = TcpListener::from(std_listener);
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
let listener = TcpListener::from_std(std_listener).map_err(map_to_string)?;
|
||||
}
|
||||
}
|
||||
|
||||
debug!("spawn_socket_listener: binding successful to {}", addr);
|
||||
|
||||
@@ -229,8 +230,16 @@ impl Network {
|
||||
let jh = spawn(async move {
|
||||
// moves listener object in and get incoming iterator
|
||||
// when this task exists, the listener will close the socket
|
||||
let _ = listener
|
||||
.incoming()
|
||||
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
let incoming_stream = listener.incoming();
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
let incoming_stream = tokio_stream::wrappers::TcpListenerStream::new(listener);
|
||||
}
|
||||
}
|
||||
|
||||
let _ = incoming_stream
|
||||
.for_each_concurrent(None, |tcp_stream| {
|
||||
let this = this.clone();
|
||||
let listener_state = listener_state.clone();
|
||||
@@ -240,8 +249,8 @@ impl Network {
|
||||
tcp_stream,
|
||||
listener_state,
|
||||
connection_manager,
|
||||
connection_initial_timeout,
|
||||
tls_connection_initial_timeout,
|
||||
connection_initial_timeout_ms,
|
||||
tls_connection_initial_timeout_ms,
|
||||
)
|
||||
})
|
||||
.timeout_at(stop_token)
|
||||
@@ -255,7 +264,7 @@ impl Network {
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// Add to join handles
|
||||
self.add_to_join_handles(MustJoinHandle::new(jh));
|
||||
self.add_to_join_handles(jh);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@@ -23,7 +23,7 @@ impl Network {
|
||||
// Run thread task to process stream of messages
|
||||
let this = self.clone();
|
||||
|
||||
let jh = spawn(async move {
|
||||
let jh = spawn_with_local_set(async move {
|
||||
trace!("UDP listener task spawned");
|
||||
|
||||
// Collect all our protocol handlers into a vector
|
||||
@@ -49,7 +49,7 @@ impl Network {
|
||||
for ph in protocol_handlers {
|
||||
let network_manager = network_manager.clone();
|
||||
let stop_token = stop_token.clone();
|
||||
let jh = spawn_local(async move {
|
||||
let jh = intf::spawn_local(async move {
|
||||
let mut data = vec![0u8; 65536];
|
||||
|
||||
loop {
|
||||
@@ -112,7 +112,7 @@ impl Network {
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// Add to join handle
|
||||
self.add_to_join_handles(MustJoinHandle::new(jh));
|
||||
self.add_to_join_handles(jh);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -134,7 +134,13 @@ impl Network {
|
||||
|
||||
// Make an async UdpSocket from the socket2 socket
|
||||
let std_udp_socket: std::net::UdpSocket = socket.into();
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket).map_err(map_to_string)?;
|
||||
}
|
||||
}
|
||||
let socket_arc = Arc::new(udp_socket);
|
||||
|
||||
// Create protocol handler
|
||||
@@ -148,7 +154,13 @@ impl Network {
|
||||
if let Ok(socket) = new_bound_shared_udp_socket(socket_addr_v6) {
|
||||
// Make an async UdpSocket from the socket2 socket
|
||||
let std_udp_socket: std::net::UdpSocket = socket.into();
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket).map_err(map_to_string)?;
|
||||
}
|
||||
}
|
||||
let socket_arc = Arc::new(udp_socket);
|
||||
|
||||
// Create protocol handler
|
||||
@@ -168,7 +180,13 @@ impl Network {
|
||||
|
||||
// Make an async UdpSocket from the socket2 socket
|
||||
let std_udp_socket: std::net::UdpSocket = socket.into();
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket).map_err(map_to_string)?;
|
||||
}
|
||||
}
|
||||
let socket_arc = Arc::new(udp_socket);
|
||||
|
||||
// Create protocol handler
|
||||
|
@@ -92,15 +92,15 @@ impl ProtocolNetworkConnection {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.close(),
|
||||
Self::RawTcp(t) => t.close().await,
|
||||
Self::WsAccepted(w) => w.close().await,
|
||||
Self::Ws(w) => w.close().await,
|
||||
Self::Wss(w) => w.close().await,
|
||||
}
|
||||
}
|
||||
// pub async fn close(&self) -> Result<(), String> {
|
||||
// match self {
|
||||
// Self::Dummy(d) => d.close(),
|
||||
// Self::RawTcp(t) => t.close().await,
|
||||
// Self::WsAccepted(w) => w.close().await,
|
||||
// Self::Ws(w) => w.close().await,
|
||||
// Self::Wss(w) => w.close().await,
|
||||
// }
|
||||
// }
|
||||
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
match self {
|
||||
|
@@ -1,7 +1,15 @@
|
||||
use crate::xx::*;
|
||||
use crate::*;
|
||||
use async_io::Async;
|
||||
use async_std::net::TcpStream;
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
pub use async_std::net::{TcpStream, TcpListener, Shutdown, UdpSocket};
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
pub use tokio::net::{TcpStream, TcpListener, UdpSocket};
|
||||
pub use tokio_util::compat::*;
|
||||
}
|
||||
}
|
||||
|
||||
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
|
||||
|
||||
cfg_if! {
|
||||
@@ -218,5 +226,11 @@ pub async fn nonblocking_connect(socket: Socket, addr: SocketAddr) -> std::io::R
|
||||
}?;
|
||||
|
||||
// Convert back to inner and then return async version
|
||||
Ok(TcpStream::from(async_stream.into_inner()?))
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
Ok(TcpStream::from(async_stream.into_inner()?))
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
Ok(TcpStream::from_std(async_stream.into_inner()?)?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -5,7 +5,6 @@ use sockets::*;
|
||||
pub struct RawTcpNetworkConnection {
|
||||
descriptor: ConnectionDescriptor,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
}
|
||||
|
||||
impl fmt::Debug for RawTcpNetworkConnection {
|
||||
@@ -15,31 +14,33 @@ impl fmt::Debug for RawTcpNetworkConnection {
|
||||
}
|
||||
|
||||
impl RawTcpNetworkConnection {
|
||||
pub fn new(
|
||||
descriptor: ConnectionDescriptor,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
) -> Self {
|
||||
Self {
|
||||
descriptor,
|
||||
stream,
|
||||
tcp_stream,
|
||||
}
|
||||
pub fn new(descriptor: ConnectionDescriptor, stream: AsyncPeekStream) -> Self {
|
||||
Self { descriptor, stream }
|
||||
}
|
||||
|
||||
pub fn descriptor(&self) -> ConnectionDescriptor {
|
||||
self.descriptor.clone()
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self))]
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
// Make an attempt to flush the stream
|
||||
self.stream.clone().close().await.map_err(map_to_string)?;
|
||||
// Then forcibly close the socket
|
||||
self.tcp_stream
|
||||
.shutdown(Shutdown::Both)
|
||||
.map_err(map_to_string)
|
||||
}
|
||||
// #[instrument(level = "trace", err, skip(self))]
|
||||
// pub async fn close(&mut self) -> Result<(), String> {
|
||||
// // Make an attempt to flush the stream
|
||||
// self.stream.clone().close().await.map_err(map_to_string)?;
|
||||
// // Then shut down the write side of the socket to effect a clean close
|
||||
// cfg_if! {
|
||||
// if #[cfg(feature="rt-async-std")] {
|
||||
// self.tcp_stream
|
||||
// .shutdown(async_std::net::Shutdown::Write)
|
||||
// .map_err(map_to_string)
|
||||
// } else if #[cfg(feature="rt-tokio")] {
|
||||
// use tokio::io::AsyncWriteExt;
|
||||
// self.tcp_stream.get_mut()
|
||||
// .shutdown()
|
||||
// .await
|
||||
// .map_err(map_to_string)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
async fn send_internal(stream: &mut AsyncPeekStream, message: Vec<u8>) -> Result<(), String> {
|
||||
log_net!("sending TCP message of size {}", message.len());
|
||||
@@ -115,11 +116,10 @@ impl RawTcpProtocolHandler {
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, stream, tcp_stream))]
|
||||
#[instrument(level = "trace", err, skip(self, stream))]
|
||||
async fn on_accept_async(
|
||||
self,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
socket_addr: SocketAddr,
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
log_net!("TCP: on_accept_async: enter");
|
||||
@@ -139,7 +139,6 @@ impl RawTcpProtocolHandler {
|
||||
let conn = ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(
|
||||
ConnectionDescriptor::new(peer_addr, SocketAddress::from_socket_addr(local_address)),
|
||||
stream,
|
||||
tcp_stream,
|
||||
));
|
||||
|
||||
log_net!(debug "TCP: on_accept_async from: {}", socket_addr);
|
||||
@@ -173,7 +172,9 @@ impl RawTcpProtocolHandler {
|
||||
.local_addr()
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("could not get local address from TCP stream"))?;
|
||||
let ps = AsyncPeekStream::new(ts.clone());
|
||||
#[cfg(feature = "rt-tokio")]
|
||||
let ts = ts.compat();
|
||||
let ps = AsyncPeekStream::new(ts);
|
||||
|
||||
// Wrap the stream in a network connection and return it
|
||||
let conn = ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(
|
||||
@@ -182,7 +183,6 @@ impl RawTcpProtocolHandler {
|
||||
SocketAddress::from_socket_addr(actual_local_address),
|
||||
),
|
||||
ps,
|
||||
ts,
|
||||
));
|
||||
|
||||
Ok(conn)
|
||||
@@ -216,7 +216,10 @@ impl RawTcpProtocolHandler {
|
||||
// .local_addr()
|
||||
// .map_err(map_to_string)
|
||||
// .map_err(logthru_net!("could not get local address from TCP stream"))?;
|
||||
let mut ps = AsyncPeekStream::new(ts.clone());
|
||||
|
||||
#[cfg(feature = "rt-tokio")]
|
||||
let ts = ts.compat();
|
||||
let mut ps = AsyncPeekStream::new(ts);
|
||||
|
||||
// Send directly from the raw network connection
|
||||
// this builds the connection and tears it down immediately after the send
|
||||
@@ -252,7 +255,9 @@ impl RawTcpProtocolHandler {
|
||||
// .local_addr()
|
||||
// .map_err(map_to_string)
|
||||
// .map_err(logthru_net!("could not get local address from TCP stream"))?;
|
||||
let mut ps = AsyncPeekStream::new(ts.clone());
|
||||
#[cfg(feature = "rt-tokio")]
|
||||
let ts = ts.compat();
|
||||
let mut ps = AsyncPeekStream::new(ts);
|
||||
|
||||
// Send directly from the raw network connection
|
||||
// this builds the connection and tears it down immediately after the send
|
||||
@@ -271,9 +276,8 @@ impl ProtocolAcceptHandler for RawTcpProtocolHandler {
|
||||
fn on_accept(
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<core::result::Result<Option<ProtocolNetworkConnection>, String>> {
|
||||
Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
|
||||
Box::pin(self.clone().on_accept_async(stream, peer_addr))
|
||||
}
|
||||
}
|
||||
|
@@ -1,4 +1,5 @@
|
||||
use super::*;
|
||||
use sockets::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RawUdpProtocolHandler {
|
||||
|
@@ -1,28 +1,35 @@
|
||||
use super::*;
|
||||
use async_std::io;
|
||||
|
||||
use async_tls::TlsConnector;
|
||||
use async_tungstenite::tungstenite::protocol::Message;
|
||||
use async_tungstenite::{accept_async, client_async, WebSocketStream};
|
||||
use futures_util::SinkExt;
|
||||
use futures_util::{AsyncRead, AsyncWrite, SinkExt};
|
||||
use sockets::*;
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
pub type WebsocketNetworkConnectionWSS =
|
||||
WebsocketNetworkConnection<async_tls::client::TlsStream<TcpStream>>;
|
||||
pub type WebsocketNetworkConnectionWS = WebsocketNetworkConnection<TcpStream>;
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
pub type WebsocketNetworkConnectionWSS =
|
||||
WebsocketNetworkConnection<async_tls::client::TlsStream<Compat<TcpStream>>>;
|
||||
pub type WebsocketNetworkConnectionWS = WebsocketNetworkConnection<Compat<TcpStream>>;
|
||||
}
|
||||
}
|
||||
|
||||
pub type WebSocketNetworkConnectionAccepted = WebsocketNetworkConnection<AsyncPeekStream>;
|
||||
pub type WebsocketNetworkConnectionWSS =
|
||||
WebsocketNetworkConnection<async_tls::client::TlsStream<TcpStream>>;
|
||||
pub type WebsocketNetworkConnectionWS = WebsocketNetworkConnection<TcpStream>;
|
||||
|
||||
pub struct WebsocketNetworkConnection<T>
|
||||
where
|
||||
T: io::Read + io::Write + Send + Unpin + 'static,
|
||||
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
{
|
||||
descriptor: ConnectionDescriptor,
|
||||
stream: CloneStream<WebSocketStream<T>>,
|
||||
tcp_stream: TcpStream,
|
||||
}
|
||||
|
||||
impl<T> fmt::Debug for WebsocketNetworkConnection<T>
|
||||
where
|
||||
T: io::Read + io::Write + Send + Unpin + 'static,
|
||||
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", std::any::type_name::<Self>())
|
||||
@@ -31,17 +38,12 @@ where
|
||||
|
||||
impl<T> WebsocketNetworkConnection<T>
|
||||
where
|
||||
T: io::Read + io::Write + Send + Unpin + 'static,
|
||||
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
{
|
||||
pub fn new(
|
||||
descriptor: ConnectionDescriptor,
|
||||
stream: WebSocketStream<T>,
|
||||
tcp_stream: TcpStream,
|
||||
) -> Self {
|
||||
pub fn new(descriptor: ConnectionDescriptor, stream: WebSocketStream<T>) -> Self {
|
||||
Self {
|
||||
descriptor,
|
||||
stream: CloneStream::new(stream),
|
||||
tcp_stream,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,15 +51,15 @@ where
|
||||
self.descriptor.clone()
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self))]
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
// Make an attempt to flush the stream
|
||||
self.stream.clone().close().await.map_err(map_to_string)?;
|
||||
// Then forcibly close the socket
|
||||
self.tcp_stream
|
||||
.shutdown(Shutdown::Both)
|
||||
.map_err(map_to_string)
|
||||
}
|
||||
// #[instrument(level = "trace", err, skip(self))]
|
||||
// pub async fn close(&self) -> Result<(), String> {
|
||||
// // Make an attempt to flush the stream
|
||||
// self.stream.clone().close().await.map_err(map_to_string)?;
|
||||
// // Then forcibly close the socket
|
||||
// self.tcp_stream
|
||||
// .shutdown(Shutdown::Both)
|
||||
// .map_err(map_to_string)
|
||||
// }
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, message), fields(message.len = message.len()))]
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
@@ -101,7 +103,7 @@ struct WebsocketProtocolHandlerArc {
|
||||
tls: bool,
|
||||
local_address: SocketAddr,
|
||||
request_path: Vec<u8>,
|
||||
connection_initial_timeout: u64,
|
||||
connection_initial_timeout_ms: u32,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -119,10 +121,10 @@ impl WebsocketProtocolHandler {
|
||||
} else {
|
||||
format!("GET /{}", c.network.protocol.wss.path.trim_end_matches('/'))
|
||||
};
|
||||
let connection_initial_timeout = if tls {
|
||||
ms_to_us(c.network.tls.connection_initial_timeout_ms)
|
||||
let connection_initial_timeout_ms = if tls {
|
||||
c.network.tls.connection_initial_timeout_ms
|
||||
} else {
|
||||
ms_to_us(c.network.connection_initial_timeout_ms)
|
||||
c.network.connection_initial_timeout_ms
|
||||
};
|
||||
|
||||
Self {
|
||||
@@ -130,34 +132,30 @@ impl WebsocketProtocolHandler {
|
||||
tls,
|
||||
local_address,
|
||||
request_path: path.as_bytes().to_vec(),
|
||||
connection_initial_timeout,
|
||||
connection_initial_timeout_ms,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, ps, tcp_stream))]
|
||||
#[instrument(level = "trace", err, skip(self, ps))]
|
||||
pub async fn on_accept_async(
|
||||
self,
|
||||
ps: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
socket_addr: SocketAddr,
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
log_net!("WS: on_accept_async: enter");
|
||||
let request_path_len = self.arc.request_path.len() + 2;
|
||||
|
||||
let mut peekbuf: Vec<u8> = vec![0u8; request_path_len];
|
||||
match io::timeout(
|
||||
Duration::from_micros(self.arc.connection_initial_timeout),
|
||||
match timeout(
|
||||
self.arc.connection_initial_timeout_ms,
|
||||
ps.peek_exact(&mut peekbuf),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(e) => {
|
||||
if e.kind() == io::ErrorKind::TimedOut {
|
||||
return Err(e).map_err(map_to_string);
|
||||
}
|
||||
return Err(e).map_err(map_to_string);
|
||||
return Err(e.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,7 +192,6 @@ impl WebsocketProtocolHandler {
|
||||
SocketAddress::from_socket_addr(self.arc.local_address),
|
||||
),
|
||||
ws_stream,
|
||||
tcp_stream,
|
||||
));
|
||||
|
||||
log_net!(debug "{}: on_accept_async from: {}", if self.arc.tls { "WSS" } else { "WS" }, socket_addr);
|
||||
@@ -238,6 +235,9 @@ impl WebsocketProtocolHandler {
|
||||
// See what local address we ended up with
|
||||
let actual_local_addr = tcp_stream.local_addr().map_err(map_to_string)?;
|
||||
|
||||
#[cfg(feature = "rt-tokio")]
|
||||
let tcp_stream = tcp_stream.compat();
|
||||
|
||||
// Make our connection descriptor
|
||||
let descriptor = ConnectionDescriptor::new(
|
||||
dial_info.to_peer_address(),
|
||||
@@ -247,7 +247,7 @@ impl WebsocketProtocolHandler {
|
||||
if tls {
|
||||
let connector = TlsConnector::default();
|
||||
let tls_stream = connector
|
||||
.connect(domain.to_string(), tcp_stream.clone())
|
||||
.connect(domain.to_string(), tcp_stream)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
@@ -257,15 +257,15 @@ impl WebsocketProtocolHandler {
|
||||
.map_err(logthru_net!(error))?;
|
||||
|
||||
Ok(ProtocolNetworkConnection::Wss(
|
||||
WebsocketNetworkConnection::new(descriptor, ws_stream, tcp_stream),
|
||||
WebsocketNetworkConnection::new(descriptor, ws_stream),
|
||||
))
|
||||
} else {
|
||||
let (ws_stream, _response) = client_async(request, tcp_stream.clone())
|
||||
let (ws_stream, _response) = client_async(request, tcp_stream)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
Ok(ProtocolNetworkConnection::Ws(
|
||||
WebsocketNetworkConnection::new(descriptor, ws_stream, tcp_stream),
|
||||
WebsocketNetworkConnection::new(descriptor, ws_stream),
|
||||
))
|
||||
}
|
||||
}
|
||||
@@ -319,9 +319,8 @@ impl ProtocolAcceptHandler for WebsocketProtocolHandler {
|
||||
fn on_accept(
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<Result<Option<ProtocolNetworkConnection>, String>> {
|
||||
Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
|
||||
Box::pin(self.clone().on_accept_async(stream, peer_addr))
|
||||
}
|
||||
}
|
||||
|
@@ -319,7 +319,6 @@ impl Network {
|
||||
// Resolve statically configured public dialinfo
|
||||
let mut public_sockaddrs = public_address
|
||||
.to_socket_addrs()
|
||||
.await
|
||||
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
|
||||
|
||||
// Add all resolved addresses as public dialinfo
|
||||
@@ -416,7 +415,6 @@ impl Network {
|
||||
let global_socket_addrs = split_url
|
||||
.host_port(80)
|
||||
.to_socket_addrs()
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
|
||||
@@ -548,7 +546,6 @@ impl Network {
|
||||
let global_socket_addrs = split_url
|
||||
.host_port(443)
|
||||
.to_socket_addrs()
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
|
||||
@@ -662,7 +659,6 @@ impl Network {
|
||||
// Resolve statically configured public dialinfo
|
||||
let mut public_sockaddrs = public_address
|
||||
.to_socket_addrs()
|
||||
.await
|
||||
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
|
||||
|
||||
// Add all resolved addresses as public dialinfo
|
||||
|
Reference in New Issue
Block a user