massive network refactor

This commit is contained in:
John Smith 2022-06-04 20:18:26 -04:00
parent 8148c37708
commit cfcf430a99
16 changed files with 515 additions and 294 deletions

View File

@ -3,7 +3,6 @@ use core::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd};
use core::convert::{TryFrom, TryInto}; use core::convert::{TryFrom, TryInto};
use core::fmt; use core::fmt;
use core::hash::{Hash, Hasher}; use core::hash::{Hash, Hasher};
use hex;
use crate::veilid_rng::*; use crate::veilid_rng::*;
use ed25519_dalek::{Keypair, PublicKey, Signature}; use ed25519_dalek::{Keypair, PublicKey, Signature};

View File

@ -0,0 +1,38 @@
use super::*;
#[derive(Clone, Debug)]
pub struct ConnectionHandle {
descriptor: ConnectionDescriptor,
channel: flume::Sender<Vec<u8>>,
}
impl ConnectionHandle {
pub(super) fn new(descriptor: ConnectionDescriptor, channel: flume::Sender<Vec<u8>>) -> Self {
Self {
descriptor,
channel,
}
}
pub fn connection_descriptor(&self) -> ConnectionDescriptor {
self.descriptor.clone()
}
pub fn send(&self, message: Vec<u8>) -> Result<(), String> {
self.channel.send(message).map_err(map_to_string)
}
pub async fn send_async(&self, message: Vec<u8>) -> Result<(), String> {
self.channel
.send_async(message)
.await
.map_err(map_to_string)
}
}
impl PartialEq for ConnectionHandle {
fn eq(&self, other: &Self) -> bool {
self.descriptor == other.descriptor
}
}
impl Eq for ConnectionHandle {}

View File

