refactor network manager
This commit is contained in:
218
veilid-core/src/network_manager/connection_limits.rs
Normal file
218
veilid-core/src/network_manager/connection_limits.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
use crate::xx::*;
|
||||
use crate::*;
|
||||
use alloc::collections::btree_map::Entry;
|
||||
use core::fmt;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AddressFilterError {
|
||||
CountExceeded,
|
||||
RateExceeded,
|
||||
}
|
||||
impl fmt::Display for AddressFilterError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match *self {
|
||||
Self::CountExceeded => "Count exceeded",
|
||||
Self::RateExceeded => "Rate exceeded",
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
impl std::error::Error for AddressFilterError {}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct AddressNotInTableError {}
|
||||
impl fmt::Display for AddressNotInTableError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "Address not in table")
|
||||
}
|
||||
}
|
||||
impl std::error::Error for AddressNotInTableError {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ConnectionLimits {
|
||||
max_connections_per_ip4: usize,
|
||||
max_connections_per_ip6_prefix: usize,
|
||||
max_connections_per_ip6_prefix_size: usize,
|
||||
max_connection_frequency_per_min: usize,
|
||||
conn_count_by_ip4: BTreeMap<Ipv4Addr, usize>,
|
||||
conn_count_by_ip6_prefix: BTreeMap<Ipv6Addr, usize>,
|
||||
conn_timestamps_by_ip4: BTreeMap<Ipv4Addr, Vec<u64>>,
|
||||
conn_timestamps_by_ip6_prefix: BTreeMap<Ipv6Addr, Vec<u64>>,
|
||||
}
|
||||
|
||||
impl ConnectionLimits {
|
||||
pub fn new(config: VeilidConfig) -> Self {
|
||||
let c = config.get();
|
||||
Self {
|
||||
max_connections_per_ip4: c.network.max_connections_per_ip4 as usize,
|
||||
max_connections_per_ip6_prefix: c.network.max_connections_per_ip6_prefix as usize,
|
||||
max_connections_per_ip6_prefix_size: c.network.max_connections_per_ip6_prefix_size
|
||||
as usize,
|
||||
max_connection_frequency_per_min: c.network.max_connection_frequency_per_min as usize,
|
||||
conn_count_by_ip4: BTreeMap::new(),
|
||||
conn_count_by_ip6_prefix: BTreeMap::new(),
|
||||
conn_timestamps_by_ip4: BTreeMap::new(),
|
||||
conn_timestamps_by_ip6_prefix: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
// Converts an ip to a ip block by applying a netmask
|
||||
// to the host part of the ip address
|
||||
// ipv4 addresses are treated as single hosts
|
||||
// ipv6 addresses are treated as prefix allocated blocks
|
||||
fn ip_to_ipblock(&self, addr: IpAddr) -> IpAddr {
|
||||
match addr {
|
||||
IpAddr::V4(_) => addr,
|
||||
IpAddr::V6(v6) => {
|
||||
let mut hostlen = 128usize.saturating_sub(self.max_connections_per_ip6_prefix_size);
|
||||
let mut out = v6.octets();
|
||||
for i in (0..16).rev() {
|
||||
if hostlen >= 8 {
|
||||
out[i] = 0xFF;
|
||||
hostlen -= 8;
|
||||
} else {
|
||||
out[i] |= !(0xFFu8 << hostlen);
|
||||
break;
|
||||
}
|
||||
}
|
||||
IpAddr::V6(Ipv6Addr::from(out))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn purge_old_timestamps(&mut self, cur_ts: u64) {
|
||||
// v4
|
||||
{
|
||||
let mut dead_keys = Vec::<Ipv4Addr>::new();
|
||||
for (key, value) in &mut self.conn_timestamps_by_ip4 {
|
||||
value.retain(|v| {
|
||||
// keep timestamps that are less than a minute away
|
||||
cur_ts.saturating_sub(*v) < 60_000_000u64
|
||||
});
|
||||
if value.is_empty() {
|
||||
dead_keys.push(*key);
|
||||
}
|
||||
}
|
||||
for key in dead_keys {
|
||||
self.conn_timestamps_by_ip4.remove(&key);
|
||||
}
|
||||
}
|
||||
// v6
|
||||
{
|
||||
let mut dead_keys = Vec::<Ipv6Addr>::new();
|
||||
for (key, value) in &mut self.conn_timestamps_by_ip6_prefix {
|
||||
value.retain(|v| {
|
||||
// keep timestamps that are less than a minute away
|
||||
cur_ts.saturating_sub(*v) < 60_000_000u64
|
||||
});
|
||||
if value.is_empty() {
|
||||
dead_keys.push(*key);
|
||||
}
|
||||
}
|
||||
for key in dead_keys {
|
||||
self.conn_timestamps_by_ip6_prefix.remove(&key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add(&mut self, addr: IpAddr) -> Result<(), AddressFilterError> {
|
||||
let ipblock = self.ip_to_ipblock(addr);
|
||||
let ts = intf::get_timestamp();
|
||||
|
||||
self.purge_old_timestamps(ts);
|
||||
|
||||
match ipblock {
|
||||
IpAddr::V4(v4) => {
|
||||
// See if we have too many connections from this ip block
|
||||
let cnt = &mut *self.conn_count_by_ip4.entry(v4).or_default();
|
||||
assert!(*cnt <= self.max_connections_per_ip4);
|
||||
if *cnt == self.max_connections_per_ip4 {
|
||||
warn!("address filter count exceeded: {:?}", v4);
|
||||
return Err(AddressFilterError::CountExceeded);
|
||||
}
|
||||
// See if this ip block has connected too frequently
|
||||
let tstamps = &mut self.conn_timestamps_by_ip4.entry(v4).or_default();
|
||||
tstamps.retain(|v| {
|
||||
// keep timestamps that are less than a minute away
|
||||
ts.saturating_sub(*v) < 60_000_000u64
|
||||
});
|
||||
assert!(tstamps.len() <= self.max_connection_frequency_per_min);
|
||||
if tstamps.len() == self.max_connection_frequency_per_min {
|
||||
warn!("address filter rate exceeded: {:?}", v4);
|
||||
return Err(AddressFilterError::RateExceeded);
|
||||
}
|
||||
|
||||
// If it's okay, add the counts and timestamps
|
||||
*cnt += 1;
|
||||
tstamps.push(ts);
|
||||
}
|
||||
IpAddr::V6(v6) => {
|
||||
// See if we have too many connections from this ip block
|
||||
let cnt = &mut *self.conn_count_by_ip6_prefix.entry(v6).or_default();
|
||||
assert!(*cnt <= self.max_connections_per_ip6_prefix);
|
||||
if *cnt == self.max_connections_per_ip6_prefix {
|
||||
warn!("address filter count exceeded: {:?}", v6);
|
||||
return Err(AddressFilterError::CountExceeded);
|
||||
}
|
||||
// See if this ip block has connected too frequently
|
||||
let tstamps = &mut self.conn_timestamps_by_ip6_prefix.entry(v6).or_default();
|
||||
assert!(tstamps.len() <= self.max_connection_frequency_per_min);
|
||||
if tstamps.len() == self.max_connection_frequency_per_min {
|
||||
warn!("address filter rate exceeded: {:?}", v6);
|
||||
return Err(AddressFilterError::RateExceeded);
|
||||
}
|
||||
|
||||
// If it's okay, add the counts and timestamps
|
||||
*cnt += 1;
|
||||
tstamps.push(ts);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, addr: IpAddr) -> Result<(), AddressNotInTableError> {
|
||||
let ipblock = self.ip_to_ipblock(addr);
|
||||
|
||||
let ts = intf::get_timestamp();
|
||||
self.purge_old_timestamps(ts);
|
||||
|
||||
match ipblock {
|
||||
IpAddr::V4(v4) => {
|
||||
match self.conn_count_by_ip4.entry(v4) {
|
||||
Entry::Vacant(_) => {
|
||||
return Err(AddressNotInTableError {});
|
||||
}
|
||||
Entry::Occupied(mut o) => {
|
||||
let cnt = o.get_mut();
|
||||
assert!(*cnt > 0);
|
||||
if *cnt == 0 {
|
||||
self.conn_count_by_ip4.remove(&v4);
|
||||
} else {
|
||||
*cnt -= 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
IpAddr::V6(v6) => {
|
||||
match self.conn_count_by_ip6_prefix.entry(v6) {
|
||||
Entry::Vacant(_) => {
|
||||
return Err(AddressNotInTableError {});
|
||||
}
|
||||
Entry::Occupied(mut o) => {
|
||||
let cnt = o.get_mut();
|
||||
assert!(*cnt > 0);
|
||||
if *cnt == 0 {
|
||||
self.conn_count_by_ip6_prefix.remove(&v6);
|
||||
} else {
|
||||
*cnt -= 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
243
veilid-core/src/network_manager/connection_manager.rs
Normal file
243
veilid-core/src/network_manager/connection_manager.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
use super::*;
|
||||
use crate::xx::*;
|
||||
use connection_table::*;
|
||||
use network_connection::*;
|
||||
|
||||
const CONNECTION_PROCESSOR_CHANNEL_SIZE: usize = 128usize;
|
||||
|
||||
///////////////////////////////////////////////////////////
|
||||
// Connection manager
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ConnectionManagerInner {
|
||||
connection_table: ConnectionTable,
|
||||
}
|
||||
|
||||
struct ConnectionManagerArc {
|
||||
network_manager: NetworkManager,
|
||||
inner: AsyncMutex<ConnectionManagerInner>,
|
||||
}
|
||||
impl core::fmt::Debug for ConnectionManagerArc {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.debug_struct("ConnectionManagerArc")
|
||||
.field("inner", &self.inner)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConnectionManager {
|
||||
arc: Arc<ConnectionManagerArc>,
|
||||
}
|
||||
|
||||
impl ConnectionManager {
|
||||
fn new_inner(config: VeilidConfig) -> ConnectionManagerInner {
|
||||
ConnectionManagerInner {
|
||||
connection_table: ConnectionTable::new(config),
|
||||
}
|
||||
}
|
||||
fn new_arc(network_manager: NetworkManager) -> ConnectionManagerArc {
|
||||
let config = network_manager.config();
|
||||
ConnectionManagerArc {
|
||||
network_manager,
|
||||
inner: AsyncMutex::new(Self::new_inner(config)),
|
||||
}
|
||||
}
|
||||
pub fn new(network_manager: NetworkManager) -> Self {
|
||||
Self {
|
||||
arc: Arc::new(Self::new_arc(network_manager)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn network_manager(&self) -> NetworkManager {
|
||||
self.arc.network_manager.clone()
|
||||
}
|
||||
|
||||
pub async fn startup(&self) {
|
||||
trace!("startup connection manager");
|
||||
//let mut inner = self.arc.inner.lock().await;
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) {
|
||||
// xxx close all connections in the connection table
|
||||
|
||||
*self.arc.inner.lock().await = Self::new_inner(self.arc.network_manager.config());
|
||||
}
|
||||
|
||||
// Returns a network connection if one already is established
|
||||
pub async fn get_connection(
|
||||
&self,
|
||||
descriptor: ConnectionDescriptor,
|
||||
) -> Option<NetworkConnection> {
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
inner.connection_table.get_connection(descriptor)
|
||||
}
|
||||
|
||||
// Internal routine to register new connection atomically
|
||||
fn on_new_connection_internal(
|
||||
&self,
|
||||
inner: &mut ConnectionManagerInner,
|
||||
conn: NetworkConnection,
|
||||
) -> Result<(), String> {
|
||||
log_net!("on_new_connection_internal: {:?}", conn);
|
||||
let tx = inner
|
||||
.connection_add_channel_tx
|
||||
.as_ref()
|
||||
.ok_or_else(fn_string!("connection channel isn't open yet"))?
|
||||
.clone();
|
||||
|
||||
let receiver_loop_future = Self::process_connection(self.clone(), conn.clone());
|
||||
tx.try_send(receiver_loop_future)
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed to start receiver loop"))?;
|
||||
|
||||
// If the receiver loop started successfully,
|
||||
// add the new connection to the table
|
||||
inner.connection_table.add_connection(conn)
|
||||
}
|
||||
|
||||
// Called by low-level network when any connection-oriented protocol connection appears
|
||||
// either from incoming or outgoing connections. Registers connection in the connection table for later access
|
||||
// and spawns a message processing loop for the connection
|
||||
pub async fn on_new_connection(&self, conn: NetworkConnection) -> Result<(), String> {
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
self.on_new_connection_internal(&mut *inner, conn)
|
||||
}
|
||||
|
||||
pub async fn get_or_create_connection(
|
||||
&self,
|
||||
local_addr: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
log_net!(
|
||||
"== get_or_create_connection local_addr={:?} dial_info={:?}",
|
||||
local_addr.green(),
|
||||
dial_info.green()
|
||||
);
|
||||
|
||||
let peer_address = dial_info.to_peer_address();
|
||||
let descriptor = match local_addr {
|
||||
Some(la) => {
|
||||
ConnectionDescriptor::new(peer_address, SocketAddress::from_socket_addr(la))
|
||||
}
|
||||
None => ConnectionDescriptor::new_no_local(peer_address),
|
||||
};
|
||||
|
||||
// If any connection to this remote exists that has the same protocol, return it
|
||||
// Any connection will do, we don't have to match the local address
|
||||
let mut inner = self.arc.inner.lock().await;
|
||||
|
||||
if let Some(conn) = inner
|
||||
.connection_table
|
||||
.get_last_connection_by_remote(descriptor.remote)
|
||||
{
|
||||
log_net!(
|
||||
"== Returning existing connection local_addr={:?} peer_address={:?}",
|
||||
local_addr.green(),
|
||||
peer_address.green()
|
||||
);
|
||||
|
||||
return Ok(conn);
|
||||
}
|
||||
|
||||
// Drop any other protocols connections that have the same local addr
|
||||
// otherwise this connection won't succeed due to binding
|
||||
if let Some(local_addr) = local_addr {
|
||||
if local_addr.port() != 0 {
|
||||
for pt in [ProtocolType::TCP, ProtocolType::WS, ProtocolType::WSS] {
|
||||
let pa = PeerAddress::new(descriptor.remote.socket_address, pt);
|
||||
for conn in inner.connection_table.get_connections_by_remote(pa) {
|
||||
let desc = conn.connection_descriptor();
|
||||
let mut kill = false;
|
||||
if let Some(conn_local) = desc.local {
|
||||
if (local_addr.ip().is_unspecified()
|
||||
|| (local_addr.ip() == conn_local.to_ip_addr()))
|
||||
&& conn_local.port() == local_addr.port()
|
||||
{
|
||||
kill = true;
|
||||
}
|
||||
}
|
||||
if kill {
|
||||
log_net!(debug
|
||||
">< Terminating connection local_addr={:?} peer_address={:?}",
|
||||
local_addr.green(),
|
||||
pa.green()
|
||||
);
|
||||
conn.close().await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Attempt new connection
|
||||
let conn = NetworkConnection::connect(local_addr, dial_info).await?;
|
||||
|
||||
self.on_new_connection_internal(&mut *inner, conn.clone())?;
|
||||
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
// Connection receiver loop
|
||||
fn process_connection(
|
||||
this: ConnectionManager,
|
||||
conn: NetworkConnection,
|
||||
) -> SystemPinBoxFuture<()> {
|
||||
log_net!("Starting process_connection loop for {:?}", conn.green());
|
||||
let network_manager = this.network_manager();
|
||||
Box::pin(async move {
|
||||
//
|
||||
let descriptor = conn.connection_descriptor();
|
||||
let inactivity_timeout = this
|
||||
.network_manager()
|
||||
.config()
|
||||
.get()
|
||||
.network
|
||||
.connection_inactivity_timeout_ms;
|
||||
loop {
|
||||
// process inactivity timeout on receives only
|
||||
// if you want a keepalive, it has to be requested from the other side
|
||||
let message = select! {
|
||||
res = conn.recv().fuse() => {
|
||||
match res {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
log_net!(debug e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = intf::sleep(inactivity_timeout).fuse()=> {
|
||||
// timeout
|
||||
log_net!("connection timeout on {:?}", descriptor.green());
|
||||
break;
|
||||
}
|
||||
};
|
||||
if let Err(e) = network_manager
|
||||
.on_recv_envelope(message.as_slice(), descriptor)
|
||||
.await
|
||||
{
|
||||
log_net!(error e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
log_net!(
|
||||
"== Connection loop finished local_addr={:?} remote={:?}",
|
||||
descriptor.local.green(),
|
||||
descriptor.remote.green()
|
||||
);
|
||||
|
||||
if let Err(e) = this
|
||||
.arc
|
||||
.inner
|
||||
.lock()
|
||||
.await
|
||||
.connection_table
|
||||
.remove_connection(descriptor)
|
||||
{
|
||||
log_net!(error e);
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
168
veilid-core/src/network_manager/connection_table.rs
Normal file
168
veilid-core/src/network_manager/connection_table.rs
Normal file
@@ -0,0 +1,168 @@
|
||||
use super::connection_limits::*;
|
||||
use super::network_connection::*;
|
||||
use crate::xx::*;
|
||||
use crate::*;
|
||||
use alloc::collections::btree_map::Entry;
|
||||
use hashlink::LruCache;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ConnectionTable {
|
||||
max_connections: Vec<usize>,
|
||||
conn_by_descriptor: Vec<LruCache<ConnectionDescriptor, NetworkConnection>>,
|
||||
conns_by_remote: BTreeMap<PeerAddress, Vec<NetworkConnection>>,
|
||||
address_filter: ConnectionLimits,
|
||||
}
|
||||
|
||||
fn protocol_to_index(protocol: ProtocolType) -> usize {
|
||||
match protocol {
|
||||
ProtocolType::TCP => 0,
|
||||
ProtocolType::WS => 1,
|
||||
ProtocolType::WSS => 2,
|
||||
ProtocolType::UDP => panic!("not a connection-oriented protocol"),
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectionTable {
|
||||
pub fn new(config: VeilidConfig) -> Self {
|
||||
let max_connections = {
|
||||
let c = config.get();
|
||||
vec![
|
||||
c.network.protocol.tcp.max_connections as usize,
|
||||
c.network.protocol.ws.max_connections as usize,
|
||||
c.network.protocol.wss.max_connections as usize,
|
||||
]
|
||||
};
|
||||
Self {
|
||||
max_connections,
|
||||
conn_by_descriptor: vec![
|
||||
LruCache::new_unbounded(),
|
||||
LruCache::new_unbounded(),
|
||||
LruCache::new_unbounded(),
|
||||
],
|
||||
conns_by_remote: BTreeMap::new(),
|
||||
address_filter: ConnectionLimits::new(config),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_connection(&mut self, conn: NetworkConnection) -> Result<(), String> {
|
||||
let descriptor = conn.connection_descriptor();
|
||||
let ip_addr = descriptor.remote.socket_address.to_ip_addr();
|
||||
|
||||
let index = protocol_to_index(descriptor.protocol_type());
|
||||
if self.conn_by_descriptor[index].contains_key(&descriptor) {
|
||||
return Err(format!(
|
||||
"Connection already added to table: {:?}",
|
||||
descriptor
|
||||
));
|
||||
}
|
||||
|
||||
// Filter by ip for connection limits
|
||||
self.address_filter.add(ip_addr).map_err(map_to_string)?;
|
||||
|
||||
// Add the connection to the table
|
||||
let res = self.conn_by_descriptor[index].insert(descriptor, conn.clone());
|
||||
assert!(res.is_none());
|
||||
|
||||
// if we have reached the maximum number of connections per protocol type
|
||||
// then drop the least recently used connection
|
||||
if self.conn_by_descriptor[index].len() > self.max_connections[index] {
|
||||
if let Some((lruk, _)) = self.conn_by_descriptor[index].remove_lru() {
|
||||
warn!("XX: connection lru out: {:?}", lruk);
|
||||
self.remove_connection_records(lruk);
|
||||
}
|
||||
}
|
||||
|
||||
// add connection records
|
||||
let conns = self.conns_by_remote.entry(descriptor.remote).or_default();
|
||||
|
||||
warn!("add_connection: {:?}", conn);
|
||||
conns.push(conn);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_connection(
|
||||
&mut self,
|
||||
descriptor: ConnectionDescriptor,
|
||||
) -> Option<NetworkConnection> {
|
||||
let index = protocol_to_index(descriptor.protocol_type());
|
||||
let out = self.conn_by_descriptor[index].get(&descriptor).cloned();
|
||||
warn!("get_connection: {:?} -> {:?}", descriptor, out);
|
||||
out
|
||||
}
|
||||
|
||||
pub fn get_last_connection_by_remote(
|
||||
&mut self,
|
||||
remote: PeerAddress,
|
||||
) -> Option<NetworkConnection> {
|
||||
let out = self
|
||||
.conns_by_remote
|
||||
.get(&remote)
|
||||
.map(|v| v[(v.len() - 1)].clone());
|
||||
warn!("get_last_connection_by_remote: {:?} -> {:?}", remote, out);
|
||||
if let Some(connection) = &out {
|
||||
// lru bump
|
||||
let index = protocol_to_index(connection.connection_descriptor().protocol_type());
|
||||
let _ = self.conn_by_descriptor[index].get(&connection.connection_descriptor());
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub fn get_connections_by_remote(&mut self, remote: PeerAddress) -> Vec<NetworkConnection> {
|
||||
let out = self
|
||||
.conns_by_remote
|
||||
.get(&remote)
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
warn!("get_connections_by_remote: {:?} -> {:?}", remote, out);
|
||||
out
|
||||
}
|
||||
|
||||
pub fn connection_count(&self) -> usize {
|
||||
self.conn_by_descriptor.iter().fold(0, |b, c| b + c.len())
|
||||
}
|
||||
|
||||
fn remove_connection_records(&mut self, descriptor: ConnectionDescriptor) {
|
||||
let ip_addr = descriptor.remote.socket_address.to_ip_addr();
|
||||
|
||||
// conns_by_remote
|
||||
match self.conns_by_remote.entry(descriptor.remote) {
|
||||
Entry::Vacant(_) => {
|
||||
panic!("inconsistency in connection table")
|
||||
}
|
||||
Entry::Occupied(mut o) => {
|
||||
let v = o.get_mut();
|
||||
|
||||
// Remove one matching connection from the list
|
||||
for (n, elem) in v.iter().enumerate() {
|
||||
if elem.connection_descriptor() == descriptor {
|
||||
v.remove(n);
|
||||
break;
|
||||
}
|
||||
}
|
||||
// No connections left for this remote, remove the entry from conns_by_remote
|
||||
if v.is_empty() {
|
||||
o.remove_entry();
|
||||
}
|
||||
}
|
||||
}
|
||||
self.address_filter
|
||||
.remove(ip_addr)
|
||||
.expect("Inconsistency in connection table");
|
||||
}
|
||||
|
||||
pub fn remove_connection(
|
||||
&mut self,
|
||||
descriptor: ConnectionDescriptor,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
warn!("remove_connection: {:?}", descriptor);
|
||||
let index = protocol_to_index(descriptor.protocol_type());
|
||||
let out = self.conn_by_descriptor[index]
|
||||
.remove(&descriptor)
|
||||
.ok_or_else(|| format!("Connection not in table: {:?}", descriptor))?;
|
||||
|
||||
self.remove_connection_records(descriptor);
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
}
|
1433
veilid-core/src/network_manager/mod.rs
Normal file
1433
veilid-core/src/network_manager/mod.rs
Normal file
File diff suppressed because it is too large
Load Diff
550
veilid-core/src/network_manager/native/mod.rs
Normal file
550
veilid-core/src/network_manager/native/mod.rs
Normal file
@@ -0,0 +1,550 @@
|
||||
mod network_class_discovery;
|
||||
mod network_tcp;
|
||||
mod network_udp;
|
||||
mod protocol;
|
||||
mod start_protocols;
|
||||
|
||||
use crate::intf::*;
|
||||
use crate::network_manager::*;
|
||||
use crate::routing_table::*;
|
||||
use connection_manager::*;
|
||||
use network_tcp::*;
|
||||
use protocol::tcp::RawTcpProtocolHandler;
|
||||
use protocol::udp::RawUdpProtocolHandler;
|
||||
use protocol::ws::WebsocketProtocolHandler;
|
||||
pub use protocol::*;
|
||||
use utils::network_interfaces::*;
|
||||
|
||||
use async_std::io;
|
||||
use async_std::net::*;
|
||||
use async_tls::TlsAcceptor;
|
||||
use futures_util::StreamExt;
|
||||
// xxx: rustls ^0.20
|
||||
//use rustls::{server::NoClientAuth, Certificate, PrivateKey, ServerConfig};
|
||||
use rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
|
||||
use rustls_pemfile::{certs, pkcs8_private_keys, rsa_private_keys};
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::Duration;
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
pub const PEEK_DETECT_LEN: usize = 64;
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
struct NetworkInner {
|
||||
routing_table: RoutingTable,
|
||||
network_manager: NetworkManager,
|
||||
network_started: bool,
|
||||
network_needs_restart: bool,
|
||||
protocol_config: Option<ProtocolConfig>,
|
||||
static_public_dialinfo: ProtocolSet,
|
||||
network_class: Option<NetworkClass>,
|
||||
join_handles: Vec<JoinHandle<()>>,
|
||||
udp_port: u16,
|
||||
tcp_port: u16,
|
||||
ws_port: u16,
|
||||
wss_port: u16,
|
||||
interfaces: NetworkInterfaces,
|
||||
// udp
|
||||
bound_first_udp: BTreeMap<u16, Option<(socket2::Socket, socket2::Socket)>>,
|
||||
inbound_udp_protocol_handlers: BTreeMap<SocketAddr, RawUdpProtocolHandler>,
|
||||
outbound_udpv4_protocol_handler: Option<RawUdpProtocolHandler>,
|
||||
outbound_udpv6_protocol_handler: Option<RawUdpProtocolHandler>,
|
||||
//tcp
|
||||
bound_first_tcp: BTreeMap<u16, Option<(socket2::Socket, socket2::Socket)>>,
|
||||
tls_acceptor: Option<TlsAcceptor>,
|
||||
listener_states: BTreeMap<SocketAddr, Arc<RwLock<ListenerState>>>,
|
||||
}
|
||||
|
||||
struct NetworkUnlockedInner {
|
||||
// Background processes
|
||||
update_network_class_task: TickTask,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Network {
|
||||
config: VeilidConfig,
|
||||
inner: Arc<Mutex<NetworkInner>>,
|
||||
unlocked_inner: Arc<NetworkUnlockedInner>,
|
||||
}
|
||||
|
||||
impl Network {
|
||||
fn new_inner(network_manager: NetworkManager) -> NetworkInner {
|
||||
NetworkInner {
|
||||
routing_table: network_manager.routing_table(),
|
||||
network_manager,
|
||||
network_started: false,
|
||||
network_needs_restart: false,
|
||||
protocol_config: None,
|
||||
static_public_dialinfo: ProtocolSet::empty(),
|
||||
network_class: None,
|
||||
join_handles: Vec::new(),
|
||||
udp_port: 0u16,
|
||||
tcp_port: 0u16,
|
||||
ws_port: 0u16,
|
||||
wss_port: 0u16,
|
||||
interfaces: NetworkInterfaces::new(),
|
||||
bound_first_udp: BTreeMap::new(),
|
||||
inbound_udp_protocol_handlers: BTreeMap::new(),
|
||||
outbound_udpv4_protocol_handler: None,
|
||||
outbound_udpv6_protocol_handler: None,
|
||||
bound_first_tcp: BTreeMap::new(),
|
||||
tls_acceptor: None,
|
||||
listener_states: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn new_unlocked_inner() -> NetworkUnlockedInner {
|
||||
NetworkUnlockedInner {
|
||||
update_network_class_task: TickTask::new(1),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(network_manager: NetworkManager) -> Self {
|
||||
let this = Self {
|
||||
config: network_manager.config(),
|
||||
inner: Arc::new(Mutex::new(Self::new_inner(network_manager))),
|
||||
unlocked_inner: Arc::new(Self::new_unlocked_inner()),
|
||||
};
|
||||
|
||||
// Set update network class tick task
|
||||
{
|
||||
let this2 = this.clone();
|
||||
this.unlocked_inner
|
||||
.update_network_class_task
|
||||
.set_routine(move |l, t| {
|
||||
Box::pin(this2.clone().update_network_class_task_routine(l, t))
|
||||
});
|
||||
}
|
||||
|
||||
this
|
||||
}
|
||||
|
||||
fn network_manager(&self) -> NetworkManager {
|
||||
self.inner.lock().network_manager.clone()
|
||||
}
|
||||
|
||||
fn routing_table(&self) -> RoutingTable {
|
||||
self.inner.lock().routing_table.clone()
|
||||
}
|
||||
|
||||
fn connection_manager(&self) -> ConnectionManager {
|
||||
self.inner.lock().network_manager.connection_manager()
|
||||
}
|
||||
|
||||
fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
|
||||
let cvec = certs(&mut BufReader::new(File::open(path)?))
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid TLS certificate"))?;
|
||||
Ok(cvec.into_iter().map(Certificate).collect())
|
||||
}
|
||||
|
||||
fn load_keys(path: &Path) -> io::Result<Vec<PrivateKey>> {
|
||||
{
|
||||
if let Ok(v) = rsa_private_keys(&mut BufReader::new(File::open(path)?)) {
|
||||
if !v.is_empty() {
|
||||
return Ok(v.into_iter().map(PrivateKey).collect());
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
if let Ok(v) = pkcs8_private_keys(&mut BufReader::new(File::open(path)?)) {
|
||||
if !v.is_empty() {
|
||||
return Ok(v.into_iter().map(PrivateKey).collect());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"invalid TLS private key",
|
||||
))
|
||||
}
|
||||
|
||||
fn load_server_config(&self) -> io::Result<ServerConfig> {
|
||||
let c = self.config.get();
|
||||
//
|
||||
trace!(
|
||||
"loading certificate from {}",
|
||||
c.network.tls.certificate_path
|
||||
);
|
||||
let certs = Self::load_certs(&PathBuf::from(&c.network.tls.certificate_path))?;
|
||||
trace!("loaded {} certificates", certs.len());
|
||||
if certs.is_empty() {
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidInput, format!("Certificates at {} could not be loaded.\nEnsure it is in PEM format, beginning with '-----BEGIN CERTIFICATE-----'",c.network.tls.certificate_path)));
|
||||
}
|
||||
//
|
||||
trace!(
|
||||
"loading private key from {}",
|
||||
c.network.tls.private_key_path
|
||||
);
|
||||
let mut keys = Self::load_keys(&PathBuf::from(&c.network.tls.private_key_path))?;
|
||||
trace!("loaded {} keys", keys.len());
|
||||
if keys.is_empty() {
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidInput, format!("Private key at {} could not be loaded.\nEnsure it is unencrypted and in RSA or PKCS8 format, beginning with '-----BEGIN RSA PRIVATE KEY-----' or '-----BEGIN PRIVATE KEY-----'",c.network.tls.private_key_path)));
|
||||
}
|
||||
|
||||
// xxx: rustls ^0.20
|
||||
// let mut config = ServerConfig::builder()
|
||||
// .with_safe_defaults()
|
||||
// .with_no_client_auth()
|
||||
// .with_single_cert(certs, keys.remove(0))
|
||||
// .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
|
||||
let mut config = ServerConfig::new(NoClientAuth::new());
|
||||
config
|
||||
.set_single_cert(certs, keys.remove(0))
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn add_to_join_handles(&self, jh: JoinHandle<()>) {
|
||||
let mut inner = self.inner.lock();
|
||||
inner.join_handles.push(jh);
|
||||
}
|
||||
|
||||
fn translate_unspecified_address(inner: &NetworkInner, from: &SocketAddr) -> Vec<SocketAddr> {
|
||||
if !from.ip().is_unspecified() {
|
||||
vec![*from]
|
||||
} else {
|
||||
inner
|
||||
.interfaces
|
||||
.best_addresses()
|
||||
.iter()
|
||||
.filter_map(|a| {
|
||||
// We create sockets that are only ipv6 or ipv6 (not dual, so only translate matching unspecified address)
|
||||
if (a.is_ipv4() && from.is_ipv4()) || (a.is_ipv6() && from.is_ipv6()) {
|
||||
Some(SocketAddr::new(*a, from.port()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_preferred_local_address(&self, dial_info: &DialInfo) -> SocketAddr {
|
||||
let inner = self.inner.lock();
|
||||
|
||||
let local_port = match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => inner.udp_port,
|
||||
ProtocolType::TCP => inner.tcp_port,
|
||||
ProtocolType::WS => inner.ws_port,
|
||||
ProtocolType::WSS => inner.wss_port,
|
||||
};
|
||||
|
||||
match dial_info.address_type() {
|
||||
AddressType::IPV4 => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), local_port),
|
||||
AddressType::IPV6 => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), local_port),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_interface_addresses<F, R>(&self, f: F) -> R
|
||||
where
|
||||
F: FnOnce(&[IpAddr]) -> R,
|
||||
{
|
||||
let inner = self.inner.lock();
|
||||
inner.interfaces.with_best_addresses(f)
|
||||
}
|
||||
|
||||
// See if our interface addresses have changed, if so we need to punt the network
|
||||
// and redo all our addresses. This is overkill, but anything more accurate
|
||||
// would require inspection of routing tables that we dont want to bother with
|
||||
pub async fn check_interface_addresses(&self) -> Result<bool, String> {
|
||||
let mut inner = self.inner.lock();
|
||||
if !inner.interfaces.refresh().await? {
|
||||
return Ok(false);
|
||||
}
|
||||
inner.network_needs_restart = true;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// Send data to a dial info, unbound, using a new connection from a random port
|
||||
// This creates a short-lived connection in the case of connection-oriented protocols
|
||||
// for the purpose of sending this one message.
|
||||
// This bypasses the connection table as it is not a 'node to node' connection.
|
||||
pub async fn send_data_unbound_to_dial_info(
|
||||
&self,
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
let data_len = data.len();
|
||||
let res = match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
ProtocolType::TCP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
};
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
pub async fn send_data_to_existing_connection(
|
||||
&self,
|
||||
descriptor: ConnectionDescriptor,
|
||||
data: Vec<u8>,
|
||||
) -> Result<Option<Vec<u8>>, String> {
|
||||
let data_len = data.len();
|
||||
|
||||
// Handle connectionless protocol
|
||||
if descriptor.protocol_type() == ProtocolType::UDP {
|
||||
// send over the best udp socket we have bound since UDP is not connection oriented
|
||||
let peer_socket_addr = descriptor.remote.to_socket_addr();
|
||||
if let Some(ph) = self.find_best_udp_protocol_handler(
|
||||
&peer_socket_addr,
|
||||
&descriptor.local.map(|sa| sa.to_socket_addr()),
|
||||
) {
|
||||
log_net!(
|
||||
"send_data_to_existing_connection connectionless to {:?}",
|
||||
descriptor
|
||||
);
|
||||
|
||||
ph.clone()
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.map_err(logthru_net!())?;
|
||||
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
.stats_packet_sent(peer_socket_addr.ip(), data_len as u64);
|
||||
|
||||
// Data was consumed
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle connection-oriented protocols
|
||||
|
||||
// Try to send to the exact existing connection if one exists
|
||||
if let Some(conn) = self.connection_manager().get_connection(descriptor).await {
|
||||
log_net!("send_data_to_existing_connection to {:?}", descriptor);
|
||||
|
||||
// connection exists, send over it
|
||||
conn.send(data).await.map_err(logthru_net!())?;
|
||||
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
.stats_packet_sent(descriptor.remote.to_socket_addr().ip(), data_len as u64);
|
||||
|
||||
// Data was consumed
|
||||
Ok(None)
|
||||
} else {
|
||||
// Connection or didn't exist
|
||||
// Pass the data back out so we don't own it any more
|
||||
Ok(Some(data))
|
||||
}
|
||||
}
|
||||
|
||||
// Send data directly to a dial info, possibly without knowing which node it is going to
|
||||
pub async fn send_data_to_dial_info(
|
||||
&self,
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
let data_len = data.len();
|
||||
// Handle connectionless protocol
|
||||
if dial_info.protocol_type() == ProtocolType::UDP {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
|
||||
let res = ph
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.map_err(logthru_net!());
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
.stats_packet_sent(peer_socket_addr.ip(), data_len as u64);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
return Err("no appropriate UDP protocol handler for dial_info".to_owned())
|
||||
.map_err(logthru_net!(error));
|
||||
}
|
||||
|
||||
// Handle connection-oriented protocols
|
||||
let local_addr = self.get_preferred_local_address(&dial_info);
|
||||
let conn = self
|
||||
.connection_manager()
|
||||
.get_or_create_connection(Some(local_addr), dial_info.clone())
|
||||
.await?;
|
||||
|
||||
let res = conn.send(data).await.map_err(logthru_net!(error));
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
pub fn get_protocol_config(&self) -> Option<ProtocolConfig> {
|
||||
self.inner.lock().protocol_config
|
||||
}
|
||||
|
||||
pub async fn startup(&self) -> Result<(), String> {
|
||||
trace!("startup network");
|
||||
|
||||
// initialize interfaces
|
||||
let mut interfaces = NetworkInterfaces::new();
|
||||
interfaces.refresh().await?;
|
||||
self.inner.lock().interfaces = interfaces;
|
||||
|
||||
// get protocol config
|
||||
let protocol_config = {
|
||||
let c = self.config.get();
|
||||
let mut inbound = ProtocolSet::new();
|
||||
|
||||
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
|
||||
inbound.insert(ProtocolType::UDP);
|
||||
}
|
||||
if c.network.protocol.tcp.listen && c.capabilities.protocol_accept_tcp {
|
||||
inbound.insert(ProtocolType::TCP);
|
||||
}
|
||||
if c.network.protocol.ws.listen && c.capabilities.protocol_accept_ws {
|
||||
inbound.insert(ProtocolType::WS);
|
||||
}
|
||||
if c.network.protocol.wss.listen && c.capabilities.protocol_accept_wss {
|
||||
inbound.insert(ProtocolType::WSS);
|
||||
}
|
||||
|
||||
let mut outbound = ProtocolSet::new();
|
||||
if c.network.protocol.udp.enabled && c.capabilities.protocol_udp {
|
||||
outbound.insert(ProtocolType::UDP);
|
||||
}
|
||||
if c.network.protocol.tcp.connect && c.capabilities.protocol_connect_tcp {
|
||||
outbound.insert(ProtocolType::TCP);
|
||||
}
|
||||
if c.network.protocol.ws.connect && c.capabilities.protocol_connect_ws {
|
||||
outbound.insert(ProtocolType::WS);
|
||||
}
|
||||
if c.network.protocol.wss.connect && c.capabilities.protocol_connect_wss {
|
||||
outbound.insert(ProtocolType::WSS);
|
||||
}
|
||||
|
||||
ProtocolConfig { inbound, outbound }
|
||||
};
|
||||
self.inner.lock().protocol_config = Some(protocol_config);
|
||||
|
||||
// start listeners
|
||||
if protocol_config.inbound.contains(ProtocolType::UDP) {
|
||||
self.start_udp_listeners().await?;
|
||||
}
|
||||
if protocol_config.inbound.contains(ProtocolType::WS) {
|
||||
self.start_ws_listeners().await?;
|
||||
}
|
||||
if protocol_config.inbound.contains(ProtocolType::WSS) {
|
||||
self.start_wss_listeners().await?;
|
||||
}
|
||||
if protocol_config.inbound.contains(ProtocolType::TCP) {
|
||||
self.start_tcp_listeners().await?;
|
||||
}
|
||||
|
||||
// release caches of available listener ports
|
||||
// this releases the 'first bound' ports we use to guarantee
|
||||
// that we have ports available to us
|
||||
self.free_bound_first_ports();
|
||||
|
||||
// If we have static public dialinfo, upgrade our network class
|
||||
{
|
||||
let mut inner = self.inner.lock();
|
||||
if !inner.static_public_dialinfo.is_empty() {
|
||||
inner.network_class = Some(NetworkClass::InboundCapable);
|
||||
}
|
||||
}
|
||||
|
||||
info!("network started");
|
||||
self.inner.lock().network_started = true;
|
||||
|
||||
// Inform routing table entries that our dial info has changed
|
||||
self.routing_table().send_node_info_updates();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn needs_restart(&self) -> bool {
|
||||
self.inner.lock().network_needs_restart
|
||||
}
|
||||
|
||||
pub fn is_started(&self) -> bool {
|
||||
self.inner.lock().network_started
|
||||
}
|
||||
|
||||
pub fn restart_network(&self) {
|
||||
self.inner.lock().network_needs_restart = true;
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) {
|
||||
info!("stopping network");
|
||||
|
||||
let network_manager = self.network_manager();
|
||||
let routing_table = self.routing_table();
|
||||
|
||||
// Cancel all tasks
|
||||
if let Err(e) = self.unlocked_inner.update_network_class_task.cancel().await {
|
||||
warn!("update_network_class_task not cancelled: {}", e);
|
||||
}
|
||||
|
||||
// Drop all dial info
|
||||
routing_table.clear_dial_info_details(RoutingDomain::PublicInternet);
|
||||
routing_table.clear_dial_info_details(RoutingDomain::LocalNetwork);
|
||||
|
||||
// Reset state including network class
|
||||
// Cancels all async background tasks by dropping join handles
|
||||
*self.inner.lock() = Self::new_inner(network_manager);
|
||||
|
||||
info!("network stopped");
|
||||
}
|
||||
|
||||
//////////////////////////////////////////
|
||||
pub fn get_network_class(&self) -> Option<NetworkClass> {
|
||||
let inner = self.inner.lock();
|
||||
inner.network_class
|
||||
}
|
||||
|
||||
pub fn reset_network_class(&self) {
|
||||
let mut inner = self.inner.lock();
|
||||
inner.network_class = None;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////
|
||||
|
||||
pub async fn tick(&self) -> Result<(), String> {
|
||||
let network_class = self.get_network_class().unwrap_or(NetworkClass::Invalid);
|
||||
let routing_table = self.routing_table();
|
||||
|
||||
// If we need to figure out our network class, tick the task for it
|
||||
if network_class == NetworkClass::Invalid {
|
||||
let rth = routing_table.get_routing_table_health();
|
||||
|
||||
// Need at least two entries to do this
|
||||
if rth.unreliable_entry_count + rth.reliable_entry_count >= 2 {
|
||||
self.unlocked_inner.update_network_class_task.tick().await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
@@ -0,0 +1,611 @@
|
||||
use super::*;
|
||||
use futures_util::stream::FuturesUnordered;
|
||||
use futures_util::FutureExt;
|
||||
|
||||
struct DetectedPublicDialInfo {
|
||||
dial_info: DialInfo,
|
||||
class: DialInfoClass,
|
||||
}
|
||||
struct DiscoveryContextInner {
|
||||
// per-protocol
|
||||
intf_addrs: Option<Vec<SocketAddress>>,
|
||||
protocol_type: Option<ProtocolType>,
|
||||
address_type: Option<AddressType>,
|
||||
// first node contacted
|
||||
external_1_dial_info: Option<DialInfo>,
|
||||
external_1_address: Option<SocketAddress>,
|
||||
node_1: Option<NodeRef>,
|
||||
// detected public dialinfo
|
||||
detected_network_class: Option<NetworkClass>,
|
||||
detected_public_dial_info: Option<DetectedPublicDialInfo>,
|
||||
}
|
||||
|
||||
pub struct DiscoveryContext {
|
||||
routing_table: RoutingTable,
|
||||
net: Network,
|
||||
inner: Arc<Mutex<DiscoveryContextInner>>,
|
||||
}
|
||||
|
||||
impl DiscoveryContext {
|
||||
pub fn new(routing_table: RoutingTable, net: Network) -> Self {
|
||||
Self {
|
||||
routing_table,
|
||||
net,
|
||||
inner: Arc::new(Mutex::new(DiscoveryContextInner {
|
||||
// per-protocol
|
||||
intf_addrs: None,
|
||||
protocol_type: None,
|
||||
address_type: None,
|
||||
external_1_dial_info: None,
|
||||
external_1_address: None,
|
||||
node_1: None,
|
||||
detected_network_class: None,
|
||||
detected_public_dial_info: None,
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
///////
|
||||
// Utilities
|
||||
|
||||
// Pick the best network class we have seen so far
|
||||
pub fn set_detected_network_class(&self, network_class: NetworkClass) {
|
||||
let mut inner = self.inner.lock();
|
||||
log_net!( debug
|
||||
"=== set_detected_network_class {:?} {:?}: {:?} ===",
|
||||
inner.protocol_type,
|
||||
inner.address_type,
|
||||
network_class
|
||||
);
|
||||
|
||||
inner.detected_network_class = Some(network_class);
|
||||
}
|
||||
|
||||
pub fn set_detected_public_dial_info(&self, dial_info: DialInfo, class: DialInfoClass) {
|
||||
let mut inner = self.inner.lock();
|
||||
log_net!( debug
|
||||
"=== set_detected_public_dial_info {:?} {:?}: {} {:?} ===",
|
||||
inner.protocol_type,
|
||||
inner.address_type,
|
||||
dial_info,
|
||||
class
|
||||
);
|
||||
inner.detected_public_dial_info = Some(DetectedPublicDialInfo { dial_info, class });
|
||||
}
|
||||
|
||||
// Ask for a public address check from a particular noderef
|
||||
// This is done over the normal port using RPC
|
||||
async fn request_public_address(&self, node_ref: NodeRef) -> Option<SocketAddress> {
|
||||
let rpc = self.routing_table.rpc_processor();
|
||||
rpc.rpc_call_status(node_ref.clone())
|
||||
.await
|
||||
.map_err(logthru_net!(
|
||||
"failed to get status answer from {:?}",
|
||||
node_ref
|
||||
))
|
||||
.map(|sa| {
|
||||
let ret = sa.sender_info.socket_address;
|
||||
log_net!("request_public_address: {:?}", ret);
|
||||
ret
|
||||
})
|
||||
.unwrap_or(None)
|
||||
}
|
||||
|
||||
// find fast peers with a particular address type, and ask them to tell us what our external address is
|
||||
// This is done over the normal port using RPC
|
||||
async fn discover_external_address(
|
||||
&self,
|
||||
protocol_type: ProtocolType,
|
||||
address_type: AddressType,
|
||||
ignore_node: Option<DHTKey>,
|
||||
) -> Option<(SocketAddress, NodeRef)> {
|
||||
let filter = DialInfoFilter::global()
|
||||
.with_protocol_type(protocol_type)
|
||||
.with_address_type(address_type);
|
||||
let peers = self.routing_table.find_fast_public_nodes_filtered(&filter);
|
||||
if peers.is_empty() {
|
||||
log_net!("no peers of type '{:?}'", filter);
|
||||
return None;
|
||||
}
|
||||
for peer in peers {
|
||||
if let Some(ignore_node) = ignore_node {
|
||||
if peer.node_id() == ignore_node {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if let Some(sa) = self.request_public_address(peer.clone()).await {
|
||||
return Some((sa, peer));
|
||||
}
|
||||
}
|
||||
log_net!("no peers responded with an external address");
|
||||
None
|
||||
}
|
||||
|
||||
// This pulls the already-detected local interface dial info from the routing table
|
||||
fn get_local_addresses(
|
||||
&self,
|
||||
protocol_type: ProtocolType,
|
||||
address_type: AddressType,
|
||||
) -> Vec<SocketAddress> {
|
||||
let filter = DialInfoFilter::local()
|
||||
.with_protocol_type(protocol_type)
|
||||
.with_address_type(address_type);
|
||||
self.routing_table
|
||||
.dial_info_details(RoutingDomain::LocalNetwork)
|
||||
.iter()
|
||||
.filter_map(|did| {
|
||||
if did.dial_info.matches_filter(&filter) {
|
||||
Some(did.dial_info.socket_address())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn validate_dial_info(
|
||||
&self,
|
||||
node_ref: NodeRef,
|
||||
dial_info: DialInfo,
|
||||
redirect: bool,
|
||||
) -> bool {
|
||||
let rpc = self.routing_table.rpc_processor();
|
||||
rpc.rpc_call_validate_dial_info(node_ref.clone(), dial_info, redirect)
|
||||
.await
|
||||
.map_err(logthru_net!(
|
||||
"failed to send validate_dial_info to {:?}",
|
||||
node_ref
|
||||
))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
async fn try_port_mapping(&self) -> Option<DialInfo> {
|
||||
//xxx
|
||||
None
|
||||
}
|
||||
|
||||
fn make_dial_info(&self, addr: SocketAddress, protocol_type: ProtocolType) -> DialInfo {
|
||||
match protocol_type {
|
||||
ProtocolType::UDP => DialInfo::udp(addr),
|
||||
ProtocolType::TCP => DialInfo::tcp(addr),
|
||||
ProtocolType::WS => {
|
||||
let c = self.net.config.get();
|
||||
DialInfo::try_ws(
|
||||
addr,
|
||||
format!("ws://{}/{}", addr, c.network.protocol.ws.path),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
ProtocolType::WSS => panic!("none of the discovery functions are used for wss"),
|
||||
}
|
||||
}
|
||||
|
||||
///////
|
||||
// Per-protocol discovery routines
|
||||
|
||||
pub fn protocol_begin(&self, protocol_type: ProtocolType, address_type: AddressType) {
|
||||
// Get our interface addresses
|
||||
let intf_addrs = self.get_local_addresses(protocol_type, address_type);
|
||||
|
||||
let mut inner = self.inner.lock();
|
||||
inner.intf_addrs = Some(intf_addrs);
|
||||
inner.protocol_type = Some(protocol_type);
|
||||
inner.address_type = Some(address_type);
|
||||
inner.external_1_dial_info = None;
|
||||
inner.external_1_address = None;
|
||||
inner.node_1 = None;
|
||||
}
|
||||
|
||||
// Get our first node's view of our external IP address via normal RPC
|
||||
pub async fn protocol_get_external_address_1(&self) -> bool {
|
||||
let (protocol_type, address_type) = {
|
||||
let inner = self.inner.lock();
|
||||
(inner.protocol_type.unwrap(), inner.address_type.unwrap())
|
||||
};
|
||||
|
||||
// Get our external address from some fast node, call it node 1
|
||||
let (external_1, node_1) = match self
|
||||
.discover_external_address(protocol_type, address_type, None)
|
||||
.await
|
||||
{
|
||||
None => {
|
||||
// If we can't get an external address, exit but don't throw an error so we can try again later
|
||||
log_net!(debug "couldn't get external address 1 for {:?} {:?}", protocol_type, address_type);
|
||||
return false;
|
||||
}
|
||||
Some(v) => v,
|
||||
};
|
||||
let external_1_dial_info = self.make_dial_info(external_1, protocol_type);
|
||||
|
||||
let mut inner = self.inner.lock();
|
||||
inner.external_1_dial_info = Some(external_1_dial_info);
|
||||
inner.external_1_address = Some(external_1);
|
||||
inner.node_1 = Some(node_1);
|
||||
|
||||
log_net!(debug "external_1_dial_info: {:?}\nexternal_1_address: {:?}\nnode_1: {:?}", inner.external_1_dial_info, inner.external_1_address, inner.node_1);
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
// If we know we are not behind NAT, check our firewall status
|
||||
pub async fn protocol_process_no_nat(&self) -> Result<(), String> {
|
||||
let (node_b, external_1_dial_info) = {
|
||||
let inner = self.inner.lock();
|
||||
(
|
||||
inner.node_1.as_ref().unwrap().clone(),
|
||||
inner.external_1_dial_info.as_ref().unwrap().clone(),
|
||||
)
|
||||
};
|
||||
|
||||
// Do a validate_dial_info on the external address from a redirected node
|
||||
if self
|
||||
.validate_dial_info(node_b.clone(), external_1_dial_info.clone(), true)
|
||||
.await
|
||||
{
|
||||
// Add public dial info with Direct dialinfo class
|
||||
self.set_detected_public_dial_info(external_1_dial_info, DialInfoClass::Direct);
|
||||
}
|
||||
// Attempt a port mapping via all available and enabled mechanisms
|
||||
else if let Some(external_mapped_dial_info) = self.try_port_mapping().await {
|
||||
// Got a port mapping, let's use it
|
||||
self.set_detected_public_dial_info(external_mapped_dial_info, DialInfoClass::Mapped);
|
||||
} else {
|
||||
// Add public dial info with Blocked dialinfo class
|
||||
self.set_detected_public_dial_info(external_1_dial_info, DialInfoClass::Blocked);
|
||||
}
|
||||
self.set_detected_network_class(NetworkClass::InboundCapable);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// If we know we are behind NAT check what kind
|
||||
pub async fn protocol_process_nat(&self) -> Result<bool, String> {
|
||||
let (node_1, external_1_dial_info, external_1_address, protocol_type, address_type) = {
|
||||
let inner = self.inner.lock();
|
||||
(
|
||||
inner.node_1.as_ref().unwrap().clone(),
|
||||
inner.external_1_dial_info.as_ref().unwrap().clone(),
|
||||
inner.external_1_address.unwrap(),
|
||||
inner.protocol_type.unwrap(),
|
||||
inner.address_type.unwrap(),
|
||||
)
|
||||
};
|
||||
|
||||
// Attempt a UDP port mapping via all available and enabled mechanisms
|
||||
if let Some(external_mapped_dial_info) = self.try_port_mapping().await {
|
||||
// Got a port mapping, let's use it
|
||||
self.set_detected_public_dial_info(external_mapped_dial_info, DialInfoClass::Mapped);
|
||||
self.set_detected_network_class(NetworkClass::InboundCapable);
|
||||
|
||||
// No more retries
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
// Port mapping was not possible, let's see what kind of NAT we have
|
||||
|
||||
// Does a redirected dial info validation from a different address and a random port find us?
|
||||
if self
|
||||
.validate_dial_info(node_1.clone(), external_1_dial_info.clone(), true)
|
||||
.await
|
||||
{
|
||||
// Yes, another machine can use the dial info directly, so Full Cone
|
||||
// Add public dial info with full cone NAT network class
|
||||
self.set_detected_public_dial_info(external_1_dial_info, DialInfoClass::FullConeNAT);
|
||||
self.set_detected_network_class(NetworkClass::InboundCapable);
|
||||
|
||||
// No more retries
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
// No, we are restricted, determine what kind of restriction
|
||||
|
||||
// Get our external address from some fast node, that is not node 1, call it node 2
|
||||
let (external_2_address, node_2) = match self
|
||||
.discover_external_address(protocol_type, address_type, Some(node_1.node_id()))
|
||||
.await
|
||||
{
|
||||
None => {
|
||||
// If we can't get an external address, allow retry
|
||||
return Ok(false);
|
||||
}
|
||||
Some(v) => v,
|
||||
};
|
||||
|
||||
// If we have two different external addresses, then this is a symmetric NAT
|
||||
if external_2_address != external_1_address {
|
||||
// Symmetric NAT is outbound only, no public dial info will work
|
||||
self.set_detected_network_class(NetworkClass::OutboundOnly);
|
||||
|
||||
// No more retries
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
// If we're going to end up as a restricted NAT of some sort
|
||||
// Address is the same, so it's address or port restricted
|
||||
|
||||
// Do a validate_dial_info on the external address from a random port
|
||||
if self
|
||||
.validate_dial_info(node_2.clone(), external_1_dial_info.clone(), false)
|
||||
.await
|
||||
{
|
||||
// Got a reply from a non-default port, which means we're only address restricted
|
||||
self.set_detected_public_dial_info(
|
||||
external_1_dial_info,
|
||||
DialInfoClass::AddressRestrictedNAT,
|
||||
);
|
||||
} else {
|
||||
// Didn't get a reply from a non-default port, which means we are also port restricted
|
||||
self.set_detected_public_dial_info(
|
||||
external_1_dial_info,
|
||||
DialInfoClass::PortRestrictedNAT,
|
||||
);
|
||||
}
|
||||
self.set_detected_network_class(NetworkClass::InboundCapable);
|
||||
|
||||
// Allow another retry because sometimes trying again will get us Full Cone NAT instead
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
impl Network {
|
||||
pub async fn update_ipv4_protocol_dialinfo(
|
||||
&self,
|
||||
context: &DiscoveryContext,
|
||||
protocol_type: ProtocolType,
|
||||
) -> Result<(), String> {
|
||||
let mut retry_count = {
|
||||
let c = self.config.get();
|
||||
c.network.restricted_nat_retries
|
||||
};
|
||||
|
||||
// Start doing ipv4 protocol
|
||||
context.protocol_begin(protocol_type, AddressType::IPV4);
|
||||
|
||||
// Loop for restricted NAT retries
|
||||
loop {
|
||||
log_net!(debug
|
||||
"=== update_ipv4_protocol_dialinfo {:?} tries_left={} ===",
|
||||
protocol_type,
|
||||
retry_count
|
||||
);
|
||||
// Get our external address from some fast node, call it node 1
|
||||
if !context.protocol_get_external_address_1().await {
|
||||
// If we couldn't get an external address, then we should just try the whole network class detection again later
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// If our local interface list contains external_1 then there is no NAT in place
|
||||
{
|
||||
let res = {
|
||||
let inner = context.inner.lock();
|
||||
inner
|
||||
.intf_addrs
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.contains(inner.external_1_address.as_ref().unwrap())
|
||||
};
|
||||
if res {
|
||||
// No NAT
|
||||
context.protocol_process_no_nat().await?;
|
||||
|
||||
// No more retries
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// There is -some NAT-
|
||||
if context.protocol_process_nat().await? {
|
||||
// We either got dial info or a network class without one
|
||||
break;
|
||||
}
|
||||
|
||||
// If we tried everything, break anyway after N attempts
|
||||
if retry_count == 0 {
|
||||
break;
|
||||
}
|
||||
retry_count -= 1;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_ipv6_protocol_dialinfo(
|
||||
&self,
|
||||
context: &DiscoveryContext,
|
||||
protocol_type: ProtocolType,
|
||||
) -> Result<(), String> {
|
||||
// Start doing ipv6 protocol
|
||||
context.protocol_begin(protocol_type, AddressType::IPV6);
|
||||
|
||||
log_net!(debug "=== update_ipv6_protocol_dialinfo {:?} ===", protocol_type);
|
||||
|
||||
// Get our external address from some fast node, call it node 1
|
||||
if !context.protocol_get_external_address_1().await {
|
||||
// If we couldn't get an external address, then we should just try the whole network class detection again later
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// If our local interface list doesn't contain external_1 then there is an Ipv6 NAT in place
|
||||
{
|
||||
let inner = context.inner.lock();
|
||||
if !inner
|
||||
.intf_addrs
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.contains(inner.external_1_address.as_ref().unwrap())
|
||||
{
|
||||
// IPv6 NAT is not supported today
|
||||
log_net!(warn
|
||||
"IPv6 NAT is not supported for external address: {}",
|
||||
inner.external_1_address.unwrap()
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// No NAT
|
||||
context.protocol_process_no_nat().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_network_class_task_routine(self, _l: u64, _t: u64) -> Result<(), String> {
|
||||
log_net!("--- updating network class");
|
||||
|
||||
// Ensure we aren't trying to update this without clearing it first
|
||||
let old_network_class = self.inner.lock().network_class;
|
||||
assert_eq!(old_network_class, None);
|
||||
|
||||
let protocol_config = self.inner.lock().protocol_config.unwrap_or_default();
|
||||
let mut unord = FuturesUnordered::new();
|
||||
|
||||
// Do UDPv4+v6 at the same time as everything else
|
||||
if protocol_config.inbound.contains(ProtocolType::UDP) {
|
||||
// UDPv4
|
||||
unord.push(
|
||||
async {
|
||||
let udpv4_context = DiscoveryContext::new(self.routing_table(), self.clone());
|
||||
if let Err(e) = self
|
||||
.update_ipv4_protocol_dialinfo(&udpv4_context, ProtocolType::UDP)
|
||||
.await
|
||||
{
|
||||
log_net!(debug "Failed UDPv4 dialinfo discovery: {}", e);
|
||||
return None;
|
||||
}
|
||||
Some(vec![udpv4_context])
|
||||
}
|
||||
.boxed(),
|
||||
);
|
||||
|
||||
// UDPv6
|
||||
unord.push(
|
||||
async {
|
||||
let udpv6_context = DiscoveryContext::new(self.routing_table(), self.clone());
|
||||
if let Err(e) = self
|
||||
.update_ipv6_protocol_dialinfo(&udpv6_context, ProtocolType::UDP)
|
||||
.await
|
||||
{
|
||||
log_net!(debug "Failed UDPv6 dialinfo discovery: {}", e);
|
||||
return None;
|
||||
}
|
||||
Some(vec![udpv6_context])
|
||||
}
|
||||
.boxed(),
|
||||
);
|
||||
}
|
||||
|
||||
// Do TCPv4 + WSv4 in series because they may use the same connection 5-tuple
|
||||
unord.push(
|
||||
async {
|
||||
// TCPv4
|
||||
let mut out = Vec::<DiscoveryContext>::new();
|
||||
if protocol_config.inbound.contains(ProtocolType::TCP) {
|
||||
let tcpv4_context = DiscoveryContext::new(self.routing_table(), self.clone());
|
||||
if let Err(e) = self
|
||||
.update_ipv4_protocol_dialinfo(&tcpv4_context, ProtocolType::TCP)
|
||||
.await
|
||||
{
|
||||
log_net!(debug "Failed TCPv4 dialinfo discovery: {}", e);
|
||||
return None;
|
||||
}
|
||||
out.push(tcpv4_context);
|
||||
}
|
||||
|
||||
// WSv4
|
||||
if protocol_config.inbound.contains(ProtocolType::WS) {
|
||||
let wsv4_context = DiscoveryContext::new(self.routing_table(), self.clone());
|
||||
if let Err(e) = self
|
||||
.update_ipv4_protocol_dialinfo(&wsv4_context, ProtocolType::WS)
|
||||
.await
|
||||
{
|
||||
log_net!(debug "Failed WSv4 dialinfo discovery: {}", e);
|
||||
return None;
|
||||
}
|
||||
out.push(wsv4_context);
|
||||
}
|
||||
Some(out)
|
||||
}
|
||||
.boxed(),
|
||||
);
|
||||
|
||||
// Do TCPv6 + WSv6 in series because they may use the same connection 5-tuple
|
||||
unord.push(
|
||||
async {
|
||||
// TCPv6
|
||||
let mut out = Vec::<DiscoveryContext>::new();
|
||||
if protocol_config.inbound.contains(ProtocolType::TCP) {
|
||||
let tcpv6_context = DiscoveryContext::new(self.routing_table(), self.clone());
|
||||
if let Err(e) = self
|
||||
.update_ipv6_protocol_dialinfo(&tcpv6_context, ProtocolType::TCP)
|
||||
.await
|
||||
{
|
||||
log_net!(debug "Failed TCPv6 dialinfo discovery: {}", e);
|
||||
return None;
|
||||
}
|
||||
out.push(tcpv6_context);
|
||||
}
|
||||
|
||||
// WSv6
|
||||
if protocol_config.inbound.contains(ProtocolType::WS) {
|
||||
let wsv6_context = DiscoveryContext::new(self.routing_table(), self.clone());
|
||||
if let Err(e) = self
|
||||
.update_ipv6_protocol_dialinfo(&wsv6_context, ProtocolType::WS)
|
||||
.await
|
||||
{
|
||||
log_net!(debug "Failed WSv6 dialinfo discovery: {}", e);
|
||||
return None;
|
||||
}
|
||||
out.push(wsv6_context);
|
||||
}
|
||||
Some(out)
|
||||
}
|
||||
.boxed(),
|
||||
);
|
||||
|
||||
// Wait for all discovery futures to complete and collect contexts
|
||||
let mut contexts = Vec::<DiscoveryContext>::new();
|
||||
let mut network_class = Option::<NetworkClass>::None;
|
||||
while let Some(ctxvec) = unord.next().await {
|
||||
if let Some(ctxvec) = ctxvec {
|
||||
for ctx in ctxvec {
|
||||
if let Some(nc) = ctx.inner.lock().detected_network_class {
|
||||
if let Some(last_nc) = network_class {
|
||||
if nc < last_nc {
|
||||
network_class = Some(nc);
|
||||
}
|
||||
} else {
|
||||
network_class = Some(nc);
|
||||
}
|
||||
}
|
||||
|
||||
contexts.push(ctx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get best network class
|
||||
if network_class.is_some() {
|
||||
// Update public dial info
|
||||
let routing_table = self.routing_table();
|
||||
for ctx in contexts {
|
||||
let inner = ctx.inner.lock();
|
||||
if let Some(pdi) = &inner.detected_public_dial_info {
|
||||
if let Err(e) = routing_table.register_dial_info(
|
||||
RoutingDomain::PublicInternet,
|
||||
pdi.dial_info.clone(),
|
||||
pdi.class,
|
||||
) {
|
||||
log_net!(warn "Failed to register detected public dial info: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Update network class
|
||||
self.inner.lock().network_class = network_class;
|
||||
log_net!(debug "network class changed to {:?}", network_class);
|
||||
|
||||
// Send updates to everyone
|
||||
routing_table.send_node_info_updates();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
289
veilid-core/src/network_manager/native/network_tcp.rs
Normal file
289
veilid-core/src/network_manager/native/network_tcp.rs
Normal file
@@ -0,0 +1,289 @@
|
||||
use super::*;
|
||||
use crate::intf::*;
|
||||
use async_tls::TlsAcceptor;
|
||||
use sockets::*;
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ListenerState {
|
||||
pub protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
|
||||
pub tls_protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
|
||||
pub tls_acceptor: Option<TlsAcceptor>,
|
||||
}
|
||||
|
||||
impl ListenerState {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
protocol_handlers: Vec::new(),
|
||||
tls_protocol_handlers: Vec::new(),
|
||||
tls_acceptor: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
impl Network {
|
||||
fn get_or_create_tls_acceptor(&self) -> Result<TlsAcceptor, String> {
|
||||
if let Some(ts) = self.inner.lock().tls_acceptor.as_ref() {
|
||||
return Ok(ts.clone());
|
||||
}
|
||||
|
||||
let server_config = self
|
||||
.load_server_config()
|
||||
.map_err(|e| format!("Couldn't create TLS configuration: {}", e))?;
|
||||
let acceptor = TlsAcceptor::from(Arc::new(server_config));
|
||||
self.inner.lock().tls_acceptor = Some(acceptor.clone());
|
||||
Ok(acceptor)
|
||||
}
|
||||
|
||||
async fn try_tls_handlers(
|
||||
&self,
|
||||
tls_acceptor: &TlsAcceptor,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
addr: SocketAddr,
|
||||
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
|
||||
tls_connection_initial_timeout: u64,
|
||||
) -> Result<Option<NetworkConnection>, String> {
|
||||
let ts = tls_acceptor
|
||||
.accept(stream)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(debug "TLS stream failed handshake"))?;
|
||||
let ps = AsyncPeekStream::new(CloneStream::new(ts));
|
||||
let mut first_packet = [0u8; PEEK_DETECT_LEN];
|
||||
|
||||
// Try the handlers but first get a chunk of data for them to process
|
||||
// Don't waste more than N seconds getting it though, in case someone
|
||||
// is trying to DoS us with a bunch of connections or something
|
||||
// read a chunk of the stream
|
||||
io::timeout(
|
||||
Duration::from_micros(tls_connection_initial_timeout),
|
||||
ps.peek_exact(&mut first_packet),
|
||||
)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
|
||||
self.try_handlers(ps, tcp_stream, addr, protocol_handlers)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn try_handlers(
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
addr: SocketAddr,
|
||||
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
|
||||
) -> Result<Option<NetworkConnection>, String> {
|
||||
for ah in protocol_handlers.iter() {
|
||||
if let Some(nc) = ah
|
||||
.on_accept(stream.clone(), tcp_stream.clone(), addr)
|
||||
.await
|
||||
.map_err(logthru_net!())?
|
||||
{
|
||||
return Ok(Some(nc));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
|
||||
// Get config
|
||||
let (connection_initial_timeout, tls_connection_initial_timeout) = {
|
||||
let c = self.config.get();
|
||||
(
|
||||
ms_to_us(c.network.connection_initial_timeout_ms),
|
||||
ms_to_us(c.network.tls.connection_initial_timeout_ms),
|
||||
)
|
||||
};
|
||||
|
||||
// Create a reusable socket with no linger time, and no delay
|
||||
let socket = new_bound_shared_tcp_socket(addr)?;
|
||||
// Listen on the socket
|
||||
socket
|
||||
.listen(128)
|
||||
.map_err(|e| format!("Couldn't listen on TCP socket: {}", e))?;
|
||||
|
||||
// Make an async tcplistener from the socket2 socket
|
||||
let std_listener: std::net::TcpListener = socket.into();
|
||||
let listener = TcpListener::from(std_listener);
|
||||
|
||||
debug!("spawn_socket_listener: binding successful to {}", addr);
|
||||
|
||||
// Create protocol handler records
|
||||
let listener_state = Arc::new(RwLock::new(ListenerState::new()));
|
||||
self.inner
|
||||
.lock()
|
||||
.listener_states
|
||||
.insert(addr, listener_state.clone());
|
||||
|
||||
// Spawn the socket task
|
||||
let this = self.clone();
|
||||
let connection_manager = self.connection_manager();
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
let jh = spawn(async move {
|
||||
// moves listener object in and get incoming iterator
|
||||
// when this task exists, the listener will close the socket
|
||||
listener
|
||||
.incoming()
|
||||
.for_each_concurrent(None, |tcp_stream| async {
|
||||
let tcp_stream = tcp_stream.unwrap();
|
||||
let listener_state = listener_state.clone();
|
||||
let connection_manager = connection_manager.clone();
|
||||
|
||||
// Limit the number of connections from the same IP address
|
||||
// and the number of total connections
|
||||
let addr = match tcp_stream.peer_addr() {
|
||||
Ok(addr) => addr,
|
||||
Err(e) => {
|
||||
log_net!(error "failed to get peer address: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
// XXX limiting
|
||||
|
||||
log_net!("TCP connection from: {}", addr);
|
||||
|
||||
// Create a stream we can peek on
|
||||
let ps = AsyncPeekStream::new(tcp_stream.clone());
|
||||
|
||||
/////////////////////////////////////////////////////////////
|
||||
let mut first_packet = [0u8; PEEK_DETECT_LEN];
|
||||
|
||||
// read a chunk of the stream
|
||||
if io::timeout(
|
||||
Duration::from_micros(connection_initial_timeout),
|
||||
ps.peek_exact(&mut first_packet),
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
// If we fail to get a packet within the connection initial timeout
|
||||
// then we punt this connection
|
||||
log_net!(warn "connection initial timeout from: {:?}", addr);
|
||||
return;
|
||||
}
|
||||
|
||||
// Run accept handlers on accepted stream
|
||||
|
||||
// Check is this could be TLS
|
||||
let ls = listener_state.read().clone();
|
||||
|
||||
let conn = if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 {
|
||||
this.try_tls_handlers(
|
||||
ls.tls_acceptor.as_ref().unwrap(),
|
||||
ps,
|
||||
tcp_stream,
|
||||
addr,
|
||||
&ls.tls_protocol_handlers,
|
||||
tls_connection_initial_timeout,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
this.try_handlers(ps, tcp_stream, addr, &ls.protocol_handlers)
|
||||
.await
|
||||
};
|
||||
|
||||
let conn = match conn {
|
||||
Ok(Some(c)) => {
|
||||
log_net!("protocol handler found for {:?}: {:?}", addr, c);
|
||||
c
|
||||
}
|
||||
Ok(None) => {
|
||||
// No protocol handlers matched? drop it.
|
||||
log_net!(warn "no protocol handler for connection from {:?}", addr);
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
// Failed to negotiate connection? drop it.
|
||||
log_net!(warn "failed to negotiate connection from {:?}: {}", addr, e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Register the new connection in the connection manager
|
||||
if let Err(e) = connection_manager.on_new_connection(conn).await {
|
||||
log_net!(error "failed to register new connection: {}", e);
|
||||
}
|
||||
})
|
||||
.await;
|
||||
log_net!(debug "exited incoming loop for {}", addr);
|
||||
// Remove our listener state from this address if we're stopping
|
||||
this.inner.lock().listener_states.remove(&addr);
|
||||
log_net!(debug "listener state removed for {}", addr);
|
||||
|
||||
// If this happened our low-level listener socket probably died
|
||||
// so it's time to restart the network
|
||||
this.inner.lock().network_needs_restart = true;
|
||||
});
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// Add to join handles
|
||||
self.add_to_join_handles(jh);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
// TCP listener that multiplexes ports so multiple protocols can exist on a single port
|
||||
pub(super) async fn start_tcp_listener(
|
||||
&self,
|
||||
ip_addrs: Vec<IpAddr>,
|
||||
port: u16,
|
||||
is_tls: bool,
|
||||
new_protocol_accept_handler: Box<NewProtocolAcceptHandler>,
|
||||
) -> Result<Vec<SocketAddress>, String> {
|
||||
let mut out = Vec::<SocketAddress>::new();
|
||||
|
||||
for ip_addr in ip_addrs {
|
||||
let addr = SocketAddr::new(ip_addr, port);
|
||||
let idi_addrs = Self::translate_unspecified_address(&*(self.inner.lock()), &addr);
|
||||
|
||||
// see if we've already bound to this already
|
||||
// if not, spawn a listener
|
||||
if !self.inner.lock().listener_states.contains_key(&addr) {
|
||||
self.clone().spawn_socket_listener(addr).await?;
|
||||
}
|
||||
|
||||
let ls = if let Some(ls) = self.inner.lock().listener_states.get_mut(&addr) {
|
||||
ls.clone()
|
||||
} else {
|
||||
panic!("this shouldn't happen");
|
||||
};
|
||||
|
||||
if is_tls {
|
||||
if ls.read().tls_acceptor.is_none() {
|
||||
ls.write().tls_acceptor = Some(self.clone().get_or_create_tls_acceptor()?);
|
||||
}
|
||||
ls.write()
|
||||
.tls_protocol_handlers
|
||||
.push(new_protocol_accept_handler(
|
||||
self.network_manager().config(),
|
||||
true,
|
||||
addr,
|
||||
));
|
||||
} else {
|
||||
ls.write()
|
||||
.protocol_handlers
|
||||
.push(new_protocol_accept_handler(
|
||||
self.network_manager().config(),
|
||||
false,
|
||||
addr,
|
||||
));
|
||||
}
|
||||
|
||||
// Return interface dial infos we listen on
|
||||
for idi_addr in idi_addrs {
|
||||
out.push(SocketAddress::from_socket_addr(idi_addr));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
}
|
209
veilid-core/src/network_manager/native/network_udp.rs
Normal file
209
veilid-core/src/network_manager/native/network_udp.rs
Normal file
@@ -0,0 +1,209 @@
|
||||
use super::*;
|
||||
use sockets::*;
|
||||
|
||||
impl Network {
|
||||
pub(super) async fn create_udp_listener_tasks(&self) -> Result<(), String> {
|
||||
// Spawn socket tasks
|
||||
let mut task_count = {
|
||||
let c = self.config.get();
|
||||
c.network.protocol.udp.socket_pool_size
|
||||
};
|
||||
if task_count == 0 {
|
||||
task_count = intf::get_concurrency() / 2;
|
||||
if task_count == 0 {
|
||||
task_count = 1;
|
||||
}
|
||||
}
|
||||
trace!("task_count: {}", task_count);
|
||||
for _ in 0..task_count {
|
||||
trace!("Spawning UDP listener task");
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
// Run thread task to process stream of messages
|
||||
let this = self.clone();
|
||||
|
||||
let jh = spawn(async move {
|
||||
trace!("UDP listener task spawned");
|
||||
|
||||
// Collect all our protocol handlers into a vector
|
||||
let mut protocol_handlers: Vec<RawUdpProtocolHandler> = this
|
||||
.inner
|
||||
.lock()
|
||||
.inbound_udp_protocol_handlers
|
||||
.values()
|
||||
.cloned()
|
||||
.collect();
|
||||
if let Some(ph) = this.inner.lock().outbound_udpv4_protocol_handler.clone() {
|
||||
protocol_handlers.push(ph);
|
||||
}
|
||||
if let Some(ph) = this.inner.lock().outbound_udpv6_protocol_handler.clone() {
|
||||
protocol_handlers.push(ph);
|
||||
}
|
||||
|
||||
// Spawn a local async task for each socket
|
||||
let mut protocol_handlers_unordered = FuturesUnordered::new();
|
||||
let network_manager = this.network_manager();
|
||||
|
||||
for ph in protocol_handlers {
|
||||
let network_manager = network_manager.clone();
|
||||
let jh = spawn_local(async move {
|
||||
let mut data = vec![0u8; 65536];
|
||||
|
||||
while let Ok((size, descriptor)) = ph.recv_message(&mut data).await {
|
||||
// XXX: Limit the number of packets from the same IP address?
|
||||
log_net!("UDP packet: {:?}", descriptor);
|
||||
|
||||
// Network accounting
|
||||
network_manager.stats_packet_rcvd(
|
||||
descriptor.remote.to_socket_addr().ip(),
|
||||
size as u64,
|
||||
);
|
||||
|
||||
// Pass it up for processing
|
||||
if let Err(e) = network_manager
|
||||
.on_recv_envelope(&data[..size], descriptor)
|
||||
.await
|
||||
{
|
||||
log_net!(error "failed to process received udp envelope: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
protocol_handlers_unordered.push(jh);
|
||||
}
|
||||
// Now we wait for any join handle to exit,
|
||||
// which would indicate an error needing
|
||||
// us to completely restart the network
|
||||
let _ = protocol_handlers_unordered.next().await;
|
||||
|
||||
trace!("UDP listener task stopped");
|
||||
// If this loop fails, our socket died and we need to restart the network
|
||||
this.inner.lock().network_needs_restart = true;
|
||||
});
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// Add to join handle
|
||||
self.add_to_join_handles(jh);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) async fn create_udp_outbound_sockets(&self) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock();
|
||||
let mut port = inner.udp_port;
|
||||
// v4
|
||||
let socket_addr_v4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
|
||||
if let Ok(socket) = new_bound_shared_udp_socket(socket_addr_v4) {
|
||||
// Pull the port if we randomly bound, so v6 can be on the same port
|
||||
port = socket
|
||||
.local_addr()
|
||||
.map_err(map_to_string)?
|
||||
.as_socket_ipv4()
|
||||
.ok_or_else(|| "expected ipv4 address type".to_owned())?
|
||||
.port();
|
||||
|
||||
// Make an async UdpSocket from the socket2 socket
|
||||
let std_udp_socket: std::net::UdpSocket = socket.into();
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
let socket_arc = Arc::new(udp_socket);
|
||||
|
||||
// Create protocol handler
|
||||
let udpv4_handler = RawUdpProtocolHandler::new(socket_arc);
|
||||
|
||||
inner.outbound_udpv4_protocol_handler = Some(udpv4_handler);
|
||||
}
|
||||
//v6
|
||||
let socket_addr_v6 =
|
||||
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), port);
|
||||
if let Ok(socket) = new_bound_shared_udp_socket(socket_addr_v6) {
|
||||
// Make an async UdpSocket from the socket2 socket
|
||||
let std_udp_socket: std::net::UdpSocket = socket.into();
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
let socket_arc = Arc::new(udp_socket);
|
||||
|
||||
// Create protocol handler
|
||||
let udpv6_handler = RawUdpProtocolHandler::new(socket_arc);
|
||||
|
||||
inner.outbound_udpv6_protocol_handler = Some(udpv6_handler);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_udp_inbound_socket(&self, addr: SocketAddr) -> Result<(), String> {
|
||||
log_net!("create_udp_inbound_socket on {:?}", &addr);
|
||||
|
||||
// Create a reusable socket
|
||||
let socket = new_bound_shared_udp_socket(addr)?;
|
||||
|
||||
// Make an async UdpSocket from the socket2 socket
|
||||
let std_udp_socket: std::net::UdpSocket = socket.into();
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
let socket_arc = Arc::new(udp_socket);
|
||||
|
||||
// Create protocol handler
|
||||
let protocol_handler = RawUdpProtocolHandler::new(socket_arc);
|
||||
|
||||
// Create message_handler records
|
||||
self.inner
|
||||
.lock()
|
||||
.inbound_udp_protocol_handlers
|
||||
.insert(addr, protocol_handler);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) async fn create_udp_inbound_sockets(
|
||||
&self,
|
||||
ip_addrs: Vec<IpAddr>,
|
||||
port: u16,
|
||||
) -> Result<Vec<DialInfo>, String> {
|
||||
let mut out = Vec::<DialInfo>::new();
|
||||
|
||||
for ip_addr in ip_addrs {
|
||||
let addr = SocketAddr::new(ip_addr, port);
|
||||
|
||||
// see if we've already bound to this already
|
||||
// if not, spawn a listener
|
||||
if !self
|
||||
.inner
|
||||
.lock()
|
||||
.inbound_udp_protocol_handlers
|
||||
.contains_key(&addr)
|
||||
{
|
||||
let idi_addrs = Self::translate_unspecified_address(&*self.inner.lock(), &addr);
|
||||
|
||||
self.clone().create_udp_inbound_socket(addr).await?;
|
||||
|
||||
// Return interface dial infos we listen on
|
||||
for idi_addr in idi_addrs {
|
||||
out.push(DialInfo::udp_from_socketaddr(idi_addr));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
pub(super) fn find_best_udp_protocol_handler(
|
||||
&self,
|
||||
peer_socket_addr: &SocketAddr,
|
||||
local_socket_addr: &Option<SocketAddr>,
|
||||
) -> Option<RawUdpProtocolHandler> {
|
||||
// if our last communication with this peer came from a particular inbound udp protocol handler, use it
|
||||
if let Some(sa) = local_socket_addr {
|
||||
if let Some(ph) = self.inner.lock().inbound_udp_protocol_handlers.get(sa) {
|
||||
return Some(ph.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// otherwise find the outbound udp protocol handler that matches the ip protocol version of the peer addr
|
||||
let inner = self.inner.lock();
|
||||
match peer_socket_addr {
|
||||
SocketAddr::V4(_) => inner.outbound_udpv4_protocol_handler.clone(),
|
||||
SocketAddr::V6(_) => inner.outbound_udpv6_protocol_handler.clone(),
|
||||
}
|
||||
}
|
||||
}
|
86
veilid-core/src/network_manager/native/protocol/mod.rs
Normal file
86
veilid-core/src/network_manager/native/protocol/mod.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
pub mod sockets;
|
||||
pub mod tcp;
|
||||
pub mod udp;
|
||||
pub mod wrtc;
|
||||
pub mod ws;
|
||||
|
||||
use super::*;
|
||||
use crate::xx::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ProtocolNetworkConnection {
|
||||
Dummy(DummyNetworkConnection),
|
||||
RawTcp(tcp::RawTcpNetworkConnection),
|
||||
WsAccepted(ws::WebSocketNetworkConnectionAccepted),
|
||||
Ws(ws::WebsocketNetworkConnectionWS),
|
||||
Wss(ws::WebsocketNetworkConnectionWSS),
|
||||
//WebRTC(wrtc::WebRTCNetworkConnection),
|
||||
}
|
||||
|
||||
impl ProtocolNetworkConnection {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
panic!("Should not connect to UDP dialinfo");
|
||||
}
|
||||
ProtocolType::TCP => {
|
||||
tcp::RawTcpProtocolHandler::connect(local_address, dial_info).await
|
||||
}
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
ws::WebsocketProtocolHandler::connect(local_address, dial_info).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
udp::RawUdpProtocolHandler::send_unbound_message(peer_socket_addr, data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
ProtocolType::TCP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
tcp::RawTcpProtocolHandler::send_unbound_message(peer_socket_addr, data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
ws::WebsocketProtocolHandler::send_unbound_message(dial_info, data).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.close(),
|
||||
Self::RawTcp(t) => t.close().await,
|
||||
Self::WsAccepted(w) => w.close().await,
|
||||
Self::Ws(w) => w.close().await,
|
||||
Self::Wss(w) => w.close().await,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.send(message),
|
||||
Self::RawTcp(t) => t.send(message).await,
|
||||
Self::WsAccepted(w) => w.send(message).await,
|
||||
Self::Ws(w) => w.send(message).await,
|
||||
Self::Wss(w) => w.send(message).await,
|
||||
}
|
||||
}
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.recv(),
|
||||
Self::RawTcp(t) => t.recv().await,
|
||||
Self::WsAccepted(w) => w.recv().await,
|
||||
Self::Ws(w) => w.recv().await,
|
||||
Self::Wss(w) => w.recv().await,
|
||||
}
|
||||
}
|
||||
}
|
222
veilid-core/src/network_manager/native/protocol/sockets.rs
Normal file
222
veilid-core/src/network_manager/native/protocol/sockets.rs
Normal file
@@ -0,0 +1,222 @@
|
||||
use crate::xx::*;
|
||||
use crate::*;
|
||||
use async_io::Async;
|
||||
use async_std::net::TcpStream;
|
||||
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
|
||||
|
||||
cfg_if! {
|
||||
if #[cfg(windows)] {
|
||||
use winapi::shared::ws2def::{ SOL_SOCKET, SO_EXCLUSIVEADDRUSE};
|
||||
use winapi::um::winsock2::{SOCKET_ERROR, setsockopt};
|
||||
use winapi::ctypes::c_int;
|
||||
use std::os::windows::io::AsRawSocket;
|
||||
|
||||
fn set_exclusiveaddruse(socket: &Socket) -> Result<(), String> {
|
||||
unsafe {
|
||||
let optval:c_int = 1;
|
||||
if setsockopt(socket.as_raw_socket().try_into().unwrap(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (&optval as *const c_int).cast(),
|
||||
std::mem::size_of::<c_int>() as c_int) == SOCKET_ERROR {
|
||||
return Err("Unable to SO_EXCLUSIVEADDRUSE".to_owned());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_unbound_shared_udp_socket(domain: Domain) -> Result<Socket, String> {
|
||||
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))
|
||||
.map_err(|e| format!("Couldn't create UDP socket: {}", e))?;
|
||||
if domain == Domain::IPV6 {
|
||||
socket
|
||||
.set_only_v6(true)
|
||||
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
|
||||
}
|
||||
socket
|
||||
.set_reuse_address(true)
|
||||
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
|
||||
cfg_if! {
|
||||
if #[cfg(unix)] {
|
||||
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
|
||||
}
|
||||
}
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_bound_shared_udp_socket(local_address: SocketAddr) -> Result<Socket, String> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
let socket = new_unbound_shared_udp_socket(domain)?;
|
||||
let socket2_addr = SockAddr::from(local_address);
|
||||
socket.bind(&socket2_addr).map_err(|e| {
|
||||
format!(
|
||||
"failed to bind UDP socket to '{}' in domain '{:?}': {} ",
|
||||
local_address, domain, e
|
||||
)
|
||||
})?;
|
||||
|
||||
log_net!("created bound shared udp socket on {:?}", &local_address);
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_bound_first_udp_socket(local_address: SocketAddr) -> Result<Socket, String> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))
|
||||
.map_err(|e| format!("Couldn't create UDP socket: {}", e))?;
|
||||
if domain == Domain::IPV6 {
|
||||
socket
|
||||
.set_only_v6(true)
|
||||
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
|
||||
}
|
||||
// Bind the socket -first- before turning on 'reuse address' this way it will
|
||||
// fail if the port is already taken
|
||||
let socket2_addr = SockAddr::from(local_address);
|
||||
|
||||
// On windows, do SO_EXCLUSIVEADDRUSE before the bind to ensure the port is fully available
|
||||
cfg_if! {
|
||||
if #[cfg(windows)] {
|
||||
set_exclusiveaddruse(&socket)?;
|
||||
}
|
||||
}
|
||||
|
||||
socket
|
||||
.bind(&socket2_addr)
|
||||
.map_err(|e| format!("failed to bind UDP socket: {}", e))?;
|
||||
|
||||
// Set 'reuse address' so future binds to this port will succeed
|
||||
// This does not work on Windows, where reuse options can not be set after the bind
|
||||
cfg_if! {
|
||||
if #[cfg(unix)] {
|
||||
socket
|
||||
.set_reuse_address(true)
|
||||
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
|
||||
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
|
||||
}
|
||||
}
|
||||
log_net!("created bound first udp socket on {:?}", &local_address);
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_unbound_shared_tcp_socket(domain: Domain) -> Result<Socket, String> {
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("failed to create TCP socket"))?;
|
||||
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
|
||||
log_net!(error "Couldn't set TCP linger: {}", e);
|
||||
}
|
||||
if let Err(e) = socket.set_nodelay(true) {
|
||||
log_net!(error "Couldn't set TCP nodelay: {}", e);
|
||||
}
|
||||
if domain == Domain::IPV6 {
|
||||
socket
|
||||
.set_only_v6(true)
|
||||
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
|
||||
}
|
||||
socket
|
||||
.set_reuse_address(true)
|
||||
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
|
||||
cfg_if! {
|
||||
if #[cfg(unix)] {
|
||||
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> Result<Socket, String> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
|
||||
let socket = new_unbound_shared_tcp_socket(domain)?;
|
||||
|
||||
let socket2_addr = SockAddr::from(local_address);
|
||||
socket
|
||||
.bind(&socket2_addr)
|
||||
.map_err(|e| format!("failed to bind TCP socket: {}", e))?;
|
||||
|
||||
log_net!("created bound shared tcp socket on {:?}", &local_address);
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> Result<Socket, String> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("failed to create TCP socket"))?;
|
||||
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
|
||||
log_net!(error "Couldn't set TCP linger: {}", e);
|
||||
}
|
||||
if let Err(e) = socket.set_nodelay(true) {
|
||||
log_net!(error "Couldn't set TCP nodelay: {}", e);
|
||||
}
|
||||
if domain == Domain::IPV6 {
|
||||
socket
|
||||
.set_only_v6(true)
|
||||
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
|
||||
}
|
||||
|
||||
// On windows, do SO_EXCLUSIVEADDRUSE before the bind to ensure the port is fully available
|
||||
cfg_if! {
|
||||
if #[cfg(windows)] {
|
||||
set_exclusiveaddruse(&socket)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Bind the socket -first- before turning on 'reuse address' this way it will
|
||||
// fail if the port is already taken
|
||||
let socket2_addr = SockAddr::from(local_address);
|
||||
socket
|
||||
.bind(&socket2_addr)
|
||||
.map_err(|e| format!("failed to bind TCP socket: {}", e))?;
|
||||
|
||||
// Set 'reuse address' so future binds to this port will succeed
|
||||
// This does not work on Windows, where reuse options can not be set after the bind
|
||||
cfg_if! {
|
||||
if #[cfg(unix)] {
|
||||
socket
|
||||
.set_reuse_address(true)
|
||||
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
|
||||
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
|
||||
}
|
||||
}
|
||||
log_net!("created bound first tcp socket on {:?}", &local_address);
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
// Non-blocking connect is tricky when you want to start with a prepared socket
|
||||
pub async fn nonblocking_connect(socket: Socket, addr: SocketAddr) -> std::io::Result<TcpStream> {
|
||||
// Set for non blocking connect
|
||||
socket.set_nonblocking(true)?;
|
||||
|
||||
// Make socket2 SockAddr
|
||||
let socket2_addr = socket2::SockAddr::from(addr);
|
||||
|
||||
// Connect to the remote address
|
||||
match socket.connect(&socket2_addr) {
|
||||
Ok(()) => Ok(()),
|
||||
#[cfg(unix)]
|
||||
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
|
||||
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
|
||||
Err(e) => Err(e),
|
||||
}?;
|
||||
|
||||
let async_stream = Async::new(std::net::TcpStream::from(socket))?;
|
||||
|
||||
// The stream becomes writable when connected
|
||||
intf::timeout(2000, async_stream.writable())
|
||||
.await
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::TimedOut, e))??;
|
||||
|
||||
// Check low level error
|
||||
let async_stream = match async_stream.get_ref().take_error()? {
|
||||
None => Ok(async_stream),
|
||||
Some(err) => Err(err),
|
||||
}?;
|
||||
|
||||
// Convert back to inner and then return async version
|
||||
Ok(TcpStream::from(async_stream.into_inner()?))
|
||||
}
|
227
veilid-core/src/network_manager/native/protocol/tcp.rs
Normal file
227
veilid-core/src/network_manager/native/protocol/tcp.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use super::*;
|
||||
use futures_util::{AsyncReadExt, AsyncWriteExt};
|
||||
use sockets::*;
|
||||
|
||||
pub struct RawTcpNetworkConnection {
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
}
|
||||
|
||||
impl fmt::Debug for RawTcpNetworkConnection {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("RawTCPNetworkConnection").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl RawTcpNetworkConnection {
|
||||
pub fn new(stream: AsyncPeekStream, tcp_stream: TcpStream) -> Self {
|
||||
Self { stream, tcp_stream }
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
// Make an attempt to flush the stream
|
||||
self.stream
|
||||
.clone()
|
||||
.close()
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
// Then forcibly close the socket
|
||||
self.tcp_stream
|
||||
.shutdown(Shutdown::Both)
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
log_net!("sending TCP message of size {}", message.len());
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large TCP message".to_owned());
|
||||
}
|
||||
let len = message.len() as u16;
|
||||
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
|
||||
|
||||
let mut stream = self.stream.clone();
|
||||
stream
|
||||
.write_all(&header)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
stream
|
||||
.write_all(&message)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
let mut header = [0u8; 4];
|
||||
|
||||
let mut stream = self.stream.clone();
|
||||
|
||||
stream
|
||||
.read_exact(&mut header)
|
||||
.await
|
||||
.map_err(|e| format!("TCP recv error: {}", e))?;
|
||||
if header[0] != b'V' || header[1] != b'L' {
|
||||
return Err("received invalid TCP frame header".to_owned());
|
||||
}
|
||||
let len = ((header[3] as usize) << 8) | (header[2] as usize);
|
||||
if len > MAX_MESSAGE_SIZE {
|
||||
return Err("received too large TCP frame".to_owned());
|
||||
}
|
||||
|
||||
let mut out: Vec<u8> = vec![0u8; len];
|
||||
stream.read_exact(&mut out).await.map_err(map_to_string)?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////
|
||||
///
|
||||
|
||||
struct RawTcpProtocolHandlerInner {
|
||||
local_address: SocketAddr,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RawTcpProtocolHandler
|
||||
where
|
||||
Self: ProtocolAcceptHandler,
|
||||
{
|
||||
inner: Arc<Mutex<RawTcpProtocolHandlerInner>>,
|
||||
}
|
||||
|
||||
impl RawTcpProtocolHandler {
|
||||
fn new_inner(local_address: SocketAddr) -> RawTcpProtocolHandlerInner {
|
||||
RawTcpProtocolHandlerInner { local_address }
|
||||
}
|
||||
|
||||
pub fn new(local_address: SocketAddr) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(Mutex::new(Self::new_inner(local_address))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_accept_async(
|
||||
self,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
socket_addr: SocketAddr,
|
||||
) -> Result<Option<NetworkConnection>, String> {
|
||||
log_net!("TCP: on_accept_async: enter");
|
||||
let mut peekbuf: [u8; PEEK_DETECT_LEN] = [0u8; PEEK_DETECT_LEN];
|
||||
let peeklen = stream
|
||||
.peek(&mut peekbuf)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("could not peek tcp stream"))?;
|
||||
assert_eq!(peeklen, PEEK_DETECT_LEN);
|
||||
|
||||
let peer_addr = PeerAddress::new(
|
||||
SocketAddress::from_socket_addr(socket_addr),
|
||||
ProtocolType::TCP,
|
||||
);
|
||||
let local_address = self.inner.lock().local_address;
|
||||
let conn = NetworkConnection::from_protocol(
|
||||
ConnectionDescriptor::new(peer_addr, SocketAddress::from_socket_addr(local_address)),
|
||||
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(stream, tcp_stream)),
|
||||
);
|
||||
|
||||
log_net!(debug "TCP: on_accept_async from: {}", socket_addr);
|
||||
|
||||
Ok(Some(conn))
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
// Get remote socket address to connect to
|
||||
let remote_socket_addr = dial_info.to_socket_addr();
|
||||
|
||||
// Make a shared socket
|
||||
let socket = match local_address {
|
||||
Some(a) => new_bound_shared_tcp_socket(a)?,
|
||||
None => {
|
||||
new_unbound_shared_tcp_socket(socket2::Domain::for_address(remote_socket_addr))?
|
||||
}
|
||||
};
|
||||
|
||||
// Non-blocking connect to remote address
|
||||
let ts = nonblocking_connect(socket, remote_socket_addr).await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
|
||||
|
||||
// See what local address we ended up with and turn this into a stream
|
||||
let actual_local_address = ts
|
||||
.local_addr()
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("could not get local address from TCP stream"))?;
|
||||
let ps = AsyncPeekStream::new(ts.clone());
|
||||
|
||||
// Wrap the stream in a network connection and return it
|
||||
let conn = NetworkConnection::from_protocol(
|
||||
ConnectionDescriptor {
|
||||
local: Some(SocketAddress::from_socket_addr(actual_local_address)),
|
||||
remote: dial_info.to_peer_address(),
|
||||
},
|
||||
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)),
|
||||
);
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
pub async fn send_unbound_message(
|
||||
socket_addr: SocketAddr,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound TCP message".to_owned());
|
||||
}
|
||||
trace!(
|
||||
"sending unbound message of length {} to {}",
|
||||
data.len(),
|
||||
socket_addr
|
||||
);
|
||||
|
||||
// Make a shared socket
|
||||
let socket = new_unbound_shared_tcp_socket(socket2::Domain::for_address(socket_addr))?;
|
||||
|
||||
// Non-blocking connect to remote address
|
||||
let ts = nonblocking_connect(socket, socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "remote_addr={}", socket_addr))?;
|
||||
|
||||
// See what local address we ended up with and turn this into a stream
|
||||
let actual_local_address = ts
|
||||
.local_addr()
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("could not get local address from TCP stream"))?;
|
||||
let ps = AsyncPeekStream::new(ts.clone());
|
||||
|
||||
// Wrap the stream in a network connection and return it
|
||||
let conn = NetworkConnection::from_protocol(
|
||||
ConnectionDescriptor {
|
||||
local: Some(SocketAddress::from_socket_addr(actual_local_address)),
|
||||
remote: PeerAddress::new(
|
||||
SocketAddress::from_socket_addr(socket_addr),
|
||||
ProtocolType::TCP,
|
||||
),
|
||||
},
|
||||
ProtocolNetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps, ts)),
|
||||
);
|
||||
conn.send(data).await
|
||||
}
|
||||
}
|
||||
|
||||
impl ProtocolAcceptHandler for RawTcpProtocolHandler {
|
||||
fn on_accept(
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<core::result::Result<Option<NetworkConnection>, String>> {
|
||||
Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
|
||||
}
|
||||
}
|
102
veilid-core/src/network_manager/native/protocol/udp.rs
Normal file
102
veilid-core/src/network_manager/native/protocol/udp.rs
Normal file
@@ -0,0 +1,102 @@
|
||||
use super::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RawUdpProtocolHandler {
|
||||
socket: Arc<UdpSocket>,
|
||||
}
|
||||
|
||||
impl RawUdpProtocolHandler {
|
||||
pub fn new(socket: Arc<UdpSocket>) -> Self {
|
||||
Self { socket }
|
||||
}
|
||||
|
||||
pub async fn recv_message(
|
||||
&self,
|
||||
data: &mut [u8],
|
||||
) -> Result<(usize, ConnectionDescriptor), String> {
|
||||
let (size, remote_addr) = self.socket.recv_from(data).await.map_err(map_to_string)?;
|
||||
|
||||
if size > MAX_MESSAGE_SIZE {
|
||||
return Err("received too large UDP message".to_owned());
|
||||
}
|
||||
|
||||
trace!(
|
||||
"receiving UDP message of length {} from {}",
|
||||
size,
|
||||
remote_addr
|
||||
);
|
||||
|
||||
let peer_addr = PeerAddress::new(
|
||||
SocketAddress::from_socket_addr(remote_addr),
|
||||
ProtocolType::UDP,
|
||||
);
|
||||
let local_socket_addr = self.socket.local_addr().map_err(map_to_string)?;
|
||||
let descriptor = ConnectionDescriptor::new(
|
||||
peer_addr,
|
||||
SocketAddress::from_socket_addr(local_socket_addr),
|
||||
);
|
||||
Ok((size, descriptor))
|
||||
}
|
||||
|
||||
pub async fn send_message(&self, data: Vec<u8>, socket_addr: SocketAddr) -> Result<(), String> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large UDP message".to_owned()).map_err(logthru_net!(error));
|
||||
}
|
||||
|
||||
log_net!(
|
||||
"sending UDP message of length {} to {}",
|
||||
data.len(),
|
||||
socket_addr
|
||||
);
|
||||
|
||||
let len = self
|
||||
.socket
|
||||
.send_to(&data, socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed udp send: addr={}", socket_addr))?;
|
||||
|
||||
if len != data.len() {
|
||||
Err("UDP partial send".to_owned()).map_err(logthru_net!(error))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_unbound_message(
|
||||
socket_addr: SocketAddr,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound UDP message".to_owned())
|
||||
.map_err(logthru_net!(error));
|
||||
}
|
||||
log_net!(
|
||||
"sending unbound message of length {} to {}",
|
||||
data.len(),
|
||||
socket_addr
|
||||
);
|
||||
|
||||
// get local wildcard address for bind
|
||||
let local_socket_addr = match socket_addr {
|
||||
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
|
||||
SocketAddr::V6(_) => {
|
||||
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
|
||||
}
|
||||
};
|
||||
let socket = UdpSocket::bind(local_socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed to bind unbound udp socket"))?;
|
||||
let len = socket
|
||||
.send_to(&data, socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed unbound udp send: addr={}", socket_addr))?;
|
||||
if len != data.len() {
|
||||
Err("UDP partial unbound send".to_owned()).map_err(logthru_net!(error))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
301
veilid-core/src/network_manager/native/protocol/ws.rs
Normal file
301
veilid-core/src/network_manager/native/protocol/ws.rs
Normal file
@@ -0,0 +1,301 @@
|
||||
use super::*;
|
||||
use async_std::io;
|
||||
use async_tls::TlsConnector;
|
||||
use async_tungstenite::tungstenite::protocol::Message;
|
||||
use async_tungstenite::{accept_async, client_async, WebSocketStream};
|
||||
use futures_util::SinkExt;
|
||||
use sockets::*;
|
||||
|
||||
pub type WebSocketNetworkConnectionAccepted = WebsocketNetworkConnection<AsyncPeekStream>;
|
||||
pub type WebsocketNetworkConnectionWSS =
|
||||
WebsocketNetworkConnection<async_tls::client::TlsStream<TcpStream>>;
|
||||
pub type WebsocketNetworkConnectionWS = WebsocketNetworkConnection<TcpStream>;
|
||||
|
||||
pub struct WebsocketNetworkConnection<T>
|
||||
where
|
||||
T: io::Read + io::Write + Send + Unpin + 'static,
|
||||
{
|
||||
stream: CloneStream<WebSocketStream<T>>,
|
||||
tcp_stream: TcpStream,
|
||||
}
|
||||
|
||||
impl<T> fmt::Debug for WebsocketNetworkConnection<T>
|
||||
where
|
||||
T: io::Read + io::Write + Send + Unpin + 'static,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", std::any::type_name::<Self>())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> WebsocketNetworkConnection<T>
|
||||
where
|
||||
T: io::Read + io::Write + Send + Unpin + 'static,
|
||||
{
|
||||
pub fn new(stream: WebSocketStream<T>, tcp_stream: TcpStream) -> Self {
|
||||
Self {
|
||||
stream: CloneStream::new(stream),
|
||||
tcp_stream,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
// Make an attempt to flush the stream
|
||||
self.stream
|
||||
.clone()
|
||||
.close()
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
// Then forcibly close the socket
|
||||
self.tcp_stream
|
||||
.shutdown(Shutdown::Both)
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("received too large WS message".to_owned());
|
||||
}
|
||||
self.stream
|
||||
.clone()
|
||||
.send(Message::binary(message))
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed to send websocket message"))
|
||||
}
|
||||
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
let out = match self.stream.clone().next().await {
|
||||
Some(Ok(Message::Binary(v))) => v,
|
||||
Some(Ok(_)) => {
|
||||
return Err("Unexpected WS message type".to_owned()).map_err(logthru_net!(error));
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
return Err(e.to_string()).map_err(logthru_net!(error));
|
||||
}
|
||||
None => {
|
||||
return Err("WS stream closed".to_owned()).map_err(logthru_net!());
|
||||
}
|
||||
};
|
||||
if out.len() > MAX_MESSAGE_SIZE {
|
||||
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
|
||||
} else {
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////
|
||||
///
|
||||
struct WebsocketProtocolHandlerArc {
|
||||
tls: bool,
|
||||
local_address: SocketAddr,
|
||||
request_path: Vec<u8>,
|
||||
connection_initial_timeout: u64,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WebsocketProtocolHandler
|
||||
where
|
||||
Self: ProtocolAcceptHandler,
|
||||
{
|
||||
arc: Arc<WebsocketProtocolHandlerArc>,
|
||||
}
|
||||
impl WebsocketProtocolHandler {
|
||||
pub fn new(config: VeilidConfig, tls: bool, local_address: SocketAddr) -> Self {
|
||||
let c = config.get();
|
||||
let path = if tls {
|
||||
format!("GET /{}", c.network.protocol.ws.path.trim_end_matches('/'))
|
||||
} else {
|
||||
format!("GET /{}", c.network.protocol.wss.path.trim_end_matches('/'))
|
||||
};
|
||||
let connection_initial_timeout = if tls {
|
||||
ms_to_us(c.network.tls.connection_initial_timeout_ms)
|
||||
} else {
|
||||
ms_to_us(c.network.connection_initial_timeout_ms)
|
||||
};
|
||||
|
||||
Self {
|
||||
arc: Arc::new(WebsocketProtocolHandlerArc {
|
||||
tls,
|
||||
local_address,
|
||||
request_path: path.as_bytes().to_vec(),
|
||||
connection_initial_timeout,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn on_accept_async(
|
||||
self,
|
||||
ps: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
socket_addr: SocketAddr,
|
||||
) -> Result<Option<NetworkConnection>, String> {
|
||||
log_net!("WS: on_accept_async: enter");
|
||||
let request_path_len = self.arc.request_path.len() + 2;
|
||||
|
||||
let mut peekbuf: Vec<u8> = vec![0u8; request_path_len];
|
||||
match io::timeout(
|
||||
Duration::from_micros(self.arc.connection_initial_timeout),
|
||||
ps.peek_exact(&mut peekbuf),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(e) => {
|
||||
if e.kind() == io::ErrorKind::TimedOut {
|
||||
return Err(e).map_err(map_to_string).map_err(logthru_net!());
|
||||
}
|
||||
return Err(e).map_err(map_to_string).map_err(logthru_net!(error));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for websocket path
|
||||
let matches_path = &peekbuf[0..request_path_len - 2] == self.arc.request_path.as_slice()
|
||||
&& (peekbuf[request_path_len - 2] == b' '
|
||||
|| (peekbuf[request_path_len - 2] == b'/'
|
||||
&& peekbuf[request_path_len - 1] == b' '));
|
||||
|
||||
if !matches_path {
|
||||
log_net!("WS: not websocket");
|
||||
return Ok(None);
|
||||
}
|
||||
log_net!("WS: found websocket");
|
||||
|
||||
let ws_stream = accept_async(ps)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("failed websockets handshake"))?;
|
||||
|
||||
// Wrap the websocket in a NetworkConnection and register it
|
||||
let protocol_type = if self.arc.tls {
|
||||
ProtocolType::WSS
|
||||
} else {
|
||||
ProtocolType::WS
|
||||
};
|
||||
|
||||
let peer_addr =
|
||||
PeerAddress::new(SocketAddress::from_socket_addr(socket_addr), protocol_type);
|
||||
|
||||
let conn = NetworkConnection::from_protocol(
|
||||
ConnectionDescriptor::new(
|
||||
peer_addr,
|
||||
SocketAddress::from_socket_addr(self.arc.local_address),
|
||||
),
|
||||
ProtocolNetworkConnection::WsAccepted(WebsocketNetworkConnection::new(
|
||||
ws_stream, tcp_stream,
|
||||
)),
|
||||
);
|
||||
|
||||
log_net!(debug "{}: on_accept_async from: {}", if self.arc.tls { "WSS" } else { "WS" }, socket_addr);
|
||||
|
||||
Ok(Some(conn))
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
// Split dial info up
|
||||
let (tls, scheme) = match &dial_info {
|
||||
DialInfo::WS(_) => (false, "ws"),
|
||||
DialInfo::WSS(_) => (true, "wss"),
|
||||
_ => panic!("invalid dialinfo for WS/WSS protocol"),
|
||||
};
|
||||
let request = dial_info.request().unwrap();
|
||||
let split_url = SplitUrl::from_str(&request)?;
|
||||
if split_url.scheme != scheme {
|
||||
return Err("invalid websocket url scheme".to_string());
|
||||
}
|
||||
let domain = split_url.host.clone();
|
||||
|
||||
// Resolve remote address
|
||||
let remote_socket_addr = dial_info.to_socket_addr();
|
||||
|
||||
// Make a shared socket
|
||||
let socket = match local_address {
|
||||
Some(a) => new_bound_shared_tcp_socket(a)?,
|
||||
None => {
|
||||
new_unbound_shared_tcp_socket(socket2::Domain::for_address(remote_socket_addr))?
|
||||
}
|
||||
};
|
||||
|
||||
// Non-blocking connect to remote address
|
||||
let tcp_stream = nonblocking_connect(socket, remote_socket_addr).await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
|
||||
|
||||
// See what local address we ended up with
|
||||
let actual_local_addr = tcp_stream
|
||||
.local_addr()
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
|
||||
// Make our connection descriptor
|
||||
let descriptor = ConnectionDescriptor {
|
||||
local: Some(SocketAddress::from_socket_addr(actual_local_addr)),
|
||||
remote: dial_info.to_peer_address(),
|
||||
};
|
||||
// Negotiate TLS if this is WSS
|
||||
if tls {
|
||||
let connector = TlsConnector::default();
|
||||
let tls_stream = connector
|
||||
.connect(domain.to_string(), tcp_stream.clone())
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
let (ws_stream, _response) = client_async(request, tls_stream)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
|
||||
Ok(NetworkConnection::from_protocol(
|
||||
descriptor,
|
||||
ProtocolNetworkConnection::Wss(WebsocketNetworkConnection::new(
|
||||
ws_stream, tcp_stream,
|
||||
)),
|
||||
))
|
||||
} else {
|
||||
let (ws_stream, _response) = client_async(request, tcp_stream.clone())
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
Ok(NetworkConnection::from_protocol(
|
||||
descriptor,
|
||||
ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(
|
||||
ws_stream, tcp_stream,
|
||||
)),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound WS message".to_owned());
|
||||
}
|
||||
trace!(
|
||||
"sending unbound websocket message of length {} to {}",
|
||||
data.len(),
|
||||
dial_info,
|
||||
);
|
||||
|
||||
let conn = Self::connect(None, dial_info.clone())
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?;
|
||||
|
||||
conn.send(data).await
|
||||
}
|
||||
}
|
||||
|
||||
impl ProtocolAcceptHandler for WebsocketProtocolHandler {
|
||||
fn on_accept(
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>> {
|
||||
Box::pin(self.clone().on_accept_async(stream, tcp_stream, peer_addr))
|
||||
}
|
||||
}
|
711
veilid-core/src/network_manager/native/start_protocols.rs
Normal file
711
veilid-core/src/network_manager/native/start_protocols.rs
Normal file
@@ -0,0 +1,711 @@
|
||||
use super::sockets::*;
|
||||
use super::*;
|
||||
use lazy_static::*;
|
||||
|
||||
lazy_static! {
|
||||
static ref BAD_PORTS: BTreeSet<u16> = BTreeSet::from([
|
||||
1, // tcpmux
|
||||
7, // echo
|
||||
9, // discard
|
||||
11, // systat
|
||||
13, // daytime
|
||||
15, // netstat
|
||||
17, // qotd
|
||||
19, // chargen
|
||||
20, // ftp data
|
||||
21, // ftp access
|
||||
22, // ssh
|
||||
23, // telnet
|
||||
25, // smtp
|
||||
37, // time
|
||||
42, // name
|
||||
43, // nicname
|
||||
53, // domain
|
||||
77, // priv-rjs
|
||||
79, // finger
|
||||
87, // ttylink
|
||||
95, // supdup
|
||||
101, // hostriame
|
||||
102, // iso-tsap
|
||||
103, // gppitnp
|
||||
104, // acr-nema
|
||||
109, // pop2
|
||||
110, // pop3
|
||||
111, // sunrpc
|
||||
113, // auth
|
||||
115, // sftp
|
||||
117, // uucp-path
|
||||
119, // nntp
|
||||
123, // NTP
|
||||
135, // loc-srv /epmap
|
||||
139, // netbios
|
||||
143, // imap2
|
||||
179, // BGP
|
||||
389, // ldap
|
||||
427, // SLP (Also used by Apple Filing Protocol)
|
||||
465, // smtp+ssl
|
||||
512, // print / exec
|
||||
513, // login
|
||||
514, // shell
|
||||
515, // printer
|
||||
526, // tempo
|
||||
530, // courier
|
||||
531, // chat
|
||||
532, // netnews
|
||||
540, // uucp
|
||||
548, // AFP (Apple Filing Protocol)
|
||||
556, // remotefs
|
||||
563, // nntp+ssl
|
||||
587, // smtp (rfc6409)
|
||||
601, // syslog-conn (rfc3195)
|
||||
636, // ldap+ssl
|
||||
993, // ldap+ssl
|
||||
995, // pop3+ssl
|
||||
2049, // nfs
|
||||
3659, // apple-sasl / PasswordServer
|
||||
4045, // lockd
|
||||
6000, // X11
|
||||
6665, // Alternate IRC [Apple addition]
|
||||
6666, // Alternate IRC [Apple addition]
|
||||
6667, // Standard IRC [Apple addition]
|
||||
6668, // Alternate IRC [Apple addition]
|
||||
6669, // Alternate IRC [Apple addition]
|
||||
6697, // IRC + TLS
|
||||
]);
|
||||
}
|
||||
|
||||
impl Network {
|
||||
/////////////////////////////////////////////////////
|
||||
// Support for binding first on ports to ensure nobody binds ahead of us
|
||||
// or two copies of the app don't accidentally collide. This is tricky
|
||||
// because we use 'reuseaddr/port' and we can accidentally bind in front of ourselves :P
|
||||
|
||||
fn bind_first_udp_port(&self, udp_port: u16) -> bool {
|
||||
let mut inner = self.inner.lock();
|
||||
if inner.bound_first_udp.contains_key(&udp_port) {
|
||||
return true;
|
||||
}
|
||||
// If the address is specified, only use the specified port and fail otherwise
|
||||
let mut bound_first_socket_v4 = None;
|
||||
let mut bound_first_socket_v6 = None;
|
||||
if let Ok(bfs4) =
|
||||
new_bound_first_udp_socket(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), udp_port))
|
||||
{
|
||||
if let Ok(bfs6) = new_bound_first_udp_socket(SocketAddr::new(
|
||||
IpAddr::V6(Ipv6Addr::UNSPECIFIED),
|
||||
udp_port,
|
||||
)) {
|
||||
bound_first_socket_v4 = Some(bfs4);
|
||||
bound_first_socket_v6 = Some(bfs6);
|
||||
}
|
||||
}
|
||||
if let (Some(bfs4), Some(bfs6)) = (bound_first_socket_v4, bound_first_socket_v6) {
|
||||
cfg_if! {
|
||||
if #[cfg(windows)] {
|
||||
// On windows, drop the socket. This is a race condition, but there's
|
||||
// no way around it. This isn't for security anyway, it's to prevent multiple copies of the
|
||||
// app from binding on the same port.
|
||||
drop(bfs4);
|
||||
drop(bfs6);
|
||||
inner.bound_first_udp.insert(udp_port, None);
|
||||
} else {
|
||||
inner.bound_first_udp.insert(udp_port, Some((bfs4, bfs6)));
|
||||
}
|
||||
}
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn bind_first_tcp_port(&self, tcp_port: u16) -> bool {
|
||||
let mut inner = self.inner.lock();
|
||||
if inner.bound_first_tcp.contains_key(&tcp_port) {
|
||||
return true;
|
||||
}
|
||||
// If the address is specified, only use the specified port and fail otherwise
|
||||
let mut bound_first_socket_v4 = None;
|
||||
let mut bound_first_socket_v6 = None;
|
||||
if let Ok(bfs4) =
|
||||
new_bound_first_tcp_socket(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), tcp_port))
|
||||
{
|
||||
if let Ok(bfs6) = new_bound_first_tcp_socket(SocketAddr::new(
|
||||
IpAddr::V6(Ipv6Addr::UNSPECIFIED),
|
||||
tcp_port,
|
||||
)) {
|
||||
bound_first_socket_v4 = Some(bfs4);
|
||||
bound_first_socket_v6 = Some(bfs6);
|
||||
}
|
||||
}
|
||||
if let (Some(bfs4), Some(bfs6)) = (bound_first_socket_v4, bound_first_socket_v6) {
|
||||
cfg_if! {
|
||||
if #[cfg(windows)] {
|
||||
// On windows, drop the socket. This is a race condition, but there's
|
||||
// no way around it. This isn't for security anyway, it's to prevent multiple copies of the
|
||||
// app from binding on the same port.
|
||||
drop(bfs4);
|
||||
drop(bfs6);
|
||||
inner.bound_first_tcp.insert(tcp_port, None);
|
||||
} else {
|
||||
inner.bound_first_tcp.insert(tcp_port, Some((bfs4, bfs6)));
|
||||
}
|
||||
}
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn free_bound_first_ports(&self) {
|
||||
let mut inner = self.inner.lock();
|
||||
inner.bound_first_udp.clear();
|
||||
inner.bound_first_tcp.clear();
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
fn find_available_udp_port(&self) -> Result<u16, String> {
|
||||
// If the address is empty, iterate ports until we find one we can use.
|
||||
let mut udp_port = 5150u16;
|
||||
loop {
|
||||
if BAD_PORTS.contains(&udp_port) {
|
||||
continue;
|
||||
}
|
||||
if self.bind_first_udp_port(udp_port) {
|
||||
break;
|
||||
}
|
||||
if udp_port == 65535 {
|
||||
return Err("Could not find free udp port to listen on".to_owned());
|
||||
}
|
||||
udp_port += 1;
|
||||
}
|
||||
Ok(udp_port)
|
||||
}
|
||||
|
||||
fn find_available_tcp_port(&self) -> Result<u16, String> {
|
||||
// If the address is empty, iterate ports until we find one we can use.
|
||||
let mut tcp_port = 5150u16;
|
||||
loop {
|
||||
if BAD_PORTS.contains(&tcp_port) {
|
||||
continue;
|
||||
}
|
||||
if self.bind_first_tcp_port(tcp_port) {
|
||||
break;
|
||||
}
|
||||
if tcp_port == 65535 {
|
||||
return Err("Could not find free tcp port to listen on".to_owned());
|
||||
}
|
||||
tcp_port += 1;
|
||||
}
|
||||
Ok(tcp_port)
|
||||
}
|
||||
|
||||
async fn allocate_udp_port(
|
||||
&self,
|
||||
listen_address: String,
|
||||
) -> Result<(u16, Vec<IpAddr>), String> {
|
||||
if listen_address.is_empty() {
|
||||
// If listen address is empty, find us a port iteratively
|
||||
let port = self.find_available_udp_port()?;
|
||||
let ip_addrs = vec![
|
||||
IpAddr::V4(Ipv4Addr::UNSPECIFIED),
|
||||
IpAddr::V6(Ipv6Addr::UNSPECIFIED),
|
||||
];
|
||||
Ok((port, ip_addrs))
|
||||
} else {
|
||||
// If no address is specified, but the port is, use ipv4 and ipv6 unspecified
|
||||
// If the address is specified, only use the specified port and fail otherwise
|
||||
let sockaddrs = listen_address_to_socket_addrs(&listen_address)?;
|
||||
if sockaddrs.is_empty() {
|
||||
return Err(format!("No valid listen address: {}", listen_address));
|
||||
}
|
||||
let port = sockaddrs[0].port();
|
||||
if self.bind_first_udp_port(port) {
|
||||
Ok((port, sockaddrs.iter().map(|s| s.ip()).collect()))
|
||||
} else {
|
||||
Err("Could not find free udp port to listen on".to_owned())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn allocate_tcp_port(
|
||||
&self,
|
||||
listen_address: String,
|
||||
) -> Result<(u16, Vec<IpAddr>), String> {
|
||||
if listen_address.is_empty() {
|
||||
// If listen address is empty, find us a port iteratively
|
||||
let port = self.find_available_tcp_port()?;
|
||||
let ip_addrs = vec![
|
||||
IpAddr::V4(Ipv4Addr::UNSPECIFIED),
|
||||
IpAddr::V6(Ipv6Addr::UNSPECIFIED),
|
||||
];
|
||||
Ok((port, ip_addrs))
|
||||
} else {
|
||||
// If no address is specified, but the port is, use ipv4 and ipv6 unspecified
|
||||
// If the address is specified, only use the specified port and fail otherwise
|
||||
let sockaddrs = listen_address_to_socket_addrs(&listen_address)?;
|
||||
if sockaddrs.is_empty() {
|
||||
return Err(format!("No valid listen address: {}", listen_address));
|
||||
}
|
||||
let port = sockaddrs[0].port();
|
||||
if self.bind_first_tcp_port(port) {
|
||||
Ok((port, sockaddrs.iter().map(|s| s.ip()).collect()))
|
||||
} else {
|
||||
Err("Could not find free tcp port to listen on".to_owned())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
pub(super) async fn start_udp_listeners(&self) -> Result<(), String> {
|
||||
trace!("starting udp listeners");
|
||||
let routing_table = self.routing_table();
|
||||
let (listen_address, public_address, enable_local_peer_scope) = {
|
||||
let c = self.config.get();
|
||||
(
|
||||
c.network.protocol.udp.listen_address.clone(),
|
||||
c.network.protocol.udp.public_address.clone(),
|
||||
c.network.enable_local_peer_scope,
|
||||
)
|
||||
};
|
||||
|
||||
// Pick out UDP port we're going to use everywhere
|
||||
// Keep sockets around until the end of this function
|
||||
// to keep anyone else from binding in front of us
|
||||
let (udp_port, ip_addrs) = self.allocate_udp_port(listen_address.clone()).await?;
|
||||
|
||||
// Save the bound udp port for use later on
|
||||
self.inner.lock().udp_port = udp_port;
|
||||
|
||||
// First, create outbound sockets
|
||||
// (unlike tcp where we create sockets for every connection)
|
||||
// and we'll add protocol handlers for them too
|
||||
self.create_udp_outbound_sockets().await?;
|
||||
|
||||
// Now create udp inbound sockets for whatever interfaces we're listening on
|
||||
info!(
|
||||
"UDP: starting listeners on port {} at {:?}",
|
||||
udp_port, ip_addrs
|
||||
);
|
||||
let local_dial_info_list = self.create_udp_inbound_sockets(ip_addrs, udp_port).await?;
|
||||
let mut static_public = false;
|
||||
|
||||
trace!("UDP: listener started on {:#?}", local_dial_info_list);
|
||||
|
||||
// Register local dial info
|
||||
for di in &local_dial_info_list {
|
||||
// If the local interface address is global, or we are enabling local peer scope
|
||||
// register global dial info if no public address is specified
|
||||
if public_address.is_none() && (di.is_global() || enable_local_peer_scope) {
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::PublicInternet,
|
||||
di.clone(),
|
||||
DialInfoClass::Direct,
|
||||
)?;
|
||||
static_public = true;
|
||||
}
|
||||
|
||||
// Register interface dial info as well since the address is on the local interface
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::LocalNetwork,
|
||||
di.clone(),
|
||||
DialInfoClass::Direct,
|
||||
)?;
|
||||
}
|
||||
|
||||
// Add static public dialinfo if it's configured
|
||||
if let Some(public_address) = public_address.as_ref() {
|
||||
// Resolve statically configured public dialinfo
|
||||
let mut public_sockaddrs = public_address
|
||||
.to_socket_addrs()
|
||||
.await
|
||||
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
|
||||
|
||||
// Add all resolved addresses as public dialinfo
|
||||
for pdi_addr in &mut public_sockaddrs {
|
||||
let pdi = DialInfo::udp_from_socketaddr(pdi_addr);
|
||||
|
||||
// Register the public address
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::PublicInternet,
|
||||
pdi.clone(),
|
||||
DialInfoClass::Direct,
|
||||
)?;
|
||||
|
||||
// See if this public address is also a local interface address we haven't registered yet
|
||||
let is_interface_address = self.with_interface_addresses(|ip_addrs| {
|
||||
for ip_addr in ip_addrs {
|
||||
if pdi_addr.ip() == *ip_addr {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
});
|
||||
if !local_dial_info_list.contains(&pdi) && is_interface_address {
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::LocalNetwork,
|
||||
DialInfo::udp_from_socketaddr(pdi_addr),
|
||||
DialInfoClass::Direct,
|
||||
)?;
|
||||
}
|
||||
|
||||
static_public = true;
|
||||
}
|
||||
}
|
||||
|
||||
if static_public {
|
||||
self.inner
|
||||
.lock()
|
||||
.static_public_dialinfo
|
||||
.insert(ProtocolType::UDP);
|
||||
}
|
||||
|
||||
// Now create tasks for udp listeners
|
||||
self.create_udp_listener_tasks().await
|
||||
}
|
||||
|
||||
pub(super) async fn start_ws_listeners(&self) -> Result<(), String> {
|
||||
trace!("starting ws listeners");
|
||||
let routing_table = self.routing_table();
|
||||
let (listen_address, url, path, enable_local_peer_scope) = {
|
||||
let c = self.config.get();
|
||||
(
|
||||
c.network.protocol.ws.listen_address.clone(),
|
||||
c.network.protocol.ws.url.clone(),
|
||||
c.network.protocol.ws.path.clone(),
|
||||
c.network.enable_local_peer_scope,
|
||||
)
|
||||
};
|
||||
|
||||
// Pick out TCP port we're going to use everywhere
|
||||
// Keep sockets around until the end of this function
|
||||
// to keep anyone else from binding in front of us
|
||||
let (ws_port, ip_addrs) = self.allocate_tcp_port(listen_address.clone()).await?;
|
||||
|
||||
// Save the bound ws port for use later on
|
||||
self.inner.lock().ws_port = ws_port;
|
||||
|
||||
trace!(
|
||||
"WS: starting listener on port {} at {:?}",
|
||||
ws_port,
|
||||
ip_addrs
|
||||
);
|
||||
let socket_addresses = self
|
||||
.start_tcp_listener(
|
||||
ip_addrs,
|
||||
ws_port,
|
||||
false,
|
||||
Box::new(|c, t, a| Box::new(WebsocketProtocolHandler::new(c, t, a))),
|
||||
)
|
||||
.await?;
|
||||
trace!("WS: listener started on {:#?}", socket_addresses);
|
||||
|
||||
let mut static_public = false;
|
||||
let mut registered_addresses: HashSet<IpAddr> = HashSet::new();
|
||||
|
||||
// Add static public dialinfo if it's configured
|
||||
if let Some(url) = url.as_ref() {
|
||||
let mut split_url = SplitUrl::from_str(url)?;
|
||||
if split_url.scheme.to_ascii_lowercase() != "ws" {
|
||||
return Err("WS URL must use 'ws://' scheme".to_owned());
|
||||
}
|
||||
split_url.scheme = "ws".to_owned();
|
||||
|
||||
// Resolve static public hostnames
|
||||
let global_socket_addrs = split_url
|
||||
.host_port(80)
|
||||
.to_socket_addrs()
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
|
||||
for gsa in global_socket_addrs {
|
||||
let pdi = DialInfo::try_ws(SocketAddress::from_socket_addr(gsa), url.clone())
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::PublicInternet,
|
||||
pdi.clone(),
|
||||
DialInfoClass::Direct,
|
||||
)?;
|
||||
static_public = true;
|
||||
|
||||
// See if this public address is also a local interface address
|
||||
let is_interface_address = self.with_interface_addresses(|ip_addrs| {
|
||||
for ip_addr in ip_addrs {
|
||||
if gsa.ip() == *ip_addr {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
});
|
||||
if !registered_addresses.contains(&gsa.ip()) && is_interface_address {
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::LocalNetwork,
|
||||
pdi,
|
||||
DialInfoClass::Direct,
|
||||
)?;
|
||||
}
|
||||
|
||||
registered_addresses.insert(gsa.ip());
|
||||
}
|
||||
}
|
||||
|
||||
for socket_address in socket_addresses {
|
||||
// Skip addresses we already did
|
||||
if registered_addresses.contains(&socket_address.to_ip_addr()) {
|
||||
continue;
|
||||
}
|
||||
// Build dial info request url
|
||||
let local_url = format!("ws://{}/{}", socket_address, path);
|
||||
let local_di = DialInfo::try_ws(socket_address, local_url)
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
|
||||
if url.is_none() && (socket_address.address().is_global() || enable_local_peer_scope) {
|
||||
// Register public dial info
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::PublicInternet,
|
||||
local_di.clone(),
|
||||
DialInfoClass::Direct,
|
||||
)?;
|
||||
static_public = true;
|
||||
}
|
||||
|
||||
// Register local dial info
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::LocalNetwork,
|
||||
local_di,
|
||||
DialInfoClass::Direct,
|
||||
)?;
|
||||
}
|
||||
|
||||
if static_public {
|
||||
self.inner
|
||||
.lock()
|
||||
.static_public_dialinfo
|
||||
.insert(ProtocolType::WS);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) async fn start_wss_listeners(&self) -> Result<(), String> {
|
||||
trace!("starting wss listeners");
|
||||
|
||||
let routing_table = self.routing_table();
|
||||
let (listen_address, url) = {
|
||||
let c = self.config.get();
|
||||
(
|
||||
c.network.protocol.wss.listen_address.clone(),
|
||||
c.network.protocol.wss.url.clone(),
|
||||
)
|
||||
};
|
||||
|
||||
// Pick out TCP port we're going to use everywhere
|
||||
// Keep sockets around until the end of this function
|
||||
// to keep anyone else from binding in front of us
|
||||
let (wss_port, ip_addrs) = self.allocate_tcp_port(listen_address.clone()).await?;
|
||||
|
||||
// Save the bound wss port for use later on
|
||||
self.inner.lock().wss_port = wss_port;
|
||||
|
||||
trace!(
|
||||
"WSS: starting listener on port {} at {:?}",
|
||||
wss_port,
|
||||
ip_addrs
|
||||
);
|
||||
let socket_addresses = self
|
||||
.start_tcp_listener(
|
||||
ip_addrs,
|
||||
wss_port,
|
||||
true,
|
||||
Box::new(|c, t, a| Box::new(WebsocketProtocolHandler::new(c, t, a))),
|
||||
)
|
||||
.await?;
|
||||
trace!("WSS: listener started on {:#?}", socket_addresses);
|
||||
|
||||
// NOTE: No interface dial info for WSS, as there is no way to connect to a local dialinfo via TLS
|
||||
// If the hostname is specified, it is the public dialinfo via the URL. If no hostname
|
||||
// is specified, then TLS won't validate, so no local dialinfo is possible.
|
||||
// This is not the case with unencrypted websockets, which can be specified solely by an IP address
|
||||
|
||||
let mut static_public = false;
|
||||
let mut registered_addresses: HashSet<IpAddr> = HashSet::new();
|
||||
|
||||
// Add static public dialinfo if it's configured
|
||||
if let Some(url) = url.as_ref() {
|
||||
// Add static public dialinfo if it's configured
|
||||
let mut split_url = SplitUrl::from_str(url)?;
|
||||
if split_url.scheme.to_ascii_lowercase() != "wss" {
|
||||
return Err("WSS URL must use 'wss://' scheme".to_owned());
|
||||
}
|
||||
split_url.scheme = "wss".to_owned();
|
||||
|
||||
// Resolve static public hostnames
|
||||
let global_socket_addrs = split_url
|
||||
.host_port(443)
|
||||
.to_socket_addrs()
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
|
||||
for gsa in global_socket_addrs {
|
||||
let pdi = DialInfo::try_wss(SocketAddress::from_socket_addr(gsa), url.clone())
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::PublicInternet,
|
||||
pdi.clone(),
|
||||
DialInfoClass::Direct,
|
||||
)?;
|
||||
static_public = true;
|
||||
|
||||
// See if this public address is also a local interface address
|
||||
let is_interface_address = self.with_interface_addresses(|ip_addrs| {
|
||||
for ip_addr in ip_addrs {
|
||||
if gsa.ip() == *ip_addr {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
});
|
||||
if !registered_addresses.contains(&gsa.ip()) && is_interface_address {
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::LocalNetwork,
|
||||
pdi,
|
||||
DialInfoClass::Direct,
|
||||
)?;
|
||||
}
|
||||
|
||||
registered_addresses.insert(gsa.ip());
|
||||
}
|
||||
} else {
|
||||
return Err("WSS URL must be specified due to TLS requirements".to_owned());
|
||||
}
|
||||
|
||||
if static_public {
|
||||
self.inner
|
||||
.lock()
|
||||
.static_public_dialinfo
|
||||
.insert(ProtocolType::WSS);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) async fn start_tcp_listeners(&self) -> Result<(), String> {
|
||||
trace!("starting tcp listeners");
|
||||
|
||||
let routing_table = self.routing_table();
|
||||
let (listen_address, public_address, enable_local_peer_scope) = {
|
||||
let c = self.config.get();
|
||||
(
|
||||
c.network.protocol.tcp.listen_address.clone(),
|
||||
c.network.protocol.tcp.public_address.clone(),
|
||||
c.network.enable_local_peer_scope,
|
||||
)
|
||||
};
|
||||
|
||||
// Pick out TCP port we're going to use everywhere
|
||||
// Keep sockets around until the end of this function
|
||||
// to keep anyone else from binding in front of us
|
||||
let (tcp_port, ip_addrs) = self.allocate_tcp_port(listen_address.clone()).await?;
|
||||
|
||||
// Save the bound tcp port for use later on
|
||||
self.inner.lock().tcp_port = tcp_port;
|
||||
|
||||
trace!(
|
||||
"TCP: starting listener on port {} at {:?}",
|
||||
tcp_port,
|
||||
ip_addrs
|
||||
);
|
||||
let socket_addresses = self
|
||||
.start_tcp_listener(
|
||||
ip_addrs,
|
||||
tcp_port,
|
||||
false,
|
||||
Box::new(|_, _, a| Box::new(RawTcpProtocolHandler::new(a))),
|
||||
)
|
||||
.await?;
|
||||
trace!("TCP: listener started on {:#?}", socket_addresses);
|
||||
|
||||
let mut static_public = false;
|
||||
let mut registered_addresses: HashSet<IpAddr> = HashSet::new();
|
||||
|
||||
for socket_address in socket_addresses {
|
||||
let di = DialInfo::tcp(socket_address);
|
||||
|
||||
// Register global dial info if no public address is specified
|
||||
if public_address.is_none() && (di.is_global() || enable_local_peer_scope) {
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::PublicInternet,
|
||||
di.clone(),
|
||||
DialInfoClass::Direct,
|
||||
)?;
|
||||
static_public = true;
|
||||
}
|
||||
// Register interface dial info
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::LocalNetwork,
|
||||
di.clone(),
|
||||
DialInfoClass::Direct,
|
||||
)?;
|
||||
registered_addresses.insert(socket_address.to_ip_addr());
|
||||
}
|
||||
|
||||
// Add static public dialinfo if it's configured
|
||||
if let Some(public_address) = public_address.as_ref() {
|
||||
// Resolve statically configured public dialinfo
|
||||
let mut public_sockaddrs = public_address
|
||||
.to_socket_addrs()
|
||||
.await
|
||||
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
|
||||
|
||||
// Add all resolved addresses as public dialinfo
|
||||
for pdi_addr in &mut public_sockaddrs {
|
||||
// Skip addresses we already did
|
||||
if registered_addresses.contains(&pdi_addr.ip()) {
|
||||
continue;
|
||||
}
|
||||
let pdi = DialInfo::tcp_from_socketaddr(pdi_addr);
|
||||
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::PublicInternet,
|
||||
pdi.clone(),
|
||||
DialInfoClass::Direct,
|
||||
)?;
|
||||
static_public = true;
|
||||
|
||||
// See if this public address is also a local interface address
|
||||
let is_interface_address = self.with_interface_addresses(|ip_addrs| {
|
||||
for ip_addr in ip_addrs {
|
||||
if pdi_addr.ip() == *ip_addr {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
});
|
||||
if is_interface_address {
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::LocalNetwork,
|
||||
pdi,
|
||||
DialInfoClass::Direct,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if static_public {
|
||||
self.inner
|
||||
.lock()
|
||||
.static_public_dialinfo
|
||||
.insert(ProtocolType::TCP);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
175
veilid-core/src/network_manager/network_connection.rs
Normal file
175
veilid-core/src/network_manager/network_connection.rs
Normal file
@@ -0,0 +1,175 @@
|
||||
use super::*;
|
||||
use crate::xx::*;
|
||||
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(target_arch = "wasm32")] {
|
||||
// No accept support for WASM
|
||||
} else {
|
||||
use async_std::net::*;
|
||||
|
||||
///////////////////////////////////////////////////////////
|
||||
// Accept
|
||||
|
||||
pub trait ProtocolAcceptHandler: ProtocolAcceptHandlerClone + Send + Sync {
|
||||
fn on_accept(
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
tcp_stream: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>>;
|
||||
}
|
||||
|
||||
pub trait ProtocolAcceptHandlerClone {
|
||||
fn clone_box(&self) -> Box<dyn ProtocolAcceptHandler>;
|
||||
}
|
||||
|
||||
impl<T> ProtocolAcceptHandlerClone for T
|
||||
where
|
||||
T: 'static + ProtocolAcceptHandler + Clone,
|
||||
{
|
||||
fn clone_box(&self) -> Box<dyn ProtocolAcceptHandler> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
impl Clone for Box<dyn ProtocolAcceptHandler> {
|
||||
fn clone(&self) -> Box<dyn ProtocolAcceptHandler> {
|
||||
self.clone_box()
|
||||
}
|
||||
}
|
||||
|
||||
pub type NewProtocolAcceptHandler =
|
||||
dyn Fn(VeilidConfig, bool, SocketAddr) -> Box<dyn ProtocolAcceptHandler> + Send;
|
||||
}
|
||||
}
|
||||
///////////////////////////////////////////////////////////
|
||||
// Dummy protocol network connection for testing
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DummyNetworkConnection {}
|
||||
|
||||
impl DummyNetworkConnection {
|
||||
pub fn close(&self) -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
pub fn send(&self, _message: Vec<u8>) -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
pub fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////
|
||||
// Top-level protocol independent network connection object
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NetworkConnectionStats {
|
||||
last_message_sent_time: Option<u64>,
|
||||
last_message_recv_time: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NetworkConnectionInner {
|
||||
stats: NetworkConnectionStats,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NetworkConnectionArc {
|
||||
descriptor: ConnectionDescriptor,
|
||||
protocol_connection: ProtocolNetworkConnection,
|
||||
established_time: u64,
|
||||
inner: Mutex<NetworkConnectionInner>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NetworkConnection {
|
||||
arc: Arc<NetworkConnectionArc>,
|
||||
}
|
||||
impl PartialEq for NetworkConnection {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
Arc::as_ptr(&self.arc) == Arc::as_ptr(&other.arc)
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for NetworkConnection {}
|
||||
|
||||
impl NetworkConnection {
|
||||
fn new_inner() -> NetworkConnectionInner {
|
||||
NetworkConnectionInner {
|
||||
stats: NetworkConnectionStats {
|
||||
last_message_sent_time: None,
|
||||
last_message_recv_time: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
fn new_arc(
|
||||
descriptor: ConnectionDescriptor,
|
||||
protocol_connection: ProtocolNetworkConnection,
|
||||
) -> NetworkConnectionArc {
|
||||
NetworkConnectionArc {
|
||||
descriptor,
|
||||
protocol_connection,
|
||||
established_time: intf::get_timestamp(),
|
||||
inner: Mutex::new(Self::new_inner()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dummy(descriptor: ConnectionDescriptor) -> Self {
|
||||
NetworkConnection::from_protocol(
|
||||
descriptor,
|
||||
ProtocolNetworkConnection::Dummy(DummyNetworkConnection {}),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn from_protocol(
|
||||
descriptor: ConnectionDescriptor,
|
||||
protocol_connection: ProtocolNetworkConnection,
|
||||
) -> Self {
|
||||
Self {
|
||||
arc: Arc::new(Self::new_arc(descriptor, protocol_connection)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
ProtocolNetworkConnection::connect(local_address, dial_info).await
|
||||
}
|
||||
|
||||
pub fn connection_descriptor(&self) -> ConnectionDescriptor {
|
||||
self.arc.descriptor
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
self.arc.protocol_connection.close().await
|
||||
}
|
||||
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
let ts = intf::get_timestamp();
|
||||
let out = self.arc.protocol_connection.send(message).await;
|
||||
if out.is_ok() {
|
||||
let mut inner = self.arc.inner.lock();
|
||||
inner.stats.last_message_sent_time.max_assign(Some(ts));
|
||||
}
|
||||
out
|
||||
}
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
let ts = intf::get_timestamp();
|
||||
let out = self.arc.protocol_connection.recv().await;
|
||||
if out.is_ok() {
|
||||
let mut inner = self.arc.inner.lock();
|
||||
inner.stats.last_message_recv_time.max_assign(Some(ts));
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub fn stats(&self) -> NetworkConnectionStats {
|
||||
let inner = self.arc.inner.lock();
|
||||
inner.stats.clone()
|
||||
}
|
||||
|
||||
pub fn established_time(&self) -> u64 {
|
||||
self.arc.established_time
|
||||
}
|
||||
}
|
2
veilid-core/src/network_manager/tests/mod.rs
Normal file
2
veilid-core/src/network_manager/tests/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod test_connection_table;
|
||||
use super::*;
|
101
veilid-core/src/network_manager/tests/test_connection_table.rs
Normal file
101
veilid-core/src/network_manager/tests/test_connection_table.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
use super::connection_table::*;
|
||||
use super::network_connection::*;
|
||||
use crate::tests::common::test_veilid_config::*;
|
||||
use crate::xx::*;
|
||||
use crate::*;
|
||||
|
||||
pub async fn test_add_get_remove() {
|
||||
let config = get_config();
|
||||
|
||||
let mut table = ConnectionTable::new(config);
|
||||
|
||||
let a1 = ConnectionDescriptor::new_no_local(PeerAddress::new(
|
||||
SocketAddress::new(Address::IPV4(Ipv4Addr::new(127, 0, 0, 1)), 8080),
|
||||
ProtocolType::TCP,
|
||||
));
|
||||
let a2 = a1;
|
||||
let a3 = ConnectionDescriptor::new(
|
||||
PeerAddress::new(
|
||||
SocketAddress::new(Address::IPV6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8090),
|
||||
ProtocolType::TCP,
|
||||
),
|
||||
SocketAddress::from_socket_addr(SocketAddr::V6(SocketAddrV6::new(
|
||||
Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
|
||||
8080,
|
||||
0,
|
||||
0,
|
||||
))),
|
||||
);
|
||||
let a4 = ConnectionDescriptor::new(
|
||||
PeerAddress::new(
|
||||
SocketAddress::new(Address::IPV6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8090),
|
||||
ProtocolType::TCP,
|
||||
),
|
||||
SocketAddress::from_socket_addr(SocketAddr::V6(SocketAddrV6::new(
|
||||
Ipv6Addr::new(1, 0, 0, 0, 0, 0, 0, 1),
|
||||
8080,
|
||||
0,
|
||||
0,
|
||||
))),
|
||||
);
|
||||
let a5 = ConnectionDescriptor::new(
|
||||
PeerAddress::new(
|
||||
SocketAddress::new(Address::IPV6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8090),
|
||||
ProtocolType::WSS,
|
||||
),
|
||||
SocketAddress::from_socket_addr(SocketAddr::V6(SocketAddrV6::new(
|
||||
Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
|
||||
8080,
|
||||
0,
|
||||
0,
|
||||
))),
|
||||
);
|
||||
|
||||
let c1 = NetworkConnection::dummy(a1);
|
||||
let c2 = NetworkConnection::dummy(a2);
|
||||
let c3 = NetworkConnection::dummy(a3);
|
||||
let c4 = NetworkConnection::dummy(a4);
|
||||
let c5 = NetworkConnection::dummy(a5);
|
||||
|
||||
assert_eq!(a1, c2.connection_descriptor());
|
||||
assert_ne!(a3, c4.connection_descriptor());
|
||||
assert_ne!(a4, c5.connection_descriptor());
|
||||
|
||||
assert_eq!(table.connection_count(), 0);
|
||||
assert_eq!(table.get_connection(a1), None);
|
||||
table.add_connection(c1.clone()).unwrap();
|
||||
|
||||
assert_eq!(table.connection_count(), 1);
|
||||
assert_err!(table.remove_connection(a3));
|
||||
assert_err!(table.remove_connection(a4));
|
||||
assert_eq!(table.connection_count(), 1);
|
||||
assert_eq!(table.get_connection(a1), Some(c1.clone()));
|
||||
assert_eq!(table.get_connection(a1), Some(c1.clone()));
|
||||
assert_eq!(table.connection_count(), 1);
|
||||
assert_err!(table.add_connection(c1.clone()));
|
||||
assert_err!(table.add_connection(c2.clone()));
|
||||
assert_eq!(table.connection_count(), 1);
|
||||
assert_eq!(table.get_connection(a1), Some(c1.clone()));
|
||||
assert_eq!(table.get_connection(a1), Some(c1.clone()));
|
||||
assert_eq!(table.connection_count(), 1);
|
||||
assert_eq!(table.remove_connection(a2), Ok(c1.clone()));
|
||||
assert_eq!(table.connection_count(), 0);
|
||||
assert_err!(table.remove_connection(a2));
|
||||
assert_eq!(table.connection_count(), 0);
|
||||
assert_eq!(table.get_connection(a2), None);
|
||||
assert_eq!(table.get_connection(a1), None);
|
||||
assert_eq!(table.connection_count(), 0);
|
||||
table.add_connection(c1.clone()).unwrap();
|
||||
assert_err!(table.add_connection(c2));
|
||||
table.add_connection(c3.clone()).unwrap();
|
||||
table.add_connection(c4.clone()).unwrap();
|
||||
assert_eq!(table.connection_count(), 3);
|
||||
assert_eq!(table.remove_connection(a2), Ok(c1));
|
||||
assert_eq!(table.remove_connection(a3), Ok(c3));
|
||||
assert_eq!(table.remove_connection(a4), Ok(c4));
|
||||
assert_eq!(table.connection_count(), 0);
|
||||
}
|
||||
|
||||
pub async fn test_all() {
|
||||
test_add_get_remove().await;
|
||||
}
|
235
veilid-core/src/network_manager/wasm/mod.rs
Normal file
235
veilid-core/src/network_manager/wasm/mod.rs
Normal file
@@ -0,0 +1,235 @@
|
||||
mod protocol;
|
||||
|
||||
use crate::connection_manager::*;
|
||||
use crate::network_manager::*;
|
||||
use crate::routing_table::*;
|
||||
use crate::intf::*;
|
||||
use crate::*;
|
||||
use protocol::ws::WebsocketProtocolHandler;
|
||||
pub use protocol::*;
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
struct NetworkInner {
|
||||
network_manager: NetworkManager,
|
||||
stop_network: Eventual,
|
||||
network_started: bool,
|
||||
network_needs_restart: bool,
|
||||
protocol_config: Option<ProtocolConfig>,
|
||||
//join_handle: TryJoin?
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Network {
|
||||
config: VeilidConfig,
|
||||
inner: Arc<Mutex<NetworkInner>>,
|
||||
}
|
||||
|
||||
impl Network {
|
||||
fn new_inner(network_manager: NetworkManager) -> NetworkInner {
|
||||
NetworkInner {
|
||||
network_manager,
|
||||
stop_network: Eventual::new(),
|
||||
network_started: false,
|
||||
network_needs_restart: false,
|
||||
protocol_config: None, //join_handle: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(network_manager: NetworkManager) -> Self {
|
||||
Self {
|
||||
config: network_manager.config(),
|
||||
inner: Arc::new(Mutex::new(Self::new_inner(network_manager))),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn network_manager(&self) -> NetworkManager {
|
||||
self.inner.lock().network_manager.clone()
|
||||
}
|
||||
fn connection_manager(&self) -> ConnectionManager {
|
||||
self.inner.lock().network_manager.connection_manager()
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
pub async fn send_data_unbound_to_dial_info(
|
||||
&self,
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
let data_len = data.len();
|
||||
|
||||
let res = match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
return Err("no support for UDP protocol".to_owned()).map_err(logthru_net!(error))
|
||||
}
|
||||
ProtocolType::TCP => {
|
||||
return Err("no support for TCP protocol".to_owned()).map_err(logthru_net!(error))
|
||||
}
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data)
|
||||
.await
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
};
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
pub async fn send_data_to_existing_connection(
|
||||
&self,
|
||||
descriptor: ConnectionDescriptor,
|
||||
data: Vec<u8>,
|
||||
) -> Result<Option<Vec<u8>>, String> {
|
||||
let data_len = data.len();
|
||||
match descriptor.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
return Err("no support for udp protocol".to_owned()).map_err(logthru_net!(error))
|
||||
}
|
||||
ProtocolType::TCP => {
|
||||
return Err("no support for tcp protocol".to_owned()).map_err(logthru_net!(error))
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Handle connection-oriented protocols
|
||||
|
||||
// Try to send to the exact existing connection if one exists
|
||||
if let Some(conn) = self.connection_manager().get_connection(descriptor).await {
|
||||
// connection exists, send over it
|
||||
conn.send(data).await.map_err(logthru_net!())?;
|
||||
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
.stats_packet_sent(descriptor.remote.to_socket_addr().ip(), data_len as u64);
|
||||
|
||||
// Data was consumed
|
||||
Ok(None)
|
||||
} else {
|
||||
// Connection or didn't exist
|
||||
// Pass the data back out so we don't own it any more
|
||||
Ok(Some(data))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_data_to_dial_info(
|
||||
&self,
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
let data_len = data.len();
|
||||
if dial_info.protocol_type() == ProtocolType::UDP {
|
||||
return Err("no support for UDP protocol".to_owned()).map_err(logthru_net!(error))
|
||||
}
|
||||
if dial_info.protocol_type() == ProtocolType::TCP {
|
||||
return Err("no support for TCP protocol".to_owned()).map_err(logthru_net!(error))
|
||||
}
|
||||
|
||||
// Handle connection-oriented protocols
|
||||
let conn = self
|
||||
.connection_manager()
|
||||
.get_or_create_connection(None, dial_info.clone())
|
||||
.await?;
|
||||
|
||||
let res = conn.send(data).await.map_err(logthru_net!(error));
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
.stats_packet_sent(dial_info.to_ip_addr(), data_len as u64);
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
pub async fn startup(&self) -> Result<(), String> {
|
||||
// get protocol config
|
||||
self.inner.lock().protocol_config = Some({
|
||||
let c = self.config.get();
|
||||
let inbound = ProtocolSet::new();
|
||||
let mut outbound = ProtocolSet::new();
|
||||
|
||||
if c.network.protocol.ws.connect && c.capabilities.protocol_connect_ws {
|
||||
outbound.insert(ProtocolType::WS);
|
||||
}
|
||||
if c.network.protocol.wss.connect && c.capabilities.protocol_connect_wss {
|
||||
outbound.insert(ProtocolType::WSS);
|
||||
}
|
||||
|
||||
ProtocolConfig { inbound, outbound }
|
||||
});
|
||||
|
||||
self.inner.lock().network_started = true;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn needs_restart(&self) -> bool {
|
||||
self.inner.lock().network_needs_restart
|
||||
}
|
||||
|
||||
pub fn is_started(&self) -> bool {
|
||||
self.inner.lock().network_started
|
||||
}
|
||||
|
||||
pub fn restart_network(&self) {
|
||||
self.inner.lock().network_needs_restart = true;
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) {
|
||||
trace!("stopping network");
|
||||
|
||||
// Reset state
|
||||
let network_manager = self.inner.lock().network_manager.clone();
|
||||
let routing_table = network_manager.routing_table();
|
||||
|
||||
// Drop all dial info
|
||||
routing_table.clear_dial_info_details(RoutingDomain::PublicInternet);
|
||||
routing_table.clear_dial_info_details(RoutingDomain::LocalNetwork);
|
||||
|
||||
// Cancels all async background tasks by dropping join handles
|
||||
*self.inner.lock() = Self::new_inner(network_manager);
|
||||
|
||||
trace!("network stopped");
|
||||
}
|
||||
|
||||
pub fn with_interface_addresses<F, R>(&self, f: F) -> R
|
||||
where
|
||||
F: FnOnce(&[IpAddr]) -> R,
|
||||
{
|
||||
f(&[])
|
||||
}
|
||||
|
||||
pub async fn check_interface_addresses(&self) -> Result<bool, String> {
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////
|
||||
pub fn get_network_class(&self) -> Option<NetworkClass> {
|
||||
// xxx eventually detect tor browser?
|
||||
return if self.inner.lock().network_started {
|
||||
Some(NetworkClass::WebApp)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
}
|
||||
|
||||
pub fn reset_network_class(&self) {
|
||||
//let mut inner = self.inner.lock();
|
||||
//inner.network_class = None;
|
||||
}
|
||||
|
||||
pub fn get_protocol_config(&self) -> Option<ProtocolConfig> {
|
||||
self.inner.lock().protocol_config.clone()
|
||||
}
|
||||
|
||||
//////////////////////////////////////////
|
||||
pub async fn tick(&self) -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
68
veilid-core/src/network_manager/wasm/protocol/mod.rs
Normal file
68
veilid-core/src/network_manager/wasm/protocol/mod.rs
Normal file
@@ -0,0 +1,68 @@
|
||||
pub mod wrtc;
|
||||
pub mod ws;
|
||||
|
||||
use crate::network_connection::*;
|
||||
use crate::xx::*;
|
||||
use crate::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ProtocolNetworkConnection {
|
||||
Dummy(DummyNetworkConnection),
|
||||
Ws(ws::WebsocketNetworkConnection),
|
||||
//WebRTC(wrtc::WebRTCNetworkConnection),
|
||||
}
|
||||
|
||||
impl ProtocolNetworkConnection {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
panic!("UDP dial info is not support on WASM targets");
|
||||
}
|
||||
ProtocolType::TCP => {
|
||||
panic!("TCP dial info is not support on WASM targets");
|
||||
}
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
ws::WebsocketProtocolHandler::connect(local_address, dial_info).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_unbound_message(
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
panic!("UDP dial info is not support on WASM targets");
|
||||
}
|
||||
ProtocolType::TCP => {
|
||||
panic!("TCP dial info is not support on WASM targets");
|
||||
}
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
ws::WebsocketProtocolHandler::send_unbound_message(dial_info, data).await
|
||||
}
|
||||
}
|
||||
}
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.close(),
|
||||
Self::Ws(w) => w.close().await,
|
||||
}
|
||||
}
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.send(message),
|
||||
Self::Ws(w) => w.send(message).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.recv(),
|
||||
Self::Ws(w) => w.recv().await,
|
||||
}
|
||||
}
|
||||
}
|
123
veilid-core/src/network_manager/wasm/protocol/ws.rs
Normal file
123
veilid-core/src/network_manager/wasm/protocol/ws.rs
Normal file
@@ -0,0 +1,123 @@
|
||||
use crate::intf::*;
|
||||
use crate::network_connection::*;
|
||||
use crate::network_manager::MAX_MESSAGE_SIZE;
|
||||
use crate::*;
|
||||
use alloc::fmt;
|
||||
use ws_stream_wasm::*;
|
||||
use futures_util::{StreamExt, SinkExt};
|
||||
|
||||
struct WebsocketNetworkConnectionInner {
|
||||
ws_meta: WsMeta,
|
||||
ws_stream: CloneStream<WsStream>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WebsocketNetworkConnection {
|
||||
inner: Arc<WebsocketNetworkConnectionInner>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for WebsocketNetworkConnection {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", core::any::type_name::<Self>())
|
||||
}
|
||||
}
|
||||
|
||||
impl WebsocketNetworkConnection {
|
||||
pub fn new(ws_meta: WsMeta, ws_stream: WsStream) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(WebsocketNetworkConnectionInner {
|
||||
ws_meta,
|
||||
ws_stream: CloneStream::new(ws_stream),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
self.inner.ws_meta.close().await.map_err(map_to_string).map(drop)
|
||||
}
|
||||
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large WS message".to_owned()).map_err(logthru_net!(error));
|
||||
}
|
||||
self.inner.ws_stream.clone()
|
||||
.send(WsMessage::Binary(message)).await
|
||||
.map_err(|_| "failed to send to websocket".to_owned())
|
||||
.map_err(logthru_net!(error))
|
||||
}
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
let out = match self.inner.ws_stream.clone().next().await {
|
||||
Some(WsMessage::Binary(v)) => v,
|
||||
Some(_) => {
|
||||
return Err("Unexpected WS message type".to_owned())
|
||||
.map_err(logthru_net!(error));
|
||||
}
|
||||
None => {
|
||||
return Err("WS stream closed".to_owned()).map_err(logthru_net!(error));
|
||||
}
|
||||
};
|
||||
if out.len() > MAX_MESSAGE_SIZE {
|
||||
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
|
||||
} else {
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////
|
||||
///
|
||||
|
||||
pub struct WebsocketProtocolHandler {}
|
||||
|
||||
impl WebsocketProtocolHandler {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
|
||||
assert!(local_address.is_none());
|
||||
|
||||
// Split dial info up
|
||||
let (_tls, scheme) = match &dial_info {
|
||||
DialInfo::WS(_) => (false, "ws"),
|
||||
DialInfo::WSS(_) => (true, "wss"),
|
||||
_ => panic!("invalid dialinfo for WS/WSS protocol"),
|
||||
};
|
||||
let request = dial_info.request().unwrap();
|
||||
let split_url = SplitUrl::from_str(&request)?;
|
||||
if split_url.scheme != scheme {
|
||||
return Err("invalid websocket url scheme".to_string());
|
||||
}
|
||||
|
||||
let (wsmeta, wsio) = WsMeta::connect(request, None)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
|
||||
// Make our connection descriptor
|
||||
|
||||
Ok(NetworkConnection::from_protocol(ConnectionDescriptor {
|
||||
local: None,
|
||||
remote: dial_info.to_peer_address(),
|
||||
},ProtocolNetworkConnection::Ws(WebsocketNetworkConnection::new(wsmeta, wsio))))
|
||||
}
|
||||
|
||||
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound WS message".to_owned());
|
||||
}
|
||||
trace!(
|
||||
"sending unbound websocket message of length {} to {}",
|
||||
data.len(),
|
||||
dial_info,
|
||||
);
|
||||
|
||||
// Make the real connection
|
||||
let conn = Self::connect(None, dial_info)
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?;
|
||||
|
||||
conn.send(data).await
|
||||
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user