massive network refactor

This commit is contained in:
John Smith 2022-06-04 20:18:26 -04:00
parent 8148c37708
commit cfcf430a99
16 changed files with 515 additions and 294 deletions

View File

@ -3,7 +3,6 @@ use core::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd};
use core::convert::{TryFrom, TryInto};
use core::fmt;
use core::hash::{Hash, Hasher};
use hex;
use crate::veilid_rng::*;
use ed25519_dalek::{Keypair, PublicKey, Signature};

View File

@ -0,0 +1,38 @@
use super::*;
#[derive(Clone, Debug)]
pub struct ConnectionHandle {
descriptor: ConnectionDescriptor,
channel: flume::Sender<Vec<u8>>,
}
impl ConnectionHandle {
pub(super) fn new(descriptor: ConnectionDescriptor, channel: flume::Sender<Vec<u8>>) -> Self {
Self {
descriptor,
channel,
}
}
pub fn connection_descriptor(&self) -> ConnectionDescriptor {
self.descriptor.clone()
}
pub fn send(&self, message: Vec<u8>) -> Result<(), String> {
self.channel.send(message).map_err(map_to_string)
}
pub async fn send_async(&self, message: Vec<u8>) -> Result<(), String> {
self.channel
.send_async(message)
.await
.map_err(map_to_string)
}
}
impl PartialEq for ConnectionHandle {
fn eq(&self, other: &Self) -> bool {
self.descriptor == other.descriptor
}
}
impl Eq for ConnectionHandle {}

View File

