refactor network manager

This commit is contained in:
John Smith
2022-05-31 19:54:52 -04:00
parent ad4b6328ac
commit 8148c37708
51 changed files with 500 additions and 389 deletions

View File

@@ -0,0 +1,218 @@
use crate::xx::*;
use crate::*;
use alloc::collections::btree_map::Entry;
use core::fmt;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddressFilterError {
CountExceeded,
RateExceeded,
}
impl fmt::Display for AddressFilterError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}",
match *self {
Self::CountExceeded => "Count exceeded",
Self::RateExceeded => "Rate exceeded",
}
)
}
}
impl std::error::Error for AddressFilterError {}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AddressNotInTableError {}
impl fmt::Display for AddressNotInTableError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Address not in table")
}
}
impl std::error::Error for AddressNotInTableError {}
#[derive(Debug)]
pub struct ConnectionLimits {
max_connections_per_ip4: usize,
max_connections_per_ip6_prefix: usize,
max_connections_per_ip6_prefix_size: usize,
max_connection_frequency_per_min: usize,
conn_count_by_ip4: BTreeMap<Ipv4Addr, usize>,
conn_count_by_ip6_prefix: BTreeMap<Ipv6Addr, usize>,
conn_timestamps_by_ip4: BTreeMap<Ipv4Addr, Vec<u64>>,
conn_timestamps_by_ip6_prefix: BTreeMap<Ipv6Addr, Vec<u64>>,
}
impl ConnectionLimits {
pub fn new(config: VeilidConfig) -> Self {
let c = config.get();
Self {
max_connections_per_ip4: c.network.max_connections_per_ip4 as usize,
max_connections_per_ip6_prefix: c.network.max_connections_per_ip6_prefix as usize,
max_connections_per_ip6_prefix_size: c.network.max_connections_per_ip6_prefix_size
as usize,
max_connection_frequency_per_min: c.network.max_connection_frequency_per_min as usize,
conn_count_by_ip4: BTreeMap::new(),
conn_count_by_ip6_prefix: BTreeMap::new(),
conn_timestamps_by_ip4: BTreeMap::new(),
conn_timestamps_by_ip6_prefix: BTreeMap::new(),
}
}
// Converts an ip to a ip block by applying a netmask
// to the host part of the ip address
// ipv4 addresses are treated as single hosts
// ipv6 addresses are treated as prefix allocated blocks
fn ip_to_ipblock(&self, addr: IpAddr) -> IpAddr {
match addr {
IpAddr::V4(_) => addr,
IpAddr::V6(v6) => {
let mut hostlen = 128usize.saturating_sub(self.max_connections_per_ip6_prefix_size);
let mut out = v6.octets();
for i in (0..16).rev() {
if hostlen >= 8 {
out[i] = 0xFF;
hostlen -= 8;
} else {
out[i] |= !(0xFFu8 << hostlen);
break;
}
}
IpAddr::V6(Ipv6Addr::from(out))
}
}
}
fn purge_old_timestamps(&mut self, cur_ts: u64) {
// v4
{
let mut dead_keys = Vec::<Ipv4Addr>::new();
for (key, value) in &mut self.conn_timestamps_by_ip4 {
value.retain(|v| {
// keep timestamps that are less than a minute away
cur_ts.saturating_sub(*v) < 60_000_000u64
});
if value.is_empty() {
dead_keys.push(*key);
}
}
for key in dead_keys {
self.conn_timestamps_by_ip4.remove(&key);
}
}
// v6
{
let mut dead_keys = Vec::<Ipv6Addr>::new();
for (key, value) in &mut self.conn_timestamps_by_ip6_prefix {
value.retain(|v| {
// keep timestamps that are less than a minute away
cur_ts.saturating_sub(*v) < 60_000_000u64
});
if value.is_empty() {
dead_keys.push(*key);
}
}
for key in dead_keys {
self.conn_timestamps_by_ip6_prefix.remove(&key);
}
}
}
pub fn add(&mut self, addr: IpAddr) -> Result<(), AddressFilterError> {
let ipblock = self.ip_to_ipblock(addr);
let ts = intf::get_timestamp();
self.purge_old_timestamps(ts);
match ipblock {
IpAddr::V4(v4) => {
// See if we have too many connections from this ip block
let cnt = &mut *self.conn_count_by_ip4.entry(v4).or_default();
assert!(*cnt <= self.max_connections_per_ip4);
if *cnt == self.max_connections_per_ip4 {
warn!("address filter count exceeded: {:?}", v4);
return Err(AddressFilterError::CountExceeded);
}
// See if this ip block has connected too frequently
let tstamps = &mut self.conn_timestamps_by_ip4.entry(v4).or_default();
tstamps.retain(|v| {
// keep timestamps that are less than a minute away
ts.saturating_sub(*v) < 60_000_000u64
});
assert!(tstamps.len() <= self.max_connection_frequency_per_min);
if tstamps.len() == self.max_connection_frequency_per_min {
warn!("address filter rate exceeded: {:?}", v4);
return Err(AddressFilterError::RateExceeded);
}
// If it's okay, add the counts and timestamps
*cnt += 1;
tstamps.push(ts);
}
IpAddr::V6(v6) => {
// See if we have too many connections from this ip block
let cnt = &mut *self.conn_count_by_ip6_prefix.entry(v6).or_default();
assert!(*cnt <= self.max_connections_per_ip6_prefix);
if *cnt == self.max_connections_per_ip6_prefix {
warn!("address filter count exceeded: {:?}", v6);
return Err(AddressFilterError::CountExceeded);
}
// See if this ip block has connected too frequently
let tstamps = &mut self.conn_timestamps_by_ip6_prefix.entry(v6).or_default();
assert!(tstamps.len() <= self.max_connection_frequency_per_min);
if tstamps.len() == self.max_connection_frequency_per_min {
warn!("address filter rate exceeded: {:?}", v6);
return Err(AddressFilterError::RateExceeded);
}
// If it's okay, add the counts and timestamps
*cnt += 1;
tstamps.push(ts);
}
}
Ok(())
}
pub fn remove(&mut self, addr: IpAddr) -> Result<(), AddressNotInTableError> {
let ipblock = self.ip_to_ipblock(addr);
let ts = intf::get_timestamp();
self.purge_old_timestamps(ts);
match ipblock {
IpAddr::V4(v4) => {
match self.conn_count_by_ip4.entry(v4) {
Entry::Vacant(_) => {
return Err(AddressNotInTableError {});
}
Entry::Occupied(mut o) => {
let cnt = o.get_mut();
assert!(*cnt > 0);
if *cnt == 0 {
self.conn_count_by_ip4.remove(&v4);
} else {
*cnt -= 1;
}
}
};
}
IpAddr::V6(v6) => {
match self.conn_count_by_ip6_prefix.entry(v6) {
Entry::Vacant(_) => {
return Err(AddressNotInTableError {});
}
Entry::Occupied(mut o) => {
let cnt = o.get_mut();
assert!(*cnt > 0);
if *cnt == 0 {
self.conn_count_by_ip6_prefix.remove(&v6);
} else {
*cnt -= 1;
}
}
};
}
}
Ok(())
}
}

View File

