massive network refactor

This commit is contained in:
John Smith
2022-06-04 20:18:26 -04:00
parent 8148c37708
commit cfcf430a99
16 changed files with 515 additions and 294 deletions

View File

@@ -341,7 +341,7 @@ impl Network {
log_net!("send_data_to_existing_connection to {:?}", descriptor);
// connection exists, send over it
conn.send(data).await.map_err(logthru_net!())?;
conn.send_async(data).await.map_err(logthru_net!())?;
// Network accounting
self.network_manager()
@@ -389,7 +389,7 @@ impl Network {
.get_or_create_connection(Some(local_addr), dial_info.clone())
.await?;
let res = conn.send(data).await.map_err(logthru_net!(error));
let res = conn.send_async(data).await.map_err(logthru_net!(error));
if res.is_ok() {
// Network accounting
self.network_manager()

View File

@@ -7,7 +7,7 @@ use sockets::*;
#[derive(Clone)]
pub struct ListenerState {
pub protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
pub protocol_accept_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
pub tls_protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
pub tls_acceptor: Option<TlsAcceptor>,
}
@@ -15,7 +15,7 @@ pub struct ListenerState {
impl ListenerState {
pub fn new() -> Self {
Self {
protocol_handlers: Vec::new(),
protocol_accept_handlers: Vec::new(),
tls_protocol_handlers: Vec::new(),
tls_acceptor: None,
}
@@ -46,7 +46,7 @@ impl Network {
addr: SocketAddr,
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
tls_connection_initial_timeout: u64,
) -> Result<Option<NetworkConnection>, String> {
) -> Result<Option<ProtocolNetworkConnection>, String> {
let ts = tls_acceptor
.accept(stream)
.await
@@ -76,9 +76,9 @@ impl Network {
stream: AsyncPeekStream,
tcp_stream: TcpStream,
addr: SocketAddr,
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
) -> Result<Option<NetworkConnection>, String> {
for ah in protocol_handlers.iter() {
protocol_accept_handlers: &[Box<dyn ProtocolAcceptHandler>],
) -> Result<Option<ProtocolNetworkConnection>, String> {
for ah in protocol_accept_handlers.iter() {
if let Some(nc) = ah
.on_accept(stream.clone(), tcp_stream.clone(), addr)
.await
@@ -185,7 +185,7 @@ impl Network {
)
.await
} else {
this.try_handlers(ps, tcp_stream, addr, &ls.protocol_handlers)
this.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers)
.await
};
@@ -207,7 +207,10 @@ impl Network {
};
// Register the new connection in the connection manager
if let Err(e) = connection_manager.on_new_connection(conn).await {
if let Err(e) = connection_manager
.on_accepted_protocol_network_connection(conn)
.await
{
log_net!(error "failed to register new connection: {}", e);
}
})
@@ -270,7 +273,7 @@ impl Network {
));
} else {
ls.write()
.protocol_handlers
.protocol_accept_handlers
.push(new_protocol_accept_handler(
self.network_manager().config(),
false,

View File

@@ -21,7 +21,7 @@ impl ProtocolNetworkConnection {
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<NetworkConnection, String> {
) -> Result<ProtocolNetworkConnection, String> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
panic!("Should not connect to UDP dialinfo");
@@ -55,6 +55,16 @@ impl ProtocolNetworkConnection {
}
}
pub fn descriptor(&self) -> ConnectionDescriptor {
match self {
Self::Dummy(d) => d.descriptor(),
Self::RawTcp(t) => t.descriptor(),
Self::WsAccepted(w) => w.descriptor(),
Self::Ws(w) => w.descriptor(),
Self::Wss(w) => w.descriptor(),
}
}
pub async fn close(&self) -> Result<(), String> {
match self {
Self::Dummy(d) => d.close(),

View File

@@ -3,6 +3,7 @@ use futures_util::{AsyncReadExt, AsyncWriteExt};
use sockets::*;
pub struct RawTcpNetworkConnection {
descriptor: ConnectionDescriptor,
stream: AsyncPeekStream,
tcp_stream: TcpStream,
}
@@ -14,8 +15,20 @@ impl fmt::Debug for RawTcpNetworkConnection {
}
impl RawTcpNetworkConnection {
pub fn new(stream: AsyncPeekStream, tcp_stream: TcpStream) -> Self {
Self { stream, tcp_stream }
pub fn new(
descriptor: ConnectionDescriptor,
stream: AsyncPeekStream,
tcp_stream: TcpStream,
) -> Self {
Self {
descriptor,
stream,
tcp_stream,
}
}
pub fn descriptor(&self) -> ConnectionDescriptor {
self.descriptor.clone()
}
pub async fn close(&self) -> Result<(), String> {
@@ -33,7 +46,7 @@ impl RawTcpNetworkConnection {
.map_err(logthru_net!())
}
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
async fn send_internal(mut stream: AsyncPeekStream, message: Vec<u8>) -> Result<(), String> {
log_net!("sending TCP message of size {}", message.len());
if message.len() > MAX_MESSAGE_SIZE {
return Err("sending too large TCP message".to_owned());
@@ -41,7 +54,6 @@ impl RawTcpNetworkConnection {
let len = message.len() as u16;
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
let mut stream = self.stream.clone();
stream
.write_all(&header)
.await
@@ -54,6 +66,11 @@ impl RawTcpNetworkConnection {
.map_err(logthru_net!())
}
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
let stream = self.stream.clone();
Self::send_internal(stream, message).await
}
pub async fn recv(&self) -> Result<Vec<u8>, String> {
let mut header = [0u8; 4];
@@ -108,7 +125,7 @@ impl RawTcpProtocolHandler {
stream: AsyncPeekStream,
tcp_stream: TcpStream,
socket_addr: SocketAddr,
) -> Result<Option<NetworkConnection>, String> {
) -> Result<Option<ProtocolNetworkConnection>, String> {
log_net!("TCP: on_accept_async: enter");
let mut peekbuf: [u8; PEEK_DETECT_LEN] = [0u8; PEEK_DETECT_LEN];
let peeklen = stream
@@ -123,10 +140,11 @@ impl RawTcpProtocolHandler {
ProtocolType::TCP,
);
let local_address = self.inner.lock().local_address;
let conn = NetworkConnection::from_protocol(
let conn = ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(
ConnectionDescriptor::new(peer_addr, SocketAddress::from_socket_addr(local_address)),
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(stream, tcp_stream)),
);
stream,
tcp_stream,
));
log_net!(debug "TCP: on_accept_async from: {}", socket_addr);
@@ -136,7 +154,7 @@ impl RawTcpProtocolHandler {
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<NetworkConnection, String> {
) -> Result<ProtocolNetworkConnection, String> {
// Get remote socket address to connect to
let remote_socket_addr = dial_info.to_socket_addr();
@@ -161,13 +179,15 @@ impl RawTcpProtocolHandler {
let ps = AsyncPeekStream::new(ts.clone());
// Wrap the stream in a network connection and return it
let conn = NetworkConnection::from_protocol(
let conn = ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(
ConnectionDescriptor {
local: Some(SocketAddress::from_socket_addr(actual_local_address)),
remote: dial_info.to_peer_address(),
},
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)),
);
ps,
ts,
));
Ok(conn)
}
@@ -194,24 +214,15 @@ impl RawTcpProtocolHandler {
.map_err(logthru_net!(error "remote_addr={}", socket_addr))?;
// See what local address we ended up with and turn this into a stream
let actual_local_address = ts
.local_addr()
.map_err(map_to_string)
.map_err(logthru_net!("could not get local address from TCP stream"))?;
// let actual_local_address = ts
// .local_addr()
// .map_err(map_to_string)
// .map_err(logthru_net!("could not get local address from TCP stream"))?;
let ps = AsyncPeekStream::new(ts.clone());
// Wrap the stream in a network connection and return it
let conn = NetworkConnection::from_protocol(
ConnectionDescriptor {
local: Some(SocketAddress::from_socket_addr(actual_local_address)),
remote: PeerAddress::new(
SocketAddress::from_socket_addr(socket_addr),
ProtocolType::TCP,
),
},
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)),
);
conn.send(data).await
// Send directly from the raw network connection
// this builds the connection and tears it down immediately after the send
RawTcpNetworkConnection::send_internal(ps, data).await
}
}
@@ -221,7 +232,7 @@ impl ProtocolAcceptHandler for RawTcpProtocolHandler {
stream: AsyncPeekStream,
tcp_stream: TcpStream,
peer_addr: SocketAddr,
) -> SystemPinBoxFuture<core::result::Result<Option<NetworkConnection>, String>> {
) -> SystemPinBoxFuture<core::result::Result<Option<ProtocolNetworkConnection>, String>> {
Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
}
}

View File

@@ -15,6 +15,7 @@ pub struct WebsocketNetworkConnection<T>
where
T: io::Read + io::Write + Send + Unpin + 'static,
{
descriptor: ConnectionDescriptor,
stream: CloneStream<WebSocketStream<T>>,
tcp_stream: TcpStream,
}
@@ -32,13 +33,22 @@ impl<T> WebsocketNetworkConnection<T>
where
T: io::Read + io::Write + Send + Unpin + 'static,
{
pub fn new(stream: WebSocketStream<T>, tcp_stream: TcpStream) -> Self {
pub fn new(
descriptor: ConnectionDescriptor,
stream: WebSocketStream<T>,
tcp_stream: TcpStream,
) -> Self {
Self {
descriptor,
stream: CloneStream::new(stream),
tcp_stream,
}
}
pub fn descriptor(&self) -> ConnectionDescriptor {
self.descriptor.clone()
}
pub async fn close(&self) -> Result<(), String> {
// Make an attempt to flush the stream
self.stream
@@ -132,7 +142,7 @@ impl WebsocketProtocolHandler {
ps: AsyncPeekStream,
tcp_stream: TcpStream,
socket_addr: SocketAddr,
) -> Result<Option<NetworkConnection>, String> {
) -> Result<Option<ProtocolNetworkConnection>, String> {
log_net!("WS: on_accept_async: enter");
let request_path_len = self.arc.request_path.len() + 2;
@@ -179,25 +189,24 @@ impl WebsocketProtocolHandler {
let peer_addr =
PeerAddress::new(SocketAddress::from_socket_addr(socket_addr), protocol_type);
let conn = NetworkConnection::from_protocol(
let conn = ProtocolNetworkConnection::WsAccepted(WebsocketNetworkConnection::new(
ConnectionDescriptor::new(
peer_addr,
SocketAddress::from_socket_addr(self.arc.local_address),
),
ProtocolNetworkConnection::WsAccepted(WebsocketNetworkConnection::new(
ws_stream, tcp_stream,
)),
);
ws_stream,
tcp_stream,
));
log_net!(debug "{}: on_accept_async from: {}", if self.arc.tls { "WSS" } else { "WS" }, socket_addr);
Ok(Some(conn))
}
pub async fn connect(
async fn connect_internal(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<NetworkConnection, String> {
) -> Result<ProtocolNetworkConnection, String> {
// Split dial info up
let (tls, scheme) = match &dial_info {
DialInfo::WS(_) => (false, "ws"),
@@ -251,26 +260,27 @@ impl WebsocketProtocolHandler {
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
Ok(NetworkConnection::from_protocol(
descriptor,
ProtocolNetworkConnection::Wss(WebsocketNetworkConnection::new(
ws_stream, tcp_stream,
)),
Ok(ProtocolNetworkConnection::Wss(
WebsocketNetworkConnection::new(descriptor, ws_stream, tcp_stream),
))
} else {
let (ws_stream, _response) = client_async(request, tcp_stream.clone())
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
Ok(NetworkConnection::from_protocol(
descriptor,
ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(
ws_stream, tcp_stream,
)),
Ok(ProtocolNetworkConnection::Ws(
WebsocketNetworkConnection::new(descriptor, ws_stream, tcp_stream),
))
}
}
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<ProtocolNetworkConnection, String> {
Self::connect_internal(local_address, dial_info).await
}
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large unbound WS message".to_owned());
@@ -281,11 +291,11 @@ impl WebsocketProtocolHandler {
dial_info,
);
let conn = Self::connect(None, dial_info.clone())
let protconn = Self::connect_internal(None, dial_info.clone())
.await
.map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?;
conn.send(data).await
protconn.send(data).await
}
}
@@ -295,7 +305,7 @@ impl ProtocolAcceptHandler for WebsocketProtocolHandler {
stream: AsyncPeekStream,
tcp_stream: TcpStream,
peer_addr: SocketAddr,
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>> {
) -> SystemPinBoxFuture<Result<Option<ProtocolNetworkConnection>, String>> {
Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
}
}