@ -1,7 +1,5 @@
use crate::xx::*;
use crate::*;
use super::*;
use alloc::collections::btree_map::Entry;
use core::fmt;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddressFilterError {

View File

@ -3,8 +3,6 @@ use crate::xx::*;
use connection_table::*;
use network_connection::*;
const CONNECTION_PROCESSOR_CHANNEL_SIZE: usize = 128usize;
///////////////////////////////////////////////////////////
// Connection manager
@ -59,8 +57,7 @@ impl ConnectionManager {
}
pub async fn shutdown(&self) {
// xxx close all connections in the connection table
// Drops connection table, which drops all connections in it
*self.arc.inner.lock().await = Self::new_inner(self.arc.network_manager.config());
}
@ -68,47 +65,48 @@ impl ConnectionManager {
pub async fn get_connection(
&self,
descriptor: ConnectionDescriptor,
) -> Option<NetworkConnection> {
) -> Option<ConnectionHandle> {
let mut inner = self.arc.inner.lock().await;
inner.connection_table.get_connection(descriptor)
}
// Internal routine to register new connection atomically
fn on_new_connection_internal(
// Internal routine to register new connection atomically.
// Registers connection in the connection table for later access
// and spawns a message processing loop for the connection
fn on_new_protocol_network_connection(
&self,
inner: &mut ConnectionManagerInner,
conn: NetworkConnection,
) -> Result<(), String> {
log_net!("on_new_connection_internal: {:?}", conn);
let tx = inner
.connection_add_channel_tx
.as_ref()
.ok_or_else(fn_string!("connection channel isn't open yet"))?
.clone();
conn: ProtocolNetworkConnection,
) -> Result<ConnectionHandle, String> {
log_net!("on_new_protocol_network_connection: {:?}", conn);
let receiver_loop_future = Self::process_connection(self.clone(), conn.clone());
tx.try_send(receiver_loop_future)
.map_err(map_to_string)
.map_err(logthru_net!(error "failed to start receiver loop"))?;
// If the receiver loop started successfully,
// add the new connection to the table
inner.connection_table.add_connection(conn)
// Wrap with NetworkConnection object to start the connection processing loop
let conn = NetworkConnection::from_protocol(self.clone(), conn);
let handle = conn.get_handle();
// Add to the connection table
inner.connection_table.add_connection(conn)?;
Ok(handle)
}
// Called by low-level network when any connection-oriented protocol connection appears
// either from incoming or outgoing connections. Registers connection in the connection table for later access
// and spawns a message processing loop for the connection
pub async fn on_new_connection(&self, conn: NetworkConnection) -> Result<(), String> {
// either from incoming connections.
pub(super) async fn on_accepted_protocol_network_connection(
&self,
conn: ProtocolNetworkConnection,
) -> Result<(), String> {
let mut inner = self.arc.inner.lock().await;
self.on_new_connection_internal(&mut *inner, conn)
self.on_new_protocol_network_connection(&mut *inner, conn)
.map(drop)
}
// Called when we want to create a new connection or get the current one that already exists
// This will kill off any connections that are in conflict with the new connection to be made
// in order to make room for the new connection in the system's connection table
pub async fn get_or_create_connection(
&self,
local_addr: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<NetworkConnection, String> {
) -> Result<ConnectionHandle, String> {
log_net!(
"== get_or_create_connection local_addr={:?} dial_info={:?}",
local_addr.green(),
@ -146,8 +144,10 @@ impl ConnectionManager {
if local_addr.port() != 0 {
for pt in [ProtocolType::TCP, ProtocolType::WS, ProtocolType::WSS] {
let pa = PeerAddress::new(descriptor.remote.socket_address, pt);
for conn in inner.connection_table.get_connections_by_remote(pa) {
let desc = conn.connection_descriptor();
for desc in inner
.connection_table
.get_connection_descriptors_by_remote(pa)
{
let mut kill = false;
if let Some(conn_local) = desc.local {
if (local_addr.ip().is_unspecified()
@ -163,7 +163,9 @@ impl ConnectionManager {
local_addr.green(),
pa.green()
);
conn.close().await?;
if let Err(e) = inner.connection_table.remove_connection(descriptor) {
log_net!(error e);
}
}
}
}
@ -171,73 +173,17 @@ impl ConnectionManager {
}
// Attempt new connection
let conn = NetworkConnection::connect(local_addr, dial_info).await?;
let conn = ProtocolNetworkConnection::connect(local_addr, dial_info).await?;
self.on_new_connection_internal(&mut *inner, conn.clone())?;
Ok(conn)
self.on_new_protocol_network_connection(&mut *inner, conn)
}
// Connection receiver loop
fn process_connection(
this: ConnectionManager,
conn: NetworkConnection,
) -> SystemPinBoxFuture<()> {
log_net!("Starting process_connection loop for {:?}", conn.green());
let network_manager = this.network_manager();
Box::pin(async move {
//
let descriptor = conn.connection_descriptor();
let inactivity_timeout = this
.network_manager()
.config()
.get()
.network
.connection_inactivity_timeout_ms;
loop {
// process inactivity timeout on receives only
// if you want a keepalive, it has to be requested from the other side
let message = select! {
res = conn.recv().fuse() => {
match res {
Ok(v) => v,
Err(e) => {
log_net!(debug e);
break;
}
}
}
_ = intf::sleep(inactivity_timeout).fuse()=> {
// timeout
log_net!("connection timeout on {:?}", descriptor.green());
break;
}
};
if let Err(e) = network_manager
.on_recv_envelope(message.as_slice(), descriptor)
.await
{
log_net!(error e);
break;
}
}
log_net!(
"== Connection loop finished local_addr={:?} remote={:?}",
descriptor.local.green(),
descriptor.remote.green()
);
if let Err(e) = this
.arc
.inner
.lock()
.await
.connection_table
.remove_connection(descriptor)
{
log_net!(error e);
}
})
// Callback from network connection receive loop when it exits
// cleans up the entry in the connection table
pub(super) async fn report_connection_finished(&self, descriptor: ConnectionDescriptor) {
let mut inner = self.arc.inner.lock().await;
if let Err(e) = inner.connection_table.remove_connection(descriptor) {
log_net!(error e);
}
}
}

View File

@ -1,7 +1,4 @@
use super::connection_limits::*;
use super::network_connection::*;
use crate::xx::*;
use crate::*;
use super::*;
use alloc::collections::btree_map::Entry;
use hashlink::LruCache;
@ -9,7 +6,7 @@ use hashlink::LruCache;
pub struct ConnectionTable {
max_connections: Vec<usize>,
conn_by_descriptor: Vec<LruCache<ConnectionDescriptor, NetworkConnection>>,
conns_by_remote: BTreeMap<PeerAddress, Vec<NetworkConnection>>,
descriptors_by_remote: BTreeMap<PeerAddress, Vec<ConnectionDescriptor>>,
address_filter: ConnectionLimits,
}
@ -39,7 +36,7 @@ impl ConnectionTable {
LruCache::new_unbounded(),
LruCache::new_unbounded(),
],
conns_by_remote: BTreeMap::new(),
descriptors_by_remote: BTreeMap::new(),
address_filter: ConnectionLimits::new(config),
}
}
@ -60,7 +57,7 @@ impl ConnectionTable {
self.address_filter.add(ip_addr).map_err(map_to_string)?;
// Add the connection to the table
let res = self.conn_by_descriptor[index].insert(descriptor, conn.clone());
let res = self.conn_by_descriptor[index].insert(descriptor.clone(), conn);
assert!(res.is_none());
// if we have reached the maximum number of connections per protocol type
@ -73,49 +70,54 @@ impl ConnectionTable {
}
// add connection records
let conns = self.conns_by_remote.entry(descriptor.remote).or_default();
let descriptors = self
.descriptors_by_remote
.entry(descriptor.remote)
.or_default();
warn!("add_connection: {:?}", conn);
conns.push(conn);
warn!("add_connection: {:?}", descriptor);
descriptors.push(descriptor);
Ok(())
}
pub fn get_connection(
&mut self,
descriptor: ConnectionDescriptor,
) -> Option<NetworkConnection> {
pub fn get_connection(&mut self, descriptor: ConnectionDescriptor) -> Option<ConnectionHandle> {
warn!("get_connection: {:?}", descriptor);
let index = protocol_to_index(descriptor.protocol_type());
let out = self.conn_by_descriptor[index].get(&descriptor).cloned();
warn!("get_connection: {:?} -> {:?}", descriptor, out);
out
let out = self.conn_by_descriptor[index].get(&descriptor);
out.map(|c| c.get_handle())
}
pub fn get_last_connection_by_remote(
&mut self,
remote: PeerAddress,
) -> Option<NetworkConnection> {
let out = self
.conns_by_remote
) -> Option<ConnectionHandle> {
warn!("get_last_connection_by_remote: {:?}", remote);
let descriptor = self
.descriptors_by_remote
.get(&remote)
.map(|v| v[(v.len() - 1)].clone());
warn!("get_last_connection_by_remote: {:?} -> {:?}", remote, out);
if let Some(connection) = &out {
if let Some(descriptor) = descriptor {
// lru bump
let index = protocol_to_index(connection.connection_descriptor().protocol_type());
let _ = self.conn_by_descriptor[index].get(&connection.connection_descriptor());
let index = protocol_to_index(descriptor.protocol_type());
let handle = self.conn_by_descriptor[index]
.get(&descriptor)
.map(|c| c.get_handle());
handle
} else {
None
}
out
}
pub fn get_connections_by_remote(&mut self, remote: PeerAddress) -> Vec<NetworkConnection> {
let out = self
.conns_by_remote
pub fn get_connection_descriptors_by_remote(
&mut self,
remote: PeerAddress,
) -> Vec<ConnectionDescriptor> {
warn!("get_connection_descriptors_by_remote: {:?}", remote);
self.descriptors_by_remote
.get(&remote)
.cloned()
.unwrap_or_default();
warn!("get_connections_by_remote: {:?} -> {:?}", remote, out);
out
.unwrap_or_default()
}
pub fn connection_count(&self) -> usize {
@ -126,7 +128,7 @@ impl ConnectionTable {
let ip_addr = descriptor.remote.socket_address.to_ip_addr();
// conns_by_remote
match self.conns_by_remote.entry(descriptor.remote) {
match self.descriptors_by_remote.entry(descriptor.remote) {
Entry::Vacant(_) => {
panic!("inconsistency in connection table")
}
@ -135,7 +137,7 @@ impl ConnectionTable {
// Remove one matching connection from the list
for (n, elem) in v.iter().enumerate() {
if elem.connection_descriptor() == descriptor {
if *elem == descriptor {
v.remove(n);
break;
}
@ -151,18 +153,14 @@ impl ConnectionTable {
.expect("Inconsistency in connection table");
}
pub fn remove_connection(
&mut self,
descriptor: ConnectionDescriptor,
) -> Result<NetworkConnection, String> {
pub fn remove_connection(&mut self, descriptor: ConnectionDescriptor) -> Result<(), String> {
warn!("remove_connection: {:?}", descriptor);
let index = protocol_to_index(descriptor.protocol_type());
let out = self.conn_by_descriptor[index]
let _ = self.conn_by_descriptor[index]
.remove(&descriptor)
.ok_or_else(|| format!("Connection not in table: {:?}", descriptor))?;
self.remove_connection_records(descriptor);
Ok(out)
Ok(())
}
}

View File

@ -5,6 +5,7 @@ mod native;
#[cfg(target_arch = "wasm32")]
mod wasm;
mod connection_handle;
mod connection_limits;
mod connection_manager;
mod connection_table;
@ -17,8 +18,9 @@ pub mod tests;
pub use network_connection::*;
////////////////////////////////////////////////////////////////////////////////////////
use connection_limits::*;
use connection_manager::*;
use connection_handle::*;
use dht::*;
use hashlink::LruCache;
use intf::*;
@ -1034,7 +1036,7 @@ impl NetworkManager {
// Called when a packet potentially containing an RPC envelope is received by a low-level
// network protocol handler. Processes the envelope, authenticates and decrypts the RPC message
// and passes it to the RPC handler
pub async fn on_recv_envelope(
async fn on_recv_envelope(
&self,
data: &[u8],
descriptor: ConnectionDescriptor,

View File

@ -341,7 +341,7 @@ impl Network {
log_net!("send_data_to_existing_connection to {:?}", descriptor);
// connection exists, send over it
conn.send(data).await.map_err(logthru_net!())?;
conn.send_async(data).await.map_err(logthru_net!())?;
// Network accounting
self.network_manager()
@ -389,7 +389,7 @@ impl Network {
.get_or_create_connection(Some(local_addr), dial_info.clone())
.await?;
let res = conn.send(data).await.map_err(logthru_net!(error));
let res = conn.send_async(data).await.map_err(logthru_net!(error));
if res.is_ok() {
// Network accounting
self.network_manager()

View File

@ -7,7 +7,7 @@ use sockets::*;
#[derive(Clone)]
pub struct ListenerState {
pub protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
pub protocol_accept_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
pub tls_protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
pub tls_acceptor: Option<TlsAcceptor>,
}
@ -15,7 +15,7 @@ pub struct ListenerState {
impl ListenerState {
pub fn new() -> Self {
Self {
protocol_handlers: Vec::new(),
protocol_accept_handlers: Vec::new(),
tls_protocol_handlers: Vec::new(),
tls_acceptor: None,
}
@ -46,7 +46,7 @@ impl Network {
addr: SocketAddr,
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
tls_connection_initial_timeout: u64,
) -> Result<Option<NetworkConnection>, String> {
) -> Result<Option<ProtocolNetworkConnection>, String> {
let ts = tls_acceptor
.accept(stream)
.await
@ -76,9 +76,9 @@ impl Network {
stream: AsyncPeekStream,
tcp_stream: TcpStream,
addr: SocketAddr,
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
) -> Result<Option<NetworkConnection>, String> {
for ah in protocol_handlers.iter() {
protocol_accept_handlers: &[Box<dyn ProtocolAcceptHandler>],
) -> Result<Option<ProtocolNetworkConnection>, String> {
for ah in protocol_accept_handlers.iter() {
if let Some(nc) = ah
.on_accept(stream.clone(), tcp_stream.clone(), addr)
.await
@ -185,7 +185,7 @@ impl Network {
)
.await
} else {
this.try_handlers(ps, tcp_stream, addr, &ls.protocol_handlers)
this.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers)
.await
};
@ -207,7 +207,10 @@ impl Network {
};
// Register the new connection in the connection manager
if let Err(e) = connection_manager.on_new_connection(conn).await {
if let Err(e) = connection_manager
.on_accepted_protocol_network_connection(conn)
.await
{
log_net!(error "failed to register new connection: {}", e);
}
})
@ -270,7 +273,7 @@ impl Network {
));
} else {
ls.write()
.protocol_handlers
.protocol_accept_handlers
.push(new_protocol_accept_handler(
self.network_manager().config(),
false,

View File

@ -21,7 +21,7 @@ impl ProtocolNetworkConnection {
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<NetworkConnection, String> {
) -> Result<ProtocolNetworkConnection, String> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
panic!("Should not connect to UDP dialinfo");
@ -55,6 +55,16 @@ impl ProtocolNetworkConnection {
}
}
pub fn descriptor(&self) -> ConnectionDescriptor {
match self {
Self::Dummy(d) => d.descriptor(),
Self::RawTcp(t) => t.descriptor(),
Self::WsAccepted(w) => w.descriptor(),
Self::Ws(w) => w.descriptor(),
Self::Wss(w) => w.descriptor(),
}
}
pub async fn close(&self) -> Result<(), String> {
match self {
Self::Dummy(d) => d.close(),

View File

@ -3,6 +3,7 @@ use futures_util::{AsyncReadExt, AsyncWriteExt};
use sockets::*;
pub struct RawTcpNetworkConnection {
descriptor: ConnectionDescriptor,
stream: AsyncPeekStream,
tcp_stream: TcpStream,
}
@ -14,8 +15,20 @@ impl fmt::Debug for RawTcpNetworkConnection {
}
impl RawTcpNetworkConnection {
pub fn new(stream: AsyncPeekStream, tcp_stream: TcpStream) -> Self {
Self { stream, tcp_stream }
pub fn new(
descriptor: ConnectionDescriptor,
stream: AsyncPeekStream,
tcp_stream: TcpStream,
) -> Self {
Self {
descriptor,
stream,
tcp_stream,
}
}
pub fn descriptor(&self) -> ConnectionDescriptor {
self.descriptor.clone()
}
pub async fn close(&self) -> Result<(), String> {
@ -33,7 +46,7 @@ impl RawTcpNetworkConnection {
.map_err(logthru_net!())
}
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
async fn send_internal(mut stream: AsyncPeekStream, message: Vec<u8>) -> Result<(), String> {
log_net!("sending TCP message of size {}", message.len());
if message.len() > MAX_MESSAGE_SIZE {
return Err("sending too large TCP message".to_owned());
@ -41,7 +54,6 @@ impl RawTcpNetworkConnection {
let len = message.len() as u16;
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
let mut stream = self.stream.clone();
stream
.write_all(&header)
.await
@ -54,6 +66,11 @@ impl RawTcpNetworkConnection {
.map_err(logthru_net!())
}
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
let stream = self.stream.clone();
Self::send_internal(stream, message).await
}
pub async fn recv(&self) -> Result<Vec<u8>, String> {
let mut header = [0u8; 4];
@ -108,7 +125,7 @@ impl RawTcpProtocolHandler {
stream: AsyncPeekStream,
tcp_stream: TcpStream,
socket_addr: SocketAddr,
) -> Result<Option<NetworkConnection>, String> {
) -> Result<Option<ProtocolNetworkConnection>, String> {
log_net!("TCP: on_accept_async: enter");
let mut peekbuf: [u8; PEEK_DETECT_LEN] = [0u8; PEEK_DETECT_LEN];
let peeklen = stream
@ -123,10 +140,11 @@ impl RawTcpProtocolHandler {
ProtocolType::TCP,
);
let local_address = self.inner.lock().local_address;
let conn = NetworkConnection::from_protocol(
let conn = ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(
ConnectionDescriptor::new(peer_addr, SocketAddress::from_socket_addr(local_address)),
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(stream, tcp_stream)),
);
stream,
tcp_stream,
));
log_net!(debug "TCP: on_accept_async from: {}", socket_addr);
@ -136,7 +154,7 @@ impl RawTcpProtocolHandler {
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<NetworkConnection, String> {
) -> Result<ProtocolNetworkConnection, String> {
// Get remote socket address to connect to
let remote_socket_addr = dial_info.to_socket_addr();
@ -161,13 +179,15 @@ impl RawTcpProtocolHandler {
let ps = AsyncPeekStream::new(ts.clone());
// Wrap the stream in a network connection and return it
let conn = NetworkConnection::from_protocol(
let conn = ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(
ConnectionDescriptor {
local: Some(SocketAddress::from_socket_addr(actual_local_address)),
remote: dial_info.to_peer_address(),
},
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)),
);
ps,
ts,
));
Ok(conn)
}
@ -194,24 +214,15 @@ impl RawTcpProtocolHandler {
.map_err(logthru_net!(error "remote_addr={}", socket_addr))?;
// See what local address we ended up with and turn this into a stream
let actual_local_address = ts
.local_addr()
.map_err(map_to_string)
.map_err(logthru_net!("could not get local address from TCP stream"))?;
// let actual_local_address = ts
// .local_addr()
// .map_err(map_to_string)
// .map_err(logthru_net!("could not get local address from TCP stream"))?;
let ps = AsyncPeekStream::new(ts.clone());
// Wrap the stream in a network connection and return it
let conn = NetworkConnection::from_protocol(
ConnectionDescriptor {
local: Some(SocketAddress::from_socket_addr(actual_local_address)),
remote: PeerAddress::new(
SocketAddress::from_socket_addr(socket_addr),
ProtocolType::TCP,
),
},
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)),
);
conn.send(data).await
// Send directly from the raw network connection
// this builds the connection and tears it down immediately after the send
RawTcpNetworkConnection::send_internal(ps, data).await
}
}
@ -221,7 +232,7 @@ impl ProtocolAcceptHandler for RawTcpProtocolHandler {
stream: AsyncPeekStream,
tcp_stream: TcpStream,
peer_addr: SocketAddr,
) -> SystemPinBoxFuture<core::result::Result<Option<NetworkConnection>, String>> {
) -> SystemPinBoxFuture<core::result::Result<Option<ProtocolNetworkConnection>, String>> {
Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
}
}

View File

@ -15,6 +15,7 @@ pub struct WebsocketNetworkConnection<T>
where
T: io::Read + io::Write + Send + Unpin + 'static,
{
descriptor: ConnectionDescriptor,
stream: CloneStream<WebSocketStream<T>>,
tcp_stream: TcpStream,
}
@ -32,13 +33,22 @@ impl<T> WebsocketNetworkConnection<T>
where
T: io::Read + io::Write + Send + Unpin + 'static,
{
pub fn new(stream: WebSocketStream<T>, tcp_stream: TcpStream) -> Self {
pub fn new(
descriptor: ConnectionDescriptor,
stream: WebSocketStream<T>,
tcp_stream: TcpStream,
) -> Self {
Self {
descriptor,
stream: CloneStream::new(stream),
tcp_stream,
}
}
pub fn descriptor(&self) -> ConnectionDescriptor {
self.descriptor.clone()
}
pub async fn close(&self) -> Result<(), String> {
// Make an attempt to flush the stream
self.stream
@ -132,7 +142,7 @@ impl WebsocketProtocolHandler {
ps: AsyncPeekStream,
tcp_stream: TcpStream,
socket_addr: SocketAddr,
) -> Result<Option<NetworkConnection>, String> {
) -> Result<Option<ProtocolNetworkConnection>, String> {
log_net!("WS: on_accept_async: enter");
let request_path_len = self.arc.request_path.len() + 2;
@ -179,25 +189,24 @@ impl WebsocketProtocolHandler {
let peer_addr =
PeerAddress::new(SocketAddress::from_socket_addr(socket_addr), protocol_type);
let conn = NetworkConnection::from_protocol(
let conn = ProtocolNetworkConnection::WsAccepted(WebsocketNetworkConnection::new(
ConnectionDescriptor::new(
peer_addr,
SocketAddress::from_socket_addr(self.arc.local_address),
),
ProtocolNetworkConnection::WsAccepted(WebsocketNetworkConnection::new(
ws_stream, tcp_stream,
)),
);
ws_stream,
tcp_stream,
));
log_net!(debug "{}: on_accept_async from: {}", if self.arc.tls { "WSS" } else { "WS" }, socket_addr);
Ok(Some(conn))
}
pub async fn connect(
async fn connect_internal(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<NetworkConnection, String> {
) -> Result<ProtocolNetworkConnection, String> {
// Split dial info up
let (tls, scheme) = match &dial_info {
DialInfo::WS(_) => (false, "ws"),
@ -251,26 +260,27 @@ impl WebsocketProtocolHandler {
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
Ok(NetworkConnection::from_protocol(
descriptor,
ProtocolNetworkConnection::Wss(WebsocketNetworkConnection::new(
ws_stream, tcp_stream,
)),
Ok(ProtocolNetworkConnection::Wss(
WebsocketNetworkConnection::new(descriptor, ws_stream, tcp_stream),
))
} else {
let (ws_stream, _response) = client_async(request, tcp_stream.clone())
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
Ok(NetworkConnection::from_protocol(
descriptor,
ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(
ws_stream, tcp_stream,
)),
Ok(ProtocolNetworkConnection::Ws(
WebsocketNetworkConnection::new(descriptor, ws_stream, tcp_stream),
))
}
}
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<ProtocolNetworkConnection, String> {
Self::connect_internal(local_address, dial_info).await
}
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large unbound WS message".to_owned());
@ -281,11 +291,11 @@ impl WebsocketProtocolHandler {
dial_info,
);
let conn = Self::connect(None, dial_info.clone())
let protconn = Self::connect_internal(None, dial_info.clone())
.await
.map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?;
conn.send(data).await
protconn.send(data).await
}
}
@ -295,7 +305,7 @@ impl ProtocolAcceptHandler for WebsocketProtocolHandler {
stream: AsyncPeekStream,
tcp_stream: TcpStream,
peer_addr: SocketAddr,
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>> {
) -> SystemPinBoxFuture<Result<Option<ProtocolNetworkConnection>, String>> {
Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
}
}

View File

@ -1,5 +1,5 @@
use super::*;
use crate::xx::*;
use futures_util::{FutureExt, StreamExt};
cfg_if::cfg_if! {
if #[cfg(target_arch = "wasm32")] {
@ -16,7 +16,7 @@ cfg_if::cfg_if! {
stream: AsyncPeekStream,
tcp_stream: TcpStream,
peer_addr: SocketAddr,
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>>;
) -> SystemPinBoxFuture<Result<Option<ProtocolNetworkConnection>, String>>;
}
pub trait ProtocolAcceptHandlerClone {
@ -45,9 +45,14 @@ cfg_if::cfg_if! {
// Dummy protocol network connection for testing
#[derive(Debug)]
pub struct DummyNetworkConnection {}
pub struct DummyNetworkConnection {
descriptor: ConnectionDescriptor,
}
impl DummyNetworkConnection {
pub fn descriptor(&self) -> ConnectionDescriptor {
self.descriptor.clone()
}
pub fn close(&self) -> Result<(), String> {
Ok(())
}
@ -62,6 +67,14 @@ impl DummyNetworkConnection {
///////////////////////////////////////////////////////////
// Top-level protocol independent network connection object
#[derive(Clone, Copy, Debug)]
enum RecvLoopAction {
Send,
Recv,
Finish,
Timeout,
}
#[derive(Debug, Clone)]
pub struct NetworkConnectionStats {
last_message_sent_time: Option<u64>,
@ -69,107 +82,249 @@ pub struct NetworkConnectionStats {
}
#[derive(Debug)]
struct NetworkConnectionInner {
stats: NetworkConnectionStats,
}
#[derive(Debug)]
struct NetworkConnectionArc {
descriptor: ConnectionDescriptor,
protocol_connection: ProtocolNetworkConnection,
established_time: u64,
inner: Mutex<NetworkConnectionInner>,
}
#[derive(Clone, Debug)]
pub struct NetworkConnection {
arc: Arc<NetworkConnectionArc>,
descriptor: ConnectionDescriptor,
_processor: Option<JoinHandle<()>>,
established_time: u64,
stats: Arc<Mutex<NetworkConnectionStats>>,
sender: flume::Sender<Vec<u8>>,
}
impl PartialEq for NetworkConnection {
fn eq(&self, other: &Self) -> bool {
Arc::as_ptr(&self.arc) == Arc::as_ptr(&other.arc)
}
}
impl Eq for NetworkConnection {}
impl NetworkConnection {
fn new_inner() -> NetworkConnectionInner {
NetworkConnectionInner {
stats: NetworkConnectionStats {
pub(super) fn dummy(descriptor: ConnectionDescriptor) -> Self {
// Create handle for sending (dummy is immediately disconnected)
let (sender, _receiver) = flume::bounded(intf::get_concurrency() as usize);
Self {
descriptor,
_processor: None,
established_time: intf::get_timestamp(),
stats: Arc::new(Mutex::new(NetworkConnectionStats {
last_message_sent_time: None,
last_message_recv_time: None,
},
}
}
fn new_arc(
descriptor: ConnectionDescriptor,
protocol_connection: ProtocolNetworkConnection,
) -> NetworkConnectionArc {
NetworkConnectionArc {
descriptor,
protocol_connection,
established_time: intf::get_timestamp(),
inner: Mutex::new(Self::new_inner()),
})),
sender,
}
}
pub fn dummy(descriptor: ConnectionDescriptor) -> Self {
NetworkConnection::from_protocol(
descriptor,
ProtocolNetworkConnection::Dummy(DummyNetworkConnection {}),
)
}
pub fn from_protocol(
descriptor: ConnectionDescriptor,
pub(super) fn from_protocol(
connection_manager: ConnectionManager,
protocol_connection: ProtocolNetworkConnection,
) -> Self {
Self {
arc: Arc::new(Self::new_arc(descriptor, protocol_connection)),
}
}
// Get timeout
let network_manager = connection_manager.network_manager();
let inactivity_timeout = network_manager
.config()
.get()
.network
.connection_inactivity_timeout_ms;
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<NetworkConnection, String> {
ProtocolNetworkConnection::connect(local_address, dial_info).await
// Get descriptor
let descriptor = protocol_connection.descriptor();
// Create handle for sending
let (sender, receiver) = flume::bounded(intf::get_concurrency() as usize);
// Create stats
let stats = Arc::new(Mutex::new(NetworkConnectionStats {
last_message_sent_time: None,
last_message_recv_time: None,
}));
// Spawn connection processor and pass in protocol connection
let processor = intf::spawn_local(Self::process_connection(
connection_manager,
descriptor.clone(),
receiver,
protocol_connection,
inactivity_timeout,
stats.clone(),
));
// Return the connection
Self {
descriptor,
_processor: Some(processor),
established_time: intf::get_timestamp(),
stats,
sender,
}
}
pub fn connection_descriptor(&self) -> ConnectionDescriptor {
self.arc.descriptor
self.descriptor.clone()
}
pub async fn close(&self) -> Result<(), String> {
self.arc.protocol_connection.close().await
pub fn get_handle(&self) -> ConnectionHandle {
ConnectionHandle::new(self.descriptor.clone(), self.sender.clone())
}
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
async fn send_internal(
protocol_connection: &ProtocolNetworkConnection,
stats: Arc<Mutex<NetworkConnectionStats>>,
message: Vec<u8>,
) -> Result<(), String> {
let ts = intf::get_timestamp();
let out = self.arc.protocol_connection.send(message).await;
let out = protocol_connection.send(message).await;
if out.is_ok() {
let mut inner = self.arc.inner.lock();
inner.stats.last_message_sent_time.max_assign(Some(ts));
let mut stats = stats.lock();
stats.last_message_sent_time.max_assign(Some(ts));
}
out
}
pub async fn recv(&self) -> Result<Vec<u8>, String> {
async fn recv_internal(
protocol_connection: &ProtocolNetworkConnection,
stats: Arc<Mutex<NetworkConnectionStats>>,
) -> Result<Vec<u8>, String> {
let ts = intf::get_timestamp();
let out = self.arc.protocol_connection.recv().await;
let out = protocol_connection.recv().await;
if out.is_ok() {
let mut inner = self.arc.inner.lock();
inner.stats.last_message_recv_time.max_assign(Some(ts));
let mut stats = stats.lock();
stats.last_message_recv_time.max_assign(Some(ts));
}
out
}
pub fn stats(&self) -> NetworkConnectionStats {
let inner = self.arc.inner.lock();
inner.stats.clone()
let stats = self.stats.lock();
stats.clone()
}
pub fn established_time(&self) -> u64 {
self.arc.established_time
self.established_time
}
// Connection receiver loop
fn process_connection(
connection_manager: ConnectionManager,
descriptor: ConnectionDescriptor,
receiver: flume::Receiver<Vec<u8>>,
protocol_connection: ProtocolNetworkConnection,
connection_inactivity_timeout_ms: u32,
stats: Arc<Mutex<NetworkConnectionStats>>,
) -> SystemPinBoxFuture<()> {
Box::pin(async move {
log_net!(
"Starting process_connection loop for {:?}",
descriptor.green()
);
let network_manager = connection_manager.network_manager();
let mut unord = FuturesUnordered::new();
let mut need_receiver = true;
let mut need_sender = true;
// Push mutable timer so we can reset it
// Normally we would use an io::timeout here, but WASM won't support that, so we use a mutable sleep future
let new_timer = || {
intf::sleep(connection_inactivity_timeout_ms).then(|_| async {
// timeout
log_net!("connection timeout on {:?}", descriptor.green());
RecvLoopAction::Timeout
})
};
let timer = MutableFuture::new(new_timer());
unord.push(timer.clone().boxed());
loop {
// Add another message sender future if necessary
if need_sender {
need_sender = false;
unord.push(
receiver
.recv_async()
.then(|res| async {
match res {
Ok(message) => {
// send the packet
if let Err(e) = Self::send_internal(
&protocol_connection,
stats.clone(),
message,
)
.await
{
// Sending the packet along can fail, if so, this connection is dead
log_net!(debug e);
RecvLoopAction::Finish
} else {
RecvLoopAction::Send
}
}
Err(e) => {
// All senders gone, shouldn't happen since we store one alongside the join handle
log_net!(warn e);
RecvLoopAction::Finish
}
}
})
.boxed(),
);
}
// Add another message receiver future if necessary
if need_receiver {
need_sender = false;
unord.push(
Self::recv_internal(&protocol_connection, stats.clone())
.then(|res| async {
match res {
Ok(message) => {
// Pass received messages up to the network manager for processing
if let Err(e) = network_manager
.on_recv_envelope(message.as_slice(), descriptor)
.await
{
log_net!(error e);
RecvLoopAction::Finish
} else {
RecvLoopAction::Recv
}
}
Err(e) => {
// Connection unable to receive, closed
log_net!(warn e);
RecvLoopAction::Finish
}
}
})
.boxed(),
);
}
// Process futures
match unord.next().await {
Some(RecvLoopAction::Send) => {
// Don't reset inactivity timer if we're only sending
need_sender = true;
}
Some(RecvLoopAction::Recv) => {
// Reset inactivity timer since we got something from this connection
timer.set(new_timer());
need_receiver = true;
}
Some(RecvLoopAction::Finish) | Some(RecvLoopAction::Timeout) => {
break;
}
None => {
// Should not happen
unreachable!();
}
}
}
log_net!(
"== Connection loop finished local_addr={:?} remote={:?}",
descriptor.local.green(),
descriptor.remote.green()
);
connection_manager
.report_connection_finished(descriptor)
.await
})
}
}

View File

@ -52,10 +52,15 @@ pub async fn test_add_get_remove() {
);
let c1 = NetworkConnection::dummy(a1);
let c1h = c1.get_handle();
let c2 = NetworkConnection::dummy(a2);
//let c2h = c2.get_handle();
let c3 = NetworkConnection::dummy(a3);
//let c3h = c3.get_handle();
let c4 = NetworkConnection::dummy(a4);
//let c4h = c4.get_handle();
let c5 = NetworkConnection::dummy(a5);
//let c5h = c5.get_handle();
assert_eq!(a1, c2.connection_descriptor());
assert_ne!(a3, c4.connection_descriptor());
@ -63,36 +68,39 @@ pub async fn test_add_get_remove() {
assert_eq!(table.connection_count(), 0);
assert_eq!(table.get_connection(a1), None);
table.add_connection(c1.clone()).unwrap();
table.add_connection(c1).unwrap();
assert_eq!(table.connection_count(), 1);
assert_err!(table.remove_connection(a3));
assert_err!(table.remove_connection(a4));
assert_eq!(table.connection_count(), 1);
assert_eq!(table.get_connection(a1), Some(c1.clone()));
assert_eq!(table.get_connection(a1), Some(c1.clone()));
assert_eq!(table.get_connection(a1), Some(c1h.clone()));
assert_eq!(table.get_connection(a1), Some(c1h.clone()));
assert_eq!(table.connection_count(), 1);
assert_err!(table.add_connection(c1.clone()));
assert_err!(table.add_connection(c2.clone()));
assert_err!(table.add_connection(c2));
assert_eq!(table.connection_count(), 1);
assert_eq!(table.get_connection(a1), Some(c1.clone()));
assert_eq!(table.get_connection(a1), Some(c1.clone()));
assert_eq!(table.get_connection(a1), Some(c1h.clone()));
assert_eq!(table.get_connection(a1), Some(c1h.clone()));
assert_eq!(table.connection_count(), 1);
assert_eq!(table.remove_connection(a2), Ok(c1.clone()));
assert_eq!(table.remove_connection(a2), Ok(()));
assert_eq!(table.connection_count(), 0);
assert_err!(table.remove_connection(a2));
assert_eq!(table.connection_count(), 0);
assert_eq!(table.get_connection(a2), None);
assert_eq!(table.get_connection(a1), None);
assert_eq!(table.connection_count(), 0);
table.add_connection(c1.clone()).unwrap();
let c1 = NetworkConnection::dummy(a1);
//let c1h = c1.get_handle();
table.add_connection(c1).unwrap();
let c2 = NetworkConnection::dummy(a2);
//let c2h = c2.get_handle();
assert_err!(table.add_connection(c2));
table.add_connection(c3.clone()).unwrap();
table.add_connection(c4.clone()).unwrap();
table.add_connection(c3).unwrap();
table.add_connection(c4).unwrap();
assert_eq!(table.connection_count(), 3);
assert_eq!(table.remove_connection(a2), Ok(c1));
assert_eq!(table.remove_connection(a3), Ok(c3));
assert_eq!(table.remove_connection(a4), Ok(c4));
assert_eq!(table.remove_connection(a2), Ok(()));
assert_eq!(table.remove_connection(a3), Ok(()));
assert_eq!(table.remove_connection(a4), Ok(()));
assert_eq!(table.connection_count(), 0);
}

View File

@ -13,6 +13,7 @@ struct WebsocketNetworkConnectionInner {
#[derive(Clone)]
pub struct WebsocketNetworkConnection {
descriptor: ConnectionDescriptor,
inner: Arc<WebsocketNetworkConnectionInner>,
}
@ -23,8 +24,11 @@ impl fmt::Debug for WebsocketNetworkConnection {
}
impl WebsocketNetworkConnection {
pub fn new(ws_meta: WsMeta, ws_stream: WsStream) -> Self {
pub fn new(
descriptor: ConnectionDescriptor,
ws_meta: WsMeta, ws_stream: WsStream) -> Self {
Self {
descriptor,
inner: Arc::new(WebsocketNetworkConnectionInner {
ws_meta,
ws_stream: CloneStream::new(ws_stream),
@ -32,6 +36,10 @@ impl WebsocketNetworkConnection {
}
}
pub fn descriptor(&self) -> ConnectionDescriptor {
self.descriptor.clone()
}
pub async fn close(&self) -> Result<(), String> {
self.inner.ws_meta.close().await.map_err(map_to_string).map(drop)
}
@ -73,7 +81,7 @@ impl WebsocketProtocolHandler {
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<NetworkConnection, String> {
) -> Result<ProtocolNetworkConnection, String> {
assert!(local_address.is_none());
@ -96,10 +104,10 @@ impl WebsocketProtocolHandler {
// Make our connection descriptor
Ok(NetworkConnection::from_protocol(ConnectionDescriptor {
Ok(ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(ConnectionDescriptor {
local: None,
remote: dial_info.to_peer_address(),
},ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(wsmeta, wsio))))
}, wsmeta, wsio)))
}
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {

View File

@ -8,6 +8,7 @@ mod eventual_value_clone;
mod ip_addr_port;
mod ip_extra;
mod log_thru;
mod mutable_future;
mod single_future;
mod single_shot_eventual;
mod split_url;
@ -104,6 +105,7 @@ pub use eventual_value::*;
pub use eventual_value_clone::*;
pub use ip_addr_port::*;
pub use ip_extra::*;
pub use mutable_future::*;
pub use single_future::*;
pub use single_shot_eventual::*;
pub use tick_task::*;

View File

@ -0,0 +1,33 @@
use super::*;
pub struct MutableFuture<O, T: Future<Output = O>> {
inner: Arc<Mutex<Pin<Box<T>>>>,
}
impl<O, T: Future<Output = O>> MutableFuture<O, T> {
pub fn new(inner: T) -> Self {
Self {
inner: Arc::new(Mutex::new(Box::pin(inner))),
}
}
pub fn set(&self, inner: T) {
*self.inner.lock() = Box::pin(inner);
}
}
impl<O, T: Future<Output = O>> Clone for MutableFuture<O, T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl<O, T: Future<Output = O>> Future for MutableFuture<O, T> {
type Output = O;
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
let mut inner = self.inner.lock();
T::poll(inner.as_mut(), cx)
}
}