@@ -0,0 +1,243 @@
use super::*;
use crate::xx::*;
use connection_table::*;
use network_connection::*;
const CONNECTION_PROCESSOR_CHANNEL_SIZE: usize = 128usize;
///////////////////////////////////////////////////////////
// Connection manager
#[derive(Debug)]
struct ConnectionManagerInner {
connection_table: ConnectionTable,
}
struct ConnectionManagerArc {
network_manager: NetworkManager,
inner: AsyncMutex<ConnectionManagerInner>,
}
impl core::fmt::Debug for ConnectionManagerArc {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("ConnectionManagerArc")
.field("inner", &self.inner)
.finish()
}
}
#[derive(Debug, Clone)]
pub struct ConnectionManager {
arc: Arc<ConnectionManagerArc>,
}
impl ConnectionManager {
fn new_inner(config: VeilidConfig) -> ConnectionManagerInner {
ConnectionManagerInner {
connection_table: ConnectionTable::new(config),
}
}
fn new_arc(network_manager: NetworkManager) -> ConnectionManagerArc {
let config = network_manager.config();
ConnectionManagerArc {
network_manager,
inner: AsyncMutex::new(Self::new_inner(config)),
}
}
pub fn new(network_manager: NetworkManager) -> Self {
Self {
arc: Arc::new(Self::new_arc(network_manager)),
}
}
pub fn network_manager(&self) -> NetworkManager {
self.arc.network_manager.clone()
}
pub async fn startup(&self) {
trace!("startup connection manager");
//let mut inner = self.arc.inner.lock().await;
}
pub async fn shutdown(&self) {
// xxx close all connections in the connection table
*self.arc.inner.lock().await = Self::new_inner(self.arc.network_manager.config());
}
// Returns a network connection if one already is established
pub async fn get_connection(
&self,
descriptor: ConnectionDescriptor,
) -> Option<NetworkConnection> {
let mut inner = self.arc.inner.lock().await;
inner.connection_table.get_connection(descriptor)
}
// Internal routine to register new connection atomically
fn on_new_connection_internal(
&self,
inner: &mut ConnectionManagerInner,
conn: NetworkConnection,
) -> Result<(), String> {
log_net!("on_new_connection_internal: {:?}", conn);
let tx = inner
.connection_add_channel_tx
.as_ref()
.ok_or_else(fn_string!("connection channel isn't open yet"))?
.clone();
let receiver_loop_future = Self::process_connection(self.clone(), conn.clone());
tx.try_send(receiver_loop_future)
.map_err(map_to_string)
.map_err(logthru_net!(error "failed to start receiver loop"))?;
// If the receiver loop started successfully,
// add the new connection to the table
inner.connection_table.add_connection(conn)
}
// Called by low-level network when any connection-oriented protocol connection appears
// either from incoming or outgoing connections. Registers connection in the connection table for later access
// and spawns a message processing loop for the connection
pub async fn on_new_connection(&self, conn: NetworkConnection) -> Result<(), String> {
let mut inner = self.arc.inner.lock().await;
self.on_new_connection_internal(&mut *inner, conn)
}
pub async fn get_or_create_connection(
&self,
local_addr: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<NetworkConnection, String> {
log_net!(
"== get_or_create_connection local_addr={:?} dial_info={:?}",
local_addr.green(),
dial_info.green()
);
let peer_address = dial_info.to_peer_address();
let descriptor = match local_addr {
Some(la) => {
ConnectionDescriptor::new(peer_address, SocketAddress::from_socket_addr(la))
}
None => ConnectionDescriptor::new_no_local(peer_address),
};
// If any connection to this remote exists that has the same protocol, return it
// Any connection will do, we don't have to match the local address
let mut inner = self.arc.inner.lock().await;
if let Some(conn) = inner
.connection_table
.get_last_connection_by_remote(descriptor.remote)
{
log_net!(
"== Returning existing connection local_addr={:?} peer_address={:?}",
local_addr.green(),
peer_address.green()
);
return Ok(conn);
}
// Drop any other protocols connections that have the same local addr
// otherwise this connection won't succeed due to binding
if let Some(local_addr) = local_addr {
if local_addr.port() != 0 {
for pt in [ProtocolType::TCP, ProtocolType::WS, ProtocolType::WSS] {
let pa = PeerAddress::new(descriptor.remote.socket_address, pt);
for conn in inner.connection_table.get_connections_by_remote(pa) {
let desc = conn.connection_descriptor();
let mut kill = false;
if let Some(conn_local) = desc.local {
if (local_addr.ip().is_unspecified()
|| (local_addr.ip() == conn_local.to_ip_addr()))
&& conn_local.port() == local_addr.port()
{
kill = true;
}
}
if kill {
log_net!(debug
">< Terminating connection local_addr={:?} peer_address={:?}",
local_addr.green(),
pa.green()
);
conn.close().await?;
}
}
}
}
}
// Attempt new connection
let conn = NetworkConnection::connect(local_addr, dial_info).await?;
self.on_new_connection_internal(&mut *inner, conn.clone())?;
Ok(conn)
}
// Connection receiver loop
fn process_connection(
this: ConnectionManager,
conn: NetworkConnection,
) -> SystemPinBoxFuture<()> {
log_net!("Starting process_connection loop for {:?}", conn.green());
let network_manager = this.network_manager();
Box::pin(async move {
//
let descriptor = conn.connection_descriptor();
let inactivity_timeout = this
.network_manager()
.config()
.get()
.network
.connection_inactivity_timeout_ms;
loop {
// process inactivity timeout on receives only
// if you want a keepalive, it has to be requested from the other side
let message = select! {
res = conn.recv().fuse() => {
match res {
Ok(v) => v,
Err(e) => {
log_net!(debug e);
break;
}
}
}
_ = intf::sleep(inactivity_timeout).fuse()=> {
// timeout
log_net!("connection timeout on {:?}", descriptor.green());
break;
}
};
if let Err(e) = network_manager
.on_recv_envelope(message.as_slice(), descriptor)
.await
{
log_net!(error e);
break;
}
}
log_net!(
"== Connection loop finished local_addr={:?} remote={:?}",
descriptor.local.green(),
descriptor.remote.green()
);
if let Err(e) = this
.arc
.inner
.lock()
.await
.connection_table
.remove_connection(descriptor)
{
log_net!(error e);
}
})
}
}

View File

@@ -0,0 +1,168 @@
use super::connection_limits::*;
use super::network_connection::*;
use crate::xx::*;
use crate::*;
use alloc::collections::btree_map::Entry;
use hashlink::LruCache;
#[derive(Debug)]
pub struct ConnectionTable {
max_connections: Vec<usize>,
conn_by_descriptor: Vec<LruCache<ConnectionDescriptor, NetworkConnection>>,
conns_by_remote: BTreeMap<PeerAddress, Vec<NetworkConnection>>,
address_filter: ConnectionLimits,
}
fn protocol_to_index(protocol: ProtocolType) -> usize {
match protocol {
ProtocolType::TCP => 0,
ProtocolType::WS => 1,
ProtocolType::WSS => 2,
ProtocolType::UDP => panic!("not a connection-oriented protocol"),
}
}
impl ConnectionTable {
pub fn new(config: VeilidConfig) -> Self {
let max_connections = {
let c = config.get();
vec![
c.network.protocol.tcp.max_connections as usize,
c.network.protocol.ws.max_connections as usize,
c.network.protocol.wss.max_connections as usize,
]
};
Self {
max_connections,
conn_by_descriptor: vec![
LruCache::new_unbounded(),
LruCache::new_unbounded(),
LruCache::new_unbounded(),
],
conns_by_remote: BTreeMap::new(),
address_filter: ConnectionLimits::new(config),
}
}
pub fn add_connection(&mut self, conn: NetworkConnection) -> Result<(), String> {
let descriptor = conn.connection_descriptor();
let ip_addr = descriptor.remote.socket_address.to_ip_addr();
let index = protocol_to_index(descriptor.protocol_type());
if self.conn_by_descriptor[index].contains_key(&descriptor) {
return Err(format!(
"Connection already added to table: {:?}",
descriptor
));
}
// Filter by ip for connection limits
self.address_filter.add(ip_addr).map_err(map_to_string)?;
// Add the connection to the table
let res = self.conn_by_descriptor[index].insert(descriptor, conn.clone());
assert!(res.is_none());
// if we have reached the maximum number of connections per protocol type
// then drop the least recently used connection
if self.conn_by_descriptor[index].len() > self.max_connections[index] {
if let Some((lruk, _)) = self.conn_by_descriptor[index].remove_lru() {
warn!("XX: connection lru out: {:?}", lruk);
self.remove_connection_records(lruk);
}
}
// add connection records
let conns = self.conns_by_remote.entry(descriptor.remote).or_default();
warn!("add_connection: {:?}", conn);
conns.push(conn);
Ok(())
}
pub fn get_connection(
&mut self,
descriptor: ConnectionDescriptor,
) -> Option<NetworkConnection> {
let index = protocol_to_index(descriptor.protocol_type());
let out = self.conn_by_descriptor[index].get(&descriptor).cloned();
warn!("get_connection: {:?} -> {:?}", descriptor, out);
out
}
pub fn get_last_connection_by_remote(
&mut self,
remote: PeerAddress,
) -> Option<NetworkConnection> {
let out = self
.conns_by_remote
.get(&remote)
.map(|v| v[(v.len() - 1)].clone());
warn!("get_last_connection_by_remote: {:?} -> {:?}", remote, out);
if let Some(connection) = &out {
// lru bump
let index = protocol_to_index(connection.connection_descriptor().protocol_type());
let _ = self.conn_by_descriptor[index].get(&connection.connection_descriptor());
}
out
}
pub fn get_connections_by_remote(&mut self, remote: PeerAddress) -> Vec<NetworkConnection> {
let out = self
.conns_by_remote
.get(&remote)
.cloned()
.unwrap_or_default();
warn!("get_connections_by_remote: {:?} -> {:?}", remote, out);
out
}
pub fn connection_count(&self) -> usize {
self.conn_by_descriptor.iter().fold(0, |b, c| b + c.len())
}
fn remove_connection_records(&mut self, descriptor: ConnectionDescriptor) {
let ip_addr = descriptor.remote.socket_address.to_ip_addr();
// conns_by_remote
match self.conns_by_remote.entry(descriptor.remote) {
Entry::Vacant(_) => {
panic!("inconsistency in connection table")
}
Entry::Occupied(mut o) => {
let v = o.get_mut();
// Remove one matching connection from the list
for (n, elem) in v.iter().enumerate() {
if elem.connection_descriptor() == descriptor {
v.remove(n);
break;
}
}
// No connections left for this remote, remove the entry from conns_by_remote
if v.is_empty() {
o.remove_entry();
}
}
}
self.address_filter
.remove(ip_addr)
.expect("Inconsistency in connection table");
}
pub fn remove_connection(
&mut self,
descriptor: ConnectionDescriptor,
) -> Result<NetworkConnection, String> {
warn!("remove_connection: {:?}", descriptor);
let index = protocol_to_index(descriptor.protocol_type());
let out = self.conn_by_descriptor[index]
.remove(&descriptor)
.ok_or_else(|| format!("Connection not in table: {:?}", descriptor))?;
self.remove_connection_records(descriptor);
Ok(out)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,550 @@
mod network_class_discovery;
mod network_tcp;
mod network_udp;
mod protocol;
mod start_protocols;
use crate::intf::*;
use crate::network_manager::*;
use crate::routing_table::*;
use connection_manager::*;
use network_tcp::*;
use protocol::tcp::RawTcpProtocolHandler;
use protocol::udp::RawUdpProtocolHandler;
use protocol::ws::WebsocketProtocolHandler;
pub use protocol::*;
use utils::network_interfaces::*;
use async_std::io;
use async_std::net::*;
use async_tls::TlsAcceptor;
use futures_util::StreamExt;
// xxx: rustls ^0.20
//use rustls::{server::NoClientAuth, Certificate, PrivateKey, ServerConfig};
use rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
use rustls_pemfile::{certs, pkcs8_private_keys, rsa_private_keys};
use std::fs::File;
use std::io::BufReader;
use std::path::{Path, PathBuf};
use std::time::Duration;
/////////////////////////////////////////////////////////////////
pub const PEEK_DETECT_LEN: usize = 64;
/////////////////////////////////////////////////////////////////
struct NetworkInner {
routing_table: RoutingTable,
network_manager: NetworkManager,
network_started: bool,
network_needs_restart: bool,
protocol_config: Option<ProtocolConfig>,
static_public_dialinfo: ProtocolSet,
network_class: Option<NetworkClass>,
join_handles: Vec<JoinHandle<()>>,
udp_port: u16,
tcp_port: u16,
ws_port: u16,
wss_port: u16,
interfaces: NetworkInterfaces,
// udp
bound_first_udp: BTreeMap<u16, Option<(socket2::Socket, socket2::Socket)>>,
inbound_udp_protocol_handlers: BTreeMap<SocketAddr, RawUdpProtocolHandler>,
outbound_udpv4_protocol_handler: Option<RawUdpProtocolHandler>,
outbound_udpv6_protocol_handler: Option<RawUdpProtocolHandler>,
//tcp
bound_first_tcp: BTreeMap<u16, Option<(socket2::Socket, socket2::Socket)>>,
tls_acceptor: Option<TlsAcceptor>,
listener_states: BTreeMap<SocketAddr, Arc<RwLock<ListenerState>>>,
}
struct NetworkUnlockedInner {
// Background processes
update_network_class_task: TickTask,
}
#[derive(Clone)]
pub struct Network {
config: VeilidConfig,
inner: Arc<Mutex<NetworkInner>>,
unlocked_inner: Arc<NetworkUnlockedInner>,
}
impl Network {
fn new_inner(network_manager: NetworkManager) -> NetworkInner {
NetworkInner {
routing_table: network_manager.routing_table(),
network_manager,
network_started: false,
network_needs_restart: false,
protocol_config: None,
static_public_dialinfo: ProtocolSet::empty(),
network_class: None,
join_handles: Vec::new(),
udp_port: 0u16,
tcp_port: 0u16,
ws_port: 0u16,
wss_port: 0u16,
interfaces: NetworkInterfaces::new(),
bound_first_udp: BTreeMap::new(),
inbound_udp_protocol_handlers: BTreeMap::new(),
outbound_udpv4_protocol_handler: None,
outbound_udpv6_protocol_handler: None,
bound_first_tcp: BTreeMap::new(),
tls_acceptor: None,
listener_states: BTreeMap::new(),
}
}
fn new_unlocked_inner() -> NetworkUnlockedInner {
NetworkUnlockedInner {
update_network_class_task: TickTask::new(1),
}
}
pub fn new(network_manager: NetworkManager) -> Self {
let this = Self {
config: network_manager.config(),
inner: Arc::new(Mutex::new(Self::new_inner(network_manager))),
unlocked_inner: Arc::new(Self::new_unlocked_inner()),
};
// Set update network class tick task
{
let this2 = this.clone();
this.unlocked_inner
.update_network_class_task
.set_routine(move |l, t| {
Box::pin(this2.clone().update_network_class_task_routine(l, t))
});
}
this
}
fn network_manager(&self) -> NetworkManager {
self.inner.lock().network_manager.clone()
}
fn routing_table(&self) -> RoutingTable {
self.inner.lock().routing_table.clone()
}
fn connection_manager(&self) -> ConnectionManager {
self.inner.lock().network_manager.connection_manager()
}
fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
let cvec = certs(&mut BufReader::new(File::open(path)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid TLS certificate"))?;
Ok(cvec.into_iter().map(Certificate).collect())
}
fn load_keys(path: &Path) -> io::Result<Vec<PrivateKey>> {
{
if let Ok(v) = rsa_private_keys(&mut BufReader::new(File::open(path)?)) {
if !v.is_empty() {
return Ok(v.into_iter().map(PrivateKey).collect());
}
}
}
{
if let Ok(v) = pkcs8_private_keys(&mut BufReader::new(File::open(path)?)) {
if !v.is_empty() {
return Ok(v.into_iter().map(PrivateKey).collect());
}
}
}
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid TLS private key",
))
}
fn load_server_config(&self) -> io::Result<ServerConfig> {
let c = self.config.get();
//
trace!(
"loading certificate from {}",
c.network.tls.certificate_path
);
let certs = Self::load_certs(&PathBuf::from(&c.network.tls.certificate_path))?;
trace!("loaded {} certificates", certs.len());
if certs.is_empty() {
return Err(io::Error::new(io::ErrorKind::InvalidInput, format!("Certificates at {} could not be loaded.\nEnsure it is in PEM format, beginning with '-----BEGIN CERTIFICATE-----'",c.network.tls.certificate_path)));
}
//
trace!(
"loading private key from {}",
c.network.tls.private_key_path
);
let mut keys = Self::load_keys(&PathBuf::from(&c.network.tls.private_key_path))?;
trace!("loaded {} keys", keys.len());
if keys.is_empty() {
return Err(io::Error::new(io::ErrorKind::InvalidInput, format!("Private key at {} could not be loaded.\nEnsure it is unencrypted and in RSA or PKCS8 format, beginning with '-----BEGIN RSA PRIVATE KEY-----' or '-----BEGIN PRIVATE KEY-----'",c.network.tls.private_key_path)));
}
// xxx: rustls ^0.20
// let mut config = ServerConfig::builder()
// .with_safe_defaults()
// .with_no_client_auth()
// .with_single_cert(certs, keys.remove(0))
// .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
let mut config = ServerConfig::new(NoClientAuth::new());
config
.set_single_cert(certs, keys.remove(0))
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
Ok(config)
}
fn add_to_join_handles(&self, jh: JoinHandle<()>) {
let mut inner = self.inner.lock();
inner.join_handles.push(jh);
}
fn translate_unspecified_address(inner: &NetworkInner, from: &SocketAddr) -> Vec<SocketAddr> {
if !from.ip().is_unspecified() {
vec![*from]
} else {
inner
.interfaces
.best_addresses()
.iter()
.filter_map(|a| {
// We create sockets that are only ipv6 or ipv6 (not dual, so only translate matching unspecified address)
if (a.is_ipv4() && from.is_ipv4()) || (a.is_ipv6() && from.is_ipv6()) {
Some(SocketAddr::new(*a, from.port()))
} else {
None
}
})
.collect()
}
}
fn get_preferred_local_address(&self, dial_info: &DialInfo) -> SocketAddr {
let inner = self.inner.lock();
let local_port = match dial_info.protocol_type() {
ProtocolType::UDP => inner.udp_port,
ProtocolType::TCP => inner.tcp_port,
ProtocolType::WS => inner.ws_port,
ProtocolType::WSS => inner.wss_port,
};
match dial_info.address_type() {
AddressType::IPV4 => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), local_port),
AddressType::IPV6 => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), local_port),
}
}
pub fn with_interface_addresses<F, R>(&self, f: F) -> R
where
F: FnOnce(&[IpAddr]) -> R,
{
let inner = self.inner.lock();
inner.interfaces.with_best_addresses(f)
}
// See if our interface addresses have changed, if so we need to punt the network
// and redo all our addresses. This is overkill, but anything more accurate
// would require inspection of routing tables that we dont want to bother with
pub async fn check_interface_addresses(&self) -> Result<bool, String> {
let mut inner = self.inner.lock();
if !inner.interfaces.refresh().await? {
return Ok(false);
}
inner.network_needs_restart = true;
Ok(true)
}
////////////////////////////////////////////////////////////
// Send data to a dial info, unbound, using a new connection from a random port
// This creates a short-lived connection in the case of connection-oriented protocols
// for the purpose of sending this one message.
// This bypasses the connection table as it is not a 'node to node' connection.
pub async fn send_data_unbound_to_dial_info(
&self,
dial_info: DialInfo,
data: Vec<u8>,
) -> Result<(), String> {
let data_len = data.len();
let res = match dial_info.protocol_type() {
ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr();
RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data)
.await
.map_err(logthru_net!())
}
ProtocolType::TCP => {
let peer_socket_addr = dial_info.to_socket_addr();
RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data)
.await
.map_err(logthru_net!())
}
ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data)
.await
.map_err(logthru_net!())
}
};
if res.is_ok() {
// Network accounting
self.network_manager()
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
}
res
}
pub async fn send_data_to_existing_connection(
&self,
descriptor: ConnectionDescriptor,
data: Vec<u8>,
) -> Result<Option<Vec<u8>>, String> {
let data_len = data.len();
// Handle connectionless protocol
if descriptor.protocol_type() == ProtocolType::UDP {
// send over the best udp socket we have bound since UDP is not connection oriented
let peer_socket_addr = descriptor.remote.to_socket_addr();
if let Some(ph) = self.find_best_udp_protocol_handler(
&peer_socket_addr,
&descriptor.local.map(|sa| sa.to_socket_addr()),
) {
log_net!(
"send_data_to_existing_connection connectionless to {:?}",
descriptor
);
ph.clone()
.send_message(data, peer_socket_addr)
.await
.map_err(logthru_net!())?;
// Network accounting
self.network_manager()
.stats_packet_sent(peer_socket_addr.ip(), data_len as u64);
// Data was consumed
return Ok(None);
}
}
// Handle connection-oriented protocols
// Try to send to the exact existing connection if one exists
if let Some(conn) = self.connection_manager().get_connection(descriptor).await {
log_net!("send_data_to_existing_connection to {:?}", descriptor);
// connection exists, send over it
conn.send(data).await.map_err(logthru_net!())?;
// Network accounting
self.network_manager()
.stats_packet_sent(descriptor.remote.to_socket_addr().ip(), data_len as u64);
// Data was consumed
Ok(None)
} else {
// Connection or didn't exist
// Pass the data back out so we don't own it any more
Ok(Some(data))
}
}
// Send data directly to a dial info, possibly without knowing which node it is going to
pub async fn send_data_to_dial_info(
&self,
dial_info: DialInfo,
data: Vec<u8>,
) -> Result<(), String> {
let data_len = data.len();
// Handle connectionless protocol
if dial_info.protocol_type() == ProtocolType::UDP {
let peer_socket_addr = dial_info.to_socket_addr();
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
let res = ph
.send_message(data, peer_socket_addr)
.await
.map_err(logthru_net!());
if res.is_ok() {
// Network accounting
self.network_manager()
.stats_packet_sent(peer_socket_addr.ip(), data_len as u64);
}
return res;
}
return Err("no appropriate UDP protocol handler for dial_info".to_owned())
.map_err(logthru_net!(error));
}
// Handle connection-oriented protocols
let local_addr = self.get_preferred_local_address(&dial_info);
let conn = self
.connection_manager()
.get_or_create_connection(Some(local_addr), dial_info.clone())
.await?;
let res = conn.send(data).await.map_err(logthru_net!(error));
if res.is_ok() {
// Network accounting
self.network_manager()
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
}
res
}
/////////////////////////////////////////////////////////////////
pub fn get_protocol_config(&self) -> Option<ProtocolConfig> {
self.inner.lock().protocol_config
}
pub async fn startup(&self) -> Result<(), String> {
trace!("startup network");
// initialize interfaces
let mut interfaces = NetworkInterfaces::new();
interfaces.refresh().await?;
self.inner.lock().interfaces = interfaces;
// get protocol config
let protocol_config = {
let c = self.config.get();
let mut inbound = ProtocolSet::new();
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
inbound.insert(ProtocolType::UDP);
}
if c.network.protocol.tcp.listen && c.capabilities.protocol_accept_tcp {
inbound.insert(ProtocolType::TCP);
}
if c.network.protocol.ws.listen && c.capabilities.protocol_accept_ws {
inbound.insert(ProtocolType::WS);
}
if c.network.protocol.wss.listen && c.capabilities.protocol_accept_wss {
inbound.insert(ProtocolType::WSS);
}
let mut outbound = ProtocolSet::new();
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
outbound.insert(ProtocolType::UDP);
}
if c.network.protocol.tcp.connect && c.capabilities.protocol_connect_tcp {
outbound.insert(ProtocolType::TCP);
}
if c.network.protocol.ws.connect && c.capabilities.protocol_connect_ws {
outbound.insert(ProtocolType::WS);
}
if c.network.protocol.wss.connect && c.capabilities.protocol_connect_wss {
outbound.insert(ProtocolType::WSS);
}
ProtocolConfig { inbound, outbound }
};
self.inner.lock().protocol_config = Some(protocol_config);
// start listeners
if protocol_config.inbound.contains(ProtocolType::UDP) {
self.start_udp_listeners().await?;
}
if protocol_config.inbound.contains(ProtocolType::WS) {
self.start_ws_listeners().await?;
}
if protocol_config.inbound.contains(ProtocolType::WSS) {
self.start_wss_listeners().await?;
}
if protocol_config.inbound.contains(ProtocolType::TCP) {
self.start_tcp_listeners().await?;
}
// release caches of available listener ports
// this releases the 'first bound' ports we use to guarantee
// that we have ports available to us
self.free_bound_first_ports();
// If we have static public dialinfo, upgrade our network class
{
let mut inner = self.inner.lock();
if !inner.static_public_dialinfo.is_empty() {
inner.network_class = Some(NetworkClass::InboundCapable);
}
}
info!("network started");
self.inner.lock().network_started = true;
// Inform routing table entries that our dial info has changed
self.routing_table().send_node_info_updates();
Ok(())
}
pub fn needs_restart(&self) -> bool {
self.inner.lock().network_needs_restart
}
pub fn is_started(&self) -> bool {
self.inner.lock().network_started
}
pub fn restart_network(&self) {
self.inner.lock().network_needs_restart = true;
}
pub async fn shutdown(&self) {
info!("stopping network");
let network_manager = self.network_manager();
let routing_table = self.routing_table();
// Cancel all tasks
if let Err(e) = self.unlocked_inner.update_network_class_task.cancel().await {
warn!("update_network_class_task not cancelled: {}", e);
}
// Drop all dial info
routing_table.clear_dial_info_details(RoutingDomain::PublicInternet);
routing_table.clear_dial_info_details(RoutingDomain::LocalNetwork);
// Reset state including network class
// Cancels all async background tasks by dropping join handles
*self.inner.lock() = Self::new_inner(network_manager);
info!("network stopped");
}
//////////////////////////////////////////
pub fn get_network_class(&self) -> Option<NetworkClass> {
let inner = self.inner.lock();
inner.network_class
}
pub fn reset_network_class(&self) {
let mut inner = self.inner.lock();
inner.network_class = None;
}
//////////////////////////////////////////
pub async fn tick(&self) -> Result<(), String> {
let network_class = self.get_network_class().unwrap_or(NetworkClass::Invalid);
let routing_table = self.routing_table();
// If we need to figure out our network class, tick the task for it
if network_class == NetworkClass::Invalid {
let rth = routing_table.get_routing_table_health();
// Need at least two entries to do this
if rth.unreliable_entry_count + rth.reliable_entry_count >= 2 {
self.unlocked_inner.update_network_class_task.tick().await?;
}
}
Ok(())
}
}