@ -1,7 +1,5 @@
use crate::xx::*; use super::*;
use crate::*;
use alloc::collections::btree_map::Entry; use alloc::collections::btree_map::Entry;
use core::fmt;
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddressFilterError { pub enum AddressFilterError {

View File

@ -3,8 +3,6 @@ use crate::xx::*;
use connection_table::*; use connection_table::*;
use network_connection::*; use network_connection::*;
const CONNECTION_PROCESSOR_CHANNEL_SIZE: usize = 128usize;
/////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////
// Connection manager // Connection manager
@ -59,8 +57,7 @@ impl ConnectionManager {
} }
pub async fn shutdown(&self) { pub async fn shutdown(&self) {
// xxx close all connections in the connection table // Drops connection table, which drops all connections in it
*self.arc.inner.lock().await = Self::new_inner(self.arc.network_manager.config()); *self.arc.inner.lock().await = Self::new_inner(self.arc.network_manager.config());
} }
@ -68,47 +65,48 @@ impl ConnectionManager {
pub async fn get_connection( pub async fn get_connection(
&self, &self,
descriptor: ConnectionDescriptor, descriptor: ConnectionDescriptor,
) -> Option<NetworkConnection> { ) -> Option<ConnectionHandle> {
let mut inner = self.arc.inner.lock().await; let mut inner = self.arc.inner.lock().await;
inner.connection_table.get_connection(descriptor) inner.connection_table.get_connection(descriptor)
} }
// Internal routine to register new connection atomically // Internal routine to register new connection atomically.
fn on_new_connection_internal( // Registers connection in the connection table for later access
// and spawns a message processing loop for the connection
fn on_new_protocol_network_connection(
&self, &self,
inner: &mut ConnectionManagerInner, inner: &mut ConnectionManagerInner,
conn: NetworkConnection, conn: ProtocolNetworkConnection,
) -> Result<(), String> { ) -> Result<ConnectionHandle, String> {
log_net!("on_new_connection_internal: {:?}", conn); log_net!("on_new_protocol_network_connection: {:?}", conn);
let tx = inner
.connection_add_channel_tx
.as_ref()
.ok_or_else(fn_string!("connection channel isn't open yet"))?
.clone();
let receiver_loop_future = Self::process_connection(self.clone(), conn.clone()); // Wrap with NetworkConnection object to start the connection processing loop
tx.try_send(receiver_loop_future) let conn = NetworkConnection::from_protocol(self.clone(), conn);
.map_err(map_to_string) let handle = conn.get_handle();
.map_err(logthru_net!(error "failed to start receiver loop"))?; // Add to the connection table
inner.connection_table.add_connection(conn)?;
// If the receiver loop started successfully, Ok(handle)
// add the new connection to the table
inner.connection_table.add_connection(conn)
} }
// Called by low-level network when any connection-oriented protocol connection appears // Called by low-level network when any connection-oriented protocol connection appears
// either from incoming or outgoing connections. Registers connection in the connection table for later access // either from incoming connections.
// and spawns a message processing loop for the connection pub(super) async fn on_accepted_protocol_network_connection(
pub async fn on_new_connection(&self, conn: NetworkConnection) -> Result<(), String> { &self,
conn: ProtocolNetworkConnection,
) -> Result<(), String> {
let mut inner = self.arc.inner.lock().await; let mut inner = self.arc.inner.lock().await;
self.on_new_connection_internal(&mut *inner, conn) self.on_new_protocol_network_connection(&mut *inner, conn)
.map(drop)
} }
// Called when we want to create a new connection or get the current one that already exists
// This will kill off any connections that are in conflict with the new connection to be made
// in order to make room for the new connection in the system's connection table
pub async fn get_or_create_connection( pub async fn get_or_create_connection(
&self, &self,
local_addr: Option<SocketAddr>, local_addr: Option<SocketAddr>,
dial_info: DialInfo, dial_info: DialInfo,
) -> Result<NetworkConnection, String> { ) -> Result<ConnectionHandle, String> {
log_net!( log_net!(
"== get_or_create_connection local_addr={:?} dial_info={:?}", "== get_or_create_connection local_addr={:?} dial_info={:?}",
local_addr.green(), local_addr.green(),
@ -146,8 +144,10 @@ impl ConnectionManager {
if local_addr.port() != 0 { if local_addr.port() != 0 {
for pt in [ProtocolType::TCP, ProtocolType::WS, ProtocolType::WSS] { for pt in [ProtocolType::TCP, ProtocolType::WS, ProtocolType::WSS] {
let pa = PeerAddress::new(descriptor.remote.socket_address, pt); let pa = PeerAddress::new(descriptor.remote.socket_address, pt);
for conn in inner.connection_table.get_connections_by_remote(pa) { for desc in inner
let desc = conn.connection_descriptor(); .connection_table
.get_connection_descriptors_by_remote(pa)
{
let mut kill = false; let mut kill = false;
if let Some(conn_local) = desc.local { if let Some(conn_local) = desc.local {
if (local_addr.ip().is_unspecified() if (local_addr.ip().is_unspecified()
@ -163,7 +163,9 @@ impl ConnectionManager {
local_addr.green(), local_addr.green(),
pa.green() pa.green()
); );
conn.close().await?; if let Err(e) = inner.connection_table.remove_connection(descriptor) {
log_net!(error e);
}
} }
} }
} }
@ -171,73 +173,17 @@ impl ConnectionManager {
} }
// Attempt new connection // Attempt new connection
let conn = NetworkConnection::connect(local_addr, dial_info).await?; let conn = ProtocolNetworkConnection::connect(local_addr, dial_info).await?;
self.on_new_connection_internal(&mut *inner, conn.clone())?; self.on_new_protocol_network_connection(&mut *inner, conn)
Ok(conn)
} }
// Connection receiver loop // Callback from network connection receive loop when it exits
fn process_connection( // cleans up the entry in the connection table
this: ConnectionManager, pub(super) async fn report_connection_finished(&self, descriptor: ConnectionDescriptor) {
conn: NetworkConnection, let mut inner = self.arc.inner.lock().await;
) -> SystemPinBoxFuture<()> { if let Err(e) = inner.connection_table.remove_connection(descriptor) {
log_net!("Starting process_connection loop for {:?}", conn.green()); log_net!(error e);
let network_manager = this.network_manager(); }
Box::pin(async move {
//
let descriptor = conn.connection_descriptor();
let inactivity_timeout = this
.network_manager()
.config()
.get()
.network
.connection_inactivity_timeout_ms;
loop {
// process inactivity timeout on receives only
// if you want a keepalive, it has to be requested from the other side
let message = select! {
res = conn.recv().fuse() => {
match res {
Ok(v) => v,
Err(e) => {
log_net!(debug e);
break;
}
}
}
_ = intf::sleep(inactivity_timeout).fuse()=> {
// timeout
log_net!("connection timeout on {:?}", descriptor.green());
break;
}
};
if let Err(e) = network_manager
.on_recv_envelope(message.as_slice(), descriptor)
.await
{
log_net!(error e);
break;
}
}
log_net!(
"== Connection loop finished local_addr={:?} remote={:?}",
descriptor.local.green(),
descriptor.remote.green()
);
if let Err(e) = this
.arc
.inner
.lock()
.await
.connection_table
.remove_connection(descriptor)
{
log_net!(error e);
}
})
} }
} }

View File

@ -1,7 +1,4 @@
use super::connection_limits::*; use super::*;
use super::network_connection::*;
use crate::xx::*;
use crate::*;
use alloc::collections::btree_map::Entry; use alloc::collections::btree_map::Entry;
use hashlink::LruCache; use hashlink::LruCache;
@ -9,7 +6,7 @@ use hashlink::LruCache;
pub struct ConnectionTable { pub struct ConnectionTable {
max_connections: Vec<usize>, max_connections: Vec<usize>,
conn_by_descriptor: Vec<LruCache<ConnectionDescriptor, NetworkConnection>>, conn_by_descriptor: Vec<LruCache<ConnectionDescriptor, NetworkConnection>>,
conns_by_remote: BTreeMap<PeerAddress, Vec<NetworkConnection>>, descriptors_by_remote: BTreeMap<PeerAddress, Vec<ConnectionDescriptor>>,
address_filter: ConnectionLimits, address_filter: ConnectionLimits,
} }
@ -39,7 +36,7 @@ impl ConnectionTable {
LruCache::new_unbounded(), LruCache::new_unbounded(),
LruCache::new_unbounded(), LruCache::new_unbounded(),
], ],
conns_by_remote: BTreeMap::new(), descriptors_by_remote: BTreeMap::new(),
address_filter: ConnectionLimits::new(config), address_filter: ConnectionLimits::new(config),
} }
} }
@ -60,7 +57,7 @@ impl ConnectionTable {
self.address_filter.add(ip_addr).map_err(map_to_string)?; self.address_filter.add(ip_addr).map_err(map_to_string)?;
// Add the connection to the table // Add the connection to the table
let res = self.conn_by_descriptor[index].insert(descriptor, conn.clone()); let res = self.conn_by_descriptor[index].insert(descriptor.clone(), conn);
assert!(res.is_none()); assert!(res.is_none());
// if we have reached the maximum number of connections per protocol type // if we have reached the maximum number of connections per protocol type
@ -73,49 +70,54 @@ impl ConnectionTable {
} }
// add connection records // add connection records
let conns = self.conns_by_remote.entry(descriptor.remote).or_default(); let descriptors = self
.descriptors_by_remote
.entry(descriptor.remote)
.or_default();
warn!("add_connection: {:?}", conn); warn!("add_connection: {:?}", descriptor);
conns.push(conn); descriptors.push(descriptor);
Ok(()) Ok(())
} }
pub fn get_connection( pub fn get_connection(&mut self, descriptor: ConnectionDescriptor) -> Option<ConnectionHandle> {
&mut self, warn!("get_connection: {:?}", descriptor);
descriptor: ConnectionDescriptor,
) -> Option<NetworkConnection> {
let index = protocol_to_index(descriptor.protocol_type()); let index = protocol_to_index(descriptor.protocol_type());
let out = self.conn_by_descriptor[index].get(&descriptor).cloned(); let out = self.conn_by_descriptor[index].get(&descriptor);
warn!("get_connection: {:?} -> {:?}", descriptor, out); out.map(|c| c.get_handle())
out
} }
pub fn get_last_connection_by_remote( pub fn get_last_connection_by_remote(
&mut self, &mut self,
remote: PeerAddress, remote: PeerAddress,
) -> Option<NetworkConnection> { ) -> Option<ConnectionHandle> {
let out = self warn!("get_last_connection_by_remote: {:?}", remote);
.conns_by_remote let descriptor = self
.descriptors_by_remote
.get(&remote) .get(&remote)
.map(|v| v[(v.len() - 1)].clone()); .map(|v| v[(v.len() - 1)].clone());
warn!("get_last_connection_by_remote: {:?} -> {:?}", remote, out); if let Some(descriptor) = descriptor {
if let Some(connection) = &out {
// lru bump // lru bump
let index = protocol_to_index(connection.connection_descriptor().protocol_type()); let index = protocol_to_index(descriptor.protocol_type());
let _ = self.conn_by_descriptor[index].get(&connection.connection_descriptor()); let handle = self.conn_by_descriptor[index]
.get(&descriptor)
.map(|c| c.get_handle());
handle
} else {
None
} }
out
} }
pub fn get_connections_by_remote(&mut self, remote: PeerAddress) -> Vec<NetworkConnection> { pub fn get_connection_descriptors_by_remote(
let out = self &mut self,
.conns_by_remote remote: PeerAddress,
) -> Vec<ConnectionDescriptor> {
warn!("get_connection_descriptors_by_remote: {:?}", remote);
self.descriptors_by_remote
.get(&remote) .get(&remote)
.cloned() .cloned()
.unwrap_or_default(); .unwrap_or_default()
warn!("get_connections_by_remote: {:?} -> {:?}", remote, out);
out
} }
pub fn connection_count(&self) -> usize { pub fn connection_count(&self) -> usize {
@ -126,7 +128,7 @@ impl ConnectionTable {
let ip_addr = descriptor.remote.socket_address.to_ip_addr(); let ip_addr = descriptor.remote.socket_address.to_ip_addr();
// conns_by_remote // conns_by_remote
match self.conns_by_remote.entry(descriptor.remote) { match self.descriptors_by_remote.entry(descriptor.remote) {
Entry::Vacant(_) => { Entry::Vacant(_) => {
panic!("inconsistency in connection table") panic!("inconsistency in connection table")
} }
@ -135,7 +137,7 @@ impl ConnectionTable {
// Remove one matching connection from the list // Remove one matching connection from the list
for (n, elem) in v.iter().enumerate() { for (n, elem) in v.iter().enumerate() {
if elem.connection_descriptor() == descriptor { if *elem == descriptor {
v.remove(n); v.remove(n);
break; break;
} }
@ -151,18 +153,14 @@ impl ConnectionTable {
.expect("Inconsistency in connection table"); .expect("Inconsistency in connection table");
} }
pub fn remove_connection( pub fn remove_connection(&mut self, descriptor: ConnectionDescriptor) -> Result<(), String> {
&mut self,
descriptor: ConnectionDescriptor,
) -> Result<NetworkConnection, String> {
warn!("remove_connection: {:?}", descriptor); warn!("remove_connection: {:?}", descriptor);
let index = protocol_to_index(descriptor.protocol_type()); let index = protocol_to_index(descriptor.protocol_type());
let out = self.conn_by_descriptor[index] let _ = self.conn_by_descriptor[index]
.remove(&descriptor) .remove(&descriptor)
.ok_or_else(|| format!("Connection not in table: {:?}", descriptor))?; .ok_or_else(|| format!("Connection not in table: {:?}", descriptor))?;
self.remove_connection_records(descriptor); self.remove_connection_records(descriptor);
Ok(())
Ok(out)
} }
} }

View File

@ -5,6 +5,7 @@ mod native;
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
mod wasm; mod wasm;
mod connection_handle;
mod connection_limits; mod connection_limits;
mod connection_manager; mod connection_manager;
mod connection_table; mod connection_table;
@ -17,8 +18,9 @@ pub mod tests;
pub use network_connection::*; pub use network_connection::*;
//////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////
use connection_limits::*;
use connection_manager::*; use connection_manager::*;
use connection_handle::*;
use dht::*; use dht::*;
use hashlink::LruCache; use hashlink::LruCache;
use intf::*; use intf::*;
@ -1034,7 +1036,7 @@ impl NetworkManager {
// Called when a packet potentially containing an RPC envelope is received by a low-level // Called when a packet potentially containing an RPC envelope is received by a low-level
// network protocol handler. Processes the envelope, authenticates and decrypts the RPC message // network protocol handler. Processes the envelope, authenticates and decrypts the RPC message
// and passes it to the RPC handler // and passes it to the RPC handler
pub async fn on_recv_envelope( async fn on_recv_envelope(
&self, &self,
data: &[u8], data: &[u8],
descriptor: ConnectionDescriptor, descriptor: ConnectionDescriptor,

View File

@ -341,7 +341,7 @@ impl Network {
log_net!("send_data_to_existing_connection to {:?}", descriptor); log_net!("send_data_to_existing_connection to {:?}", descriptor);
// connection exists, send over it // connection exists, send over it
conn.send(data).await.map_err(logthru_net!())?; conn.send_async(data).await.map_err(logthru_net!())?;
// Network accounting // Network accounting
self.network_manager() self.network_manager()
@ -389,7 +389,7 @@ impl Network {
.get_or_create_connection(Some(local_addr), dial_info.clone()) .get_or_create_connection(Some(local_addr), dial_info.clone())
.await?; .await?;
let res = conn.send(data).await.map_err(logthru_net!(error)); let res = conn.send_async(data).await.map_err(logthru_net!(error));
if res.is_ok() { if res.is_ok() {
// Network accounting // Network accounting
self.network_manager() self.network_manager()

View File

@ -7,7 +7,7 @@ use sockets::*;
#[derive(Clone)] #[derive(Clone)]
pub struct ListenerState { pub struct ListenerState {
pub protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>, pub protocol_accept_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
pub tls_protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>, pub tls_protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
pub tls_acceptor: Option<TlsAcceptor>, pub tls_acceptor: Option<TlsAcceptor>,
} }
@ -15,7 +15,7 @@ pub struct ListenerState {
impl ListenerState { impl ListenerState {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
protocol_handlers: Vec::new(), protocol_accept_handlers: Vec::new(),
tls_protocol_handlers: Vec::new(), tls_protocol_handlers: Vec::new(),
tls_acceptor: None, tls_acceptor: None,
} }
@ -46,7 +46,7 @@ impl Network {
addr: SocketAddr, addr: SocketAddr,
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>], protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
tls_connection_initial_timeout: u64, tls_connection_initial_timeout: u64,
) -> Result<Option<NetworkConnection>, String> { ) -> Result<Option<ProtocolNetworkConnection>, String> {
let ts = tls_acceptor let ts = tls_acceptor
.accept(stream) .accept(stream)
.await .await
@ -76,9 +76,9 @@ impl Network {
stream: AsyncPeekStream, stream: AsyncPeekStream,
tcp_stream: TcpStream, tcp_stream: TcpStream,
addr: SocketAddr, addr: SocketAddr,
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>], protocol_accept_handlers: &[Box<dyn ProtocolAcceptHandler>],
) -> Result<Option<NetworkConnection>, String> { ) -> Result<Option<ProtocolNetworkConnection>, String> {
for ah in protocol_handlers.iter() { for ah in protocol_accept_handlers.iter() {
if let Some(nc) = ah if let Some(nc) = ah
.on_accept(stream.clone(), tcp_stream.clone(), addr) .on_accept(stream.clone(), tcp_stream.clone(), addr)
.await .await
@ -185,7 +185,7 @@ impl Network {
) )
.await .await
} else { } else {
this.try_handlers(ps, tcp_stream, addr, &ls.protocol_handlers) this.try_handlers(ps, tcp_stream, addr, &ls.protocol_accept_handlers)
.await .await
}; };
@ -207,7 +207,10 @@ impl Network {
}; };
// Register the new connection in the connection manager // Register the new connection in the connection manager
if let Err(e) = connection_manager.on_new_connection(conn).await { if let Err(e) = connection_manager
.on_accepted_protocol_network_connection(conn)
.await
{
log_net!(error "failed to register new connection: {}", e); log_net!(error "failed to register new connection: {}", e);
} }
}) })
@ -270,7 +273,7 @@ impl Network {
)); ));
} else { } else {
ls.write() ls.write()
.protocol_handlers .protocol_accept_handlers
.push(new_protocol_accept_handler( .push(new_protocol_accept_handler(
self.network_manager().config(), self.network_manager().config(),
false, false,

View File

@ -21,7 +21,7 @@ impl ProtocolNetworkConnection {
pub async fn connect( pub async fn connect(
local_address: Option<SocketAddr>, local_address: Option<SocketAddr>,
dial_info: DialInfo, dial_info: DialInfo,
) -> Result<NetworkConnection, String> { ) -> Result<ProtocolNetworkConnection, String> {
match dial_info.protocol_type() { match dial_info.protocol_type() {
ProtocolType::UDP => { ProtocolType::UDP => {
panic!("Should not connect to UDP dialinfo"); panic!("Should not connect to UDP dialinfo");
@ -55,6 +55,16 @@ impl ProtocolNetworkConnection {
} }
} }
pub fn descriptor(&self) -> ConnectionDescriptor {
match self {
Self::Dummy(d) => d.descriptor(),
Self::RawTcp(t) => t.descriptor(),
Self::WsAccepted(w) => w.descriptor(),
Self::Ws(w) => w.descriptor(),
Self::Wss(w) => w.descriptor(),
}
}
pub async fn close(&self) -> Result<(), String> { pub async fn close(&self) -> Result<(), String> {
match self { match self {
Self::Dummy(d) => d.close(), Self::Dummy(d) => d.close(),

View File

@ -3,6 +3,7 @@ use futures_util::{AsyncReadExt, AsyncWriteExt};
use sockets::*; use sockets::*;
pub struct RawTcpNetworkConnection { pub struct RawTcpNetworkConnection {
descriptor: ConnectionDescriptor,
stream: AsyncPeekStream, stream: AsyncPeekStream,
tcp_stream: TcpStream, tcp_stream: TcpStream,
} }
@ -14,8 +15,20 @@ impl fmt::Debug for RawTcpNetworkConnection {
} }
impl RawTcpNetworkConnection { impl RawTcpNetworkConnection {
pub fn new(stream: AsyncPeekStream, tcp_stream: TcpStream) -> Self { pub fn new(
Self { stream, tcp_stream } descriptor: ConnectionDescriptor,
stream: AsyncPeekStream,
tcp_stream: TcpStream,
) -> Self {
Self {
descriptor,
stream,
tcp_stream,
}
}
pub fn descriptor(&self) -> ConnectionDescriptor {
self.descriptor.clone()
} }
pub async fn close(&self) -> Result<(), String> { pub async fn close(&self) -> Result<(), String> {
@ -33,7 +46,7 @@ impl RawTcpNetworkConnection {
.map_err(logthru_net!()) .map_err(logthru_net!())
} }
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> { async fn send_internal(mut stream: AsyncPeekStream, message: Vec<u8>) -> Result<(), String> {
log_net!("sending TCP message of size {}", message.len()); log_net!("sending TCP message of size {}", message.len());
if message.len() > MAX_MESSAGE_SIZE { if message.len() > MAX_MESSAGE_SIZE {
return Err("sending too large TCP message".to_owned()); return Err("sending too large TCP message".to_owned());
@ -41,7 +54,6 @@ impl RawTcpNetworkConnection {
let len = message.len() as u16; let len = message.len() as u16;
let header = [b'V', b'L', len as u8, (len >> 8) as u8]; let header = [b'V', b'L', len as u8, (len >> 8) as u8];
let mut stream = self.stream.clone();
stream stream
.write_all(&header) .write_all(&header)
.await .await
@ -54,6 +66,11 @@ impl RawTcpNetworkConnection {
.map_err(logthru_net!()) .map_err(logthru_net!())
} }
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
let stream = self.stream.clone();
Self::send_internal(stream, message).await
}
pub async fn recv(&self) -> Result<Vec<u8>, String> { pub async fn recv(&self) -> Result<Vec<u8>, String> {
let mut header = [0u8; 4]; let mut header = [0u8; 4];
@ -108,7 +125,7 @@ impl RawTcpProtocolHandler {
stream: AsyncPeekStream, stream: AsyncPeekStream,
tcp_stream: TcpStream, tcp_stream: TcpStream,
socket_addr: SocketAddr, socket_addr: SocketAddr,
) -> Result<Option<NetworkConnection>, String> { ) -> Result<Option<ProtocolNetworkConnection>, String> {
log_net!("TCP: on_accept_async: enter"); log_net!("TCP: on_accept_async: enter");
let mut peekbuf: [u8; PEEK_DETECT_LEN] = [0u8; PEEK_DETECT_LEN]; let mut peekbuf: [u8; PEEK_DETECT_LEN] = [0u8; PEEK_DETECT_LEN];
let peeklen = stream let peeklen = stream
@ -123,10 +140,11 @@ impl RawTcpProtocolHandler {
ProtocolType::TCP, ProtocolType::TCP,
); );
let local_address = self.inner.lock().local_address; let local_address = self.inner.lock().local_address;
let conn = NetworkConnection::from_protocol( let conn = ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(
ConnectionDescriptor::new(peer_addr, SocketAddress::from_socket_addr(local_address)), ConnectionDescriptor::new(peer_addr, SocketAddress::from_socket_addr(local_address)),
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(stream, tcp_stream)), stream,
); tcp_stream,
));
log_net!(debug "TCP: on_accept_async from: {}", socket_addr); log_net!(debug "TCP: on_accept_async from: {}", socket_addr);
@ -136,7 +154,7 @@ impl RawTcpProtocolHandler {
pub async fn connect( pub async fn connect(
local_address: Option<SocketAddr>, local_address: Option<SocketAddr>,
dial_info: DialInfo, dial_info: DialInfo,
) -> Result<NetworkConnection, String> { ) -> Result<ProtocolNetworkConnection, String> {
// Get remote socket address to connect to // Get remote socket address to connect to
let remote_socket_addr = dial_info.to_socket_addr(); let remote_socket_addr = dial_info.to_socket_addr();
@ -161,13 +179,15 @@ impl RawTcpProtocolHandler {
let ps = AsyncPeekStream::new(ts.clone()); let ps = AsyncPeekStream::new(ts.clone());
// Wrap the stream in a network connection and return it // Wrap the stream in a network connection and return it
let conn = NetworkConnection::from_protocol( let conn = ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(
ConnectionDescriptor { ConnectionDescriptor {
local: Some(SocketAddress::from_socket_addr(actual_local_address)), local: Some(SocketAddress::from_socket_addr(actual_local_address)),
remote: dial_info.to_peer_address(), remote: dial_info.to_peer_address(),
}, },
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)), ps,
); ts,
));
Ok(conn) Ok(conn)
} }
@ -194,24 +214,15 @@ impl RawTcpProtocolHandler {
.map_err(logthru_net!(error "remote_addr={}", socket_addr))?; .map_err(logthru_net!(error "remote_addr={}", socket_addr))?;
// See what local address we ended up with and turn this into a stream // See what local address we ended up with and turn this into a stream
let actual_local_address = ts // let actual_local_address = ts
.local_addr() // .local_addr()
.map_err(map_to_string) // .map_err(map_to_string)
.map_err(logthru_net!("could not get local address from TCP stream"))?; // .map_err(logthru_net!("could not get local address from TCP stream"))?;
let ps = AsyncPeekStream::new(ts.clone()); let ps = AsyncPeekStream::new(ts.clone());
// Wrap the stream in a network connection and return it // Send directly from the raw network connection
let conn = NetworkConnection::from_protocol( // this builds the connection and tears it down immediately after the send
ConnectionDescriptor { RawTcpNetworkConnection::send_internal(ps, data).await
local: Some(SocketAddress::from_socket_addr(actual_local_address)),
remote: PeerAddress::new(
SocketAddress::from_socket_addr(socket_addr),
ProtocolType::TCP,
),
},
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)),
);
conn.send(data).await
} }
} }
@ -221,7 +232,7 @@ impl ProtocolAcceptHandler for RawTcpProtocolHandler {
stream: AsyncPeekStream, stream: AsyncPeekStream,
tcp_stream: TcpStream, tcp_stream: TcpStream,
peer_addr: SocketAddr, peer_addr: SocketAddr,
) -> SystemPinBoxFuture<core::result::Result<Option<NetworkConnection>, String>> { ) -> SystemPinBoxFuture<core::result::Result<Option<ProtocolNetworkConnection>, String>> {
Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr)) Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
} }
} }

View File

@ -15,6 +15,7 @@ pub struct WebsocketNetworkConnection<T>
where where
T: io::Read + io::Write + Send + Unpin + 'static, T: io::Read + io::Write + Send + Unpin + 'static,
{ {
descriptor: ConnectionDescriptor,
stream: CloneStream<WebSocketStream<T>>, stream: CloneStream<WebSocketStream<T>>,
tcp_stream: TcpStream, tcp_stream: TcpStream,
} }
@ -32,13 +33,22 @@ impl<T> WebsocketNetworkConnection<T>
where where
T: io::Read + io::Write + Send + Unpin + 'static, T: io::Read + io::Write + Send + Unpin + 'static,
{ {
pub fn new(stream: WebSocketStream<T>, tcp_stream: TcpStream) -> Self { pub fn new(
descriptor: ConnectionDescriptor,
stream: WebSocketStream<T>,
tcp_stream: TcpStream,
) -> Self {
Self { Self {
descriptor,
stream: CloneStream::new(stream), stream: CloneStream::new(stream),
tcp_stream, tcp_stream,
} }
} }
pub fn descriptor(&self) -> ConnectionDescriptor {
self.descriptor.clone()
}
pub async fn close(&self) -> Result<(), String> { pub async fn close(&self) -> Result<(), String> {
// Make an attempt to flush the stream // Make an attempt to flush the stream
self.stream self.stream
@ -132,7 +142,7 @@ impl WebsocketProtocolHandler {
ps: AsyncPeekStream, ps: AsyncPeekStream,
tcp_stream: TcpStream, tcp_stream: TcpStream,
socket_addr: SocketAddr, socket_addr: SocketAddr,
) -> Result<Option<NetworkConnection>, String> { ) -> Result<Option<ProtocolNetworkConnection>, String> {
log_net!("WS: on_accept_async: enter"); log_net!("WS: on_accept_async: enter");
let request_path_len = self.arc.request_path.len() + 2; let request_path_len = self.arc.request_path.len() + 2;
@ -179,25 +189,24 @@ impl WebsocketProtocolHandler {
let peer_addr = let peer_addr =
PeerAddress::new(SocketAddress::from_socket_addr(socket_addr), protocol_type); PeerAddress::new(SocketAddress::from_socket_addr(socket_addr), protocol_type);
let conn = NetworkConnection::from_protocol( let conn = ProtocolNetworkConnection::WsAccepted(WebsocketNetworkConnection::new(
ConnectionDescriptor::new( ConnectionDescriptor::new(
peer_addr, peer_addr,
SocketAddress::from_socket_addr(self.arc.local_address), SocketAddress::from_socket_addr(self.arc.local_address),
), ),
ProtocolNetworkConnection::WsAccepted(WebsocketNetworkConnection::new( ws_stream,
ws_stream, tcp_stream, tcp_stream,
)), ));
);
log_net!(debug "{}: on_accept_async from: {}", if self.arc.tls { "WSS" } else { "WS" }, socket_addr); log_net!(debug "{}: on_accept_async from: {}", if self.arc.tls { "WSS" } else { "WS" }, socket_addr);
Ok(Some(conn)) Ok(Some(conn))
} }
pub async fn connect( async fn connect_internal(
local_address: Option<SocketAddr>, local_address: Option<SocketAddr>,
dial_info: DialInfo, dial_info: DialInfo,
) -> Result<NetworkConnection, String> { ) -> Result<ProtocolNetworkConnection, String> {
// Split dial info up // Split dial info up
let (tls, scheme) = match &dial_info { let (tls, scheme) = match &dial_info {
DialInfo::WS(_) => (false, "ws"), DialInfo::WS(_) => (false, "ws"),
@ -251,26 +260,27 @@ impl WebsocketProtocolHandler {
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!(error))?; .map_err(logthru_net!(error))?;
Ok(NetworkConnection::from_protocol( Ok(ProtocolNetworkConnection::Wss(
descriptor, WebsocketNetworkConnection::new(descriptor, ws_stream, tcp_stream),
ProtocolNetworkConnection::Wss(WebsocketNetworkConnection::new(
ws_stream, tcp_stream,
)),
)) ))
} else { } else {
let (ws_stream, _response) = client_async(request, tcp_stream.clone()) let (ws_stream, _response) = client_async(request, tcp_stream.clone())
.await .await
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_net!(error))?; .map_err(logthru_net!(error))?;
Ok(NetworkConnection::from_protocol( Ok(ProtocolNetworkConnection::Ws(
descriptor, WebsocketNetworkConnection::new(descriptor, ws_stream, tcp_stream),
ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(
ws_stream, tcp_stream,
)),
)) ))
} }
} }
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<ProtocolNetworkConnection, String> {
Self::connect_internal(local_address, dial_info).await
}
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> { pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
if data.len() > MAX_MESSAGE_SIZE { if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large unbound WS message".to_owned()); return Err("sending too large unbound WS message".to_owned());
@ -281,11 +291,11 @@ impl WebsocketProtocolHandler {
dial_info, dial_info,
); );
let conn = Self::connect(None, dial_info.clone()) let protconn = Self::connect_internal(None, dial_info.clone())
.await .await
.map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?; .map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?;
conn.send(data).await protconn.send(data).await
} }
} }
@ -295,7 +305,7 @@ impl ProtocolAcceptHandler for WebsocketProtocolHandler {
stream: AsyncPeekStream, stream: AsyncPeekStream,
tcp_stream: TcpStream, tcp_stream: TcpStream,
peer_addr: SocketAddr, peer_addr: SocketAddr,
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>> { ) -> SystemPinBoxFuture<Result<Option<ProtocolNetworkConnection>, String>> {
Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr)) Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
} }
} }

View File

@ -1,5 +1,5 @@
use super::*; use super::*;
use crate::xx::*; use futures_util::{FutureExt, StreamExt};
cfg_if::cfg_if! { cfg_if::cfg_if! {
if #[cfg(target_arch = "wasm32")] { if #[cfg(target_arch = "wasm32")] {
@ -16,7 +16,7 @@ cfg_if::cfg_if! {
stream: AsyncPeekStream, stream: AsyncPeekStream,
tcp_stream: TcpStream, tcp_stream: TcpStream,
peer_addr: SocketAddr, peer_addr: SocketAddr,
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>>; ) -> SystemPinBoxFuture<Result<Option<ProtocolNetworkConnection>, String>>;
} }
pub trait ProtocolAcceptHandlerClone { pub trait ProtocolAcceptHandlerClone {
@ -45,9 +45,14 @@ cfg_if::cfg_if! {
// Dummy protocol network connection for testing // Dummy protocol network connection for testing
#[derive(Debug)] #[derive(Debug)]
pub struct DummyNetworkConnection {} pub struct DummyNetworkConnection {
descriptor: ConnectionDescriptor,
}
impl DummyNetworkConnection { impl DummyNetworkConnection {
pub fn descriptor(&self) -> ConnectionDescriptor {
self.descriptor.clone()
}
pub fn close(&self) -> Result<(), String> { pub fn close(&self) -> Result<(), String> {
Ok(()) Ok(())
} }
@ -62,6 +67,14 @@ impl DummyNetworkConnection {
/////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////
// Top-level protocol independent network connection object // Top-level protocol independent network connection object
#[derive(Clone, Copy, Debug)]
enum RecvLoopAction {
Send,
Recv,
Finish,
Timeout,
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct NetworkConnectionStats { pub struct NetworkConnectionStats {
last_message_sent_time: Option<u64>, last_message_sent_time: Option<u64>,
@ -69,107 +82,249 @@ pub struct NetworkConnectionStats {
} }
#[derive(Debug)] #[derive(Debug)]
struct NetworkConnectionInner {
stats: NetworkConnectionStats,
}
#[derive(Debug)]
struct NetworkConnectionArc {
descriptor: ConnectionDescriptor,
protocol_connection: ProtocolNetworkConnection,
established_time: u64,
inner: Mutex<NetworkConnectionInner>,
}
#[derive(Clone, Debug)]
pub struct NetworkConnection { pub struct NetworkConnection {
arc: Arc<NetworkConnectionArc>, descriptor: ConnectionDescriptor,
_processor: Option<JoinHandle<()>>,
established_time: u64,
stats: Arc<Mutex<NetworkConnectionStats>>,
sender: flume::Sender<Vec<u8>>,
} }
impl PartialEq for NetworkConnection {
fn eq(&self, other: &Self) -> bool {
Arc::as_ptr(&self.arc) == Arc::as_ptr(&other.arc)
}
}
impl Eq for NetworkConnection {}
impl NetworkConnection { impl NetworkConnection {
fn new_inner() -> NetworkConnectionInner { pub(super) fn dummy(descriptor: ConnectionDescriptor) -> Self {
NetworkConnectionInner { // Create handle for sending (dummy is immediately disconnected)
stats: NetworkConnectionStats { let (sender, _receiver) = flume::bounded(intf::get_concurrency() as usize);
Self {
descriptor,
_processor: None,
established_time: intf::get_timestamp(),
stats: Arc::new(Mutex::new(NetworkConnectionStats {
last_message_sent_time: None, last_message_sent_time: None,
last_message_recv_time: None, last_message_recv_time: None,
}, })),
} sender,
}
fn new_arc(
descriptor: ConnectionDescriptor,
protocol_connection: ProtocolNetworkConnection,
) -> NetworkConnectionArc {
NetworkConnectionArc {
descriptor,
protocol_connection,
established_time: intf::get_timestamp(),
inner: Mutex::new(Self::new_inner()),
} }
} }
pub fn dummy(descriptor: ConnectionDescriptor) -> Self { pub(super) fn from_protocol(
NetworkConnection::from_protocol( connection_manager: ConnectionManager,
descriptor,
ProtocolNetworkConnection::Dummy(DummyNetworkConnection {}),
)
}
pub fn from_protocol(
descriptor: ConnectionDescriptor,
protocol_connection: ProtocolNetworkConnection, protocol_connection: ProtocolNetworkConnection,
) -> Self { ) -> Self {
Self { // Get timeout
arc: Arc::new(Self::new_arc(descriptor, protocol_connection)), let network_manager = connection_manager.network_manager();
} let inactivity_timeout = network_manager
} .config()
.get()
.network
.connection_inactivity_timeout_ms;
pub async fn connect( // Get descriptor
local_address: Option<SocketAddr>, let descriptor = protocol_connection.descriptor();
dial_info: DialInfo,
) -> Result<NetworkConnection, String> { // Create handle for sending
ProtocolNetworkConnection::connect(local_address, dial_info).await let (sender, receiver) = flume::bounded(intf::get_concurrency() as usize);
// Create stats
let stats = Arc::new(Mutex::new(NetworkConnectionStats {
last_message_sent_time: None,
last_message_recv_time: None,
}));
// Spawn connection processor and pass in protocol connection
let processor = intf::spawn_local(Self::process_connection(
connection_manager,
descriptor.clone(),
receiver,
protocol_connection,
inactivity_timeout,
stats.clone(),
));
// Return the connection
Self {
descriptor,
_processor: Some(processor),
established_time: intf::get_timestamp(),
stats,
sender,
}
} }
pub fn connection_descriptor(&self) -> ConnectionDescriptor { pub fn connection_descriptor(&self) -> ConnectionDescriptor {
self.arc.descriptor self.descriptor.clone()
} }
pub async fn close(&self) -> Result<(), String> { pub fn get_handle(&self) -> ConnectionHandle {
self.arc.protocol_connection.close().await ConnectionHandle::new(self.descriptor.clone(), self.sender.clone())
} }
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> { async fn send_internal(
protocol_connection: &ProtocolNetworkConnection,
stats: Arc<Mutex<NetworkConnectionStats>>,
message: Vec<u8>,
) -> Result<(), String> {
let ts = intf::get_timestamp(); let ts = intf::get_timestamp();
let out = self.arc.protocol_connection.send(message).await; let out = protocol_connection.send(message).await;
if out.is_ok() { if out.is_ok() {
let mut inner = self.arc.inner.lock(); let mut stats = stats.lock();
inner.stats.last_message_sent_time.max_assign(Some(ts)); stats.last_message_sent_time.max_assign(Some(ts));
} }
out out
} }
pub async fn recv(&self) -> Result<Vec<u8>, String> { async fn recv_internal(
protocol_connection: &ProtocolNetworkConnection,
stats: Arc<Mutex<NetworkConnectionStats>>,
) -> Result<Vec<u8>, String> {
let ts = intf::get_timestamp(); let ts = intf::get_timestamp();
let out = self.arc.protocol_connection.recv().await; let out = protocol_connection.recv().await;
if out.is_ok() { if out.is_ok() {
let mut inner = self.arc.inner.lock(); let mut stats = stats.lock();
inner.stats.last_message_recv_time.max_assign(Some(ts)); stats.last_message_recv_time.max_assign(Some(ts));
} }
out out
} }
pub fn stats(&self) -> NetworkConnectionStats { pub fn stats(&self) -> NetworkConnectionStats {
let inner = self.arc.inner.lock(); let stats = self.stats.lock();
inner.stats.clone() stats.clone()
} }
pub fn established_time(&self) -> u64 { pub fn established_time(&self) -> u64 {
self.arc.established_time self.established_time
}
// Connection receiver loop
fn process_connection(
connection_manager: ConnectionManager,
descriptor: ConnectionDescriptor,
receiver: flume::Receiver<Vec<u8>>,
protocol_connection: ProtocolNetworkConnection,
connection_inactivity_timeout_ms: u32,
stats: Arc<Mutex<NetworkConnectionStats>>,
) -> SystemPinBoxFuture<()> {
Box::pin(async move {
log_net!(
"Starting process_connection loop for {:?}",
descriptor.green()
);
let network_manager = connection_manager.network_manager();
let mut unord = FuturesUnordered::new();
let mut need_receiver = true;
let mut need_sender = true;
// Push mutable timer so we can reset it
// Normally we would use an io::timeout here, but WASM won't support that, so we use a mutable sleep future
let new_timer = || {
intf::sleep(connection_inactivity_timeout_ms).then(|_| async {
// timeout
log_net!("connection timeout on {:?}", descriptor.green());
RecvLoopAction::Timeout
})
};
let timer = MutableFuture::new(new_timer());
unord.push(timer.clone().boxed());
loop {
// Add another message sender future if necessary
if need_sender {
need_sender = false;
unord.push(
receiver
.recv_async()
.then(|res| async {
match res {
Ok(message) => {
// send the packet
if let Err(e) = Self::send_internal(
&protocol_connection,
stats.clone(),
message,
)
.await
{
// Sending the packet along can fail, if so, this connection is dead
log_net!(debug e);
RecvLoopAction::Finish
} else {
RecvLoopAction::Send
}
}
Err(e) => {
// All senders gone, shouldn't happen since we store one alongside the join handle
log_net!(warn e);
RecvLoopAction::Finish
}
}
})
.boxed(),
);
}
// Add another message receiver future if necessary
if need_receiver {
need_sender = false;
unord.push(
Self::recv_internal(&protocol_connection, stats.clone())
.then(|res| async {
match res {
Ok(message) => {
// Pass received messages up to the network manager for processing
if let Err(e) = network_manager
.on_recv_envelope(message.as_slice(), descriptor)
.await
{
log_net!(error e);
RecvLoopAction::Finish
} else {
RecvLoopAction::Recv
}
}
Err(e) => {
// Connection unable to receive, closed
log_net!(warn e);
RecvLoopAction::Finish
}
}
})
.boxed(),
);
}
// Process futures
match unord.next().await {
Some(RecvLoopAction::Send) => {
// Don't reset inactivity timer if we're only sending
need_sender = true;
}
Some(RecvLoopAction::Recv) => {
// Reset inactivity timer since we got something from this connection
timer.set(new_timer());
need_receiver = true;
}
Some(RecvLoopAction::Finish) | Some(RecvLoopAction::Timeout) => {
break;
}
None => {
// Should not happen
unreachable!();
}
}
}
log_net!(
"== Connection loop finished local_addr={:?} remote={:?}",
descriptor.local.green(),
descriptor.remote.green()
);
connection_manager
.report_connection_finished(descriptor)
.await
})
} }
} }

View File

@ -52,10 +52,15 @@ pub async fn test_add_get_remove() {
); );
let c1 = NetworkConnection::dummy(a1); let c1 = NetworkConnection::dummy(a1);
let c1h = c1.get_handle();
let c2 = NetworkConnection::dummy(a2); let c2 = NetworkConnection::dummy(a2);
//let c2h = c2.get_handle();
let c3 = NetworkConnection::dummy(a3); let c3 = NetworkConnection::dummy(a3);
//let c3h = c3.get_handle();
let c4 = NetworkConnection::dummy(a4); let c4 = NetworkConnection::dummy(a4);
//let c4h = c4.get_handle();
let c5 = NetworkConnection::dummy(a5); let c5 = NetworkConnection::dummy(a5);
//let c5h = c5.get_handle();
assert_eq!(a1, c2.connection_descriptor()); assert_eq!(a1, c2.connection_descriptor());
assert_ne!(a3, c4.connection_descriptor()); assert_ne!(a3, c4.connection_descriptor());
@ -63,36 +68,39 @@ pub async fn test_add_get_remove() {
assert_eq!(table.connection_count(), 0); assert_eq!(table.connection_count(), 0);
assert_eq!(table.get_connection(a1), None); assert_eq!(table.get_connection(a1), None);
table.add_connection(c1.clone()).unwrap(); table.add_connection(c1).unwrap();
assert_eq!(table.connection_count(), 1); assert_eq!(table.connection_count(), 1);
assert_err!(table.remove_connection(a3)); assert_err!(table.remove_connection(a3));
assert_err!(table.remove_connection(a4)); assert_err!(table.remove_connection(a4));
assert_eq!(table.connection_count(), 1); assert_eq!(table.connection_count(), 1);
assert_eq!(table.get_connection(a1), Some(c1.clone())); assert_eq!(table.get_connection(a1), Some(c1h.clone()));
assert_eq!(table.get_connection(a1), Some(c1.clone())); assert_eq!(table.get_connection(a1), Some(c1h.clone()));
assert_eq!(table.connection_count(), 1); assert_eq!(table.connection_count(), 1);
assert_err!(table.add_connection(c1.clone())); assert_err!(table.add_connection(c2));
assert_err!(table.add_connection(c2.clone()));
assert_eq!(table.connection_count(), 1); assert_eq!(table.connection_count(), 1);
assert_eq!(table.get_connection(a1), Some(c1.clone())); assert_eq!(table.get_connection(a1), Some(c1h.clone()));
assert_eq!(table.get_connection(a1), Some(c1.clone())); assert_eq!(table.get_connection(a1), Some(c1h.clone()));
assert_eq!(table.connection_count(), 1); assert_eq!(table.connection_count(), 1);
assert_eq!(table.remove_connection(a2), Ok(c1.clone())); assert_eq!(table.remove_connection(a2), Ok(()));
assert_eq!(table.connection_count(), 0); assert_eq!(table.connection_count(), 0);
assert_err!(table.remove_connection(a2)); assert_err!(table.remove_connection(a2));
assert_eq!(table.connection_count(), 0); assert_eq!(table.connection_count(), 0);
assert_eq!(table.get_connection(a2), None); assert_eq!(table.get_connection(a2), None);
assert_eq!(table.get_connection(a1), None); assert_eq!(table.get_connection(a1), None);
assert_eq!(table.connection_count(), 0); assert_eq!(table.connection_count(), 0);
table.add_connection(c1.clone()).unwrap(); let c1 = NetworkConnection::dummy(a1);
//let c1h = c1.get_handle();
table.add_connection(c1).unwrap();
let c2 = NetworkConnection::dummy(a2);
//let c2h = c2.get_handle();
assert_err!(table.add_connection(c2)); assert_err!(table.add_connection(c2));
table.add_connection(c3.clone()).unwrap(); table.add_connection(c3).unwrap();
table.add_connection(c4.clone()).unwrap(); table.add_connection(c4).unwrap();
assert_eq!(table.connection_count(), 3); assert_eq!(table.connection_count(), 3);
assert_eq!(table.remove_connection(a2), Ok(c1)); assert_eq!(table.remove_connection(a2), Ok(()));
assert_eq!(table.remove_connection(a3), Ok(c3)); assert_eq!(table.remove_connection(a3), Ok(()));
assert_eq!(table.remove_connection(a4), Ok(c4)); assert_eq!(table.remove_connection(a4), Ok(()));
assert_eq!(table.connection_count(), 0); assert_eq!(table.connection_count(), 0);
} }

View File

@ -13,6 +13,7 @@ struct WebsocketNetworkConnectionInner {
#[derive(Clone)] #[derive(Clone)]
pub struct WebsocketNetworkConnection { pub struct WebsocketNetworkConnection {
descriptor: ConnectionDescriptor,
inner: Arc<WebsocketNetworkConnectionInner>, inner: Arc<WebsocketNetworkConnectionInner>,
} }
@ -23,8 +24,11 @@ impl fmt::Debug for WebsocketNetworkConnection {
} }
impl WebsocketNetworkConnection { impl WebsocketNetworkConnection {
pub fn new(ws_meta: WsMeta, ws_stream: WsStream) -> Self { pub fn new(
descriptor: ConnectionDescriptor,
ws_meta: WsMeta, ws_stream: WsStream) -> Self {
Self { Self {
descriptor,
inner: Arc::new(WebsocketNetworkConnectionInner { inner: Arc::new(WebsocketNetworkConnectionInner {
ws_meta, ws_meta,
ws_stream: CloneStream::new(ws_stream), ws_stream: CloneStream::new(ws_stream),
@ -32,6 +36,10 @@ impl WebsocketNetworkConnection {
} }
} }
pub fn descriptor(&self) -> ConnectionDescriptor {
self.descriptor.clone()
}
pub async fn close(&self) -> Result<(), String> { pub async fn close(&self) -> Result<(), String> {
self.inner.ws_meta.close().await.map_err(map_to_string).map(drop) self.inner.ws_meta.close().await.map_err(map_to_string).map(drop)
} }
@ -73,7 +81,7 @@ impl WebsocketProtocolHandler {
pub async fn connect( pub async fn connect(
local_address: Option<SocketAddr>, local_address: Option<SocketAddr>,
dial_info: DialInfo, dial_info: DialInfo,
) -> Result<NetworkConnection, String> { ) -> Result<ProtocolNetworkConnection, String> {
assert!(local_address.is_none()); assert!(local_address.is_none());
@ -96,10 +104,10 @@ impl WebsocketProtocolHandler {
// Make our connection descriptor // Make our connection descriptor
Ok(NetworkConnection::from_protocol(ConnectionDescriptor { Ok(ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(ConnectionDescriptor {
local: None, local: None,
remote: dial_info.to_peer_address(), remote: dial_info.to_peer_address(),
},ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(wsmeta, wsio)))) }, wsmeta, wsio)))
} }
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> { pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {

View File

@ -8,6 +8,7 @@ mod eventual_value_clone;
mod ip_addr_port; mod ip_addr_port;
mod ip_extra; mod ip_extra;
mod log_thru; mod log_thru;
mod mutable_future;
mod single_future; mod single_future;
mod single_shot_eventual; mod single_shot_eventual;
mod split_url; mod split_url;
@ -104,6 +105,7 @@ pub use eventual_value::*;
pub use eventual_value_clone::*; pub use eventual_value_clone::*;
pub use ip_addr_port::*; pub use ip_addr_port::*;
pub use ip_extra::*; pub use ip_extra::*;
pub use mutable_future::*;
pub use single_future::*; pub use single_future::*;
pub use single_shot_eventual::*; pub use single_shot_eventual::*;
pub use tick_task::*; pub use tick_task::*;

View File

@ -0,0 +1,33 @@
use super::*;
pub struct MutableFuture<O, T: Future<Output = O>> {
inner: Arc<Mutex<Pin<Box<T>>>>,
}
impl<O, T: Future<Output = O>> MutableFuture<O, T> {
pub fn new(inner: T) -> Self {
Self {
inner: Arc::new(Mutex::new(Box::pin(inner))),
}
}
pub fn set(&self, inner: T) {
*self.inner.lock() = Box::pin(inner);
}
}
impl<O, T: Future<Output = O>> Clone for MutableFuture<O, T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl<O, T: Future<Output = O>> Future for MutableFuture<O, T> {
type Output = O;
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
let mut inner = self.inner.lock();
T::poll(inner.as_mut(), cx)
}
}