fix udp and refactor native network

This commit is contained in:
John Smith
2021-12-31 22:09:30 -05:00
parent c6f573ffe0
commit 0e0209a54b
18 changed files with 870 additions and 760 deletions

View File

@@ -1,6 +1,9 @@
mod listener_state;
mod network_tcp;
mod network_udp;
mod protocol;
mod public_dialinfo_discovery;
mod start_protocols;
use crate::intf::*;
use crate::network_manager::*;
@@ -12,7 +15,6 @@ use protocol::udp::RawUdpProtocolHandler;
use protocol::ws::WebsocketProtocolHandler;
pub use protocol::*;
use utils::async_peek_stream::*;
use utils::clone_stream::*;
use utils::network_interfaces::*;
use async_std::io;
@@ -45,16 +47,18 @@ struct NetworkInner {
ws_static_public_dialinfo: bool,
network_class: Option<NetworkClass>,
join_handles: Vec<JoinHandle<()>>,
listener_states: BTreeMap<SocketAddr, Arc<RwLock<ListenerState>>>,
udp_protocol_handlers: BTreeMap<SocketAddr, RawUdpProtocolHandler>,
tls_acceptor: Option<TlsAcceptor>,
udp_port: u16,
tcp_port: u16,
ws_port: u16,
wss_port: u16,
interfaces: NetworkInterfaces,
// udp
inbound_udp_protocol_handlers: BTreeMap<SocketAddr, RawUdpProtocolHandler>,
outbound_udpv4_protocol_handler: Option<RawUdpProtocolHandler>,
outbound_udpv6_protocol_handler: Option<RawUdpProtocolHandler>,
interfaces: NetworkInterfaces,
//tcp
tls_acceptor: Option<TlsAcceptor>,
listener_states: BTreeMap<SocketAddr, Arc<RwLock<ListenerState>>>,
}
struct NetworkUnlockedInner {
@@ -83,16 +87,16 @@ impl Network {
ws_static_public_dialinfo: false,
network_class: None,
join_handles: Vec::new(),
listener_states: BTreeMap::new(),
udp_protocol_handlers: BTreeMap::new(),
tls_acceptor: None,
udp_port: 0u16,
tcp_port: 0u16,
ws_port: 0u16,
wss_port: 0u16,
interfaces: NetworkInterfaces::new(),
inbound_udp_protocol_handlers: BTreeMap::new(),
outbound_udpv4_protocol_handler: None,
outbound_udpv6_protocol_handler: None,
interfaces: NetworkInterfaces::new(),
tls_acceptor: None,
listener_states: BTreeMap::new(),
}
}
@@ -201,430 +205,21 @@ impl Network {
Ok(config)
}
fn get_or_create_tls_acceptor(&self) -> Result<TlsAcceptor, String> {
if let Some(ts) = self.inner.lock().tls_acceptor.as_ref() {
return Ok(ts.clone());
}
let server_config = self
.load_server_config()
.map_err(|e| format!("Couldn't create TLS configuration: {}", e))?;
let acceptor = TlsAcceptor::from(Arc::new(server_config));
self.inner.lock().tls_acceptor = Some(acceptor.clone());
Ok(acceptor)
}
fn add_to_join_handles(&self, jh: JoinHandle<()>) {
let mut inner = self.inner.lock();
inner.join_handles.push(jh);
}
async fn try_tls_handlers(
&self,
tls_acceptor: &TlsAcceptor,
stream: AsyncPeekStream,
addr: SocketAddr,
protocol_handlers: &[Box<dyn TcpProtocolHandler>],
tls_connection_initial_timeout: u64,
) {
match tls_acceptor.accept(stream).await {
Ok(ts) => {
let ps = AsyncPeekStream::new(CloneStream::new(ts));
let mut first_packet = [0u8; PEEK_DETECT_LEN];
// Try the handlers but first get a chunk of data for them to process
// Don't waste more than N seconds getting it though, in case someone
// is trying to DoS us with a bunch of connections or something
// read a chunk of the stream
match io::timeout(
Duration::from_micros(tls_connection_initial_timeout),
ps.peek_exact(&mut first_packet),
)
.await
{
Ok(()) => (),
Err(_) => return,
}
self.clone().try_handlers(ps, addr, protocol_handlers).await;
}
Err(e) => {
debug!("TLS stream failed handshake: {}", e);
}
}
}
async fn try_handlers(
&self,
stream: AsyncPeekStream,
addr: SocketAddr,
protocol_handlers: &[Box<dyn TcpProtocolHandler>],
) {
for ah in protocol_handlers.iter() {
if ah.on_accept(stream.clone(), addr).await == Ok(true) {
return;
}
}
}
async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
// Get config
let (connection_initial_timeout, tls_connection_initial_timeout) = {
let c = self.config.get();
(
c.network.connection_initial_timeout,
c.network.tls.connection_initial_timeout,
)
};
// Create a reusable socket with no linger time, and no delay
let socket = new_shared_tcp_socket(addr)?;
// Listen on the socket
socket
.listen(128)
.map_err(|e| format!("Couldn't listen on TCP socket: {}", e))?;
// Make an async tcplistener from the socket2 socket
let std_listener: std::net::TcpListener = socket.into();
let listener = TcpListener::from(std_listener);
trace!("spawn_socket_listener: binding successful to {}", addr);
// Create protocol handler records
let listener_state = Arc::new(RwLock::new(ListenerState::new()));
self.inner
.lock()
.listener_states
.insert(addr, listener_state.clone());
// Spawn the socket task
let this = self.clone();
////////////////////////////////////////////////////////////
let jh = spawn(async move {
// moves listener object in and get incoming iterator
// when this task exists, the listener will close the socket
listener
.incoming()
.for_each_concurrent(None, |tcp_stream| async {
let tcp_stream = tcp_stream.unwrap();
let listener_state = listener_state.clone();
// match tcp_stream.set_nodelay(true) {
// Ok(_) => (),
// _ => continue,
// };
// Limit the number of connections from the same IP address
// and the number of total connections
let addr = match tcp_stream.peer_addr() {
Ok(addr) => addr,
Err(err) => {
error!("failed to get peer address: {}", err);
return;
}
};
// XXX limiting
trace!("TCP connection from: {}", addr);
// Create a stream we can peek on
let ps = AsyncPeekStream::new(tcp_stream);
/////////////////////////////////////////////////////////////
let mut first_packet = [0u8; PEEK_DETECT_LEN];
// read a chunk of the stream
trace!("reading chunk");
if io::timeout(
Duration::from_micros(connection_initial_timeout),
ps.peek_exact(&mut first_packet),
)
.await
.is_err()
{
// If we fail to get a packet within the connection initial timeout
// then we punt this connection
return;
}
// Run accept handlers on accepted stream
trace!("packet ready");
// Check is this could be TLS
let ls = listener_state.read().clone();
if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 {
trace!("trying TLS");
this.clone()
.try_tls_handlers(
ls.tls_acceptor.as_ref().unwrap(),
ps,
addr,
&ls.tls_protocol_handlers,
tls_connection_initial_timeout,
)
.await;
} else {
trace!("not TLS");
this.clone()
.try_handlers(ps, addr, &ls.protocol_handlers)
.await;
}
})
.await;
trace!("exited incoming loop for {}", addr);
// Remove our listener state from this address if we're stopping
this.inner.lock().listener_states.remove(&addr);
trace!("listener state removed for {}", addr);
// If this happened our low-level listener socket probably died
// so it's time to restart the network
this.inner.lock().network_needs_restart = true;
});
////////////////////////////////////////////////////////////
// Add to join handles
self.add_to_join_handles(jh);
Ok(())
}
/////////////////////////////////////////////////////////////////
// TCP listener that multiplexes ports so multiple protocols can exist on a single port
async fn start_tcp_listener(
&self,
address: String,
is_tls: bool,
new_tcp_protocol_handler: Box<NewTcpProtocolHandler>,
) -> Result<Vec<SocketAddress>, String> {
let mut out = Vec::<SocketAddress>::new();
// convert to socketaddrs
let mut sockaddrs = address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", address, e))?;
for addr in &mut sockaddrs {
let ldi_addrs = Self::translate_unspecified_address(&*(self.inner.lock()), &addr);
// see if we've already bound to this already
// if not, spawn a listener
if !self.inner.lock().listener_states.contains_key(&addr) {
self.clone().spawn_socket_listener(addr).await?;
}
let ls = if let Some(ls) = self.inner.lock().listener_states.get_mut(&addr) {
ls.clone()
} else {
panic!("this shouldn't happen");
};
if is_tls {
if ls.read().tls_acceptor.is_none() {
ls.write().tls_acceptor = Some(self.clone().get_or_create_tls_acceptor()?);
}
ls.write()
.tls_protocol_handlers
.push(new_tcp_protocol_handler(
self.inner.lock().network_manager.clone(),
true,
addr,
));
} else {
ls.write().protocol_handlers.push(new_tcp_protocol_handler(
self.inner.lock().network_manager.clone(),
false,
addr,
));
}
// Return local dial infos we listen on
for ldi_addr in ldi_addrs {
out.push(SocketAddress::from_socket_addr(ldi_addr));
}
}
Ok(out)
}
////////////////////////////////////////////////////////////
async fn create_udp_outbound_sockets(&self) -> Result<(), String> {
let mut inner = self.inner.lock();
let mut port = inner.udp_port;
// v4
let socket_addr_v4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
if let Ok(socket) = new_shared_udp_socket(socket_addr_v4) {
log_net!("created udpv4 outbound socket on {:?}", &socket_addr_v4);
// Pull the port if we randomly bound, so v6 can be on the same port
port = socket
.local_addr()
.map_err(map_to_string)?
.as_socket_ipv4()
.ok_or_else(|| "expected ipv4 address type".to_owned())?
.port();
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
let udp_socket = UdpSocket::from(std_udp_socket);
let socket_arc = Arc::new(udp_socket);
// Create protocol handler
let udpv4_handler =
RawUdpProtocolHandler::new(inner.network_manager.clone(), socket_arc);
inner.outbound_udpv4_protocol_handler = Some(udpv4_handler);
}
//v6
let socket_addr_v6 =
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), port);
if let Ok(socket) = new_shared_udp_socket(socket_addr_v6) {
log_net!("created udpv6 outbound socket on {:?}", &socket_addr_v6);
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
let udp_socket = UdpSocket::from(std_udp_socket);
let socket_arc = Arc::new(udp_socket);
// Create protocol handler
let udpv6_handler =
RawUdpProtocolHandler::new(inner.network_manager.clone(), socket_arc);
inner.outbound_udpv6_protocol_handler = Some(udpv6_handler);
}
Ok(())
}
async fn spawn_udp_inbound_socket(&self, addr: SocketAddr) -> Result<(), String> {
log_net!("spawn_udp_inbound_socket on {:?}", &addr);
// Create a reusable socket
let socket = new_shared_udp_socket(addr)?;
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
let udp_socket = UdpSocket::from(std_udp_socket);
let socket_arc = Arc::new(udp_socket);
// Create protocol handler
let protocol_handler = RawUdpProtocolHandler::new(
self.inner.lock().network_manager.clone(),
socket_arc.clone(),
);
// Create message_handler records
self.inner
.lock()
.udp_protocol_handlers
.insert(addr, protocol_handler.clone());
// Spawn socket tasks
let mut task_count = {
let c = self.config.get();
c.network.protocol.udp.socket_pool_size
};
if task_count == 0 {
task_count = intf::get_concurrency() / 2;
if task_count == 0 {
task_count = 1;
}
}
trace!("task_count: {}", task_count);
for _ in 0..task_count {
let socket = socket_arc.clone();
let protocol_handler = protocol_handler.clone();
trace!("Spawning UDP listener task");
////////////////////////////////////////////////////////////
// Run task for messages
let this = self.clone();
let jh = spawn(async move {
trace!("UDP listener task spawned");
let mut data = vec![0u8; 65536];
while let Ok((size, socket_addr)) = socket.recv_from(&mut data).await {
// XXX: Limit the number of packets from the same IP address?
trace!("UDP packet from: {}", socket_addr);
let _processed = protocol_handler
.clone()
.on_message(&data[..size], socket_addr)
.await;
}
trace!("UDP listener task stopped");
// If this loop fails, our socket died and we need to restart the network
this.inner.lock().network_needs_restart = true;
});
////////////////////////////////////////////////////////////
// Add to join handle
self.add_to_join_handles(jh);
}
Ok(())
}
fn translate_unspecified_address(inner: &NetworkInner, from: &SocketAddr) -> Vec<SocketAddr> {
if !from.ip().is_unspecified() {
vec![*from]
} else {
let mut out = Vec::<SocketAddr>::with_capacity(inner.interfaces.len());
for (_, intf) in inner.interfaces.iter() {
if intf.is_loopback() {
continue;
}
if let Some(pipv4) = intf.primary_ipv4() {
out.push(SocketAddr::new(IpAddr::V4(pipv4), from.port()));
}
if let Some(pipv6) = intf.primary_ipv6() {
out.push(SocketAddr::new(IpAddr::V6(pipv6), from.port()));
}
}
out
}
}
async fn start_udp_handler(&self, address: String) -> Result<Vec<DialInfo>, String> {
let mut out = Vec::<DialInfo>::new();
// convert to socketaddrs
let mut sockaddrs = address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", address, e))?;
for addr in &mut sockaddrs {
// see if we've already bound to this already
// if not, spawn a listener
if !self.inner.lock().udp_protocol_handlers.contains_key(&addr) {
let ldi_addrs = Self::translate_unspecified_address(&*self.inner.lock(), &addr);
self.clone().spawn_udp_inbound_socket(addr).await?;
// Return local dial infos we listen on
for ldi_addr in ldi_addrs {
out.push(DialInfo::udp_from_socketaddr(ldi_addr));
}
}
}
Ok(out)
}
/////////////////////////////////////////////////////////////////
fn find_best_udp_protocol_handler(
&self,
peer_socket_addr: &SocketAddr,
local_socket_addr: &Option<SocketAddr>,
) -> Option<RawUdpProtocolHandler> {
// if our last communication with this peer came from a particular udp protocol handler, use it
if let Some(sa) = local_socket_addr {
if let Some(ph) = self.inner.lock().udp_protocol_handlers.get(sa) {
return Some(ph.clone());
}
}
// otherwise find the outbound udp protocol handler that matches the ip protocol version of the peer addr
let inner = self.inner.lock();
match peer_socket_addr {
SocketAddr::V4(_) => inner.outbound_udpv4_protocol_handler.clone(),
SocketAddr::V6(_) => inner.outbound_udpv6_protocol_handler.clone(),
inner
.interfaces
.default_route_addresses()
.iter()
.map(|a| SocketAddr::new(*a, from.port()))
.collect()
}
}
@@ -642,6 +237,8 @@ impl Network {
}
}
////////////////////////////////////////////////////////////
async fn send_data_to_existing_connection(
&self,
descriptor: &ConnectionDescriptor,
@@ -787,272 +384,6 @@ impl Network {
/////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////
pub async fn start_udp_listeners(&self) -> Result<(), String> {
let routing_table = self.routing_table();
let (listen_address, public_address) = {
let c = self.config.get();
(
c.network.protocol.udp.listen_address.clone(),
c.network.protocol.udp.public_address.clone(),
)
};
info!("UDP: starting listener at {:?}", listen_address);
let dial_infos = self.start_udp_handler(listen_address.clone()).await?;
let mut static_public = false;
for di in &dial_infos {
// Pick out UDP port for outbound connections (they will all be the same)
self.inner.lock().udp_port = di.port();
// Register local dial info only here if we specify a public address
if public_address.is_none() && di.is_global() {
// Register global dial info if no public address is specified
routing_table.register_dial_info(
di.clone(),
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
static_public = true;
} else if di.is_local() {
// Register local dial info
routing_table.register_dial_info(di.clone(), DialInfoOrigin::Static, None);
}
}
// Add static public dialinfo if it's configured
if let Some(public_address) = public_address.as_ref() {
// Resolve statically configured public dialinfo
let mut public_sockaddrs = public_address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
// Add all resolved addresses as public dialinfo
for pdi_addr in &mut public_sockaddrs {
routing_table.register_dial_info(
DialInfo::udp_from_socketaddr(pdi_addr),
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
static_public = true;
}
}
self.inner.lock().udp_static_public_dialinfo = static_public;
Ok(())
}
pub async fn start_ws_listeners(&self) -> Result<(), String> {
let routing_table = self.routing_table();
let (listen_address, url, path) = {
let c = self.config.get();
(
c.network.protocol.ws.listen_address.clone(),
c.network.protocol.ws.url.clone(),
c.network.protocol.ws.path.clone(),
)
};
trace!("WS: starting listener at {:?}", listen_address);
let socket_addresses = self
.start_tcp_listener(
listen_address.clone(),
false,
Box::new(|n, t, a| Box::new(WebsocketProtocolHandler::new(n, t, a))),
)
.await?;
trace!("WS: listener started");
let mut static_public = false;
for socket_address in socket_addresses {
// Pick out WS port for outbound connections (they will all be the same)
self.inner.lock().ws_port = socket_address.port();
if url.is_none() && socket_address.address().is_global() {
// Build global dial info request url
let global_url = format!("ws://{}/{}", socket_address, path);
// Create global dial info
let di = DialInfo::try_ws(socket_address, global_url)
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
routing_table.register_dial_info(
di,
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
static_public = true;
} else if socket_address.address().is_local() {
// Build local dial info request url
let local_url = format!("ws://{}/{}", socket_address, path);
// Create local dial info
let di = DialInfo::try_ws(socket_address, local_url)
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
routing_table.register_dial_info(di, DialInfoOrigin::Static, None);
}
}
// Add static public dialinfo if it's configured
if let Some(url) = url.as_ref() {
let mut split_url = SplitUrl::from_str(url)?;
if split_url.scheme.to_ascii_lowercase() != "ws" {
return Err("WS URL must use 'ws://' scheme".to_owned());
}
split_url.scheme = "ws".to_owned();
// Resolve static public hostnames
let global_socket_addrs = split_url
.host
.to_socket_addrs()
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
for gsa in global_socket_addrs {
routing_table.register_dial_info(
DialInfo::try_ws(SocketAddress::from_socket_addr(gsa), url.clone())
.map_err(map_to_string)
.map_err(logthru_net!(error))?,
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
}
static_public = true;
}
self.inner.lock().ws_static_public_dialinfo = static_public;
Ok(())
}
pub async fn start_wss_listeners(&self) -> Result<(), String> {
let routing_table = self.routing_table();
let (listen_address, url) = {
let c = self.config.get();
(
c.network.protocol.wss.listen_address.clone(),
c.network.protocol.wss.url.clone(),
)
};
trace!("WSS: starting listener at {}", listen_address);
let socket_addresses = self
.start_tcp_listener(
listen_address.clone(),
true,
Box::new(|n, t, a| Box::new(WebsocketProtocolHandler::new(n, t, a))),
)
.await?;
trace!("WSS: listener started");
// NOTE: No local dial info for WSS, as there is no way to connect to a local dialinfo via TLS
// If the hostname is specified, it is the public dialinfo via the URL. If no hostname
// is specified, then TLS won't validate, so no local dialinfo is possible.
// This is not the case with unencrypted websockets, which can be specified solely by an IP address
//
if let Some(socket_address) = socket_addresses.first() {
// Pick out WSS port for outbound connections (they will all be the same)
self.inner.lock().wss_port = socket_address.port();
}
// Add static public dialinfo if it's configured
if let Some(url) = url.as_ref() {
// Add static public dialinfo if it's configured
let mut split_url = SplitUrl::from_str(url)?;
if split_url.scheme.to_ascii_lowercase() != "wss" {
return Err("WSS URL must use 'wss://' scheme".to_owned());
}
split_url.scheme = "wss".to_owned();
// Resolve static public hostnames
let global_socket_addrs = split_url
.host
.to_socket_addrs()
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
for gsa in global_socket_addrs {
routing_table.register_dial_info(
DialInfo::try_wss(SocketAddress::from_socket_addr(gsa), url.clone())
.map_err(map_to_string)
.map_err(logthru_net!(error))?,
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
}
} else {
return Err("WSS URL must be specified due to TLS requirements".to_owned());
}
Ok(())
}
pub async fn start_tcp_listeners(&self) -> Result<(), String> {
let routing_table = self.routing_table();
let (listen_address, public_address) = {
let c = self.config.get();
(
c.network.protocol.tcp.listen_address.clone(),
c.network.protocol.tcp.public_address.clone(),
)
};
trace!("TCP: starting listener at {}", &listen_address);
let socket_addresses = self
.start_tcp_listener(
listen_address.clone(),
false,
Box::new(|n, _, a| Box::new(RawTcpProtocolHandler::new(n, a))),
)
.await?;
trace!("TCP: listener started");
let mut static_public = false;
for socket_address in socket_addresses {
// Pick out TCP port for outbound connections (they will all be the same)
self.inner.lock().tcp_port = socket_address.port();
let di = DialInfo::tcp(socket_address);
// Register local dial info only here if we specify a public address
if public_address.is_none() && di.is_global() {
// Register global dial info if no public address is specified
routing_table.register_dial_info(
di.clone(),
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
static_public = true;
} else if di.is_local() {
// Register local dial info
routing_table.register_dial_info(di.clone(), DialInfoOrigin::Static, None);
}
}
// Add static public dialinfo if it's configured
if let Some(public_address) = public_address.as_ref() {
// Resolve statically configured public dialinfo
let mut public_sockaddrs = public_address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
// Add all resolved addresses as public dialinfo
for pdi_addr in &mut public_sockaddrs {
routing_table.register_dial_info(
DialInfo::tcp_from_socketaddr(pdi_addr),
DialInfoOrigin::Static,
None,
);
static_public = true;
}
}
self.inner.lock().tcp_static_public_dialinfo = static_public;
Ok(())
}
pub fn get_protocol_config(&self) -> Option<ProtocolConfig> {
self.inner.lock().protocol_config
}
@@ -1083,7 +414,6 @@ impl Network {
// start listeners
if protocol_config.udp_enabled {
self.start_udp_listeners().await?;
self.create_udp_outbound_sockets().await?;
}
if protocol_config.ws_listen {
self.start_ws_listeners().await?;

View File

@@ -0,0 +1,242 @@
use super::*;
use utils::clone_stream::*;
impl Network {
fn get_or_create_tls_acceptor(&self) -> Result<TlsAcceptor, String> {
if let Some(ts) = self.inner.lock().tls_acceptor.as_ref() {
return Ok(ts.clone());
}
let server_config = self
.load_server_config()
.map_err(|e| format!("Couldn't create TLS configuration: {}", e))?;
let acceptor = TlsAcceptor::from(Arc::new(server_config));
self.inner.lock().tls_acceptor = Some(acceptor.clone());
Ok(acceptor)
}
async fn try_tls_handlers(
&self,
tls_acceptor: &TlsAcceptor,
stream: AsyncPeekStream,
addr: SocketAddr,
protocol_handlers: &[Box<dyn TcpProtocolHandler>],
tls_connection_initial_timeout: u64,
) {
match tls_acceptor.accept(stream).await {
Ok(ts) => {
let ps = AsyncPeekStream::new(CloneStream::new(ts));
let mut first_packet = [0u8; PEEK_DETECT_LEN];
// Try the handlers but first get a chunk of data for them to process
// Don't waste more than N seconds getting it though, in case someone
// is trying to DoS us with a bunch of connections or something
// read a chunk of the stream
match io::timeout(
Duration::from_micros(tls_connection_initial_timeout),
ps.peek_exact(&mut first_packet),
)
.await
{
Ok(()) => (),
Err(_) => return,
}
self.clone().try_handlers(ps, addr, protocol_handlers).await;
}
Err(e) => {
debug!("TLS stream failed handshake: {}", e);
}
}
}
async fn try_handlers(
&self,
stream: AsyncPeekStream,
addr: SocketAddr,
protocol_handlers: &[Box<dyn TcpProtocolHandler>],
) {
for ah in protocol_handlers.iter() {
if ah.on_accept(stream.clone(), addr).await == Ok(true) {
return;
}
}
}
async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
// Get config
let (connection_initial_timeout, tls_connection_initial_timeout) = {
let c = self.config.get();
(
c.network.connection_initial_timeout,
c.network.tls.connection_initial_timeout,
)
};
// Create a reusable socket with no linger time, and no delay
let socket = new_shared_tcp_socket(addr)?;
// Listen on the socket
socket
.listen(128)
.map_err(|e| format!("Couldn't listen on TCP socket: {}", e))?;
// Make an async tcplistener from the socket2 socket
let std_listener: std::net::TcpListener = socket.into();
let listener = TcpListener::from(std_listener);
trace!("spawn_socket_listener: binding successful to {}", addr);
// Create protocol handler records
let listener_state = Arc::new(RwLock::new(ListenerState::new()));
self.inner
.lock()
.listener_states
.insert(addr, listener_state.clone());
// Spawn the socket task
let this = self.clone();
////////////////////////////////////////////////////////////
let jh = spawn(async move {
// moves listener object in and get incoming iterator
// when this task exists, the listener will close the socket
listener
.incoming()
.for_each_concurrent(None, |tcp_stream| async {
let tcp_stream = tcp_stream.unwrap();
let listener_state = listener_state.clone();
// match tcp_stream.set_nodelay(true) {
// Ok(_) => (),
// _ => continue,
// };
// Limit the number of connections from the same IP address
// and the number of total connections
let addr = match tcp_stream.peer_addr() {
Ok(addr) => addr,
Err(err) => {
error!("failed to get peer address: {}", err);
return;
}
};
// XXX limiting
trace!("TCP connection from: {}", addr);
// Create a stream we can peek on
let ps = AsyncPeekStream::new(tcp_stream);
/////////////////////////////////////////////////////////////
let mut first_packet = [0u8; PEEK_DETECT_LEN];
// read a chunk of the stream
trace!("reading chunk");
if io::timeout(
Duration::from_micros(connection_initial_timeout),
ps.peek_exact(&mut first_packet),
)
.await
.is_err()
{
// If we fail to get a packet within the connection initial timeout
// then we punt this connection
return;
}
// Run accept handlers on accepted stream
trace!("packet ready");
// Check is this could be TLS
let ls = listener_state.read().clone();
if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 {
trace!("trying TLS");
this.clone()
.try_tls_handlers(
ls.tls_acceptor.as_ref().unwrap(),
ps,
addr,
&ls.tls_protocol_handlers,
tls_connection_initial_timeout,
)
.await;
} else {
trace!("not TLS");
this.clone()
.try_handlers(ps, addr, &ls.protocol_handlers)
.await;
}
})
.await;
trace!("exited incoming loop for {}", addr);
// Remove our listener state from this address if we're stopping
this.inner.lock().listener_states.remove(&addr);
trace!("listener state removed for {}", addr);
// If this happened our low-level listener socket probably died
// so it's time to restart the network
this.inner.lock().network_needs_restart = true;
});
////////////////////////////////////////////////////////////
// Add to join handles
self.add_to_join_handles(jh);
Ok(())
}
/////////////////////////////////////////////////////////////////
// TCP listener that multiplexes ports so multiple protocols can exist on a single port
pub(super) async fn start_tcp_listener(
&self,
address: String,
is_tls: bool,
new_tcp_protocol_handler: Box<NewTcpProtocolHandler>,
) -> Result<Vec<SocketAddress>, String> {
let mut out = Vec::<SocketAddress>::new();
// convert to socketaddrs
let mut sockaddrs = address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", address, e))?;
for addr in &mut sockaddrs {
let ldi_addrs = Self::translate_unspecified_address(&*(self.inner.lock()), &addr);
// see if we've already bound to this already
// if not, spawn a listener
if !self.inner.lock().listener_states.contains_key(&addr) {
self.clone().spawn_socket_listener(addr).await?;
}
let ls = if let Some(ls) = self.inner.lock().listener_states.get_mut(&addr) {
ls.clone()
} else {
panic!("this shouldn't happen");
};
if is_tls {
if ls.read().tls_acceptor.is_none() {
ls.write().tls_acceptor = Some(self.clone().get_or_create_tls_acceptor()?);
}
ls.write()
.tls_protocol_handlers
.push(new_tcp_protocol_handler(
self.inner.lock().network_manager.clone(),
true,
addr,
));
} else {
ls.write().protocol_handlers.push(new_tcp_protocol_handler(
self.inner.lock().network_manager.clone(),
false,
addr,
));
}
// Return local dial infos we listen on
for ldi_addr in ldi_addrs {
out.push(SocketAddress::from_socket_addr(ldi_addr));
}
}
Ok(out)
}
}

View File

@@ -0,0 +1,187 @@
use super::*;
use futures_util::stream;
impl Network {
pub(super) async fn create_udp_listener_tasks(&self) -> Result<(), String> {
// Spawn socket tasks
let mut task_count = {
let c = self.config.get();
c.network.protocol.udp.socket_pool_size
};
if task_count == 0 {
task_count = intf::get_concurrency() / 2;
if task_count == 0 {
task_count = 1;
}
}
trace!("task_count: {}", task_count);
for _ in 0..task_count {
trace!("Spawning UDP listener task");
////////////////////////////////////////////////////////////
// Run thread task to process stream of messages
let this = self.clone();
let jh = spawn(async move {
trace!("UDP listener task spawned");
// Collect all our protocol handlers into a vector
let mut protocol_handlers: Vec<RawUdpProtocolHandler> = this
.inner
.lock()
.inbound_udp_protocol_handlers
.values()
.cloned()
.collect();
if let Some(ph) = this.inner.lock().outbound_udpv4_protocol_handler.clone() {
protocol_handlers.push(ph);
}
if let Some(ph) = this.inner.lock().outbound_udpv6_protocol_handler.clone() {
protocol_handlers.push(ph);
}
// Spawn a local async task for each socket
let mut protocol_handlers_unordered = stream::FuturesUnordered::new();
for ph in protocol_handlers {
let jh = spawn_local(ph.clone().receive_loop());
protocol_handlers_unordered.push(jh);
}
// Now we wait for any join handle to exit,
// which would indicate an error needing
// us to completely restart the network
let _ = protocol_handlers_unordered.next().await;
trace!("UDP listener task stopped");
// If this loop fails, our socket died and we need to restart the network
this.inner.lock().network_needs_restart = true;
});
////////////////////////////////////////////////////////////
// Add to join handle
self.add_to_join_handles(jh);
}
Ok(())
}
pub(super) async fn create_udp_outbound_sockets(&self) -> Result<(), String> {
let mut inner = self.inner.lock();
let mut port = inner.udp_port;
// v4
let socket_addr_v4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
if let Ok(socket) = new_shared_udp_socket(socket_addr_v4) {
// Pull the port if we randomly bound, so v6 can be on the same port
port = socket
.local_addr()
.map_err(map_to_string)?
.as_socket_ipv4()
.ok_or_else(|| "expected ipv4 address type".to_owned())?
.port();
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
let udp_socket = UdpSocket::from(std_udp_socket);
let socket_arc = Arc::new(udp_socket);
// Create protocol handler
let udpv4_handler =
RawUdpProtocolHandler::new(inner.network_manager.clone(), socket_arc);
inner.outbound_udpv4_protocol_handler = Some(udpv4_handler);
}
//v6
let socket_addr_v6 =
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), port);
if let Ok(socket) = new_shared_udp_socket(socket_addr_v6) {
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
let udp_socket = UdpSocket::from(std_udp_socket);
let socket_arc = Arc::new(udp_socket);
// Create protocol handler
let udpv6_handler =
RawUdpProtocolHandler::new(inner.network_manager.clone(), socket_arc);
inner.outbound_udpv6_protocol_handler = Some(udpv6_handler);
}
Ok(())
}
async fn create_udp_inbound_socket(&self, addr: SocketAddr) -> Result<(), String> {
log_net!("create_udp_inbound_socket on {:?}", &addr);
// Create a reusable socket
let socket = new_shared_udp_socket(addr)?;
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
let udp_socket = UdpSocket::from(std_udp_socket);
let socket_arc = Arc::new(udp_socket);
// Create protocol handler
let protocol_handler =
RawUdpProtocolHandler::new(self.inner.lock().network_manager.clone(), socket_arc);
// Create message_handler records
self.inner
.lock()
.inbound_udp_protocol_handlers
.insert(addr, protocol_handler);
Ok(())
}
pub(super) async fn create_udp_inbound_sockets(
&self,
address: String,
) -> Result<Vec<DialInfo>, String> {
let mut out = Vec::<DialInfo>::new();
// convert to socketaddrs
let mut sockaddrs = address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", address, e))?;
for addr in &mut sockaddrs {
// see if we've already bound to this already
// if not, spawn a listener
if !self
.inner
.lock()
.inbound_udp_protocol_handlers
.contains_key(&addr)
{
let ldi_addrs = Self::translate_unspecified_address(&*self.inner.lock(), &addr);
self.clone().create_udp_inbound_socket(addr).await?;
// Return local dial infos we listen on
for ldi_addr in ldi_addrs {
out.push(DialInfo::udp_from_socketaddr(ldi_addr));
}
}
}
Ok(out)
}
/////////////////////////////////////////////////////////////////
pub(super) fn find_best_udp_protocol_handler(
&self,
peer_socket_addr: &SocketAddr,
local_socket_addr: &Option<SocketAddr>,
) -> Option<RawUdpProtocolHandler> {
// if our last communication with this peer came from a particular inbound udp protocol handler, use it
if let Some(sa) = local_socket_addr {
if let Some(ph) = self.inner.lock().inbound_udp_protocol_handlers.get(sa) {
return Some(ph.clone());
}
}
// otherwise find the outbound udp protocol handler that matches the ip protocol version of the peer addr
let inner = self.inner.lock();
match peer_socket_addr {
SocketAddr::V4(_) => inner.outbound_udpv4_protocol_handler.clone(),
SocketAddr::V6(_) => inner.outbound_udpv6_protocol_handler.clone(),
}
}
}

View File

@@ -72,6 +72,8 @@ pub fn new_shared_udp_socket(local_address: SocketAddr) -> Result<socket2::Socke
.bind(&socket2_addr)
.map_err(|e| format!("failed to bind UDP socket: {}", e))?;
log_net!("created shared udp socket on {:?}", &local_address);
Ok(socket)
}

View File

@@ -30,6 +30,17 @@ impl RawUdpProtocolHandler {
}
}
pub async fn receive_loop(self) {
let mut data = vec![0u8; 65536];
let socket = self.inner.lock().socket.clone();
while let Ok((size, socket_addr)) = socket.recv_from(&mut data).await {
// XXX: Limit the number of packets from the same IP address?
trace!("UDP packet from: {}", socket_addr);
let _processed = self.clone().on_message(&data[..size], socket_addr).await;
}
}
pub async fn on_message(&self, data: &[u8], remote_addr: SocketAddr) -> Result<bool, String> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("received too large UDP message".to_owned());

View File

@@ -0,0 +1,275 @@
use super::*;
impl Network {
pub(super) async fn start_udp_listeners(&self) -> Result<(), String> {
// First, create outbound sockets and we'll listen on them too
self.create_udp_outbound_sockets().await?;
// Now create udp inbound sockets for whatever interfaces we're listening on
let routing_table = self.routing_table();
let (listen_address, public_address) = {
let c = self.config.get();
(
c.network.protocol.udp.listen_address.clone(),
c.network.protocol.udp.public_address.clone(),
)
};
info!("UDP: starting listener at {:?}", listen_address);
let dial_infos = self
.create_udp_inbound_sockets(listen_address.clone())
.await?;
let mut static_public = false;
for di in &dial_infos {
// Pick out UDP port for outbound connections (they will all be the same)
self.inner.lock().udp_port = di.port();
// Register local dial info only here if we specify a public address
if public_address.is_none() && di.is_global() {
// Register global dial info if no public address is specified
routing_table.register_dial_info(
di.clone(),
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
static_public = true;
} else if di.is_local() {
// Register local dial info
routing_table.register_dial_info(di.clone(), DialInfoOrigin::Static, None);
}
}
// Add static public dialinfo if it's configured
if let Some(public_address) = public_address.as_ref() {
// Resolve statically configured public dialinfo
let mut public_sockaddrs = public_address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
// Add all resolved addresses as public dialinfo
for pdi_addr in &mut public_sockaddrs {
routing_table.register_dial_info(
DialInfo::udp_from_socketaddr(pdi_addr),
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
static_public = true;
}
}
self.inner.lock().udp_static_public_dialinfo = static_public;
// Now create tasks for udp listeners
self.create_udp_listener_tasks().await
}
pub(super) async fn start_ws_listeners(&self) -> Result<(), String> {
let routing_table = self.routing_table();
let (listen_address, url, path) = {
let c = self.config.get();
(
c.network.protocol.ws.listen_address.clone(),
c.network.protocol.ws.url.clone(),
c.network.protocol.ws.path.clone(),
)
};
trace!("WS: starting listener at {:?}", listen_address);
let socket_addresses = self
.start_tcp_listener(
listen_address.clone(),
false,
Box::new(|n, t, a| Box::new(WebsocketProtocolHandler::new(n, t, a))),
)
.await?;
trace!("WS: listener started");
let mut static_public = false;
for socket_address in socket_addresses {
// Pick out WS port for outbound connections (they will all be the same)
self.inner.lock().ws_port = socket_address.port();
if url.is_none() && socket_address.address().is_global() {
// Build global dial info request url
let global_url = format!("ws://{}/{}", socket_address, path);
// Create global dial info
let di = DialInfo::try_ws(socket_address, global_url)
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
routing_table.register_dial_info(
di,
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
static_public = true;
} else if socket_address.address().is_local() {
// Build local dial info request url
let local_url = format!("ws://{}/{}", socket_address, path);
// Create local dial info
let di = DialInfo::try_ws(socket_address, local_url)
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
routing_table.register_dial_info(di, DialInfoOrigin::Static, None);
}
}
// Add static public dialinfo if it's configured
if let Some(url) = url.as_ref() {
let mut split_url = SplitUrl::from_str(url)?;
if split_url.scheme.to_ascii_lowercase() != "ws" {
return Err("WS URL must use 'ws://' scheme".to_owned());
}
split_url.scheme = "ws".to_owned();
// Resolve static public hostnames
let global_socket_addrs = split_url
.host
.to_socket_addrs()
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
for gsa in global_socket_addrs {
routing_table.register_dial_info(
DialInfo::try_ws(SocketAddress::from_socket_addr(gsa), url.clone())
.map_err(map_to_string)
.map_err(logthru_net!(error))?,
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
}
static_public = true;
}
self.inner.lock().ws_static_public_dialinfo = static_public;
Ok(())
}
pub(super) async fn start_wss_listeners(&self) -> Result<(), String> {
let routing_table = self.routing_table();
let (listen_address, url) = {
let c = self.config.get();
(
c.network.protocol.wss.listen_address.clone(),
c.network.protocol.wss.url.clone(),
)
};
trace!("WSS: starting listener at {}", listen_address);
let socket_addresses = self
.start_tcp_listener(
listen_address.clone(),
true,
Box::new(|n, t, a| Box::new(WebsocketProtocolHandler::new(n, t, a))),
)
.await?;
trace!("WSS: listener started");
// NOTE: No local dial info for WSS, as there is no way to connect to a local dialinfo via TLS
// If the hostname is specified, it is the public dialinfo via the URL. If no hostname
// is specified, then TLS won't validate, so no local dialinfo is possible.
// This is not the case with unencrypted websockets, which can be specified solely by an IP address
//
if let Some(socket_address) = socket_addresses.first() {
// Pick out WSS port for outbound connections (they will all be the same)
self.inner.lock().wss_port = socket_address.port();
}
// Add static public dialinfo if it's configured
if let Some(url) = url.as_ref() {
// Add static public dialinfo if it's configured
let mut split_url = SplitUrl::from_str(url)?;
if split_url.scheme.to_ascii_lowercase() != "wss" {
return Err("WSS URL must use 'wss://' scheme".to_owned());
}
split_url.scheme = "wss".to_owned();
// Resolve static public hostnames
let global_socket_addrs = split_url
.host
.to_socket_addrs()
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
for gsa in global_socket_addrs {
routing_table.register_dial_info(
DialInfo::try_wss(SocketAddress::from_socket_addr(gsa), url.clone())
.map_err(map_to_string)
.map_err(logthru_net!(error))?,
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
}
} else {
return Err("WSS URL must be specified due to TLS requirements".to_owned());
}
Ok(())
}
pub(super) async fn start_tcp_listeners(&self) -> Result<(), String> {
let routing_table = self.routing_table();
let (listen_address, public_address) = {
let c = self.config.get();
(
c.network.protocol.tcp.listen_address.clone(),
c.network.protocol.tcp.public_address.clone(),
)
};
trace!("TCP: starting listener at {}", &listen_address);
let socket_addresses = self
.start_tcp_listener(
listen_address.clone(),
false,
Box::new(|n, _, a| Box::new(RawTcpProtocolHandler::new(n, a))),
)
.await?;
trace!("TCP: listener started");
let mut static_public = false;
for socket_address in socket_addresses {
// Pick out TCP port for outbound connections (they will all be the same)
self.inner.lock().tcp_port = socket_address.port();
let di = DialInfo::tcp(socket_address);
// Register local dial info only here if we specify a public address
if public_address.is_none() && di.is_global() {
// Register global dial info if no public address is specified
routing_table.register_dial_info(
di.clone(),
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
static_public = true;
} else if di.is_local() {
// Register local dial info
routing_table.register_dial_info(di.clone(), DialInfoOrigin::Static, None);
}
}
// Add static public dialinfo if it's configured
if let Some(public_address) = public_address.as_ref() {
// Resolve statically configured public dialinfo
let mut public_sockaddrs = public_address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
// Add all resolved addresses as public dialinfo
for pdi_addr in &mut public_sockaddrs {
routing_table.register_dial_info(
DialInfo::tcp_from_socketaddr(pdi_addr),
DialInfoOrigin::Static,
None,
);
static_public = true;
}
}
self.inner.lock().tcp_static_public_dialinfo = static_public;
Ok(())
}
}