View File

@@ -0,0 +1,611 @@
use super::*;
use futures_util::stream::FuturesUnordered;
use futures_util::FutureExt;
struct DetectedPublicDialInfo {
dial_info: DialInfo,
class: DialInfoClass,
}
struct DiscoveryContextInner {
// per-protocol
intf_addrs: Option<Vec<SocketAddress>>,
protocol_type: Option<ProtocolType>,
address_type: Option<AddressType>,
// first node contacted
external_1_dial_info: Option<DialInfo>,
external_1_address: Option<SocketAddress>,
node_1: Option<NodeRef>,
// detected public dialinfo
detected_network_class: Option<NetworkClass>,
detected_public_dial_info: Option<DetectedPublicDialInfo>,
}
pub struct DiscoveryContext {
routing_table: RoutingTable,
net: Network,
inner: Arc<Mutex<DiscoveryContextInner>>,
}
impl DiscoveryContext {
pub fn new(routing_table: RoutingTable, net: Network) -> Self {
Self {
routing_table,
net,
inner: Arc::new(Mutex::new(DiscoveryContextInner {
// per-protocol
intf_addrs: None,
protocol_type: None,
address_type: None,
external_1_dial_info: None,
external_1_address: None,
node_1: None,
detected_network_class: None,
detected_public_dial_info: None,
})),
}
}
///////
// Utilities
// Pick the best network class we have seen so far
pub fn set_detected_network_class(&self, network_class: NetworkClass) {
let mut inner = self.inner.lock();
log_net!( debug
"=== set_detected_network_class {:?} {:?}: {:?} ===",
inner.protocol_type,
inner.address_type,
network_class
);
inner.detected_network_class = Some(network_class);
}
pub fn set_detected_public_dial_info(&self, dial_info: DialInfo, class: DialInfoClass) {
let mut inner = self.inner.lock();
log_net!( debug
"=== set_detected_public_dial_info {:?} {:?}: {} {:?} ===",
inner.protocol_type,
inner.address_type,
dial_info,
class
);
inner.detected_public_dial_info = Some(DetectedPublicDialInfo { dial_info, class });
}
// Ask for a public address check from a particular noderef
// This is done over the normal port using RPC
async fn request_public_address(&self, node_ref: NodeRef) -> Option<SocketAddress> {
let rpc = self.routing_table.rpc_processor();
rpc.rpc_call_status(node_ref.clone())
.await
.map_err(logthru_net!(
"failed to get status answer from {:?}",
node_ref
))
.map(|sa| {
let ret = sa.sender_info.socket_address;
log_net!("request_public_address: {:?}", ret);
ret
})
.unwrap_or(None)
}
// find fast peers with a particular address type, and ask them to tell us what our external address is
// This is done over the normal port using RPC
async fn discover_external_address(
&self,
protocol_type: ProtocolType,
address_type: AddressType,
ignore_node: Option<DHTKey>,
) -> Option<(SocketAddress, NodeRef)> {
let filter = DialInfoFilter::global()
.with_protocol_type(protocol_type)
.with_address_type(address_type);
let peers = self.routing_table.find_fast_public_nodes_filtered(&filter);
if peers.is_empty() {
log_net!("no peers of type '{:?}'", filter);
return None;
}
for peer in peers {
if let Some(ignore_node) = ignore_node {
if peer.node_id() == ignore_node {
continue;
}
}
if let Some(sa) = self.request_public_address(peer.clone()).await {
return Some((sa, peer));
}
}
log_net!("no peers responded with an external address");
None
}
// This pulls the already-detected local interface dial info from the routing table
fn get_local_addresses(
&self,
protocol_type: ProtocolType,
address_type: AddressType,
) -> Vec<SocketAddress> {
let filter = DialInfoFilter::local()
.with_protocol_type(protocol_type)
.with_address_type(address_type);
self.routing_table
.dial_info_details(RoutingDomain::LocalNetwork)
.iter()
.filter_map(|did| {
if did.dial_info.matches_filter(&filter) {
Some(did.dial_info.socket_address())
} else {
None
}
})
.collect()
}
async fn validate_dial_info(
&self,
node_ref: NodeRef,
dial_info: DialInfo,
redirect: bool,
) -> bool {
let rpc = self.routing_table.rpc_processor();
rpc.rpc_call_validate_dial_info(node_ref.clone(), dial_info, redirect)
.await
.map_err(logthru_net!(
"failed to send validate_dial_info to {:?}",
node_ref
))
.unwrap_or(false)
}
async fn try_port_mapping(&self) -> Option<DialInfo> {
//xxx
None
}
fn make_dial_info(&self, addr: SocketAddress, protocol_type: ProtocolType) -> DialInfo {
match protocol_type {
ProtocolType::UDP => DialInfo::udp(addr),
ProtocolType::TCP => DialInfo::tcp(addr),
ProtocolType::WS => {
let c = self.net.config.get();
DialInfo::try_ws(
addr,
format!("ws://{}/{}", addr, c.network.protocol.ws.path),
)
.unwrap()
}
ProtocolType::WSS => panic!("none of the discovery functions are used for wss"),
}
}
///////
// Per-protocol discovery routines
pub fn protocol_begin(&self, protocol_type: ProtocolType, address_type: AddressType) {
// Get our interface addresses
let intf_addrs = self.get_local_addresses(protocol_type, address_type);
let mut inner = self.inner.lock();
inner.intf_addrs = Some(intf_addrs);
inner.protocol_type = Some(protocol_type);
inner.address_type = Some(address_type);
inner.external_1_dial_info = None;
inner.external_1_address = None;
inner.node_1 = None;
}
// Get our first node's view of our external IP address via normal RPC
pub async fn protocol_get_external_address_1(&self) -> bool {
let (protocol_type, address_type) = {
let inner = self.inner.lock();
(inner.protocol_type.unwrap(), inner.address_type.unwrap())
};
// Get our external address from some fast node, call it node 1
let (external_1, node_1) = match self
.discover_external_address(protocol_type, address_type, None)
.await
{
None => {
// If we can't get an external address, exit but don't throw an error so we can try again later
log_net!(debug "couldn't get external address 1 for {:?} {:?}", protocol_type, address_type);
return false;
}
Some(v) => v,
};
let external_1_dial_info = self.make_dial_info(external_1, protocol_type);
let mut inner = self.inner.lock();
inner.external_1_dial_info = Some(external_1_dial_info);
inner.external_1_address = Some(external_1);
inner.node_1 = Some(node_1);
log_net!(debug "external_1_dial_info: {:?}\nexternal_1_address: {:?}\nnode_1: {:?}", inner.external_1_dial_info, inner.external_1_address, inner.node_1);
true
}
// If we know we are not behind NAT, check our firewall status
pub async fn protocol_process_no_nat(&self) -> Result<(), String> {
let (node_b, external_1_dial_info) = {
let inner = self.inner.lock();
(
inner.node_1.as_ref().unwrap().clone(),
inner.external_1_dial_info.as_ref().unwrap().clone(),
)
};
// Do a validate_dial_info on the external address from a redirected node
if self
.validate_dial_info(node_b.clone(), external_1_dial_info.clone(), true)
.await
{
// Add public dial info with Direct dialinfo class
self.set_detected_public_dial_info(external_1_dial_info, DialInfoClass::Direct);
}
// Attempt a port mapping via all available and enabled mechanisms
else if let Some(external_mapped_dial_info) = self.try_port_mapping().await {
// Got a port mapping, let's use it
self.set_detected_public_dial_info(external_mapped_dial_info, DialInfoClass::Mapped);
} else {
// Add public dial info with Blocked dialinfo class
self.set_detected_public_dial_info(external_1_dial_info, DialInfoClass::Blocked);
}
self.set_detected_network_class(NetworkClass::InboundCapable);
Ok(())
}
// If we know we are behind NAT check what kind
pub async fn protocol_process_nat(&self) -> Result<bool, String> {
let (node_1, external_1_dial_info, external_1_address, protocol_type, address_type) = {
let inner = self.inner.lock();
(
inner.node_1.as_ref().unwrap().clone(),
inner.external_1_dial_info.as_ref().unwrap().clone(),
inner.external_1_address.unwrap(),
inner.protocol_type.unwrap(),
inner.address_type.unwrap(),
)
};
// Attempt a UDP port mapping via all available and enabled mechanisms
if let Some(external_mapped_dial_info) = self.try_port_mapping().await {
// Got a port mapping, let's use it
self.set_detected_public_dial_info(external_mapped_dial_info, DialInfoClass::Mapped);
self.set_detected_network_class(NetworkClass::InboundCapable);
// No more retries
return Ok(true);
}
// Port mapping was not possible, let's see what kind of NAT we have
// Does a redirected dial info validation from a different address and a random port find us?
if self
.validate_dial_info(node_1.clone(), external_1_dial_info.clone(), true)
.await
{
// Yes, another machine can use the dial info directly, so Full Cone
// Add public dial info with full cone NAT network class
self.set_detected_public_dial_info(external_1_dial_info, DialInfoClass::FullConeNAT);
self.set_detected_network_class(NetworkClass::InboundCapable);
// No more retries
return Ok(true);
}
// No, we are restricted, determine what kind of restriction
// Get our external address from some fast node, that is not node 1, call it node 2
let (external_2_address, node_2) = match self
.discover_external_address(protocol_type, address_type, Some(node_1.node_id()))
.await
{
None => {
// If we can't get an external address, allow retry
return Ok(false);
}
Some(v) => v,
};
// If we have two different external addresses, then this is a symmetric NAT
if external_2_address != external_1_address {
// Symmetric NAT is outbound only, no public dial info will work
self.set_detected_network_class(NetworkClass::OutboundOnly);
// No more retries
return Ok(true);
}
// If we're going to end up as a restricted NAT of some sort
// Address is the same, so it's address or port restricted
// Do a validate_dial_info on the external address from a random port
if self
.validate_dial_info(node_2.clone(), external_1_dial_info.clone(), false)
.await
{
// Got a reply from a non-default port, which means we're only address restricted
self.set_detected_public_dial_info(
external_1_dial_info,
DialInfoClass::AddressRestrictedNAT,
);
} else {
// Didn't get a reply from a non-default port, which means we are also port restricted
self.set_detected_public_dial_info(
external_1_dial_info,
DialInfoClass::PortRestrictedNAT,
);
}
self.set_detected_network_class(NetworkClass::InboundCapable);
// Allow another retry because sometimes trying again will get us Full Cone NAT instead
Ok(false)
}
}
impl Network {
pub async fn update_ipv4_protocol_dialinfo(
&self,
context: &DiscoveryContext,
protocol_type: ProtocolType,
) -> Result<(), String> {
let mut retry_count = {
let c = self.config.get();
c.network.restricted_nat_retries
};
// Start doing ipv4 protocol
context.protocol_begin(protocol_type, AddressType::IPV4);
// Loop for restricted NAT retries
loop {
log_net!(debug
"=== update_ipv4_protocol_dialinfo {:?} tries_left={} ===",
protocol_type,
retry_count
);
// Get our external address from some fast node, call it node 1
if !context.protocol_get_external_address_1().await {
// If we couldn't get an external address, then we should just try the whole network class detection again later
return Ok(());
}
// If our local interface list contains external_1 then there is no NAT in place
{
let res = {
let inner = context.inner.lock();
inner
.intf_addrs
.as_ref()
.unwrap()
.contains(inner.external_1_address.as_ref().unwrap())
};
if res {
// No NAT
context.protocol_process_no_nat().await?;
// No more retries
break;
}
}
// There is -some NAT-
if context.protocol_process_nat().await? {
// We either got dial info or a network class without one
break;
}
// If we tried everything, break anyway after N attempts
if retry_count == 0 {
break;
}
retry_count -= 1;
}
Ok(())
}
pub async fn update_ipv6_protocol_dialinfo(
&self,
context: &DiscoveryContext,
protocol_type: ProtocolType,
) -> Result<(), String> {
// Start doing ipv6 protocol
context.protocol_begin(protocol_type, AddressType::IPV6);
log_net!(debug "=== update_ipv6_protocol_dialinfo {:?} ===", protocol_type);
// Get our external address from some fast node, call it node 1
if !context.protocol_get_external_address_1().await {
// If we couldn't get an external address, then we should just try the whole network class detection again later
return Ok(());
}
// If our local interface list doesn't contain external_1 then there is an Ipv6 NAT in place
{
let inner = context.inner.lock();
if !inner
.intf_addrs
.as_ref()
.unwrap()
.contains(inner.external_1_address.as_ref().unwrap())
{
// IPv6 NAT is not supported today
log_net!(warn
"IPv6 NAT is not supported for external address: {}",
inner.external_1_address.unwrap()
);
return Ok(());
}
}
// No NAT
context.protocol_process_no_nat().await?;
Ok(())
}
pub async fn update_network_class_task_routine(self, _l: u64, _t: u64) -> Result<(), String> {
log_net!("--- updating network class");
// Ensure we aren't trying to update this without clearing it first
let old_network_class = self.inner.lock().network_class;
assert_eq!(old_network_class, None);
let protocol_config = self.inner.lock().protocol_config.unwrap_or_default();
let mut unord = FuturesUnordered::new();
// Do UDPv4+v6 at the same time as everything else
if protocol_config.inbound.contains(ProtocolType::UDP) {
// UDPv4
unord.push(
async {
let udpv4_context = DiscoveryContext::new(self.routing_table(), self.clone());
if let Err(e) = self
.update_ipv4_protocol_dialinfo(&udpv4_context, ProtocolType::UDP)
.await
{
log_net!(debug "Failed UDPv4 dialinfo discovery: {}", e);
return None;
}
Some(vec![udpv4_context])
}
.boxed(),
);
// UDPv6
unord.push(
async {
let udpv6_context = DiscoveryContext::new(self.routing_table(), self.clone());
if let Err(e) = self
.update_ipv6_protocol_dialinfo(&udpv6_context, ProtocolType::UDP)
.await
{
log_net!(debug "Failed UDPv6 dialinfo discovery: {}", e);
return None;
}
Some(vec![udpv6_context])
}
.boxed(),
);
}
// Do TCPv4 + WSv4 in series because they may use the same connection 5-tuple
unord.push(
async {
// TCPv4
let mut out = Vec::<DiscoveryContext>::new();
if protocol_config.inbound.contains(ProtocolType::TCP) {
let tcpv4_context = DiscoveryContext::new(self.routing_table(), self.clone());
if let Err(e) = self
.update_ipv4_protocol_dialinfo(&tcpv4_context, ProtocolType::TCP)
.await
{
log_net!(debug "Failed TCPv4 dialinfo discovery: {}", e);
return None;
}
out.push(tcpv4_context);
}
// WSv4
if protocol_config.inbound.contains(ProtocolType::WS) {
let wsv4_context = DiscoveryContext::new(self.routing_table(), self.clone());
if let Err(e) = self
.update_ipv4_protocol_dialinfo(&wsv4_context, ProtocolType::WS)
.await
{
log_net!(debug "Failed WSv4 dialinfo discovery: {}", e);
return None;
}
out.push(wsv4_context);
}
Some(out)
}
.boxed(),
);
// Do TCPv6 + WSv6 in series because they may use the same connection 5-tuple
unord.push(
async {
// TCPv6
let mut out = Vec::<DiscoveryContext>::new();
if protocol_config.inbound.contains(ProtocolType::TCP) {
let tcpv6_context = DiscoveryContext::new(self.routing_table(), self.clone());
if let Err(e) = self
.update_ipv6_protocol_dialinfo(&tcpv6_context, ProtocolType::TCP)
.await
{
log_net!(debug "Failed TCPv6 dialinfo discovery: {}", e);
return None;
}
out.push(tcpv6_context);
}
// WSv6
if protocol_config.inbound.contains(ProtocolType::WS) {
let wsv6_context = DiscoveryContext::new(self.routing_table(), self.clone());
if let Err(e) = self
.update_ipv6_protocol_dialinfo(&wsv6_context, ProtocolType::WS)
.await
{
log_net!(debug "Failed WSv6 dialinfo discovery: {}", e);
return None;
}
out.push(wsv6_context);
}
Some(out)
}
.boxed(),
);
// Wait for all discovery futures to complete and collect contexts
let mut contexts = Vec::<DiscoveryContext>::new();
let mut network_class = Option::<NetworkClass>::None;
while let Some(ctxvec) = unord.next().await {
if let Some(ctxvec) = ctxvec {
for ctx in ctxvec {
if let Some(nc) = ctx.inner.lock().detected_network_class {
if let Some(last_nc) = network_class {
if nc < last_nc {
network_class = Some(nc);
}
} else {
network_class = Some(nc);
}
}
contexts.push(ctx);
}
}
}
// Get best network class
if network_class.is_some() {
// Update public dial info
let routing_table = self.routing_table();
for ctx in contexts {
let inner = ctx.inner.lock();
if let Some(pdi) = &inner.detected_public_dial_info {
if let Err(e) = routing_table.register_dial_info(
RoutingDomain::PublicInternet,
pdi.dial_info.clone(),
pdi.class,
) {
log_net!(warn "Failed to register detected public dial info: {}", e);
}
}
}
// Update network class
self.inner.lock().network_class = network_class;
log_net!(debug "network class changed to {:?}", network_class);
// Send updates to everyone
routing_table.send_node_info_updates();
}
Ok(())
}
}

