massive network refactor
This commit is contained in:
@@ -341,7 +341,7 @@ impl Network {
|
||||
log_net!("send_data_to_existing_connection to {:?}", descriptor);
|
||||
|
||||
// connection exists, send over it
|
||||
conn.send(data).await.map_err(logthru_net!())?;
|
||||
conn.send_async(data).await.map_err(logthru_net!())?;
|
||||
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@@ -389,7 +389,7 @@ impl Network {
|
||||
.get_or_create_connection(Some(local_addr), dial_info.clone())
|
||||
.await?;
|
||||
|
||||
let res = conn.send(data).await.map_err(logthru_net!(error));
|
||||
let res = conn.send_async(data).await.map_err(logthru_net!(error));
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
|
@@ -7,7 +7,7 @@ use sockets::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ListenerState {
|
||||
pub protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
|
||||
pub protocol_accept_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
|
||||
pub tls_protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
|
||||
pub tls_acceptor: Option<TlsAcceptor>,
|
||||
}
|
||||
@@ -15,7 +15,7 @@ pub struct ListenerState {
|
||||
impl ListenerState {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
protocol_handlers: Vec::new(),
|
||||
protocol_accept_handlers: Vec::new(),
|
||||
tls_protocol_handlers: Vec::new(),
|
||||
tls_acceptor: None,
|
||||
}
|
||||
@@ -46,7 +46,7 @@ impl Network {
|
||||
addr: SocketAddr,
|
||||
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
|
||||
tls_connection_initial_timeout: u64,
|
||||
) -> Result<Option<NetworkConnection>, String> {
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
let ts = tls_acceptor
|
||||
.accept(stream)
|
||||
.await
|
||||
@@ -76,9 +76,9 @@ impl Network {
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
addr: SocketAddr,
|
||||
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
|
||||
) -> Result<Option<NetworkConnection>, String> {
|
||||
for ah in protocol_handlers.iter() {
|
||||
protocol_accept_handlers: &[Box<dyn ProtocolAcceptHandler>],
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
for ah in protocol_accept_handlers.iter() {
|
||||
if let Some(nc) = ah
|
||||
.on_accept(stream.clone(), tcp_stream.clone(), addr)
|
||||
.await
|
||||
@@ -185,7 +185,7 @@ impl Network {
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
this.try_handlers(ps, tcp_stream, addr, &ls.protocol_handlers)
|
||||
this.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers)
|
||||
.await
|
||||
};
|
||||
|
||||
@@ -207,7 +207,10 @@ impl Network {
|
||||
};
|
||||
|
||||
// Register the new connection in the connection manager
|
||||
if let Err(e) = connection_manager.on_new_connection(conn).await {
|
||||
if let Err(e) = connection_manager
|
||||
.on_accepted_protocol_network_connection(conn)
|
||||
.await
|
||||
{
|
||||
log_net!(error "failed to register new connection: {}", e);
|
||||
}
|
||||
})
|
||||
@@ -270,7 +273,7 @@ impl Network {
|
||||
));
|
||||
} else {
|
||||
ls.write()
|
||||
.protocol_handlers
|
||||
.protocol_accept_handlers
|
||||
.push(new_protocol_accept_handler(
|
||||
self.network_manager().config(),
|
||||
false,
|
||||
|
@@ -21,7 +21,7 @@ impl ProtocolNetworkConnection {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
panic!("Should not connect to UDP dialinfo");
|
||||
@@ -55,6 +55,16 @@ impl ProtocolNetworkConnection {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn descriptor(&self) -> ConnectionDescriptor {
|
||||
match self {
|
||||
Self::Dummy(d) => d.descriptor(),
|
||||
Self::RawTcp(t) => t.descriptor(),
|
||||
Self::WsAccepted(w) => w.descriptor(),
|
||||
Self::Ws(w) => w.descriptor(),
|
||||
Self::Wss(w) => w.descriptor(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.close(),
|
||||
|
@@ -3,6 +3,7 @@ use futures_util::{AsyncReadExt, AsyncWriteExt};
|
||||
use sockets::*;
|
||||
|
||||
pub struct RawTcpNetworkConnection {
|
||||
descriptor: ConnectionDescriptor,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
}
|
||||
@@ -14,8 +15,20 @@ impl fmt::Debug for RawTcpNetworkConnection {
|
||||
}
|
||||
|
||||
impl RawTcpNetworkConnection {
|
||||
pub fn new(stream: AsyncPeekStream, tcp_stream: TcpStream) -> Self {
|
||||
Self { stream, tcp_stream }
|
||||
pub fn new(
|
||||
descriptor: ConnectionDescriptor,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
) -> Self {
|
||||
Self {
|
||||
descriptor,
|
||||
stream,
|
||||
tcp_stream,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn descriptor(&self) -> ConnectionDescriptor {
|
||||
self.descriptor.clone()
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
@@ -33,7 +46,7 @@ impl RawTcpNetworkConnection {
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
async fn send_internal(mut stream: AsyncPeekStream, message: Vec<u8>) -> Result<(), String> {
|
||||
log_net!("sending TCP message of size {}", message.len());
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large TCP message".to_owned());
|
||||
@@ -41,7 +54,6 @@ impl RawTcpNetworkConnection {
|
||||
let len = message.len() as u16;
|
||||
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
|
||||
|
||||
let mut stream = self.stream.clone();
|
||||
stream
|
||||
.write_all(&header)
|
||||
.await
|
||||
@@ -54,6 +66,11 @@ impl RawTcpNetworkConnection {
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
let stream = self.stream.clone();
|
||||
Self::send_internal(stream, message).await
|
||||
}
|
||||
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
let mut header = [0u8; 4];
|
||||
|
||||
@@ -108,7 +125,7 @@ impl RawTcpProtocolHandler {
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
socket_addr: SocketAddr,
|
||||
) -> Result<Option<NetworkConnection>, String> {
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
log_net!("TCP: on_accept_async: enter");
|
||||
let mut peekbuf: [u8; PEEK_DETECT_LEN] = [0u8; PEEK_DETECT_LEN];
|
||||
let peeklen = stream
|
||||
@@ -123,10 +140,11 @@ impl RawTcpProtocolHandler {
|
||||
ProtocolType::TCP,
|
||||
);
|
||||
let local_address = self.inner.lock().local_address;
|
||||
let conn = NetworkConnection::from_protocol(
|
||||
let conn = ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(
|
||||
ConnectionDescriptor::new(peer_addr, SocketAddress::from_socket_addr(local_address)),
|
||||
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(stream, tcp_stream)),
|
||||
);
|
||||
stream,
|
||||
tcp_stream,
|
||||
));
|
||||
|
||||
log_net!(debug "TCP: on_accept_async from: {}", socket_addr);
|
||||
|
||||
@@ -136,7 +154,7 @@ impl RawTcpProtocolHandler {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
// Get remote socket address to connect to
|
||||
let remote_socket_addr = dial_info.to_socket_addr();
|
||||
|
||||
@@ -161,13 +179,15 @@ impl RawTcpProtocolHandler {
|
||||
let ps = AsyncPeekStream::new(ts.clone());
|
||||
|
||||
// Wrap the stream in a network connection and return it
|
||||
let conn = NetworkConnection::from_protocol(
|
||||
let conn = ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(
|
||||
ConnectionDescriptor {
|
||||
local: Some(SocketAddress::from_socket_addr(actual_local_address)),
|
||||
remote: dial_info.to_peer_address(),
|
||||
},
|
||||
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)),
|
||||
);
|
||||
ps,
|
||||
ts,
|
||||
));
|
||||
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
@@ -194,24 +214,15 @@ impl RawTcpProtocolHandler {
|
||||
.map_err(logthru_net!(error "remote_addr={}", socket_addr))?;
|
||||
|
||||
// See what local address we ended up with and turn this into a stream
|
||||
let actual_local_address = ts
|
||||
.local_addr()
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("could not get local address from TCP stream"))?;
|
||||
// let actual_local_address = ts
|
||||
// .local_addr()
|
||||
// .map_err(map_to_string)
|
||||
// .map_err(logthru_net!("could not get local address from TCP stream"))?;
|
||||
let ps = AsyncPeekStream::new(ts.clone());
|
||||
|
||||
// Wrap the stream in a network connection and return it
|
||||
let conn = NetworkConnection::from_protocol(
|
||||
ConnectionDescriptor {
|
||||
local: Some(SocketAddress::from_socket_addr(actual_local_address)),
|
||||
remote: PeerAddress::new(
|
||||
SocketAddress::from_socket_addr(socket_addr),
|
||||
ProtocolType::TCP,
|
||||
),
|
||||
},
|
||||
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)),
|
||||
);
|
||||
conn.send(data).await
|
||||
// Send directly from the raw network connection
|
||||
// this builds the connection and tears it down immediately after the send
|
||||
RawTcpNetworkConnection::send_internal(ps, data).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -221,7 +232,7 @@ impl ProtocolAcceptHandler for RawTcpProtocolHandler {
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<core::result::Result<Option<NetworkConnection>, String>> {
|
||||
) -> SystemPinBoxFuture<core::result::Result<Option<ProtocolNetworkConnection>, String>> {
|
||||
Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
|
||||
}
|
||||
}
|
||||
|
@@ -15,6 +15,7 @@ pub struct WebsocketNetworkConnection<T>
|
||||
where
|
||||
T: io::Read + io::Write + Send + Unpin + 'static,
|
||||
{
|
||||
descriptor: ConnectionDescriptor,
|
||||
stream: CloneStream<WebSocketStream<T>>,
|
||||
tcp_stream: TcpStream,
|
||||
}
|
||||
@@ -32,13 +33,22 @@ impl<T> WebsocketNetworkConnection<T>
|
||||
where
|
||||
T: io::Read + io::Write + Send + Unpin + 'static,
|
||||
{
|
||||
pub fn new(stream: WebSocketStream<T>, tcp_stream: TcpStream) -> Self {
|
||||
pub fn new(
|
||||
descriptor: ConnectionDescriptor,
|
||||
stream: WebSocketStream<T>,
|
||||
tcp_stream: TcpStream,
|
||||
) -> Self {
|
||||
Self {
|
||||
descriptor,
|
||||
stream: CloneStream::new(stream),
|
||||
tcp_stream,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn descriptor(&self) -> ConnectionDescriptor {
|
||||
self.descriptor.clone()
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
// Make an attempt to flush the stream
|
||||
self.stream
|
||||
@@ -132,7 +142,7 @@ impl WebsocketProtocolHandler {
|
||||
ps: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
socket_addr: SocketAddr,
|
||||
) -> Result<Option<NetworkConnection>, String> {
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
log_net!("WS: on_accept_async: enter");
|
||||
let request_path_len = self.arc.request_path.len() + 2;
|
||||
|
||||
@@ -179,25 +189,24 @@ impl WebsocketProtocolHandler {
|
||||
let peer_addr =
|
||||
PeerAddress::new(SocketAddress::from_socket_addr(socket_addr), protocol_type);
|
||||
|
||||
let conn = NetworkConnection::from_protocol(
|
||||
let conn = ProtocolNetworkConnection::WsAccepted(WebsocketNetworkConnection::new(
|
||||
ConnectionDescriptor::new(
|
||||
peer_addr,
|
||||
SocketAddress::from_socket_addr(self.arc.local_address),
|
||||
),
|
||||
ProtocolNetworkConnection::WsAccepted(WebsocketNetworkConnection::new(
|
||||
ws_stream, tcp_stream,
|
||||
)),
|
||||
);
|
||||
ws_stream,
|
||||
tcp_stream,
|
||||
));
|
||||
|
||||
log_net!(debug "{}: on_accept_async from: {}", if self.arc.tls { "WSS" } else { "WS" }, socket_addr);
|
||||
|
||||
Ok(Some(conn))
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
async fn connect_internal(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
// Split dial info up
|
||||
let (tls, scheme) = match &dial_info {
|
||||
DialInfo::WS(_) => (false, "ws"),
|
||||
@@ -251,26 +260,27 @@ impl WebsocketProtocolHandler {
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
|
||||
Ok(NetworkConnection::from_protocol(
|
||||
descriptor,
|
||||
ProtocolNetworkConnection::Wss(WebsocketNetworkConnection::new(
|
||||
ws_stream, tcp_stream,
|
||||
)),
|
||||
Ok(ProtocolNetworkConnection::Wss(
|
||||
WebsocketNetworkConnection::new(descriptor, ws_stream, tcp_stream),
|
||||
))
|
||||
} else {
|
||||
let (ws_stream, _response) = client_async(request, tcp_stream.clone())
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
Ok(NetworkConnection::from_protocol(
|
||||
descriptor,
|
||||
ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(
|
||||
ws_stream, tcp_stream,
|
||||
)),
|
||||
Ok(ProtocolNetworkConnection::Ws(
|
||||
WebsocketNetworkConnection::new(descriptor, ws_stream, tcp_stream),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
Self::connect_internal(local_address, dial_info).await
|
||||
}
|
||||
|
||||
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound WS message".to_owned());
|
||||
@@ -281,11 +291,11 @@ impl WebsocketProtocolHandler {
|
||||
dial_info,
|
||||
);
|
||||
|
||||
let conn = Self::connect(None, dial_info.clone())
|
||||
let protconn = Self::connect_internal(None, dial_info.clone())
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?;
|
||||
|
||||
conn.send(data).await
|
||||
protconn.send(data).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -295,7 +305,7 @@ impl ProtocolAcceptHandler for WebsocketProtocolHandler {
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>> {
|
||||
) -> SystemPinBoxFuture<Result<Option<ProtocolNetworkConnection>, String>> {
|
||||
Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user