View File

@@ -0,0 +1,289 @@
use super::*;
use crate::intf::*;
use async_tls::TlsAcceptor;
use sockets::*;
/////////////////////////////////////////////////////////////////
#[derive(Clone)]
pub struct ListenerState {
pub protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
pub tls_protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
pub tls_acceptor: Option<TlsAcceptor>,
}
impl ListenerState {
pub fn new() -> Self {
Self {
protocol_handlers: Vec::new(),
tls_protocol_handlers: Vec::new(),
tls_acceptor: None,
}
}
}
/////////////////////////////////////////////////////////////////
impl Network {
fn get_or_create_tls_acceptor(&self) -> Result<TlsAcceptor, String> {
if let Some(ts) = self.inner.lock().tls_acceptor.as_ref() {
return Ok(ts.clone());
}
let server_config = self
.load_server_config()
.map_err(|e| format!("Couldn't create TLS configuration: {}", e))?;
let acceptor = TlsAcceptor::from(Arc::new(server_config));
self.inner.lock().tls_acceptor = Some(acceptor.clone());
Ok(acceptor)
}
async fn try_tls_handlers(
&self,
tls_acceptor: &TlsAcceptor,
stream: AsyncPeekStream,
tcp_stream: TcpStream,
addr: SocketAddr,
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
tls_connection_initial_timeout: u64,
) -> Result<Option<NetworkConnection>, String> {
let ts = tls_acceptor
.accept(stream)
.await
.map_err(map_to_string)
.map_err(logthru_net!(debug "TLS stream failed handshake"))?;
let ps = AsyncPeekStream::new(CloneStream::new(ts));
let mut first_packet = [0u8; PEEK_DETECT_LEN];
// Try the handlers but first get a chunk of data for them to process
// Don't waste more than N seconds getting it though, in case someone
// is trying to DoS us with a bunch of connections or something
// read a chunk of the stream
io::timeout(
Duration::from_micros(tls_connection_initial_timeout),
ps.peek_exact(&mut first_packet),
)
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
self.try_handlers(ps, tcp_stream, addr, protocol_handlers)
.await
}
async fn try_handlers(
&self,
stream: AsyncPeekStream,
tcp_stream: TcpStream,
addr: SocketAddr,
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
) -> Result<Option<NetworkConnection>, String> {
for ah in protocol_handlers.iter() {
if let Some(nc) = ah
.on_accept(stream.clone(), tcp_stream.clone(), addr)
.await
.map_err(logthru_net!())?
{
return Ok(Some(nc));
}
}
Ok(None)
}
async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
// Get config
let (connection_initial_timeout, tls_connection_initial_timeout) = {
let c = self.config.get();
(
ms_to_us(c.network.connection_initial_timeout_ms),
ms_to_us(c.network.tls.connection_initial_timeout_ms),
)
};
// Create a reusable socket with no linger time, and no delay
let socket = new_bound_shared_tcp_socket(addr)?;
// Listen on the socket
socket
.listen(128)
.map_err(|e| format!("Couldn't listen on TCP socket: {}", e))?;
// Make an async tcplistener from the socket2 socket
let std_listener: std::net::TcpListener = socket.into();
let listener = TcpListener::from(std_listener);
debug!("spawn_socket_listener: binding successful to {}", addr);
// Create protocol handler records
let listener_state = Arc::new(RwLock::new(ListenerState::new()));
self.inner
.lock()
.listener_states
.insert(addr, listener_state.clone());
// Spawn the socket task
let this = self.clone();
let connection_manager = self.connection_manager();
////////////////////////////////////////////////////////////
let jh = spawn(async move {
// moves listener object in and get incoming iterator
// when this task exists, the listener will close the socket
listener
.incoming()
.for_each_concurrent(None, |tcp_stream| async {
let tcp_stream = tcp_stream.unwrap();
let listener_state = listener_state.clone();
let connection_manager = connection_manager.clone();
// Limit the number of connections from the same IP address
// and the number of total connections
let addr = match tcp_stream.peer_addr() {
Ok(addr) => addr,
Err(e) => {
log_net!(error "failed to get peer address: {}", e);
return;
}
};
// XXX limiting
log_net!("TCP connection from: {}", addr);
// Create a stream we can peek on
let ps = AsyncPeekStream::new(tcp_stream.clone());
/////////////////////////////////////////////////////////////
let mut first_packet = [0u8; PEEK_DETECT_LEN];
// read a chunk of the stream
if io::timeout(
Duration::from_micros(connection_initial_timeout),
ps.peek_exact(&mut first_packet),
)
.await
.is_err()
{
// If we fail to get a packet within the connection initial timeout
// then we punt this connection
log_net!(warn "connection initial timeout from: {:?}", addr);
return;
}
// Run accept handlers on accepted stream
// Check is this could be TLS
let ls = listener_state.read().clone();
let conn = if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 {
this.try_tls_handlers(
ls.tls_acceptor.as_ref().unwrap(),
ps,
tcp_stream,
addr,
&ls.tls_protocol_handlers,
tls_connection_initial_timeout,
)
.await
} else {
this.try_handlers(ps, tcp_stream, addr, &ls.protocol_handlers)
.await
};
let conn = match conn {
Ok(Some(c)) => {
log_net!("protocol handler found for {:?}: {:?}", addr, c);
c
}
Ok(None) => {
// No protocol handlers matched? drop it.
log_net!(warn "no protocol handler for connection from {:?}", addr);
return;
}
Err(e) => {
// Failed to negotiate connection? drop it.
log_net!(warn "failed to negotiate connection from {:?}: {}", addr, e);
return;
}
};
// Register the new connection in the connection manager
if let Err(e) = connection_manager.on_new_connection(conn).await {
log_net!(error "failed to register new connection: {}", e);
}
})
.await;
log_net!(debug "exited incoming loop for {}", addr);
// Remove our listener state from this address if we're stopping
this.inner.lock().listener_states.remove(&addr);
log_net!(debug "listener state removed for {}", addr);
// If this happened our low-level listener socket probably died
// so it's time to restart the network
this.inner.lock().network_needs_restart = true;
});
////////////////////////////////////////////////////////////
// Add to join handles
self.add_to_join_handles(jh);
Ok(())
}
/////////////////////////////////////////////////////////////////
// TCP listener that multiplexes ports so multiple protocols can exist on a single port
pub(super) async fn start_tcp_listener(
&self,
ip_addrs: Vec<IpAddr>,
port: u16,
is_tls: bool,
new_protocol_accept_handler: Box<NewProtocolAcceptHandler>,
) -> Result<Vec<SocketAddress>, String> {
let mut out = Vec::<SocketAddress>::new();
for ip_addr in ip_addrs {
let addr = SocketAddr::new(ip_addr, port);
let idi_addrs = Self::translate_unspecified_address(&*(self.inner.lock()), &addr);
// see if we've already bound to this already
// if not, spawn a listener
if !self.inner.lock().listener_states.contains_key(&addr) {
self.clone().spawn_socket_listener(addr).await?;
}
let ls = if let Some(ls) = self.inner.lock().listener_states.get_mut(&addr) {
ls.clone()
} else {
panic!("this shouldn't happen");
};
if is_tls {
if ls.read().tls_acceptor.is_none() {
ls.write().tls_acceptor = Some(self.clone().get_or_create_tls_acceptor()?);
}
ls.write()
.tls_protocol_handlers
.push(new_protocol_accept_handler(
self.network_manager().config(),
true,
addr,
));
} else {
ls.write()
.protocol_handlers
.push(new_protocol_accept_handler(
self.network_manager().config(),
false,
addr,
));
}
// Return interface dial infos we listen on
for idi_addr in idi_addrs {
out.push(SocketAddress::from_socket_addr(idi_addr));
}
}
Ok(out)
}
}

View File

@@ -0,0 +1,209 @@
use super::*;
use sockets::*;
impl Network {
pub(super) async fn create_udp_listener_tasks(&self) -> Result<(), String> {
// Spawn socket tasks
let mut task_count = {
let c = self.config.get();
c.network.protocol.udp.socket_pool_size
};
if task_count == 0 {
task_count = intf::get_concurrency() / 2;
if task_count == 0 {
task_count = 1;
}
}
trace!("task_count: {}", task_count);
for _ in 0..task_count {
trace!("Spawning UDP listener task");
////////////////////////////////////////////////////////////
// Run thread task to process stream of messages
let this = self.clone();
let jh = spawn(async move {
trace!("UDP listener task spawned");
// Collect all our protocol handlers into a vector
let mut protocol_handlers: Vec<RawUdpProtocolHandler> = this
.inner
.lock()
.inbound_udp_protocol_handlers
.values()
.cloned()
.collect();
if let Some(ph) = this.inner.lock().outbound_udpv4_protocol_handler.clone() {
protocol_handlers.push(ph);
}
if let Some(ph) = this.inner.lock().outbound_udpv6_protocol_handler.clone() {
protocol_handlers.push(ph);
}
// Spawn a local async task for each socket
let mut protocol_handlers_unordered = FuturesUnordered::new();
let network_manager = this.network_manager();
for ph in protocol_handlers {
let network_manager = network_manager.clone();
let jh = spawn_local(async move {
let mut data = vec![0u8; 65536];
while let Ok((size, descriptor)) = ph.recv_message(&mut data).await {
// XXX: Limit the number of packets from the same IP address?
log_net!("UDP packet: {:?}", descriptor);
// Network accounting
network_manager.stats_packet_rcvd(
descriptor.remote.to_socket_addr().ip(),
size as u64,
);
// Pass it up for processing
if let Err(e) = network_manager
.on_recv_envelope(&data[..size], descriptor)
.await
{
log_net!(error "failed to process received udp envelope: {}", e);
}
}
});
protocol_handlers_unordered.push(jh);
}
// Now we wait for any join handle to exit,
// which would indicate an error needing
// us to completely restart the network
let _ = protocol_handlers_unordered.next().await;
trace!("UDP listener task stopped");
// If this loop fails, our socket died and we need to restart the network
this.inner.lock().network_needs_restart = true;
});
////////////////////////////////////////////////////////////
// Add to join handle
self.add_to_join_handles(jh);
}
Ok(())
}
pub(super) async fn create_udp_outbound_sockets(&self) -> Result<(), String> {
let mut inner = self.inner.lock();
let mut port = inner.udp_port;
// v4
let socket_addr_v4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
if let Ok(socket) = new_bound_shared_udp_socket(socket_addr_v4) {
// Pull the port if we randomly bound, so v6 can be on the same port
port = socket
.local_addr()
.map_err(map_to_string)?
.as_socket_ipv4()
.ok_or_else(|| "expected ipv4 address type".to_owned())?
.port();
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
let udp_socket = UdpSocket::from(std_udp_socket);
let socket_arc = Arc::new(udp_socket);
// Create protocol handler
let udpv4_handler = RawUdpProtocolHandler::new(socket_arc);
inner.outbound_udpv4_protocol_handler = Some(udpv4_handler);
}
//v6
let socket_addr_v6 =
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), port);
if let Ok(socket) = new_bound_shared_udp_socket(socket_addr_v6) {
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
let udp_socket = UdpSocket::from(std_udp_socket);
let socket_arc = Arc::new(udp_socket);
// Create protocol handler
let udpv6_handler = RawUdpProtocolHandler::new(socket_arc);
inner.outbound_udpv6_protocol_handler = Some(udpv6_handler);
}
Ok(())
}
async fn create_udp_inbound_socket(&self, addr: SocketAddr) -> Result<(), String> {
log_net!("create_udp_inbound_socket on {:?}", &addr);
// Create a reusable socket
let socket = new_bound_shared_udp_socket(addr)?;
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
let udp_socket = UdpSocket::from(std_udp_socket);
let socket_arc = Arc::new(udp_socket);
// Create protocol handler
let protocol_handler = RawUdpProtocolHandler::new(socket_arc);
// Create message_handler records
self.inner
.lock()
.inbound_udp_protocol_handlers
.insert(addr, protocol_handler);
Ok(())
}
pub(super) async fn create_udp_inbound_sockets(
&self,
ip_addrs: Vec<IpAddr>,
port: u16,
) -> Result<Vec<DialInfo>, String> {
let mut out = Vec::<DialInfo>::new();
for ip_addr in ip_addrs {
let addr = SocketAddr::new(ip_addr, port);
// see if we've already bound to this already
// if not, spawn a listener
if !self
.inner
.lock()
.inbound_udp_protocol_handlers
.contains_key(&addr)
{
let idi_addrs = Self::translate_unspecified_address(&*self.inner.lock(), &addr);
self.clone().create_udp_inbound_socket(addr).await?;
// Return interface dial infos we listen on
for idi_addr in idi_addrs {
out.push(DialInfo::udp_from_socketaddr(idi_addr));
}
}
}
Ok(out)
}
/////////////////////////////////////////////////////////////////
pub(super) fn find_best_udp_protocol_handler(
&self,
peer_socket_addr: &SocketAddr,
local_socket_addr: &Option<SocketAddr>,
) -> Option<RawUdpProtocolHandler> {
// if our last communication with this peer came from a particular inbound udp protocol handler, use it
if let Some(sa) = local_socket_addr {
if let Some(ph) = self.inner.lock().inbound_udp_protocol_handlers.get(sa) {
return Some(ph.clone());
}
}
// otherwise find the outbound udp protocol handler that matches the ip protocol version of the peer addr
let inner = self.inner.lock();
match peer_socket_addr {
SocketAddr::V4(_) => inner.outbound_udpv4_protocol_handler.clone(),
SocketAddr::V6(_) => inner.outbound_udpv6_protocol_handler.clone(),
}
}
}

View File

@@ -0,0 +1,86 @@
pub mod sockets;
pub mod tcp;
pub mod udp;
pub mod wrtc;
pub mod ws;
use super::*;
use crate::xx::*;
#[derive(Debug)]
pub enum ProtocolNetworkConnection {
Dummy(DummyNetworkConnection),
RawTcp(tcp::RawTcpNetworkConnection),
WsAccepted(ws::WebSocketNetworkConnectionAccepted),
Ws(ws::WebsocketNetworkConnectionWS),
Wss(ws::WebsocketNetworkConnectionWSS),
//WebRTC(wrtc::WebRTCNetworkConnection),
}
impl ProtocolNetworkConnection {
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<NetworkConnection, String> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
panic!("Should not connect to UDP dialinfo");
}
ProtocolType::TCP => {
tcp::RawTcpProtocolHandler::connect(local_address, dial_info).await
}
ProtocolType::WS | ProtocolType::WSS => {
ws::WebsocketProtocolHandler::connect(local_address, dial_info).await
}
}
}
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr();
udp::RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data)
.await
.map_err(logthru_net!())
}
ProtocolType::TCP => {
let peer_socket_addr = dial_info.to_socket_addr();
tcp::RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data)
.await
.map_err(logthru_net!())
}
ProtocolType::WS | ProtocolType::WSS => {
ws::WebsocketProtocolHandler::send_unbound_message(dial_info, data).await
}
}
}
pub async fn close(&self) -> Result<(), String> {
match self {
Self::Dummy(d) => d.close(),
Self::RawTcp(t) => t.close().await,
Self::WsAccepted(w) => w.close().await,
Self::Ws(w) => w.close().await,
Self::Wss(w) => w.close().await,
}
}
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
match self {
Self::Dummy(d) => d.send(message),
Self::RawTcp(t) => t.send(message).await,
Self::WsAccepted(w) => w.send(message).await,
Self::Ws(w) => w.send(message).await,
Self::Wss(w) => w.send(message).await,
}
}
pub async fn recv(&self) -> Result<Vec<u8>, String> {
match self {
Self::Dummy(d) => d.recv(),
Self::RawTcp(t) => t.recv().await,
Self::WsAccepted(w) => w.recv().await,
Self::Ws(w) => w.recv().await,
Self::Wss(w) => w.recv().await,
}
}
}

View File

@@ -0,0 +1,222 @@
use crate::xx::*;
use crate::*;
use async_io::Async;
use async_std::net::TcpStream;
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
cfg_if! {
if #[cfg(windows)] {
use winapi::shared::ws2def::{ SOL_SOCKET, SO_EXCLUSIVEADDRUSE};
use winapi::um::winsock2::{SOCKET_ERROR, setsockopt};
use winapi::ctypes::c_int;
use std::os::windows::io::AsRawSocket;
fn set_exclusiveaddruse(socket: &Socket) -> Result<(), String> {
unsafe {
let optval:c_int = 1;
if setsockopt(socket.as_raw_socket().try_into().unwrap(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (&optval as *const c_int).cast(),
std::mem::size_of::<c_int>() as c_int) == SOCKET_ERROR {
return Err("Unable to SO_EXCLUSIVEADDRUSE".to_owned());
}
Ok(())
}
}
}
}
pub fn new_unbound_shared_udp_socket(domain: Domain) -> Result<Socket, String> {
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))
.map_err(|e| format!("Couldn't create UDP socket: {}", e))?;
if domain == Domain::IPV6 {
socket
.set_only_v6(true)
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
}
socket
.set_reuse_address(true)
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
cfg_if! {
if #[cfg(unix)] {
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
}
}
Ok(socket)
}
pub fn new_bound_shared_udp_socket(local_address: SocketAddr) -> Result<Socket, String> {
let domain = Domain::for_address(local_address);
let socket = new_unbound_shared_udp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
socket.bind(&socket2_addr).map_err(|e| {
format!(
"failed to bind UDP socket to '{}' in domain '{:?}': {} ",
local_address, domain, e
)
})?;
log_net!("created bound shared udp socket on {:?}", &local_address);
Ok(socket)
}
pub fn new_bound_first_udp_socket(local_address: SocketAddr) -> Result<Socket, String> {
let domain = Domain::for_address(local_address);
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))
.map_err(|e| format!("Couldn't create UDP socket: {}", e))?;
if domain == Domain::IPV6 {
socket
.set_only_v6(true)
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
}
// Bind the socket -first- before turning on 'reuse address' this way it will
// fail if the port is already taken
let socket2_addr = SockAddr::from(local_address);
// On windows, do SO_EXCLUSIVEADDRUSE before the bind to ensure the port is fully available
cfg_if! {
if #[cfg(windows)] {
set_exclusiveaddruse(&socket)?;
}
}
socket
.bind(&socket2_addr)
.map_err(|e| format!("failed to bind UDP socket: {}", e))?;
// Set 'reuse address' so future binds to this port will succeed
// This does not work on Windows, where reuse options can not be set after the bind
cfg_if! {
if #[cfg(unix)] {
socket
.set_reuse_address(true)
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
}
}
log_net!("created bound first udp socket on {:?}", &local_address);
Ok(socket)
}
pub fn new_unbound_shared_tcp_socket(domain: Domain) -> Result<Socket, String> {
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
.map_err(map_to_string)
.map_err(logthru_net!("failed to create TCP socket"))?;
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
log_net!(error "Couldn't set TCP linger: {}", e);
}
if let Err(e) = socket.set_nodelay(true) {
log_net!(error "Couldn't set TCP nodelay: {}", e);
}
if domain == Domain::IPV6 {
socket
.set_only_v6(true)
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
}
socket
.set_reuse_address(true)
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
cfg_if! {
if #[cfg(unix)] {
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
}
}
Ok(socket)
}
pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> Result<Socket, String> {
let domain = Domain::for_address(local_address);
let socket = new_unbound_shared_tcp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
socket
.bind(&socket2_addr)
.map_err(|e| format!("failed to bind TCP socket: {}", e))?;
log_net!("created bound shared tcp socket on {:?}", &local_address);
Ok(socket)
}
pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> Result<Socket, String> {
let domain = Domain::for_address(local_address);
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
.map_err(map_to_string)
.map_err(logthru_net!("failed to create TCP socket"))?;
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
log_net!(error "Couldn't set TCP linger: {}", e);
}
if let Err(e) = socket.set_nodelay(true) {
log_net!(error "Couldn't set TCP nodelay: {}", e);
}
if domain == Domain::IPV6 {
socket
.set_only_v6(true)
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
}
// On windows, do SO_EXCLUSIVEADDRUSE before the bind to ensure the port is fully available
cfg_if! {
if #[cfg(windows)] {
set_exclusiveaddruse(&socket)?;
}
}
// Bind the socket -first- before turning on 'reuse address' this way it will
// fail if the port is already taken
let socket2_addr = SockAddr::from(local_address);
socket
.bind(&socket2_addr)
.map_err(|e| format!("failed to bind TCP socket: {}", e))?;
// Set 'reuse address' so future binds to this port will succeed
// This does not work on Windows, where reuse options can not be set after the bind
cfg_if! {
if #[cfg(unix)] {
socket
.set_reuse_address(true)
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
}
}
log_net!("created bound first tcp socket on {:?}", &local_address);
Ok(socket)
}
// Non-blocking connect is tricky when you want to start with a prepared socket
pub async fn nonblocking_connect(socket: Socket, addr: SocketAddr) -> std::io::Result<TcpStream> {
// Set for non blocking connect
socket.set_nonblocking(true)?;
// Make socket2 SockAddr
let socket2_addr = socket2::SockAddr::from(addr);
// Connect to the remote address
match socket.connect(&socket2_addr) {
Ok(()) => Ok(()),
#[cfg(unix)]
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
Err(e) => Err(e),
}?;
let async_stream = Async::new(std::net::TcpStream::from(socket))?;
// The stream becomes writable when connected
intf::timeout(2000, async_stream.writable())
.await
.map_err(|e| std::io::Error::new(std::io::ErrorKind::TimedOut, e))??;
// Check low level error
let async_stream = match async_stream.get_ref().take_error()? {
None => Ok(async_stream),
Some(err) => Err(err),
}?;
// Convert back to inner and then return async version
Ok(TcpStream::from(async_stream.into_inner()?))
}

View File

@@ -0,0 +1,227 @@
use super::*;
use futures_util::{AsyncReadExt, AsyncWriteExt};
use sockets::*;
pub struct RawTcpNetworkConnection {
stream: AsyncPeekStream,
tcp_stream: TcpStream,
}
impl fmt::Debug for RawTcpNetworkConnection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RawTCPNetworkConnection").finish()
}
}
impl RawTcpNetworkConnection {
pub fn new(stream: AsyncPeekStream, tcp_stream: TcpStream) -> Self {
Self { stream, tcp_stream }
}
pub async fn close(&self) -> Result<(), String> {
// Make an attempt to flush the stream
self.stream
.clone()
.close()
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
// Then forcibly close the socket
self.tcp_stream
.shutdown(Shutdown::Both)
.map_err(map_to_string)
.map_err(logthru_net!())
}
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
log_net!("sending TCP message of size {}", message.len());
if message.len() > MAX_MESSAGE_SIZE {
return Err("sending too large TCP message".to_owned());
}
let len = message.len() as u16;
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
let mut stream = self.stream.clone();
stream
.write_all(&header)
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
stream
.write_all(&message)
.await
.map_err(map_to_string)
.map_err(logthru_net!())
}
pub async fn recv(&self) -> Result<Vec<u8>, String> {
let mut header = [0u8; 4];
let mut stream = self.stream.clone();
stream
.read_exact(&mut header)
.await
.map_err(|e| format!("TCP recv error: {}", e))?;
if header[0] != b'V' || header[1] != b'L' {
return Err("received invalid TCP frame header".to_owned());
}
let len = ((header[3] as usize) << 8) | (header[2] as usize);
if len > MAX_MESSAGE_SIZE {
return Err("received too large TCP frame".to_owned());
}
let mut out: Vec<u8> = vec![0u8; len];
stream.read_exact(&mut out).await.map_err(map_to_string)?;
Ok(out)
}
}
///////////////////////////////////////////////////////////
///
struct RawTcpProtocolHandlerInner {
local_address: SocketAddr,
}
#[derive(Clone)]
pub struct RawTcpProtocolHandler
where
Self: ProtocolAcceptHandler,
{
inner: Arc<Mutex<RawTcpProtocolHandlerInner>>,
}
impl RawTcpProtocolHandler {
fn new_inner(local_address: SocketAddr) -> RawTcpProtocolHandlerInner {
RawTcpProtocolHandlerInner { local_address }
}
pub fn new(local_address: SocketAddr) -> Self {
Self {
inner: Arc::new(Mutex::new(Self::new_inner(local_address))),
}
}
async fn on_accept_async(
self,
stream: AsyncPeekStream,
tcp_stream: TcpStream,
socket_addr: SocketAddr,
) -> Result<Option<NetworkConnection>, String> {
log_net!("TCP: on_accept_async: enter");
let mut peekbuf: [u8; PEEK_DETECT_LEN] = [0u8; PEEK_DETECT_LEN];
let peeklen = stream
.peek(&mut peekbuf)
.await
.map_err(map_to_string)
.map_err(logthru_net!("could not peek tcp stream"))?;
assert_eq!(peeklen, PEEK_DETECT_LEN);
let peer_addr = PeerAddress::new(
SocketAddress::from_socket_addr(socket_addr),
ProtocolType::TCP,
);
let local_address = self.inner.lock().local_address;
let conn = NetworkConnection::from_protocol(
ConnectionDescriptor::new(peer_addr, SocketAddress::from_socket_addr(local_address)),
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(stream, tcp_stream)),
);
log_net!(debug "TCP: on_accept_async from: {}", socket_addr);
Ok(Some(conn))
}
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<NetworkConnection, String> {
// Get remote socket address to connect to
let remote_socket_addr = dial_info.to_socket_addr();
// Make a shared socket
let socket = match local_address {
Some(a) => new_bound_shared_tcp_socket(a)?,
None => {
new_unbound_shared_tcp_socket(socket2::Domain::for_address(remote_socket_addr))?
}
};
// Non-blocking connect to remote address
let ts = nonblocking_connect(socket, remote_socket_addr).await
.map_err(map_to_string)
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
// See what local address we ended up with and turn this into a stream
let actual_local_address = ts
.local_addr()
.map_err(map_to_string)
.map_err(logthru_net!("could not get local address from TCP stream"))?;
let ps = AsyncPeekStream::new(ts.clone());
// Wrap the stream in a network connection and return it
let conn = NetworkConnection::from_protocol(
ConnectionDescriptor {
local: Some(SocketAddress::from_socket_addr(actual_local_address)),
remote: dial_info.to_peer_address(),
},
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)),
);
Ok(conn)
}
pub async fn send_unbound_message(
socket_addr: SocketAddr,
data: Vec<u8>,
) -> Result<(), String> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large unbound TCP message".to_owned());
}
trace!(
"sending unbound message of length {} to {}",
data.len(),
socket_addr
);
// Make a shared socket
let socket = new_unbound_shared_tcp_socket(socket2::Domain::for_address(socket_addr))?;
// Non-blocking connect to remote address
let ts = nonblocking_connect(socket, socket_addr)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "remote_addr={}", socket_addr))?;
// See what local address we ended up with and turn this into a stream
let actual_local_address = ts
.local_addr()
.map_err(map_to_string)
.map_err(logthru_net!("could not get local address from TCP stream"))?;
let ps = AsyncPeekStream::new(ts.clone());
// Wrap the stream in a network connection and return it
let conn = NetworkConnection::from_protocol(
ConnectionDescriptor {
local: Some(SocketAddress::from_socket_addr(actual_local_address)),
remote: PeerAddress::new(
SocketAddress::from_socket_addr(socket_addr),
ProtocolType::TCP,
),
},
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)),
);
conn.send(data).await
}
}
impl ProtocolAcceptHandler for RawTcpProtocolHandler {
fn on_accept(
&self,
stream: AsyncPeekStream,
tcp_stream: TcpStream,
peer_addr: SocketAddr,
) -> SystemPinBoxFuture<core::result::Result<Option<NetworkConnection>, String>> {
Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
}
}

View File

@@ -0,0 +1,102 @@
use super::*;
#[derive(Clone)]
pub struct RawUdpProtocolHandler {
socket: Arc<UdpSocket>,
}
impl RawUdpProtocolHandler {
pub fn new(socket: Arc<UdpSocket>) -> Self {
Self { socket }
}
pub async fn recv_message(
&self,
data: &mut [u8],
) -> Result<(usize, ConnectionDescriptor), String> {
let (size, remote_addr) = self.socket.recv_from(data).await.map_err(map_to_string)?;
if size > MAX_MESSAGE_SIZE {
return Err("received too large UDP message".to_owned());
}
trace!(
"receiving UDP message of length {} from {}",
size,
remote_addr
);
let peer_addr = PeerAddress::new(
SocketAddress::from_socket_addr(remote_addr),
ProtocolType::UDP,
);
let local_socket_addr = self.socket.local_addr().map_err(map_to_string)?;
let descriptor = ConnectionDescriptor::new(
peer_addr,
SocketAddress::from_socket_addr(local_socket_addr),
);
Ok((size, descriptor))
}
pub async fn send_message(&self, data: Vec<u8>, socket_addr: SocketAddr) -> Result<(), String> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large UDP message".to_owned()).map_err(logthru_net!(error));
}
log_net!(
"sending UDP message of length {} to {}",
data.len(),
socket_addr
);
let len = self
.socket
.send_to(&data, socket_addr)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "failed udp send: addr={}", socket_addr))?;
if len != data.len() {
Err("UDP partial send".to_owned()).map_err(logthru_net!(error))
} else {
Ok(())
}
}
pub async fn send_unbound_message(
socket_addr: SocketAddr,
data: Vec<u8>,
) -> Result<(), String> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large unbound UDP message".to_owned())
.map_err(logthru_net!(error));
}
log_net!(
"sending unbound message of length {} to {}",
data.len(),
socket_addr
);
// get local wildcard address for bind
let local_socket_addr = match socket_addr {
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
SocketAddr::V6(_) => {
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
}
};
let socket = UdpSocket::bind(local_socket_addr)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "failed to bind unbound udp socket"))?;
let len = socket
.send_to(&data, socket_addr)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "failed unbound udp send: addr={}", socket_addr))?;
if len != data.len() {
Err("UDP partial unbound send".to_owned()).map_err(logthru_net!(error))
} else {
Ok(())
}
}
}

View File

@@ -0,0 +1,301 @@
use super::*;
use async_std::io;
use async_tls::TlsConnector;
use async_tungstenite::tungstenite::protocol::Message;
use async_tungstenite::{accept_async, client_async, WebSocketStream};
use futures_util::SinkExt;
use sockets::*;
pub type WebSocketNetworkConnectionAccepted = WebsocketNetworkConnection<AsyncPeekStream>;
pub type WebsocketNetworkConnectionWSS =
WebsocketNetworkConnection<async_tls::client::TlsStream<TcpStream>>;
pub type WebsocketNetworkConnectionWS = WebsocketNetworkConnection<TcpStream>;
pub struct WebsocketNetworkConnection<T>
where
T: io::Read + io::Write + Send + Unpin + 'static,
{
stream: CloneStream<WebSocketStream<T>>,
tcp_stream: TcpStream,
}
impl<T> fmt::Debug for WebsocketNetworkConnection<T>
where
T: io::Read + io::Write + Send + Unpin + 'static,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", std::any::type_name::<Self>())
}
}
impl<T> WebsocketNetworkConnection<T>
where
T: io::Read + io::Write + Send + Unpin + 'static,
{
pub fn new(stream: WebSocketStream<T>, tcp_stream: TcpStream) -> Self {
Self {
stream: CloneStream::new(stream),
tcp_stream,
}
}
pub async fn close(&self) -> Result<(), String> {
// Make an attempt to flush the stream
self.stream
.clone()
.close()
.await
.map_err(map_to_string)
.map_err(logthru_net!())?;
// Then forcibly close the socket
self.tcp_stream
.shutdown(Shutdown::Both)
.map_err(map_to_string)
.map_err(logthru_net!())
}
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
if message.len() > MAX_MESSAGE_SIZE {
return Err("received too large WS message".to_owned());
}
self.stream
.clone()
.send(Message::binary(message))
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "failed to send websocket message"))
}
pub async fn recv(&self) -> Result<Vec<u8>, String> {
let out = match self.stream.clone().next().await {
Some(Ok(Message::Binary(v))) => v,
Some(Ok(_)) => {
return Err("Unexpected WS message type".to_owned()).map_err(logthru_net!(error));
}
Some(Err(e)) => {
return Err(e.to_string()).map_err(logthru_net!(error));
}
None => {
return Err("WS stream closed".to_owned()).map_err(logthru_net!());
}
};
if out.len() > MAX_MESSAGE_SIZE {
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
} else {
Ok(out)
}
}
}
///////////////////////////////////////////////////////////
///
struct WebsocketProtocolHandlerArc {
tls: bool,
local_address: SocketAddr,
request_path: Vec<u8>,
connection_initial_timeout: u64,
}
#[derive(Clone)]
pub struct WebsocketProtocolHandler
where
Self: ProtocolAcceptHandler,
{
arc: Arc<WebsocketProtocolHandlerArc>,
}
impl WebsocketProtocolHandler {
pub fn new(config: VeilidConfig, tls: bool, local_address: SocketAddr) -> Self {
let c = config.get();
let path = if tls {
format!("GET /{}", c.network.protocol.ws.path.trim_end_matches('/'))
} else {
format!("GET /{}", c.network.protocol.wss.path.trim_end_matches('/'))
};
let connection_initial_timeout = if tls {
ms_to_us(c.network.tls.connection_initial_timeout_ms)
} else {
ms_to_us(c.network.connection_initial_timeout_ms)
};
Self {
arc: Arc::new(WebsocketProtocolHandlerArc {
tls,
local_address,
request_path: path.as_bytes().to_vec(),
connection_initial_timeout,
}),
}
}
pub async fn on_accept_async(
self,
ps: AsyncPeekStream,
tcp_stream: TcpStream,
socket_addr: SocketAddr,
) -> Result<Option<NetworkConnection>, String> {
log_net!("WS: on_accept_async: enter");
let request_path_len = self.arc.request_path.len() + 2;
let mut peekbuf: Vec<u8> = vec![0u8; request_path_len];
match io::timeout(
Duration::from_micros(self.arc.connection_initial_timeout),
ps.peek_exact(&mut peekbuf),
)
.await
{
Ok(_) => (),
Err(e) => {
if e.kind() == io::ErrorKind::TimedOut {
return Err(e).map_err(map_to_string).map_err(logthru_net!());
}
return Err(e).map_err(map_to_string).map_err(logthru_net!(error));
}
}
// Check for websocket path
let matches_path = &peekbuf[0..request_path_len - 2] == self.arc.request_path.as_slice()
&& (peekbuf[request_path_len - 2] == b' '
|| (peekbuf[request_path_len - 2] == b'/'
&& peekbuf[request_path_len - 1] == b' '));
if !matches_path {
log_net!("WS: not websocket");
return Ok(None);
}
log_net!("WS: found websocket");
let ws_stream = accept_async(ps)
.await
.map_err(map_to_string)
.map_err(logthru_net!("failed websockets handshake"))?;
// Wrap the websocket in a NetworkConnection and register it
let protocol_type = if self.arc.tls {
ProtocolType::WSS
} else {
ProtocolType::WS
};
let peer_addr =
PeerAddress::new(SocketAddress::from_socket_addr(socket_addr), protocol_type);
let conn = NetworkConnection::from_protocol(
ConnectionDescriptor::new(
peer_addr,
SocketAddress::from_socket_addr(self.arc.local_address),
),
ProtocolNetworkConnection::WsAccepted(WebsocketNetworkConnection::new(
ws_stream, tcp_stream,
)),
);
log_net!(debug "{}: on_accept_async from: {}", if self.arc.tls { "WSS" } else { "WS" }, socket_addr);
Ok(Some(conn))
}
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<NetworkConnection, String> {
// Split dial info up
let (tls, scheme) = match &dial_info {
DialInfo::WS(_) => (false, "ws"),
DialInfo::WSS(_) => (true, "wss"),
_ => panic!("invalid dialinfo for WS/WSS protocol"),
};
let request = dial_info.request().unwrap();
let split_url = SplitUrl::from_str(&request)?;
if split_url.scheme != scheme {
return Err("invalid websocket url scheme".to_string());
}
let domain = split_url.host.clone();
// Resolve remote address
let remote_socket_addr = dial_info.to_socket_addr();
// Make a shared socket
let socket = match local_address {
Some(a) => new_bound_shared_tcp_socket(a)?,
None => {
new_unbound_shared_tcp_socket(socket2::Domain::for_address(remote_socket_addr))?
}
};
// Non-blocking connect to remote address
let tcp_stream = nonblocking_connect(socket, remote_socket_addr).await
.map_err(map_to_string)
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
// See what local address we ended up with
let actual_local_addr = tcp_stream
.local_addr()
.map_err(map_to_string)
.map_err(logthru_net!())?;
// Make our connection descriptor
let descriptor = ConnectionDescriptor {
local: Some(SocketAddress::from_socket_addr(actual_local_addr)),
remote: dial_info.to_peer_address(),
};
// Negotiate TLS if this is WSS
if tls {
let connector = TlsConnector::default();
let tls_stream = connector
.connect(domain.to_string(), tcp_stream.clone())
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
let (ws_stream, _response) = client_async(request, tls_stream)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
Ok(NetworkConnection::from_protocol(
descriptor,
ProtocolNetworkConnection::Wss(WebsocketNetworkConnection::new(
ws_stream, tcp_stream,
)),
))
} else {
let (ws_stream, _response) = client_async(request, tcp_stream.clone())
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
Ok(NetworkConnection::from_protocol(
descriptor,
ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(
ws_stream, tcp_stream,
)),
))
}
}
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large unbound WS message".to_owned());
}
trace!(
"sending unbound websocket message of length {} to {}",
data.len(),
dial_info,
);
let conn = Self::connect(None, dial_info.clone())
.await
.map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?;
conn.send(data).await
}
}
impl ProtocolAcceptHandler for WebsocketProtocolHandler {
fn on_accept(
&self,
stream: AsyncPeekStream,
tcp_stream: TcpStream,
peer_addr: SocketAddr,
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>> {
Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
}
}

View File

@@ -0,0 +1,711 @@
use super::sockets::*;
use super::*;
use lazy_static::*;
lazy_static! {
static ref BAD_PORTS: BTreeSet<u16> = BTreeSet::from([
1, // tcpmux
7, // echo
9, // discard
11, // systat
13, // daytime
15, // netstat
17, // qotd
19, // chargen
20, // ftp data
21, // ftp access
22, // ssh
23, // telnet
25, // smtp
37, // time
42, // name
43, // nicname
53, // domain
77, // priv-rjs
79, // finger
87, // ttylink
95, // supdup
101, // hostriame
102, // iso-tsap
103, // gppitnp
104, // acr-nema
109, // pop2
110, // pop3
111, // sunrpc
113, // auth
115, // sftp
117, // uucp-path
119, // nntp
123, // NTP
135, // loc-srv /epmap
139, // netbios
143, // imap2
179, // BGP
389, // ldap
427, // SLP (Also used by Apple Filing Protocol)
465, // smtp+ssl
512, // print / exec
513, // login
514, // shell
515, // printer
526, // tempo
530, // courier
531, // chat
532, // netnews
540, // uucp
548, // AFP (Apple Filing Protocol)
556, // remotefs
563, // nntp+ssl
587, // smtp (rfc6409)
601, // syslog-conn (rfc3195)
636, // ldap+ssl
993, // ldap+ssl
995, // pop3+ssl
2049, // nfs
3659, // apple-sasl / PasswordServer
4045, // lockd
6000, // X11
6665, // Alternate IRC [Apple addition]
6666, // Alternate IRC [Apple addition]
6667, // Standard IRC [Apple addition]
6668, // Alternate IRC [Apple addition]
6669, // Alternate IRC [Apple addition]
6697, // IRC + TLS
]);
}
impl Network {
/////////////////////////////////////////////////////
// Support for binding first on ports to ensure nobody binds ahead of us
// or two copies of the app don't accidentally collide. This is tricky
// because we use 'reuseaddr/port' and we can accidentally bind in front of ourselves :P
fn bind_first_udp_port(&self, udp_port: u16) -> bool {
let mut inner = self.inner.lock();
if inner.bound_first_udp.contains_key(&udp_port) {
return true;
}
// If the address is specified, only use the specified port and fail otherwise
let mut bound_first_socket_v4 = None;
let mut bound_first_socket_v6 = None;
if let Ok(bfs4) =
new_bound_first_udp_socket(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), udp_port))
{
if let Ok(bfs6) = new_bound_first_udp_socket(SocketAddr::new(
IpAddr::V6(Ipv6Addr::UNSPECIFIED),
udp_port,
)) {
bound_first_socket_v4 = Some(bfs4);
bound_first_socket_v6 = Some(bfs6);
}
}
if let (Some(bfs4), Some(bfs6)) = (bound_first_socket_v4, bound_first_socket_v6) {
cfg_if! {
if #[cfg(windows)] {
// On windows, drop the socket. This is a race condition, but there's
// no way around it. This isn't for security anyway, it's to prevent multiple copies of the
// app from binding on the same port.
drop(bfs4);
drop(bfs6);
inner.bound_first_udp.insert(udp_port, None);
} else {
inner.bound_first_udp.insert(udp_port, Some((bfs4, bfs6)));
}
}
true
} else {
false
}
}
fn bind_first_tcp_port(&self, tcp_port: u16) -> bool {
let mut inner = self.inner.lock();
if inner.bound_first_tcp.contains_key(&tcp_port) {
return true;
}
// If the address is specified, only use the specified port and fail otherwise
let mut bound_first_socket_v4 = None;
let mut bound_first_socket_v6 = None;
if let Ok(bfs4) =
new_bound_first_tcp_socket(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), tcp_port))
{
if let Ok(bfs6) = new_bound_first_tcp_socket(SocketAddr::new(
IpAddr::V6(Ipv6Addr::UNSPECIFIED),
tcp_port,
)) {
bound_first_socket_v4 = Some(bfs4);
bound_first_socket_v6 = Some(bfs6);
}
}
if let (Some(bfs4), Some(bfs6)) = (bound_first_socket_v4, bound_first_socket_v6) {
cfg_if! {
if #[cfg(windows)] {
// On windows, drop the socket. This is a race condition, but there's
// no way around it. This isn't for security anyway, it's to prevent multiple copies of the
// app from binding on the same port.
drop(bfs4);
drop(bfs6);
inner.bound_first_tcp.insert(tcp_port, None);
} else {
inner.bound_first_tcp.insert(tcp_port, Some((bfs4, bfs6)));
}
}
true
} else {
false
}
}
pub(super) fn free_bound_first_ports(&self) {
let mut inner = self.inner.lock();
inner.bound_first_udp.clear();
inner.bound_first_tcp.clear();
}
/////////////////////////////////////////////////////
fn find_available_udp_port(&self) -> Result<u16, String> {
// If the address is empty, iterate ports until we find one we can use.
let mut udp_port = 5150u16;
loop {
if BAD_PORTS.contains(&udp_port) {
continue;
}
if self.bind_first_udp_port(udp_port) {
break;
}
if udp_port == 65535 {
return Err("Could not find free udp port to listen on".to_owned());
}
udp_port += 1;
}
Ok(udp_port)
}
fn find_available_tcp_port(&self) -> Result<u16, String> {
// If the address is empty, iterate ports until we find one we can use.
let mut tcp_port = 5150u16;
loop {
if BAD_PORTS.contains(&tcp_port) {
continue;
}
if self.bind_first_tcp_port(tcp_port) {
break;
}
if tcp_port == 65535 {
return Err("Could not find free tcp port to listen on".to_owned());
}
tcp_port += 1;
}
Ok(tcp_port)
}
async fn allocate_udp_port(
&self,
listen_address: String,
) -> Result<(u16, Vec<IpAddr>), String> {
if listen_address.is_empty() {
// If listen address is empty, find us a port iteratively
let port = self.find_available_udp_port()?;
let ip_addrs = vec![
IpAddr::V4(Ipv4Addr::UNSPECIFIED),
IpAddr::V6(Ipv6Addr::UNSPECIFIED),
];
Ok((port, ip_addrs))
} else {
// If no address is specified, but the port is, use ipv4 and ipv6 unspecified
// If the address is specified, only use the specified port and fail otherwise
let sockaddrs = listen_address_to_socket_addrs(&listen_address)?;
if sockaddrs.is_empty() {
return Err(format!("No valid listen address: {}", listen_address));
}
let port = sockaddrs[0].port();
if self.bind_first_udp_port(port) {
Ok((port, sockaddrs.iter().map(|s| s.ip()).collect()))
} else {
Err("Could not find free udp port to listen on".to_owned())
}
}
}
async fn allocate_tcp_port(
&self,
listen_address: String,
) -> Result<(u16, Vec<IpAddr>), String> {
if listen_address.is_empty() {
// If listen address is empty, find us a port iteratively
let port = self.find_available_tcp_port()?;
let ip_addrs = vec![
IpAddr::V4(Ipv4Addr::UNSPECIFIED),
IpAddr::V6(Ipv6Addr::UNSPECIFIED),
];
Ok((port, ip_addrs))
} else {
// If no address is specified, but the port is, use ipv4 and ipv6 unspecified
// If the address is specified, only use the specified port and fail otherwise
let sockaddrs = listen_address_to_socket_addrs(&listen_address)?;
if sockaddrs.is_empty() {
return Err(format!("No valid listen address: {}", listen_address));
}
let port = sockaddrs[0].port();
if self.bind_first_tcp_port(port) {
Ok((port, sockaddrs.iter().map(|s| s.ip()).collect()))
} else {
Err("Could not find free tcp port to listen on".to_owned())
}
}
}
/////////////////////////////////////////////////////
pub(super) async fn start_udp_listeners(&self) -> Result<(), String> {
trace!("starting udp listeners");
let routing_table = self.routing_table();
let (listen_address, public_address, enable_local_peer_scope) = {
let c = self.config.get();
(
c.network.protocol.udp.listen_address.clone(),
c.network.protocol.udp.public_address.clone(),
c.network.enable_local_peer_scope,
)
};
// Pick out UDP port we're going to use everywhere
// Keep sockets around until the end of this function
// to keep anyone else from binding in front of us
let (udp_port, ip_addrs) = self.allocate_udp_port(listen_address.clone()).await?;
// Save the bound udp port for use later on
self.inner.lock().udp_port = udp_port;
// First, create outbound sockets
// (unlike tcp where we create sockets for every connection)
// and we'll add protocol handlers for them too
self.create_udp_outbound_sockets().await?;
// Now create udp inbound sockets for whatever interfaces we're listening on
info!(
"UDP: starting listeners on port {} at {:?}",
udp_port, ip_addrs
);
let local_dial_info_list = self.create_udp_inbound_sockets(ip_addrs, udp_port).await?;
let mut static_public = false;
trace!("UDP: listener started on {:#?}", local_dial_info_list);
// Register local dial info
for di in &local_dial_info_list {
// If the local interface address is global, or we are enabling local peer scope
// register global dial info if no public address is specified
if public_address.is_none() && (di.is_global() || enable_local_peer_scope) {
routing_table.register_dial_info(
RoutingDomain::PublicInternet,
di.clone(),
DialInfoClass::Direct,
)?;
static_public = true;
}
// Register interface dial info as well since the address is on the local interface
routing_table.register_dial_info(
RoutingDomain::LocalNetwork,
di.clone(),
DialInfoClass::Direct,
)?;
}
// Add static public dialinfo if it's configured
if let Some(public_address) = public_address.as_ref() {
// Resolve statically configured public dialinfo
let mut public_sockaddrs = public_address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
// Add all resolved addresses as public dialinfo
for pdi_addr in &mut public_sockaddrs {
let pdi = DialInfo::udp_from_socketaddr(pdi_addr);
// Register the public address
routing_table.register_dial_info(
RoutingDomain::PublicInternet,
pdi.clone(),
DialInfoClass::Direct,
)?;
// See if this public address is also a local interface address we haven't registered yet
let is_interface_address = self.with_interface_addresses(|ip_addrs| {
for ip_addr in ip_addrs {
if pdi_addr.ip() == *ip_addr {
return true;
}
}
false
});
if !local_dial_info_list.contains(&pdi) && is_interface_address {
routing_table.register_dial_info(
RoutingDomain::LocalNetwork,
DialInfo::udp_from_socketaddr(pdi_addr),
DialInfoClass::Direct,
)?;
}
static_public = true;
}
}
if static_public {
self.inner
.lock()
.static_public_dialinfo
.insert(ProtocolType::UDP);
}
// Now create tasks for udp listeners
self.create_udp_listener_tasks().await
}
pub(super) async fn start_ws_listeners(&self) -> Result<(), String> {
trace!("starting ws listeners");
let routing_table = self.routing_table();
let (listen_address, url, path, enable_local_peer_scope) = {
let c = self.config.get();
(
c.network.protocol.ws.listen_address.clone(),
c.network.protocol.ws.url.clone(),
c.network.protocol.ws.path.clone(),
c.network.enable_local_peer_scope,
)
};
// Pick out TCP port we're going to use everywhere
// Keep sockets around until the end of this function
// to keep anyone else from binding in front of us
let (ws_port, ip_addrs) = self.allocate_tcp_port(listen_address.clone()).await?;
// Save the bound ws port for use later on
self.inner.lock().ws_port = ws_port;
trace!(
"WS: starting listener on port {} at {:?}",
ws_port,
ip_addrs
);
let socket_addresses = self
.start_tcp_listener(
ip_addrs,
ws_port,
false,
Box::new(|c, t, a| Box::new(WebsocketProtocolHandler::new(c, t, a))),
)
.await?;
trace!("WS: listener started on {:#?}", socket_addresses);
let mut static_public = false;
let mut registered_addresses: HashSet<IpAddr> = HashSet::new();
// Add static public dialinfo if it's configured
if let Some(url) = url.as_ref() {
let mut split_url = SplitUrl::from_str(url)?;
if split_url.scheme.to_ascii_lowercase() != "ws" {
return Err("WS URL must use 'ws://' scheme".to_owned());
}
split_url.scheme = "ws".to_owned();
// Resolve static public hostnames
let global_socket_addrs = split_url
.host_port(80)
.to_socket_addrs()
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
for gsa in global_socket_addrs {
let pdi = DialInfo::try_ws(SocketAddress::from_socket_addr(gsa), url.clone())
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
routing_table.register_dial_info(
RoutingDomain::PublicInternet,
pdi.clone(),
DialInfoClass::Direct,
)?;
static_public = true;
// See if this public address is also a local interface address
let is_interface_address = self.with_interface_addresses(|ip_addrs| {
for ip_addr in ip_addrs {
if gsa.ip() == *ip_addr {
return true;
}
}
false
});
if !registered_addresses.contains(&gsa.ip()) && is_interface_address {
routing_table.register_dial_info(
RoutingDomain::LocalNetwork,
pdi,
DialInfoClass::Direct,
)?;
}
registered_addresses.insert(gsa.ip());
}
}
for socket_address in socket_addresses {
// Skip addresses we already did
if registered_addresses.contains(&socket_address.to_ip_addr()) {
continue;
}
// Build dial info request url
let local_url = format!("ws://{}/{}", socket_address, path);
let local_di = DialInfo::try_ws(socket_address, local_url)
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
if url.is_none() && (socket_address.address().is_global() || enable_local_peer_scope) {
// Register public dial info
routing_table.register_dial_info(
RoutingDomain::PublicInternet,
local_di.clone(),
DialInfoClass::Direct,
)?;
static_public = true;
}
// Register local dial info
routing_table.register_dial_info(
RoutingDomain::LocalNetwork,
local_di,
DialInfoClass::Direct,
)?;
}
if static_public {
self.inner
.lock()
.static_public_dialinfo
.insert(ProtocolType::WS);
}
Ok(())
}
pub(super) async fn start_wss_listeners(&self) -> Result<(), String> {
trace!("starting wss listeners");
let routing_table = self.routing_table();
let (listen_address, url) = {
let c = self.config.get();
(
c.network.protocol.wss.listen_address.clone(),
c.network.protocol.wss.url.clone(),
)
};
// Pick out TCP port we're going to use everywhere
// Keep sockets around until the end of this function
// to keep anyone else from binding in front of us
let (wss_port, ip_addrs) = self.allocate_tcp_port(listen_address.clone()).await?;
// Save the bound wss port for use later on
self.inner.lock().wss_port = wss_port;
trace!(
"WSS: starting listener on port {} at {:?}",
wss_port,
ip_addrs
);
let socket_addresses = self
.start_tcp_listener(
ip_addrs,
wss_port,
true,
Box::new(|c, t, a| Box::new(WebsocketProtocolHandler::new(c, t, a))),
)
.await?;
trace!("WSS: listener started on {:#?}", socket_addresses);
// NOTE: No interface dial info for WSS, as there is no way to connect to a local dialinfo via TLS
// If the hostname is specified, it is the public dialinfo via the URL. If no hostname
// is specified, then TLS won't validate, so no local dialinfo is possible.
// This is not the case with unencrypted websockets, which can be specified solely by an IP address
let mut static_public = false;
let mut registered_addresses: HashSet<IpAddr> = HashSet::new();
// Add static public dialinfo if it's configured
if let Some(url) = url.as_ref() {
// Add static public dialinfo if it's configured
let mut split_url = SplitUrl::from_str(url)?;
if split_url.scheme.to_ascii_lowercase() != "wss" {
return Err("WSS URL must use 'wss://' scheme".to_owned());
}
split_url.scheme = "wss".to_owned();
// Resolve static public hostnames
let global_socket_addrs = split_url
.host_port(443)
.to_socket_addrs()
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
for gsa in global_socket_addrs {
let pdi = DialInfo::try_wss(SocketAddress::from_socket_addr(gsa), url.clone())
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
routing_table.register_dial_info(
RoutingDomain::PublicInternet,
pdi.clone(),
DialInfoClass::Direct,
)?;
static_public = true;
// See if this public address is also a local interface address
let is_interface_address = self.with_interface_addresses(|ip_addrs| {
for ip_addr in ip_addrs {
if gsa.ip() == *ip_addr {
return true;
}
}
false
});
if !registered_addresses.contains(&gsa.ip()) && is_interface_address {
routing_table.register_dial_info(
RoutingDomain::LocalNetwork,
pdi,
DialInfoClass::Direct,
)?;
}
registered_addresses.insert(gsa.ip());
}
} else {
return Err("WSS URL must be specified due to TLS requirements".to_owned());
}
if static_public {
self.inner
.lock()
.static_public_dialinfo
.insert(ProtocolType::WSS);
}
Ok(())
}
pub(super) async fn start_tcp_listeners(&self) -> Result<(), String> {
trace!("starting tcp listeners");
let routing_table = self.routing_table();
let (listen_address, public_address, enable_local_peer_scope) = {
let c = self.config.get();
(
c.network.protocol.tcp.listen_address.clone(),
c.network.protocol.tcp.public_address.clone(),
c.network.enable_local_peer_scope,
)
};
// Pick out TCP port we're going to use everywhere
// Keep sockets around until the end of this function
// to keep anyone else from binding in front of us
let (tcp_port, ip_addrs) = self.allocate_tcp_port(listen_address.clone()).await?;
// Save the bound tcp port for use later on
self.inner.lock().tcp_port = tcp_port;
trace!(
"TCP: starting listener on port {} at {:?}",
tcp_port,
ip_addrs
);
let socket_addresses = self
.start_tcp_listener(
ip_addrs,
tcp_port,
false,
Box::new(|_, _, a| Box::new(RawTcpProtocolHandler::new(a))),
)
.await?;
trace!("TCP: listener started on {:#?}", socket_addresses);
let mut static_public = false;
let mut registered_addresses: HashSet<IpAddr> = HashSet::new();
for socket_address in socket_addresses {
let di = DialInfo::tcp(socket_address);
// Register global dial info if no public address is specified
if public_address.is_none() && (di.is_global() || enable_local_peer_scope) {
routing_table.register_dial_info(
RoutingDomain::PublicInternet,
di.clone(),
DialInfoClass::Direct,
)?;
static_public = true;
}
// Register interface dial info
routing_table.register_dial_info(
RoutingDomain::LocalNetwork,
di.clone(),
DialInfoClass::Direct,
)?;
registered_addresses.insert(socket_address.to_ip_addr());
}
// Add static public dialinfo if it's configured
if let Some(public_address) = public_address.as_ref() {
// Resolve statically configured public dialinfo
let mut public_sockaddrs = public_address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
// Add all resolved addresses as public dialinfo
for pdi_addr in &mut public_sockaddrs {
// Skip addresses we already did
if registered_addresses.contains(&pdi_addr.ip()) {
continue;
}
let pdi = DialInfo::tcp_from_socketaddr(pdi_addr);
routing_table.register_dial_info(
RoutingDomain::PublicInternet,
pdi.clone(),
DialInfoClass::Direct,
)?;
static_public = true;
// See if this public address is also a local interface address
let is_interface_address = self.with_interface_addresses(|ip_addrs| {
for ip_addr in ip_addrs {
if pdi_addr.ip() == *ip_addr {
return true;
}
}
false
});
if is_interface_address {
routing_table.register_dial_info(
RoutingDomain::LocalNetwork,
pdi,
DialInfoClass::Direct,
)?;
}
}
}
if static_public {
self.inner
.lock()
.static_public_dialinfo
.insert(ProtocolType::TCP);
}
Ok(())
}
}

View File

@@ -0,0 +1,175 @@
use super::*;
use crate::xx::*;
cfg_if::cfg_if! {
if #[cfg(target_arch = "wasm32")] {
// No accept support for WASM
} else {
use async_std::net::*;
///////////////////////////////////////////////////////////
// Accept
pub trait ProtocolAcceptHandler: ProtocolAcceptHandlerClone + Send + Sync {
fn on_accept(
&self,
stream: AsyncPeekStream,
tcp_stream: TcpStream,
peer_addr: SocketAddr,
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>>;
}
pub trait ProtocolAcceptHandlerClone {
fn clone_box(&self) -> Box<dyn ProtocolAcceptHandler>;
}
impl<T> ProtocolAcceptHandlerClone for T
where
T: 'static + ProtocolAcceptHandler + Clone,
{
fn clone_box(&self) -> Box<dyn ProtocolAcceptHandler> {
Box::new(self.clone())
}
}
impl Clone for Box<dyn ProtocolAcceptHandler> {
fn clone(&self) -> Box<dyn ProtocolAcceptHandler> {
self.clone_box()
}
}
pub type NewProtocolAcceptHandler =
dyn Fn(VeilidConfig, bool, SocketAddr) -> Box<dyn ProtocolAcceptHandler> + Send;
}
}
///////////////////////////////////////////////////////////
// Dummy protocol network connection for testing
#[derive(Debug)]
pub struct DummyNetworkConnection {}
impl DummyNetworkConnection {
pub fn close(&self) -> Result<(), String> {
Ok(())
}
pub fn send(&self, _message: Vec<u8>) -> Result<(), String> {
Ok(())
}
pub fn recv(&self) -> Result<Vec<u8>, String> {
Ok(Vec::new())
}
}
///////////////////////////////////////////////////////////
// Top-level protocol independent network connection object
#[derive(Debug, Clone)]
pub struct NetworkConnectionStats {
last_message_sent_time: Option<u64>,
last_message_recv_time: Option<u64>,
}
#[derive(Debug)]
struct NetworkConnectionInner {
stats: NetworkConnectionStats,
}
#[derive(Debug)]
struct NetworkConnectionArc {
descriptor: ConnectionDescriptor,
protocol_connection: ProtocolNetworkConnection,
established_time: u64,
inner: Mutex<NetworkConnectionInner>,
}
#[derive(Clone, Debug)]
pub struct NetworkConnection {
arc: Arc<NetworkConnectionArc>,
}
impl PartialEq for NetworkConnection {
fn eq(&self, other: &Self) -> bool {
Arc::as_ptr(&self.arc) == Arc::as_ptr(&other.arc)
}
}
impl Eq for NetworkConnection {}
impl NetworkConnection {
fn new_inner() -> NetworkConnectionInner {
NetworkConnectionInner {
stats: NetworkConnectionStats {
last_message_sent_time: None,
last_message_recv_time: None,
},
}
}
fn new_arc(
descriptor: ConnectionDescriptor,
protocol_connection: ProtocolNetworkConnection,
) -> NetworkConnectionArc {
NetworkConnectionArc {
descriptor,
protocol_connection,
established_time: intf::get_timestamp(),
inner: Mutex::new(Self::new_inner()),
}
}
pub fn dummy(descriptor: ConnectionDescriptor) -> Self {
NetworkConnection::from_protocol(
descriptor,
ProtocolNetworkConnection::Dummy(DummyNetworkConnection {}),
)
}
pub fn from_protocol(
descriptor: ConnectionDescriptor,
protocol_connection: ProtocolNetworkConnection,
) -> Self {
Self {
arc: Arc::new(Self::new_arc(descriptor, protocol_connection)),
}
}
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<NetworkConnection, String> {
ProtocolNetworkConnection::connect(local_address, dial_info).await
}
pub fn connection_descriptor(&self) -> ConnectionDescriptor {
self.arc.descriptor
}
pub async fn close(&self) -> Result<(), String> {
self.arc.protocol_connection.close().await
}
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
let ts = intf::get_timestamp();
let out = self.arc.protocol_connection.send(message).await;
if out.is_ok() {
let mut inner = self.arc.inner.lock();
inner.stats.last_message_sent_time.max_assign(Some(ts));
}
out
}
pub async fn recv(&self) -> Result<Vec<u8>, String> {
let ts = intf::get_timestamp();
let out = self.arc.protocol_connection.recv().await;
if out.is_ok() {
let mut inner = self.arc.inner.lock();
inner.stats.last_message_recv_time.max_assign(Some(ts));
}
out
}
pub fn stats(&self) -> NetworkConnectionStats {
let inner = self.arc.inner.lock();
inner.stats.clone()
}
pub fn established_time(&self) -> u64 {
self.arc.established_time
}
}

View File

@@ -0,0 +1,2 @@
pub mod test_connection_table;
use super::*;

View File

@@ -0,0 +1,101 @@
use super::connection_table::*;
use super::network_connection::*;
use crate::tests::common::test_veilid_config::*;
use crate::xx::*;
use crate::*;
pub async fn test_add_get_remove() {
let config = get_config();
let mut table = ConnectionTable::new(config);
let a1 = ConnectionDescriptor::new_no_local(PeerAddress::new(
SocketAddress::new(Address::IPV4(Ipv4Addr::new(127, 0, 0, 1)), 8080),
ProtocolType::TCP,
));
let a2 = a1;
let a3 = ConnectionDescriptor::new(
PeerAddress::new(
SocketAddress::new(Address::IPV6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8090),
ProtocolType::TCP,
),
SocketAddress::from_socket_addr(SocketAddr::V6(SocketAddrV6::new(
Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
8080,
0,
0,
))),
);
let a4 = ConnectionDescriptor::new(
PeerAddress::new(
SocketAddress::new(Address::IPV6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8090),
ProtocolType::TCP,
),
SocketAddress::from_socket_addr(SocketAddr::V6(SocketAddrV6::new(
Ipv6Addr::new(1, 0, 0, 0, 0, 0, 0, 1),
8080,
0,
0,
))),
);
let a5 = ConnectionDescriptor::new(
PeerAddress::new(
SocketAddress::new(Address::IPV6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8090),
ProtocolType::WSS,
),
SocketAddress::from_socket_addr(SocketAddr::V6(SocketAddrV6::new(
Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
8080,
0,
0,
))),
);
let c1 = NetworkConnection::dummy(a1);
let c2 = NetworkConnection::dummy(a2);
let c3 = NetworkConnection::dummy(a3);
let c4 = NetworkConnection::dummy(a4);
let c5 = NetworkConnection::dummy(a5);
assert_eq!(a1, c2.connection_descriptor());
assert_ne!(a3, c4.connection_descriptor());
assert_ne!(a4, c5.connection_descriptor());
assert_eq!(table.connection_count(), 0);
assert_eq!(table.get_connection(a1), None);
table.add_connection(c1.clone()).unwrap();
assert_eq!(table.connection_count(), 1);
assert_err!(table.remove_connection(a3));
assert_err!(table.remove_connection(a4));
assert_eq!(table.connection_count(), 1);
assert_eq!(table.get_connection(a1), Some(c1.clone()));
assert_eq!(table.get_connection(a1), Some(c1.clone()));
assert_eq!(table.connection_count(), 1);
assert_err!(table.add_connection(c1.clone()));
assert_err!(table.add_connection(c2.clone()));
assert_eq!(table.connection_count(), 1);
assert_eq!(table.get_connection(a1), Some(c1.clone()));
assert_eq!(table.get_connection(a1), Some(c1.clone()));
assert_eq!(table.connection_count(), 1);
assert_eq!(table.remove_connection(a2), Ok(c1.clone()));
assert_eq!(table.connection_count(), 0);
assert_err!(table.remove_connection(a2));
assert_eq!(table.connection_count(), 0);
assert_eq!(table.get_connection(a2), None);
assert_eq!(table.get_connection(a1), None);
assert_eq!(table.connection_count(), 0);
table.add_connection(c1.clone()).unwrap();
assert_err!(table.add_connection(c2));
table.add_connection(c3.clone()).unwrap();
table.add_connection(c4.clone()).unwrap();
assert_eq!(table.connection_count(), 3);
assert_eq!(table.remove_connection(a2), Ok(c1));
assert_eq!(table.remove_connection(a3), Ok(c3));
assert_eq!(table.remove_connection(a4), Ok(c4));
assert_eq!(table.connection_count(), 0);
}
pub async fn test_all() {
test_add_get_remove().await;
}

View File

@@ -0,0 +1,235 @@
mod protocol;
use crate::connection_manager::*;
use crate::network_manager::*;
use crate::routing_table::*;
use crate::intf::*;
use crate::*;
use protocol::ws::WebsocketProtocolHandler;
pub use protocol::*;
/////////////////////////////////////////////////////////////////
struct NetworkInner {
network_manager: NetworkManager,
stop_network: Eventual,
network_started: bool,
network_needs_restart: bool,
protocol_config: Option<ProtocolConfig>,
//join_handle: TryJoin?
}
#[derive(Clone)]
pub struct Network {
config: VeilidConfig,
inner: Arc<Mutex<NetworkInner>>,
}
impl Network {
fn new_inner(network_manager: NetworkManager) -> NetworkInner {
NetworkInner {
network_manager,
stop_network: Eventual::new(),
network_started: false,
network_needs_restart: false,
protocol_config: None, //join_handle: None,
}
}
pub fn new(network_manager: NetworkManager) -> Self {
Self {
config: network_manager.config(),
inner: Arc::new(Mutex::new(Self::new_inner(network_manager))),
}
}
fn network_manager(&self) -> NetworkManager {
self.inner.lock().network_manager.clone()
}
fn connection_manager(&self) -> ConnectionManager {
self.inner.lock().network_manager.connection_manager()
}
/////////////////////////////////////////////////////////////////
pub async fn send_data_unbound_to_dial_info(
&self,
dial_info: DialInfo,
data: Vec<u8>,
) -> Result<(), String> {
let data_len = data.len();
let res = match dial_info.protocol_type() {
ProtocolType::UDP => {
return Err("no support for UDP protocol".to_owned()).map_err(logthru_net!(error))
}
ProtocolType::TCP => {
return Err("no support for TCP protocol".to_owned()).map_err(logthru_net!(error))
}
ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data)
.await
.map_err(logthru_net!())
}
};
if res.is_ok() {
// Network accounting
self.network_manager()
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
}
res
}
pub async fn send_data_to_existing_connection(
&self,
descriptor: ConnectionDescriptor,
data: Vec<u8>,
) -> Result<Option<Vec<u8>>, String> {
let data_len = data.len();
match descriptor.protocol_type() {
ProtocolType::UDP => {
return Err("no support for udp protocol".to_owned()).map_err(logthru_net!(error))
}
ProtocolType::TCP => {
return Err("no support for tcp protocol".to_owned()).map_err(logthru_net!(error))
}
_ => {}
}
// Handle connection-oriented protocols
// Try to send to the exact existing connection if one exists
if let Some(conn) = self.connection_manager().get_connection(descriptor).await {
// connection exists, send over it
conn.send(data).await.map_err(logthru_net!())?;
// Network accounting
self.network_manager()
.stats_packet_sent(descriptor.remote.to_socket_addr().ip(), data_len as u64);
// Data was consumed
Ok(None)
} else {
// Connection or didn't exist
// Pass the data back out so we don't own it any more
Ok(Some(data))
}
}
pub async fn send_data_to_dial_info(
&self,
dial_info: DialInfo,
data: Vec<u8>,
) -> Result<(), String> {
let data_len = data.len();
if dial_info.protocol_type() == ProtocolType::UDP {
return Err("no support for UDP protocol".to_owned()).map_err(logthru_net!(error))
}
if dial_info.protocol_type() == ProtocolType::TCP {
return Err("no support for TCP protocol".to_owned()).map_err(logthru_net!(error))
}
// Handle connection-oriented protocols
let conn = self
.connection_manager()
.get_or_create_connection(None, dial_info.clone())
.await?;
let res = conn.send(data).await.map_err(logthru_net!(error));
if res.is_ok() {
// Network accounting
self.network_manager()
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
}
res
}
/////////////////////////////////////////////////////////////////
pub async fn startup(&self) -> Result<(), String> {
// get protocol config
self.inner.lock().protocol_config = Some({
let c = self.config.get();
let inbound = ProtocolSet::new();
let mut outbound = ProtocolSet::new();
if c.network.protocol.ws.connect && c.capabilities.protocol_connect_ws {
outbound.insert(ProtocolType::WS);
}
if c.network.protocol.wss.connect && c.capabilities.protocol_connect_wss {
outbound.insert(ProtocolType::WSS);
}
ProtocolConfig { inbound, outbound }
});
self.inner.lock().network_started = true;
Ok(())
}
pub fn needs_restart(&self) -> bool {
self.inner.lock().network_needs_restart
}
pub fn is_started(&self) -> bool {
self.inner.lock().network_started
}
pub fn restart_network(&self) {
self.inner.lock().network_needs_restart = true;
}
pub async fn shutdown(&self) {
trace!("stopping network");
// Reset state
let network_manager = self.inner.lock().network_manager.clone();
let routing_table = network_manager.routing_table();
// Drop all dial info
routing_table.clear_dial_info_details(RoutingDomain::PublicInternet);
routing_table.clear_dial_info_details(RoutingDomain::LocalNetwork);
// Cancels all async background tasks by dropping join handles
*self.inner.lock() = Self::new_inner(network_manager);
trace!("network stopped");
}
pub fn with_interface_addresses<F, R>(&self, f: F) -> R
where
F: FnOnce(&[IpAddr]) -> R,
{
f(&[])
}
pub async fn check_interface_addresses(&self) -> Result<bool, String> {
Ok(false)
}
//////////////////////////////////////////
pub fn get_network_class(&self) -> Option<NetworkClass> {
// xxx eventually detect tor browser?
return if self.inner.lock().network_started {
Some(NetworkClass::WebApp)
} else {
None
};
}
pub fn reset_network_class(&self) {
//let mut inner = self.inner.lock();
//inner.network_class = None;
}
pub fn get_protocol_config(&self) -> Option<ProtocolConfig> {
self.inner.lock().protocol_config.clone()
}
//////////////////////////////////////////
pub async fn tick(&self) -> Result<(), String> {
Ok(())
}
}

View File

@@ -0,0 +1,68 @@
pub mod wrtc;
pub mod ws;
use crate::network_connection::*;
use crate::xx::*;
use crate::*;
#[derive(Debug)]
pub enum ProtocolNetworkConnection {
Dummy(DummyNetworkConnection),
Ws(ws::WebsocketNetworkConnection),
//WebRTC(wrtc::WebRTCNetworkConnection),
}
impl ProtocolNetworkConnection {
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<NetworkConnection, String> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
panic!("UDP dial info is not support on WASM targets");
}
ProtocolType::TCP => {
panic!("TCP dial info is not support on WASM targets");
}
ProtocolType::WS | ProtocolType::WSS => {
ws::WebsocketProtocolHandler::connect(local_address, dial_info).await
}
}
}
pub async fn send_unbound_message(
dial_info: DialInfo,
data: Vec<u8>,
) -> Result<(), String> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
panic!("UDP dial info is not support on WASM targets");
}
ProtocolType::TCP => {
panic!("TCP dial info is not support on WASM targets");
}
ProtocolType::WS | ProtocolType::WSS => {
ws::WebsocketProtocolHandler::send_unbound_message(dial_info, data).await
}
}
}
pub async fn close(&self) -> Result<(), String> {
match self {
Self::Dummy(d) => d.close(),
Self::Ws(w) => w.close().await,
}
}
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
match self {
Self::Dummy(d) => d.send(message),
Self::Ws(w) => w.send(message).await,
}
}
pub async fn recv(&self) -> Result<Vec<u8>, String> {
match self {
Self::Dummy(d) => d.recv(),
Self::Ws(w) => w.recv().await,
}
}
}

View File

@@ -0,0 +1,123 @@
use crate::intf::*;
use crate::network_connection::*;
use crate::network_manager::MAX_MESSAGE_SIZE;
use crate::*;
use alloc::fmt;
use ws_stream_wasm::*;
use futures_util::{StreamExt, SinkExt};
struct WebsocketNetworkConnectionInner {
ws_meta: WsMeta,
ws_stream: CloneStream<WsStream>,
}
#[derive(Clone)]
pub struct WebsocketNetworkConnection {
inner: Arc<WebsocketNetworkConnectionInner>,
}
impl fmt::Debug for WebsocketNetworkConnection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", core::any::type_name::<Self>())
}
}
impl WebsocketNetworkConnection {
pub fn new(ws_meta: WsMeta, ws_stream: WsStream) -> Self {
Self {
inner: Arc::new(WebsocketNetworkConnectionInner {
ws_meta,
ws_stream: CloneStream::new(ws_stream),
}),
}
}
pub async fn close(&self) -> Result<(), String> {
self.inner.ws_meta.close().await.map_err(map_to_string).map(drop)
}
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
if message.len() > MAX_MESSAGE_SIZE {
return Err("sending too large WS message".to_owned()).map_err(logthru_net!(error));
}
self.inner.ws_stream.clone()
.send(WsMessage::Binary(message)).await
.map_err(|_| "failed to send to websocket".to_owned())
.map_err(logthru_net!(error))
}
pub async fn recv(&self) -> Result<Vec<u8>, String> {
let out = match self.inner.ws_stream.clone().next().await {
Some(WsMessage::Binary(v)) => v,
Some(_) => {
return Err("Unexpected WS message type".to_owned())
.map_err(logthru_net!(error));
}
None => {
return Err("WS stream closed".to_owned()).map_err(logthru_net!(error));
}
};
if out.len() > MAX_MESSAGE_SIZE {
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
} else {
Ok(out)
}
}
}
///////////////////////////////////////////////////////////
///
pub struct WebsocketProtocolHandler {}
impl WebsocketProtocolHandler {
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<NetworkConnection, String> {
assert!(local_address.is_none());
// Split dial info up
let (_tls, scheme) = match &dial_info {
DialInfo::WS(_) => (false, "ws"),
DialInfo::WSS(_) => (true, "wss"),
_ => panic!("invalid dialinfo for WS/WSS protocol"),
};
let request = dial_info.request().unwrap();
let split_url = SplitUrl::from_str(&request)?;
if split_url.scheme != scheme {
return Err("invalid websocket url scheme".to_string());
}
let (wsmeta, wsio) = WsMeta::connect(request, None)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
// Make our connection descriptor
Ok(NetworkConnection::from_protocol(ConnectionDescriptor {
local: None,
remote: dial_info.to_peer_address(),
},ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(wsmeta, wsio))))
}
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large unbound WS message".to_owned());
}
trace!(
"sending unbound websocket message of length {} to {}",
data.len(),
dial_info,
);
// Make the real connection
let conn = Self::connect(None, dial_info)
.await
.map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?;
conn.send(data).await
}
}