fix udp and refactor native network

This commit is contained in:
John Smith 2021-12-31 22:09:30 -05:00
parent c6f573ffe0
commit 0e0209a54b
18 changed files with 870 additions and 760 deletions

View File

@ -61,7 +61,7 @@ def read_until_local_dial_info(proc, proto):
if idx != -1: if idx != -1:
idx += len(local_dial_info_str) idx += len(local_dial_info_str)
di = ln[idx:] di = ln[idx:]
if di.startswith(proto): if b"@"+bytes(proto)+b"|" in di:
return di.decode("utf-8").strip() return di.decode("utf-8").strip()
return None return None
@ -95,7 +95,7 @@ def main():
help='specify subnode index to wait for the debugger') help='specify subnode index to wait for the debugger')
parser.add_argument("--config-file", type=str, parser.add_argument("--config-file", type=str,
help='configuration file to specify for the bootstrap node') help='configuration file to specify for the bootstrap node')
parser.add_argument("--protocol", type=bytes, default=b"udp", parser.add_argument("--protocol", type=str, default="udp",
help='default protocol to choose for dial info') help='default protocol to choose for dial info')
args = parser.parse_args() args = parser.parse_args()
@ -110,7 +110,6 @@ def main():
veilid_server_exe = veilid_server_exe_debug veilid_server_exe = veilid_server_exe_debug
base_args = [veilid_server_exe] base_args = [veilid_server_exe]
base_args.append("--attach=true")
if args.log_info: if args.log_info:
pass pass
elif args.log_trace: elif args.log_trace:
@ -131,7 +130,10 @@ def main():
main_args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) main_args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
print(">>> MAIN NODE PID={}".format(main_proc.pid)) print(">>> MAIN NODE PID={}".format(main_proc.pid))
main_di = read_until_local_dial_info(main_proc, args.protocol) main_di = read_until_local_dial_info(
main_proc, bytes(args.protocol, 'utf-8'))
print(">>> MAIN DIAL INFO={}".format(main_di))
threads.append( threads.append(
tee(b"Veilid-0: ", main_proc.stdout, open("/tmp/veilid-0-out", "wb"), tee(b"Veilid-0: ", main_proc.stdout, open("/tmp/veilid-0-out", "wb"),

View File

@ -102,7 +102,7 @@ struct DialInfo {
} }
} }
struct NodeDialInfoSingle { struct NodeDialInfo {
nodeId @0 :NodeID; # node id nodeId @0 :NodeID; # node id
dialInfo @1 :DialInfo; # how to get to the node dialInfo @1 :DialInfo; # how to get to the node
} }
@ -119,7 +119,7 @@ struct RouteHopData {
} }
struct RouteHop { struct RouteHop {
dialInfo @0 :NodeDialInfoSingle; # dial info for this hop dialInfo @0 :NodeDialInfo; # dial info for this hop
nextHop @1 :RouteHopData; # Optional: next hop in encrypted blob nextHop @1 :RouteHopData; # Optional: next hop in encrypted blob
# Null means no next hop, at destination (only used in private route, safety routes must enclose a stub private route) # Null means no next hop, at destination (only used in private route, safety routes must enclose a stub private route)
} }

View File

@ -1,6 +1,9 @@
mod listener_state; mod listener_state;
mod network_tcp;
mod network_udp;
mod protocol; mod protocol;
mod public_dialinfo_discovery; mod public_dialinfo_discovery;
mod start_protocols;
use crate::intf::*; use crate::intf::*;
use crate::network_manager::*; use crate::network_manager::*;
@ -12,7 +15,6 @@ use protocol::udp::RawUdpProtocolHandler;
use protocol::ws::WebsocketProtocolHandler; use protocol::ws::WebsocketProtocolHandler;
pub use protocol::*; pub use protocol::*;
use utils::async_peek_stream::*; use utils::async_peek_stream::*;
use utils::clone_stream::*;
use utils::network_interfaces::*; use utils::network_interfaces::*;
use async_std::io; use async_std::io;
@ -45,16 +47,18 @@ struct NetworkInner {
ws_static_public_dialinfo: bool, ws_static_public_dialinfo: bool,
network_class: Option<NetworkClass>, network_class: Option<NetworkClass>,
join_handles: Vec<JoinHandle<()>>, join_handles: Vec<JoinHandle<()>>,
listener_states: BTreeMap<SocketAddr, Arc<RwLock<ListenerState>>>,
udp_protocol_handlers: BTreeMap<SocketAddr, RawUdpProtocolHandler>,
tls_acceptor: Option<TlsAcceptor>,
udp_port: u16, udp_port: u16,
tcp_port: u16, tcp_port: u16,
ws_port: u16, ws_port: u16,
wss_port: u16, wss_port: u16,
interfaces: NetworkInterfaces,
// udp
inbound_udp_protocol_handlers: BTreeMap<SocketAddr, RawUdpProtocolHandler>,
outbound_udpv4_protocol_handler: Option<RawUdpProtocolHandler>, outbound_udpv4_protocol_handler: Option<RawUdpProtocolHandler>,
outbound_udpv6_protocol_handler: Option<RawUdpProtocolHandler>, outbound_udpv6_protocol_handler: Option<RawUdpProtocolHandler>,
interfaces: NetworkInterfaces, //tcp
tls_acceptor: Option<TlsAcceptor>,
listener_states: BTreeMap<SocketAddr, Arc<RwLock<ListenerState>>>,
} }
struct NetworkUnlockedInner { struct NetworkUnlockedInner {
@ -83,16 +87,16 @@ impl Network {
ws_static_public_dialinfo: false, ws_static_public_dialinfo: false,
network_class: None, network_class: None,
join_handles: Vec::new(), join_handles: Vec::new(),
listener_states: BTreeMap::new(),
udp_protocol_handlers: BTreeMap::new(),
tls_acceptor: None,
udp_port: 0u16, udp_port: 0u16,
tcp_port: 0u16, tcp_port: 0u16,
ws_port: 0u16, ws_port: 0u16,
wss_port: 0u16, wss_port: 0u16,
interfaces: NetworkInterfaces::new(),
inbound_udp_protocol_handlers: BTreeMap::new(),
outbound_udpv4_protocol_handler: None, outbound_udpv4_protocol_handler: None,
outbound_udpv6_protocol_handler: None, outbound_udpv6_protocol_handler: None,
interfaces: NetworkInterfaces::new(), tls_acceptor: None,
listener_states: BTreeMap::new(),
} }
} }
@ -201,430 +205,21 @@ impl Network {
Ok(config) Ok(config)
} }
fn get_or_create_tls_acceptor(&self) -> Result<TlsAcceptor, String> {
if let Some(ts) = self.inner.lock().tls_acceptor.as_ref() {
return Ok(ts.clone());
}
let server_config = self
.load_server_config()
.map_err(|e| format!("Couldn't create TLS configuration: {}", e))?;
let acceptor = TlsAcceptor::from(Arc::new(server_config));
self.inner.lock().tls_acceptor = Some(acceptor.clone());
Ok(acceptor)
}
fn add_to_join_handles(&self, jh: JoinHandle<()>) { fn add_to_join_handles(&self, jh: JoinHandle<()>) {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
inner.join_handles.push(jh); inner.join_handles.push(jh);
} }
async fn try_tls_handlers(
&self,
tls_acceptor: &TlsAcceptor,
stream: AsyncPeekStream,
addr: SocketAddr,
protocol_handlers: &[Box<dyn TcpProtocolHandler>],
tls_connection_initial_timeout: u64,
) {
match tls_acceptor.accept(stream).await {
Ok(ts) => {
let ps = AsyncPeekStream::new(CloneStream::new(ts));
let mut first_packet = [0u8; PEEK_DETECT_LEN];
// Try the handlers but first get a chunk of data for them to process
// Don't waste more than N seconds getting it though, in case someone
// is trying to DoS us with a bunch of connections or something
// read a chunk of the stream
match io::timeout(
Duration::from_micros(tls_connection_initial_timeout),
ps.peek_exact(&mut first_packet),
)
.await
{
Ok(()) => (),
Err(_) => return,
}
self.clone().try_handlers(ps, addr, protocol_handlers).await;
}
Err(e) => {
debug!("TLS stream failed handshake: {}", e);
}
}
}
async fn try_handlers(
&self,
stream: AsyncPeekStream,
addr: SocketAddr,
protocol_handlers: &[Box<dyn TcpProtocolHandler>],
) {
for ah in protocol_handlers.iter() {
if ah.on_accept(stream.clone(), addr).await == Ok(true) {
return;
}
}
}
async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
// Get config
let (connection_initial_timeout, tls_connection_initial_timeout) = {
let c = self.config.get();
(
c.network.connection_initial_timeout,
c.network.tls.connection_initial_timeout,
)
};
// Create a reusable socket with no linger time, and no delay
let socket = new_shared_tcp_socket(addr)?;
// Listen on the socket
socket
.listen(128)
.map_err(|e| format!("Couldn't listen on TCP socket: {}", e))?;
// Make an async tcplistener from the socket2 socket
let std_listener: std::net::TcpListener = socket.into();
let listener = TcpListener::from(std_listener);
trace!("spawn_socket_listener: binding successful to {}", addr);
// Create protocol handler records
let listener_state = Arc::new(RwLock::new(ListenerState::new()));
self.inner
.lock()
.listener_states
.insert(addr, listener_state.clone());
// Spawn the socket task
let this = self.clone();
////////////////////////////////////////////////////////////
let jh = spawn(async move {
// moves listener object in and get incoming iterator
// when this task exists, the listener will close the socket
listener
.incoming()
.for_each_concurrent(None, |tcp_stream| async {
let tcp_stream = tcp_stream.unwrap();
let listener_state = listener_state.clone();
// match tcp_stream.set_nodelay(true) {
// Ok(_) => (),
// _ => continue,
// };
// Limit the number of connections from the same IP address
// and the number of total connections
let addr = match tcp_stream.peer_addr() {
Ok(addr) => addr,
Err(err) => {
error!("failed to get peer address: {}", err);
return;
}
};
// XXX limiting
trace!("TCP connection from: {}", addr);
// Create a stream we can peek on
let ps = AsyncPeekStream::new(tcp_stream);
/////////////////////////////////////////////////////////////
let mut first_packet = [0u8; PEEK_DETECT_LEN];
// read a chunk of the stream
trace!("reading chunk");
if io::timeout(
Duration::from_micros(connection_initial_timeout),
ps.peek_exact(&mut first_packet),
)
.await
.is_err()
{
// If we fail to get a packet within the connection initial timeout
// then we punt this connection
return;
}
// Run accept handlers on accepted stream
trace!("packet ready");
// Check is this could be TLS
let ls = listener_state.read().clone();
if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 {
trace!("trying TLS");
this.clone()
.try_tls_handlers(
ls.tls_acceptor.as_ref().unwrap(),
ps,
addr,
&ls.tls_protocol_handlers,
tls_connection_initial_timeout,
)
.await;
} else {
trace!("not TLS");
this.clone()
.try_handlers(ps, addr, &ls.protocol_handlers)
.await;
}
})
.await;
trace!("exited incoming loop for {}", addr);
// Remove our listener state from this address if we're stopping
this.inner.lock().listener_states.remove(&addr);
trace!("listener state removed for {}", addr);
// If this happened our low-level listener socket probably died
// so it's time to restart the network
this.inner.lock().network_needs_restart = true;
});
////////////////////////////////////////////////////////////
// Add to join handles
self.add_to_join_handles(jh);
Ok(())
}
/////////////////////////////////////////////////////////////////
// TCP listener that multiplexes ports so multiple protocols can exist on a single port
async fn start_tcp_listener(
&self,
address: String,
is_tls: bool,
new_tcp_protocol_handler: Box<NewTcpProtocolHandler>,
) -> Result<Vec<SocketAddress>, String> {
let mut out = Vec::<SocketAddress>::new();
// convert to socketaddrs
let mut sockaddrs = address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", address, e))?;
for addr in &mut sockaddrs {
let ldi_addrs = Self::translate_unspecified_address(&*(self.inner.lock()), &addr);
// see if we've already bound to this already
// if not, spawn a listener
if !self.inner.lock().listener_states.contains_key(&addr) {
self.clone().spawn_socket_listener(addr).await?;
}
let ls = if let Some(ls) = self.inner.lock().listener_states.get_mut(&addr) {
ls.clone()
} else {
panic!("this shouldn't happen");
};
if is_tls {
if ls.read().tls_acceptor.is_none() {
ls.write().tls_acceptor = Some(self.clone().get_or_create_tls_acceptor()?);
}
ls.write()
.tls_protocol_handlers
.push(new_tcp_protocol_handler(
self.inner.lock().network_manager.clone(),
true,
addr,
));
} else {
ls.write().protocol_handlers.push(new_tcp_protocol_handler(
self.inner.lock().network_manager.clone(),
false,
addr,
));
}
// Return local dial infos we listen on
for ldi_addr in ldi_addrs {
out.push(SocketAddress::from_socket_addr(ldi_addr));
}
}
Ok(out)
}
////////////////////////////////////////////////////////////
async fn create_udp_outbound_sockets(&self) -> Result<(), String> {
let mut inner = self.inner.lock();
let mut port = inner.udp_port;
// v4
let socket_addr_v4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
if let Ok(socket) = new_shared_udp_socket(socket_addr_v4) {
log_net!("created udpv4 outbound socket on {:?}", &socket_addr_v4);
// Pull the port if we randomly bound, so v6 can be on the same port
port = socket
.local_addr()
.map_err(map_to_string)?
.as_socket_ipv4()
.ok_or_else(|| "expected ipv4 address type".to_owned())?
.port();
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
let udp_socket = UdpSocket::from(std_udp_socket);
let socket_arc = Arc::new(udp_socket);
// Create protocol handler
let udpv4_handler =
RawUdpProtocolHandler::new(inner.network_manager.clone(), socket_arc);
inner.outbound_udpv4_protocol_handler = Some(udpv4_handler);
}
//v6
let socket_addr_v6 =
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), port);
if let Ok(socket) = new_shared_udp_socket(socket_addr_v6) {
log_net!("created udpv6 outbound socket on {:?}", &socket_addr_v6);
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
let udp_socket = UdpSocket::from(std_udp_socket);
let socket_arc = Arc::new(udp_socket);
// Create protocol handler
let udpv6_handler =
RawUdpProtocolHandler::new(inner.network_manager.clone(), socket_arc);
inner.outbound_udpv6_protocol_handler = Some(udpv6_handler);
}
Ok(())
}
async fn spawn_udp_inbound_socket(&self, addr: SocketAddr) -> Result<(), String> {
log_net!("spawn_udp_inbound_socket on {:?}", &addr);
// Create a reusable socket
let socket = new_shared_udp_socket(addr)?;
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
let udp_socket = UdpSocket::from(std_udp_socket);
let socket_arc = Arc::new(udp_socket);
// Create protocol handler
let protocol_handler = RawUdpProtocolHandler::new(
self.inner.lock().network_manager.clone(),
socket_arc.clone(),
);
// Create message_handler records
self.inner
.lock()
.udp_protocol_handlers
.insert(addr, protocol_handler.clone());
// Spawn socket tasks
let mut task_count = {
let c = self.config.get();
c.network.protocol.udp.socket_pool_size
};
if task_count == 0 {
task_count = intf::get_concurrency() / 2;
if task_count == 0 {
task_count = 1;
}
}
trace!("task_count: {}", task_count);
for _ in 0..task_count {
let socket = socket_arc.clone();
let protocol_handler = protocol_handler.clone();
trace!("Spawning UDP listener task");
////////////////////////////////////////////////////////////
// Run task for messages
let this = self.clone();
let jh = spawn(async move {
trace!("UDP listener task spawned");
let mut data = vec![0u8; 65536];
while let Ok((size, socket_addr)) = socket.recv_from(&mut data).await {
// XXX: Limit the number of packets from the same IP address?
trace!("UDP packet from: {}", socket_addr);
let _processed = protocol_handler
.clone()
.on_message(&data[..size], socket_addr)
.await;
}
trace!("UDP listener task stopped");
// If this loop fails, our socket died and we need to restart the network
this.inner.lock().network_needs_restart = true;
});
////////////////////////////////////////////////////////////
// Add to join handle
self.add_to_join_handles(jh);
}
Ok(())
}
fn translate_unspecified_address(inner: &NetworkInner, from: &SocketAddr) -> Vec<SocketAddr> { fn translate_unspecified_address(inner: &NetworkInner, from: &SocketAddr) -> Vec<SocketAddr> {
if !from.ip().is_unspecified() { if !from.ip().is_unspecified() {
vec![*from] vec![*from]
} else { } else {
let mut out = Vec::<SocketAddr>::with_capacity(inner.interfaces.len()); inner
for (_, intf) in inner.interfaces.iter() { .interfaces
if intf.is_loopback() { .default_route_addresses()
continue; .iter()
} .map(|a| SocketAddr::new(*a, from.port()))
if let Some(pipv4) = intf.primary_ipv4() { .collect()
out.push(SocketAddr::new(IpAddr::V4(pipv4), from.port()));
}
if let Some(pipv6) = intf.primary_ipv6() {
out.push(SocketAddr::new(IpAddr::V6(pipv6), from.port()));
}
}
out
}
}
async fn start_udp_handler(&self, address: String) -> Result<Vec<DialInfo>, String> {
let mut out = Vec::<DialInfo>::new();
// convert to socketaddrs
let mut sockaddrs = address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", address, e))?;
for addr in &mut sockaddrs {
// see if we've already bound to this already
// if not, spawn a listener
if !self.inner.lock().udp_protocol_handlers.contains_key(&addr) {
let ldi_addrs = Self::translate_unspecified_address(&*self.inner.lock(), &addr);
self.clone().spawn_udp_inbound_socket(addr).await?;
// Return local dial infos we listen on
for ldi_addr in ldi_addrs {
out.push(DialInfo::udp_from_socketaddr(ldi_addr));
}
}
}
Ok(out)
}
/////////////////////////////////////////////////////////////////
fn find_best_udp_protocol_handler(
&self,
peer_socket_addr: &SocketAddr,
local_socket_addr: &Option<SocketAddr>,
) -> Option<RawUdpProtocolHandler> {
// if our last communication with this peer came from a particular udp protocol handler, use it
if let Some(sa) = local_socket_addr {
if let Some(ph) = self.inner.lock().udp_protocol_handlers.get(sa) {
return Some(ph.clone());
}
}
// otherwise find the outbound udp protocol handler that matches the ip protocol version of the peer addr
let inner = self.inner.lock();
match peer_socket_addr {
SocketAddr::V4(_) => inner.outbound_udpv4_protocol_handler.clone(),
SocketAddr::V6(_) => inner.outbound_udpv6_protocol_handler.clone(),
} }
} }
@ -642,6 +237,8 @@ impl Network {
} }
} }
////////////////////////////////////////////////////////////
async fn send_data_to_existing_connection( async fn send_data_to_existing_connection(
&self, &self,
descriptor: &ConnectionDescriptor, descriptor: &ConnectionDescriptor,
@ -787,272 +384,6 @@ impl Network {
///////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////
pub async fn start_udp_listeners(&self) -> Result<(), String> {
let routing_table = self.routing_table();
let (listen_address, public_address) = {
let c = self.config.get();
(
c.network.protocol.udp.listen_address.clone(),
c.network.protocol.udp.public_address.clone(),
)
};
info!("UDP: starting listener at {:?}", listen_address);
let dial_infos = self.start_udp_handler(listen_address.clone()).await?;
let mut static_public = false;
for di in &dial_infos {
// Pick out UDP port for outbound connections (they will all be the same)
self.inner.lock().udp_port = di.port();
// Register local dial info only here if we specify a public address
if public_address.is_none() && di.is_global() {
// Register global dial info if no public address is specified
routing_table.register_dial_info(
di.clone(),
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
static_public = true;
} else if di.is_local() {
// Register local dial info
routing_table.register_dial_info(di.clone(), DialInfoOrigin::Static, None);
}
}
// Add static public dialinfo if it's configured
if let Some(public_address) = public_address.as_ref() {
// Resolve statically configured public dialinfo
let mut public_sockaddrs = public_address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
// Add all resolved addresses as public dialinfo
for pdi_addr in &mut public_sockaddrs {
routing_table.register_dial_info(
DialInfo::udp_from_socketaddr(pdi_addr),
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
static_public = true;
}
}
self.inner.lock().udp_static_public_dialinfo = static_public;
Ok(())
}
pub async fn start_ws_listeners(&self) -> Result<(), String> {
let routing_table = self.routing_table();
let (listen_address, url, path) = {
let c = self.config.get();
(
c.network.protocol.ws.listen_address.clone(),
c.network.protocol.ws.url.clone(),
c.network.protocol.ws.path.clone(),
)
};
trace!("WS: starting listener at {:?}", listen_address);
let socket_addresses = self
.start_tcp_listener(
listen_address.clone(),
false,
Box::new(|n, t, a| Box::new(WebsocketProtocolHandler::new(n, t, a))),
)
.await?;
trace!("WS: listener started");
let mut static_public = false;
for socket_address in socket_addresses {
// Pick out WS port for outbound connections (they will all be the same)
self.inner.lock().ws_port = socket_address.port();
if url.is_none() && socket_address.address().is_global() {
// Build global dial info request url
let global_url = format!("ws://{}/{}", socket_address, path);
// Create global dial info
let di = DialInfo::try_ws(socket_address, global_url)
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
routing_table.register_dial_info(
di,
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
static_public = true;
} else if socket_address.address().is_local() {
// Build local dial info request url
let local_url = format!("ws://{}/{}", socket_address, path);
// Create local dial info
let di = DialInfo::try_ws(socket_address, local_url)
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
routing_table.register_dial_info(di, DialInfoOrigin::Static, None);
}
}
// Add static public dialinfo if it's configured
if let Some(url) = url.as_ref() {
let mut split_url = SplitUrl::from_str(url)?;
if split_url.scheme.to_ascii_lowercase() != "ws" {
return Err("WS URL must use 'ws://' scheme".to_owned());
}
split_url.scheme = "ws".to_owned();
// Resolve static public hostnames
let global_socket_addrs = split_url
.host
.to_socket_addrs()
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
for gsa in global_socket_addrs {
routing_table.register_dial_info(
DialInfo::try_ws(SocketAddress::from_socket_addr(gsa), url.clone())
.map_err(map_to_string)
.map_err(logthru_net!(error))?,
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
}
static_public = true;
}
self.inner.lock().ws_static_public_dialinfo = static_public;
Ok(())
}
pub async fn start_wss_listeners(&self) -> Result<(), String> {
let routing_table = self.routing_table();
let (listen_address, url) = {
let c = self.config.get();
(
c.network.protocol.wss.listen_address.clone(),
c.network.protocol.wss.url.clone(),
)
};
trace!("WSS: starting listener at {}", listen_address);
let socket_addresses = self
.start_tcp_listener(
listen_address.clone(),
true,
Box::new(|n, t, a| Box::new(WebsocketProtocolHandler::new(n, t, a))),
)
.await?;
trace!("WSS: listener started");
// NOTE: No local dial info for WSS, as there is no way to connect to a local dialinfo via TLS
// If the hostname is specified, it is the public dialinfo via the URL. If no hostname
// is specified, then TLS won't validate, so no local dialinfo is possible.
// This is not the case with unencrypted websockets, which can be specified solely by an IP address
//
if let Some(socket_address) = socket_addresses.first() {
// Pick out WSS port for outbound connections (they will all be the same)
self.inner.lock().wss_port = socket_address.port();
}
// Add static public dialinfo if it's configured
if let Some(url) = url.as_ref() {
// Add static public dialinfo if it's configured
let mut split_url = SplitUrl::from_str(url)?;
if split_url.scheme.to_ascii_lowercase() != "wss" {
return Err("WSS URL must use 'wss://' scheme".to_owned());
}
split_url.scheme = "wss".to_owned();
// Resolve static public hostnames
let global_socket_addrs = split_url
.host
.to_socket_addrs()
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
for gsa in global_socket_addrs {
routing_table.register_dial_info(
DialInfo::try_wss(SocketAddress::from_socket_addr(gsa), url.clone())
.map_err(map_to_string)
.map_err(logthru_net!(error))?,
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
}
} else {
return Err("WSS URL must be specified due to TLS requirements".to_owned());
}
Ok(())
}
pub async fn start_tcp_listeners(&self) -> Result<(), String> {
let routing_table = self.routing_table();
let (listen_address, public_address) = {
let c = self.config.get();
(
c.network.protocol.tcp.listen_address.clone(),
c.network.protocol.tcp.public_address.clone(),
)
};
trace!("TCP: starting listener at {}", &listen_address);
let socket_addresses = self
.start_tcp_listener(
listen_address.clone(),
false,
Box::new(|n, _, a| Box::new(RawTcpProtocolHandler::new(n, a))),
)
.await?;
trace!("TCP: listener started");
let mut static_public = false;
for socket_address in socket_addresses {
// Pick out TCP port for outbound connections (they will all be the same)
self.inner.lock().tcp_port = socket_address.port();
let di = DialInfo::tcp(socket_address);
// Register local dial info only here if we specify a public address
if public_address.is_none() && di.is_global() {
// Register global dial info if no public address is specified
routing_table.register_dial_info(
di.clone(),
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
static_public = true;
} else if di.is_local() {
// Register local dial info
routing_table.register_dial_info(di.clone(), DialInfoOrigin::Static, None);
}
}
// Add static public dialinfo if it's configured
if let Some(public_address) = public_address.as_ref() {
// Resolve statically configured public dialinfo
let mut public_sockaddrs = public_address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
// Add all resolved addresses as public dialinfo
for pdi_addr in &mut public_sockaddrs {
routing_table.register_dial_info(
DialInfo::tcp_from_socketaddr(pdi_addr),
DialInfoOrigin::Static,
None,
);
static_public = true;
}
}
self.inner.lock().tcp_static_public_dialinfo = static_public;
Ok(())
}
pub fn get_protocol_config(&self) -> Option<ProtocolConfig> { pub fn get_protocol_config(&self) -> Option<ProtocolConfig> {
self.inner.lock().protocol_config self.inner.lock().protocol_config
} }
@ -1083,7 +414,6 @@ impl Network {
// start listeners // start listeners
if protocol_config.udp_enabled { if protocol_config.udp_enabled {
self.start_udp_listeners().await?; self.start_udp_listeners().await?;
self.create_udp_outbound_sockets().await?;
} }
if protocol_config.ws_listen { if protocol_config.ws_listen {
self.start_ws_listeners().await?; self.start_ws_listeners().await?;

View File

@ -0,0 +1,242 @@
use super::*;
use utils::clone_stream::*;
impl Network {
fn get_or_create_tls_acceptor(&self) -> Result<TlsAcceptor, String> {
if let Some(ts) = self.inner.lock().tls_acceptor.as_ref() {
return Ok(ts.clone());
}
let server_config = self
.load_server_config()
.map_err(|e| format!("Couldn't create TLS configuration: {}", e))?;
let acceptor = TlsAcceptor::from(Arc::new(server_config));
self.inner.lock().tls_acceptor = Some(acceptor.clone());
Ok(acceptor)
}
async fn try_tls_handlers(
&self,
tls_acceptor: &TlsAcceptor,
stream: AsyncPeekStream,
addr: SocketAddr,
protocol_handlers: &[Box<dyn TcpProtocolHandler>],
tls_connection_initial_timeout: u64,
) {
match tls_acceptor.accept(stream).await {
Ok(ts) => {
let ps = AsyncPeekStream::new(CloneStream::new(ts));
let mut first_packet = [0u8; PEEK_DETECT_LEN];
// Try the handlers but first get a chunk of data for them to process
// Don't waste more than N seconds getting it though, in case someone
// is trying to DoS us with a bunch of connections or something
// read a chunk of the stream
match io::timeout(
Duration::from_micros(tls_connection_initial_timeout),
ps.peek_exact(&mut first_packet),
)
.await
{
Ok(()) => (),
Err(_) => return,
}
self.clone().try_handlers(ps, addr, protocol_handlers).await;
}
Err(e) => {
debug!("TLS stream failed handshake: {}", e);
}
}
}
async fn try_handlers(
&self,
stream: AsyncPeekStream,
addr: SocketAddr,
protocol_handlers: &[Box<dyn TcpProtocolHandler>],
) {
for ah in protocol_handlers.iter() {
if ah.on_accept(stream.clone(), addr).await == Ok(true) {
return;
}
}
}
async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
// Get config
let (connection_initial_timeout, tls_connection_initial_timeout) = {
let c = self.config.get();
(
c.network.connection_initial_timeout,
c.network.tls.connection_initial_timeout,
)
};
// Create a reusable socket with no linger time, and no delay
let socket = new_shared_tcp_socket(addr)?;
// Listen on the socket
socket
.listen(128)
.map_err(|e| format!("Couldn't listen on TCP socket: {}", e))?;
// Make an async tcplistener from the socket2 socket
let std_listener: std::net::TcpListener = socket.into();
let listener = TcpListener::from(std_listener);
trace!("spawn_socket_listener: binding successful to {}", addr);
// Create protocol handler records
let listener_state = Arc::new(RwLock::new(ListenerState::new()));
self.inner
.lock()
.listener_states
.insert(addr, listener_state.clone());
// Spawn the socket task
let this = self.clone();
////////////////////////////////////////////////////////////
let jh = spawn(async move {
// moves listener object in and get incoming iterator
// when this task exists, the listener will close the socket
listener
.incoming()
.for_each_concurrent(None, |tcp_stream| async {
let tcp_stream = tcp_stream.unwrap();
let listener_state = listener_state.clone();
// match tcp_stream.set_nodelay(true) {
// Ok(_) => (),
// _ => continue,
// };
// Limit the number of connections from the same IP address
// and the number of total connections
let addr = match tcp_stream.peer_addr() {
Ok(addr) => addr,
Err(err) => {
error!("failed to get peer address: {}", err);
return;
}
};
// XXX limiting
trace!("TCP connection from: {}", addr);
// Create a stream we can peek on
let ps = AsyncPeekStream::new(tcp_stream);
/////////////////////////////////////////////////////////////
let mut first_packet = [0u8; PEEK_DETECT_LEN];
// read a chunk of the stream
trace!("reading chunk");
if io::timeout(
Duration::from_micros(connection_initial_timeout),
ps.peek_exact(&mut first_packet),
)
.await
.is_err()
{
// If we fail to get a packet within the connection initial timeout
// then we punt this connection
return;
}
// Run accept handlers on accepted stream
trace!("packet ready");
// Check is this could be TLS
let ls = listener_state.read().clone();
if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 {
trace!("trying TLS");
this.clone()
.try_tls_handlers(
ls.tls_acceptor.as_ref().unwrap(),
ps,
addr,
&ls.tls_protocol_handlers,
tls_connection_initial_timeout,
)
.await;
} else {
trace!("not TLS");
this.clone()
.try_handlers(ps, addr, &ls.protocol_handlers)
.await;
}
})
.await;
trace!("exited incoming loop for {}", addr);
// Remove our listener state from this address if we're stopping
this.inner.lock().listener_states.remove(&addr);
trace!("listener state removed for {}", addr);
// If this happened our low-level listener socket probably died
// so it's time to restart the network
this.inner.lock().network_needs_restart = true;
});
////////////////////////////////////////////////////////////
// Add to join handles
self.add_to_join_handles(jh);
Ok(())
}
/////////////////////////////////////////////////////////////////
// TCP listener that multiplexes ports so multiple protocols can exist on a single port
pub(super) async fn start_tcp_listener(
&self,
address: String,
is_tls: bool,
new_tcp_protocol_handler: Box<NewTcpProtocolHandler>,
) -> Result<Vec<SocketAddress>, String> {
let mut out = Vec::<SocketAddress>::new();
// convert to socketaddrs
let mut sockaddrs = address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", address, e))?;
for addr in &mut sockaddrs {
let ldi_addrs = Self::translate_unspecified_address(&*(self.inner.lock()), &addr);
// see if we've already bound to this already
// if not, spawn a listener
if !self.inner.lock().listener_states.contains_key(&addr) {
self.clone().spawn_socket_listener(addr).await?;
}
let ls = if let Some(ls) = self.inner.lock().listener_states.get_mut(&addr) {
ls.clone()
} else {
panic!("this shouldn't happen");
};
if is_tls {
if ls.read().tls_acceptor.is_none() {
ls.write().tls_acceptor = Some(self.clone().get_or_create_tls_acceptor()?);
}
ls.write()
.tls_protocol_handlers
.push(new_tcp_protocol_handler(
self.inner.lock().network_manager.clone(),
true,
addr,
));
} else {
ls.write().protocol_handlers.push(new_tcp_protocol_handler(
self.inner.lock().network_manager.clone(),
false,
addr,
));
}
// Return local dial infos we listen on
for ldi_addr in ldi_addrs {
out.push(SocketAddress::from_socket_addr(ldi_addr));
}
}
Ok(out)
}
}

View File

@ -0,0 +1,187 @@
use super::*;
use futures_util::stream;
impl Network {
pub(super) async fn create_udp_listener_tasks(&self) -> Result<(), String> {
// Spawn socket tasks
let mut task_count = {
let c = self.config.get();
c.network.protocol.udp.socket_pool_size
};
if task_count == 0 {
task_count = intf::get_concurrency() / 2;
if task_count == 0 {
task_count = 1;
}
}
trace!("task_count: {}", task_count);
for _ in 0..task_count {
trace!("Spawning UDP listener task");
////////////////////////////////////////////////////////////
// Run thread task to process stream of messages
let this = self.clone();
let jh = spawn(async move {
trace!("UDP listener task spawned");
// Collect all our protocol handlers into a vector
let mut protocol_handlers: Vec<RawUdpProtocolHandler> = this
.inner
.lock()
.inbound_udp_protocol_handlers
.values()
.cloned()
.collect();
if let Some(ph) = this.inner.lock().outbound_udpv4_protocol_handler.clone() {
protocol_handlers.push(ph);
}
if let Some(ph) = this.inner.lock().outbound_udpv6_protocol_handler.clone() {
protocol_handlers.push(ph);
}
// Spawn a local async task for each socket
let mut protocol_handlers_unordered = stream::FuturesUnordered::new();
for ph in protocol_handlers {
let jh = spawn_local(ph.clone().receive_loop());
protocol_handlers_unordered.push(jh);
}
// Now we wait for any join handle to exit,
// which would indicate an error needing
// us to completely restart the network
let _ = protocol_handlers_unordered.next().await;
trace!("UDP listener task stopped");
// If this loop fails, our socket died and we need to restart the network
this.inner.lock().network_needs_restart = true;
});
////////////////////////////////////////////////////////////
// Add to join handle
self.add_to_join_handles(jh);
}
Ok(())
}
pub(super) async fn create_udp_outbound_sockets(&self) -> Result<(), String> {
let mut inner = self.inner.lock();
let mut port = inner.udp_port;
// v4
let socket_addr_v4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
if let Ok(socket) = new_shared_udp_socket(socket_addr_v4) {
// Pull the port if we randomly bound, so v6 can be on the same port
port = socket
.local_addr()
.map_err(map_to_string)?
.as_socket_ipv4()
.ok_or_else(|| "expected ipv4 address type".to_owned())?
.port();
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
let udp_socket = UdpSocket::from(std_udp_socket);
let socket_arc = Arc::new(udp_socket);
// Create protocol handler
let udpv4_handler =
RawUdpProtocolHandler::new(inner.network_manager.clone(), socket_arc);
inner.outbound_udpv4_protocol_handler = Some(udpv4_handler);
}
//v6
let socket_addr_v6 =
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), port);
if let Ok(socket) = new_shared_udp_socket(socket_addr_v6) {
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
let udp_socket = UdpSocket::from(std_udp_socket);
let socket_arc = Arc::new(udp_socket);
// Create protocol handler
let udpv6_handler =
RawUdpProtocolHandler::new(inner.network_manager.clone(), socket_arc);
inner.outbound_udpv6_protocol_handler = Some(udpv6_handler);
}
Ok(())
}
async fn create_udp_inbound_socket(&self, addr: SocketAddr) -> Result<(), String> {
log_net!("create_udp_inbound_socket on {:?}", &addr);
// Create a reusable socket
let socket = new_shared_udp_socket(addr)?;
// Make an async UdpSocket from the socket2 socket
let std_udp_socket: std::net::UdpSocket = socket.into();
let udp_socket = UdpSocket::from(std_udp_socket);
let socket_arc = Arc::new(udp_socket);
// Create protocol handler
let protocol_handler =
RawUdpProtocolHandler::new(self.inner.lock().network_manager.clone(), socket_arc);
// Create message_handler records
self.inner
.lock()
.inbound_udp_protocol_handlers
.insert(addr, protocol_handler);
Ok(())
}
pub(super) async fn create_udp_inbound_sockets(
&self,
address: String,
) -> Result<Vec<DialInfo>, String> {
let mut out = Vec::<DialInfo>::new();
// convert to socketaddrs
let mut sockaddrs = address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", address, e))?;
for addr in &mut sockaddrs {
// see if we've already bound to this already
// if not, spawn a listener
if !self
.inner
.lock()
.inbound_udp_protocol_handlers
.contains_key(&addr)
{
let ldi_addrs = Self::translate_unspecified_address(&*self.inner.lock(), &addr);
self.clone().create_udp_inbound_socket(addr).await?;
// Return local dial infos we listen on
for ldi_addr in ldi_addrs {
out.push(DialInfo::udp_from_socketaddr(ldi_addr));
}
}
}
Ok(out)
}
/////////////////////////////////////////////////////////////////
pub(super) fn find_best_udp_protocol_handler(
&self,
peer_socket_addr: &SocketAddr,
local_socket_addr: &Option<SocketAddr>,
) -> Option<RawUdpProtocolHandler> {
// if our last communication with this peer came from a particular inbound udp protocol handler, use it
if let Some(sa) = local_socket_addr {
if let Some(ph) = self.inner.lock().inbound_udp_protocol_handlers.get(sa) {
return Some(ph.clone());
}
}
// otherwise find the outbound udp protocol handler that matches the ip protocol version of the peer addr
let inner = self.inner.lock();
match peer_socket_addr {
SocketAddr::V4(_) => inner.outbound_udpv4_protocol_handler.clone(),
SocketAddr::V6(_) => inner.outbound_udpv6_protocol_handler.clone(),
}
}
}

View File

@ -72,6 +72,8 @@ pub fn new_shared_udp_socket(local_address: SocketAddr) -> Result<socket2::Socke
.bind(&socket2_addr) .bind(&socket2_addr)
.map_err(|e| format!("failed to bind UDP socket: {}", e))?; .map_err(|e| format!("failed to bind UDP socket: {}", e))?;
log_net!("created shared udp socket on {:?}", &local_address);
Ok(socket) Ok(socket)
} }

View File

@ -30,6 +30,17 @@ impl RawUdpProtocolHandler {
} }
} }
pub async fn receive_loop(self) {
let mut data = vec![0u8; 65536];
let socket = self.inner.lock().socket.clone();
while let Ok((size, socket_addr)) = socket.recv_from(&mut data).await {
// XXX: Limit the number of packets from the same IP address?
trace!("UDP packet from: {}", socket_addr);
let _processed = self.clone().on_message(&data[..size], socket_addr).await;
}
}
pub async fn on_message(&self, data: &[u8], remote_addr: SocketAddr) -> Result<bool, String> { pub async fn on_message(&self, data: &[u8], remote_addr: SocketAddr) -> Result<bool, String> {
if data.len() > MAX_MESSAGE_SIZE { if data.len() > MAX_MESSAGE_SIZE {
return Err("received too large UDP message".to_owned()); return Err("received too large UDP message".to_owned());

View File

@ -0,0 +1,275 @@
use super::*;
impl Network {
pub(super) async fn start_udp_listeners(&self) -> Result<(), String> {
// First, create outbound sockets and we'll listen on them too
self.create_udp_outbound_sockets().await?;
// Now create udp inbound sockets for whatever interfaces we're listening on
let routing_table = self.routing_table();
let (listen_address, public_address) = {
let c = self.config.get();
(
c.network.protocol.udp.listen_address.clone(),
c.network.protocol.udp.public_address.clone(),
)
};
info!("UDP: starting listener at {:?}", listen_address);
let dial_infos = self
.create_udp_inbound_sockets(listen_address.clone())
.await?;
let mut static_public = false;
for di in &dial_infos {
// Pick out UDP port for outbound connections (they will all be the same)
self.inner.lock().udp_port = di.port();
// Register local dial info only here if we specify a public address
if public_address.is_none() && di.is_global() {
// Register global dial info if no public address is specified
routing_table.register_dial_info(
di.clone(),
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
static_public = true;
} else if di.is_local() {
// Register local dial info
routing_table.register_dial_info(di.clone(), DialInfoOrigin::Static, None);
}
}
// Add static public dialinfo if it's configured
if let Some(public_address) = public_address.as_ref() {
// Resolve statically configured public dialinfo
let mut public_sockaddrs = public_address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
// Add all resolved addresses as public dialinfo
for pdi_addr in &mut public_sockaddrs {
routing_table.register_dial_info(
DialInfo::udp_from_socketaddr(pdi_addr),
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
static_public = true;
}
}
self.inner.lock().udp_static_public_dialinfo = static_public;
// Now create tasks for udp listeners
self.create_udp_listener_tasks().await
}
pub(super) async fn start_ws_listeners(&self) -> Result<(), String> {
let routing_table = self.routing_table();
let (listen_address, url, path) = {
let c = self.config.get();
(
c.network.protocol.ws.listen_address.clone(),
c.network.protocol.ws.url.clone(),
c.network.protocol.ws.path.clone(),
)
};
trace!("WS: starting listener at {:?}", listen_address);
let socket_addresses = self
.start_tcp_listener(
listen_address.clone(),
false,
Box::new(|n, t, a| Box::new(WebsocketProtocolHandler::new(n, t, a))),
)
.await?;
trace!("WS: listener started");
let mut static_public = false;
for socket_address in socket_addresses {
// Pick out WS port for outbound connections (they will all be the same)
self.inner.lock().ws_port = socket_address.port();
if url.is_none() && socket_address.address().is_global() {
// Build global dial info request url
let global_url = format!("ws://{}/{}", socket_address, path);
// Create global dial info
let di = DialInfo::try_ws(socket_address, global_url)
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
routing_table.register_dial_info(
di,
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
static_public = true;
} else if socket_address.address().is_local() {
// Build local dial info request url
let local_url = format!("ws://{}/{}", socket_address, path);
// Create local dial info
let di = DialInfo::try_ws(socket_address, local_url)
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
routing_table.register_dial_info(di, DialInfoOrigin::Static, None);
}
}
// Add static public dialinfo if it's configured
if let Some(url) = url.as_ref() {
let mut split_url = SplitUrl::from_str(url)?;
if split_url.scheme.to_ascii_lowercase() != "ws" {
return Err("WS URL must use 'ws://' scheme".to_owned());
}
split_url.scheme = "ws".to_owned();
// Resolve static public hostnames
let global_socket_addrs = split_url
.host
.to_socket_addrs()
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
for gsa in global_socket_addrs {
routing_table.register_dial_info(
DialInfo::try_ws(SocketAddress::from_socket_addr(gsa), url.clone())
.map_err(map_to_string)
.map_err(logthru_net!(error))?,
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
}
static_public = true;
}
self.inner.lock().ws_static_public_dialinfo = static_public;
Ok(())
}
pub(super) async fn start_wss_listeners(&self) -> Result<(), String> {
let routing_table = self.routing_table();
let (listen_address, url) = {
let c = self.config.get();
(
c.network.protocol.wss.listen_address.clone(),
c.network.protocol.wss.url.clone(),
)
};
trace!("WSS: starting listener at {}", listen_address);
let socket_addresses = self
.start_tcp_listener(
listen_address.clone(),
true,
Box::new(|n, t, a| Box::new(WebsocketProtocolHandler::new(n, t, a))),
)
.await?;
trace!("WSS: listener started");
// NOTE: No local dial info for WSS, as there is no way to connect to a local dialinfo via TLS
// If the hostname is specified, it is the public dialinfo via the URL. If no hostname
// is specified, then TLS won't validate, so no local dialinfo is possible.
// This is not the case with unencrypted websockets, which can be specified solely by an IP address
//
if let Some(socket_address) = socket_addresses.first() {
// Pick out WSS port for outbound connections (they will all be the same)
self.inner.lock().wss_port = socket_address.port();
}
// Add static public dialinfo if it's configured
if let Some(url) = url.as_ref() {
// Add static public dialinfo if it's configured
let mut split_url = SplitUrl::from_str(url)?;
if split_url.scheme.to_ascii_lowercase() != "wss" {
return Err("WSS URL must use 'wss://' scheme".to_owned());
}
split_url.scheme = "wss".to_owned();
// Resolve static public hostnames
let global_socket_addrs = split_url
.host
.to_socket_addrs()
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
for gsa in global_socket_addrs {
routing_table.register_dial_info(
DialInfo::try_wss(SocketAddress::from_socket_addr(gsa), url.clone())
.map_err(map_to_string)
.map_err(logthru_net!(error))?,
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
}
} else {
return Err("WSS URL must be specified due to TLS requirements".to_owned());
}
Ok(())
}
pub(super) async fn start_tcp_listeners(&self) -> Result<(), String> {
let routing_table = self.routing_table();
let (listen_address, public_address) = {
let c = self.config.get();
(
c.network.protocol.tcp.listen_address.clone(),
c.network.protocol.tcp.public_address.clone(),
)
};
trace!("TCP: starting listener at {}", &listen_address);
let socket_addresses = self
.start_tcp_listener(
listen_address.clone(),
false,
Box::new(|n, _, a| Box::new(RawTcpProtocolHandler::new(n, a))),
)
.await?;
trace!("TCP: listener started");
let mut static_public = false;
for socket_address in socket_addresses {
// Pick out TCP port for outbound connections (they will all be the same)
self.inner.lock().tcp_port = socket_address.port();
let di = DialInfo::tcp(socket_address);
// Register local dial info only here if we specify a public address
if public_address.is_none() && di.is_global() {
// Register global dial info if no public address is specified
routing_table.register_dial_info(
di.clone(),
DialInfoOrigin::Static,
Some(NetworkClass::Server),
);
static_public = true;
} else if di.is_local() {
// Register local dial info
routing_table.register_dial_info(di.clone(), DialInfoOrigin::Static, None);
}
}
// Add static public dialinfo if it's configured
if let Some(public_address) = public_address.as_ref() {
// Resolve statically configured public dialinfo
let mut public_sockaddrs = public_address
.to_socket_addrs()
.await
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
// Add all resolved addresses as public dialinfo
for pdi_addr in &mut public_sockaddrs {
routing_table.register_dial_info(
DialInfo::tcp_from_socketaddr(pdi_addr),
DialInfoOrigin::Static,
None,
);
static_public = true;
}
}
self.inner.lock().tcp_static_public_dialinfo = static_public;
Ok(())
}
}

View File

@ -219,7 +219,7 @@ impl RoutingTable {
} else { } else {
"Other " "Other "
}, },
NodeDialInfoSingle { NodeDialInfo {
node_id: NodeId::new(inner.node_id), node_id: NodeId::new(inner.node_id),
dial_info dial_info
} }
@ -507,7 +507,7 @@ impl RoutingTable {
// Map all bootstrap entries to a single key with multiple dialinfo // Map all bootstrap entries to a single key with multiple dialinfo
let mut bsmap: BTreeMap<DHTKey, Vec<DialInfo>> = BTreeMap::new(); let mut bsmap: BTreeMap<DHTKey, Vec<DialInfo>> = BTreeMap::new();
for b in bootstrap { for b in bootstrap {
let ndis = NodeDialInfoSingle::from_str(b.as_str()) let ndis = NodeDialInfo::from_str(b.as_str())
.map_err(map_to_string) .map_err(map_to_string)
.map_err(logthru_rtab!("Invalid dial info in bootstrap entry: {}", b))?; .map_err(logthru_rtab!("Invalid dial info in bootstrap entry: {}", b))?;
let node_id = ndis.node_id.key; let node_id = ndis.node_id.key;

View File

@ -70,13 +70,17 @@ pub fn encode_dial_info(
&ws.socket_address, &ws.socket_address,
&mut di_ws_builder.reborrow().init_socket_address(), &mut di_ws_builder.reborrow().init_socket_address(),
)?; )?;
let request = dial_info
.request()
.ok_or_else(|| rpc_error_internal("no request for WS dialinfo"))?;
let mut requestb = di_ws_builder.init_request( let mut requestb = di_ws_builder.init_request(
ws.request request
.len() .len()
.try_into() .try_into()
.map_err(map_error_protocol!("request too long"))?, .map_err(map_error_protocol!("request too long"))?,
); );
requestb.push_str(ws.request.as_str()); requestb.push_str(request.as_str());
} }
DialInfo::WSS(wss) => { DialInfo::WSS(wss) => {
let mut di_wss_builder = builder.reborrow().init_wss(); let mut di_wss_builder = builder.reborrow().init_wss();
@ -84,13 +88,17 @@ pub fn encode_dial_info(
&wss.socket_address, &wss.socket_address,
&mut di_wss_builder.reborrow().init_socket_address(), &mut di_wss_builder.reborrow().init_socket_address(),
)?; )?;
let request = dial_info
.request()
.ok_or_else(|| rpc_error_internal("no request for WSS dialinfo"))?;
let mut requestb = di_wss_builder.init_request( let mut requestb = di_wss_builder.init_request(
wss.request request
.len() .len()
.try_into() .try_into()
.map_err(map_error_protocol!("request too long"))?, .map_err(map_error_protocol!("request too long"))?,
); );
requestb.push_str(wss.request.as_str()); requestb.push_str(request.as_str());
} }
}; };
Ok(()) Ok(())

View File

@ -1,6 +1,6 @@
mod address; mod address;
mod dial_info; mod dial_info;
mod node_dial_info_single; mod node_dial_info;
mod node_info; mod node_info;
mod nonce; mod nonce;
mod peer_info; mod peer_info;
@ -11,7 +11,7 @@ mod socket_address;
pub use address::*; pub use address::*;
pub use dial_info::*; pub use dial_info::*;
pub use node_dial_info_single::*; pub use node_dial_info::*;
pub use node_info::*; pub use node_info::*;
pub use nonce::*; pub use nonce::*;
pub use peer_info::*; pub use peer_info::*;

View File

@ -0,0 +1,33 @@
use crate::*;
use rpc_processor::*;
pub fn encode_node_dial_info(
ndis: &NodeDialInfo,
builder: &mut veilid_capnp::node_dial_info::Builder,
) -> Result<(), RPCError> {
let mut ni_builder = builder.reborrow().init_node_id();
encode_public_key(&ndis.node_id.key, &mut ni_builder)?;
let mut di_builder = builder.reborrow().init_dial_info();
encode_dial_info(&ndis.dial_info, &mut di_builder)?;
Ok(())
}
pub fn decode_node_dial_info(
reader: &veilid_capnp::node_dial_info::Reader,
) -> Result<NodeDialInfo, RPCError> {
let node_id = decode_public_key(
&reader
.get_node_id()
.map_err(map_error_protocol!("invalid public key in node_dial_info"))?,
);
let dial_info = decode_dial_info(
&reader
.get_dial_info()
.map_err(map_error_protocol!("invalid dial_info in node_dial_info"))?,
)?;
Ok(NodeDialInfo {
node_id: NodeId::new(node_id),
dial_info,
})
}

View File

@ -1,29 +0,0 @@
use crate::*;
use rpc_processor::*;
pub fn encode_node_dial_info_single(
ndis: &NodeDialInfoSingle,
builder: &mut veilid_capnp::node_dial_info_single::Builder,
) -> Result<(), RPCError> {
let mut ni_builder = builder.reborrow().init_node_id();
encode_public_key(&ndis.node_id.key, &mut ni_builder)?;
let mut di_builder = builder.reborrow().init_dial_info();
encode_dial_info(&ndis.dial_info, &mut di_builder)?;
Ok(())
}
pub fn decode_node_dial_info_single(
reader: &veilid_capnp::node_dial_info_single::Reader,
) -> Result<NodeDialInfoSingle, RPCError> {
let node_id = decode_public_key(&reader.get_node_id().map_err(map_error_protocol!(
"invalid public key in node_dial_info_single"
))?);
let dial_info = decode_dial_info(&reader.get_dial_info().map_err(map_error_protocol!(
"invalid dial_info in node_dial_info_single"
))?)?;
Ok(NodeDialInfoSingle {
node_id: NodeId::new(node_id),
dial_info,
})
}

View File

@ -13,7 +13,7 @@ pub struct RouteHopData {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct RouteHop { pub struct RouteHop {
pub dial_info: NodeDialInfoSingle, pub dial_info: NodeDialInfo,
pub next_hop: Option<RouteHopData>, pub next_hop: Option<RouteHopData>,
} }
@ -61,7 +61,7 @@ pub fn encode_route_hop(
route_hop: &RouteHop, route_hop: &RouteHop,
builder: &mut veilid_capnp::route_hop::Builder, builder: &mut veilid_capnp::route_hop::Builder,
) -> Result<(), RPCError> { ) -> Result<(), RPCError> {
encode_node_dial_info_single( encode_node_dial_info(
&route_hop.dial_info, &route_hop.dial_info,
&mut builder.reborrow().init_dial_info(), &mut builder.reborrow().init_dial_info(),
)?; )?;
@ -133,7 +133,7 @@ pub fn decode_route_hop_data(
} }
pub fn decode_route_hop(reader: &veilid_capnp::route_hop::Reader) -> Result<RouteHop, RPCError> { pub fn decode_route_hop(reader: &veilid_capnp::route_hop::Reader) -> Result<RouteHop, RPCError> {
let dial_info = decode_node_dial_info_single( let dial_info = decode_node_dial_info(
&reader &reader
.reborrow() .reborrow()
.get_dial_info() .get_dial_info()

View File

@ -78,7 +78,7 @@ impl RPCProcessor {
let mut rh_message = ::capnp::message::Builder::new_default(); let mut rh_message = ::capnp::message::Builder::new_default();
let mut rh_builder = rh_message.init_root::<veilid_capnp::route_hop::Builder>(); let mut rh_builder = rh_message.init_root::<veilid_capnp::route_hop::Builder>();
let mut di_builder = rh_builder.reborrow().init_dial_info(); let mut di_builder = rh_builder.reborrow().init_dial_info();
encode_node_dial_info_single(&safety_route.hops[h].dial_info, &mut di_builder)?; encode_node_dial_info(&safety_route.hops[h].dial_info, &mut di_builder)?;
// RouteHopData // RouteHopData
let mut rhd_builder = rh_builder.init_next_hop(); let mut rhd_builder = rh_builder.init_next_hop();
// Add the nonce // Add the nonce

View File

@ -577,6 +577,14 @@ impl DialInfo {
Self::WSS(di) => di.socket_address.port, Self::WSS(di) => di.socket_address.port,
} }
} }
pub fn set_port(&mut self, port: u16) {
match self {
Self::UDP(di) => di.socket_address.port = port,
Self::TCP(di) => di.socket_address.port = port,
Self::WS(di) => di.socket_address.port = port,
Self::WSS(di) => di.socket_address.port = port,
}
}
pub fn to_socket_addr(&self) -> SocketAddr { pub fn to_socket_addr(&self) -> SocketAddr {
match self { match self {
Self::UDP(di) => di.socket_address.to_socket_addr(), Self::UDP(di) => di.socket_address.to_socket_addr(),
@ -757,40 +765,37 @@ impl MatchesDialInfoFilter for ConnectionDescriptor {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)] #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)]
pub struct NodeDialInfoSingle { pub struct NodeDialInfo {
pub node_id: NodeId, pub node_id: NodeId,
pub dial_info: DialInfo, pub dial_info: DialInfo,
} }
impl fmt::Display for NodeDialInfoSingle { impl fmt::Display for NodeDialInfo {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(f, "{}@{}", self.node_id, self.dial_info) write!(f, "{}@{}", self.node_id, self.dial_info)
} }
} }
impl FromStr for NodeDialInfoSingle { impl FromStr for NodeDialInfo {
type Err = VeilidAPIError; type Err = VeilidAPIError;
fn from_str(s: &str) -> Result<NodeDialInfoSingle, VeilidAPIError> { fn from_str(s: &str) -> Result<NodeDialInfo, VeilidAPIError> {
// split out node id from the dial info // split out node id from the dial info
let (node_id_str, rest) = s.split_once('@').ok_or_else(|| { let (node_id_str, rest) = s
parse_error!( .split_once('@')
"NodeDialInfoSingle::from_str missing @ node id separator", .ok_or_else(|| parse_error!("NodeDialInfo::from_str missing @ node id separator", s))?;
s
)
})?;
// parse out node id // parse out node id
let node_id = NodeId::new(DHTKey::try_decode(node_id_str).map_err(|e| { let node_id = NodeId::new(DHTKey::try_decode(node_id_str).map_err(|e| {
parse_error!( parse_error!(
format!("NodeDialInfoSingle::from_str couldn't parse node id: {}", e), format!("NodeDialInfo::from_str couldn't parse node id: {}", e),
s s
) )
})?); })?);
// parse out dial info // parse out dial info
let dial_info = DialInfo::from_str(rest)?; let dial_info = DialInfo::from_str(rest)?;
// return completed NodeDialInfoSingle // return completed NodeDialInfo
Ok(NodeDialInfoSingle { node_id, dial_info }) Ok(NodeDialInfo { node_id, dial_info })
} }
} }
@ -890,7 +895,7 @@ pub struct PartialTunnel {
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
pub struct RouteHopSpec { pub struct RouteHopSpec {
pub dial_info: NodeDialInfoSingle, pub dial_info: NodeDialInfo,
} }
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]

View File

@ -19,7 +19,7 @@ daemon: false
client_api: client_api:
enabled: true enabled: true
listen_address: "localhost:5959" listen_address: "localhost:5959"
auto_attach: false auto_attach: true
logging: logging:
terminal: terminal:
enabled: true enabled: true
@ -223,6 +223,46 @@ impl<'de> serde::Deserialize<'de> for ParsedUrl {
} }
} }
#[derive(Debug, Clone, PartialEq)]
pub struct ParsedNodeDialInfo {
pub node_dial_info_string: String,
pub node_dial_info: veilid_core::NodeDialInfo,
}
// impl ParsedNodeDialInfo {
// pub fn offset_port(&mut self, offset: u16) -> Result<(), ()> {
// // Bump port on dial_info
// self.node_dial_info
// .dial_info
// .set_port(self.node_dial_info.dial_info.port() + 1);
// self.node_dial_info_string = self.node_dial_info.to_string();
// Ok(())
// }
// }
impl FromStr for ParsedNodeDialInfo {
type Err = veilid_core::VeilidAPIError;
fn from_str(
node_dial_info_string: &str,
) -> Result<ParsedNodeDialInfo, veilid_core::VeilidAPIError> {
let node_dial_info = veilid_core::NodeDialInfo::from_str(node_dial_info_string)?;
Ok(Self {
node_dial_info_string: node_dial_info_string.to_owned(),
node_dial_info,
})
}
}
impl<'de> serde::Deserialize<'de> for ParsedNodeDialInfo {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
ParsedNodeDialInfo::from_str(s.as_str()).map_err(serde::de::Error::custom)
}
}
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub struct NamedSocketAddrs { pub struct NamedSocketAddrs {
pub name: String, pub name: String,
@ -420,7 +460,7 @@ pub struct Network {
pub connection_initial_timeout: u64, pub connection_initial_timeout: u64,
pub node_id: veilid_core::DHTKey, pub node_id: veilid_core::DHTKey,
pub node_id_secret: veilid_core::DHTKeySecret, pub node_id_secret: veilid_core::DHTKeySecret,
pub bootstrap: Vec<ParsedUrl>, pub bootstrap: Vec<ParsedNodeDialInfo>,
pub rpc: Rpc, pub rpc: Rpc,
pub dht: Dht, pub dht: Dht,
pub upnp: bool, pub upnp: bool,
@ -634,7 +674,7 @@ impl Settings {
.bootstrap .bootstrap
.clone() .clone()
.into_iter() .into_iter()
.map(|e| e.urlstring) .map(|e| e.node_dial_info_string)
.collect::<Vec<String>>(), .collect::<Vec<String>>(),
)), )),
"network.rpc.concurrency" => Ok(Box::new(inner.core.network.rpc.concurrency)), "network.rpc.concurrency" => Ok(Box::new(inner.core.network.rpc.concurrency)),
@ -933,7 +973,7 @@ mod tests {
.unwrap() .unwrap()
.collect::<Vec<SocketAddr>>() .collect::<Vec<SocketAddr>>()
); );
assert_eq!(s.auto_attach, false); assert_eq!(s.auto_attach, true);
assert_eq!(s.logging.terminal.enabled, true); assert_eq!(s.logging.terminal.enabled, true);
assert_eq!(s.logging.terminal.level, LogLevel::Info); assert_eq!(s.logging.terminal.level, LogLevel::Info);
assert_eq!(s.logging.file.enabled, false); assert_eq!(s.logging.file.enabled, false);

View File

@ -156,17 +156,20 @@ pub async fn main() -> Result<(), String> {
settingsrw.logging.terminal.level = settings::LogLevel::Trace; settingsrw.logging.terminal.level = settings::LogLevel::Trace;
} }
if matches.is_present("attach") { if matches.is_present("attach") {
settingsrw.auto_attach = !matches!(matches.value_of("attach"), Some("false")); settingsrw.auto_attach = !matches!(matches.value_of("attach"), Some("true"));
} }
if matches.occurrences_of("bootstrap") != 0 { if matches.occurrences_of("bootstrap") != 0 {
let bootstrap = match matches.value_of("bootstrap") { let bootstrap = match matches.value_of("bootstrap") {
Some(x) => { Some(x) => {
println!("Overriding bootstrap with: "); println!("Overriding bootstrap with: ");
let mut out: Vec<settings::ParsedUrl> = Vec::new(); let mut out: Vec<settings::ParsedNodeDialInfo> = Vec::new();
for x in x.split(',') { for x in x.split(',') {
println!(" {}", x); println!(" {}", x);
out.push(settings::ParsedUrl::from_str(x).map_err(|e| { out.push(settings::ParsedNodeDialInfo::from_str(x).map_err(|e| {
format!("unable to parse url in bootstrap list: {} for {}", e, x) format!(
"unable to parse dial info in bootstrap list: {} for {}",
e, x
)
})?); })?);
} }
out out
@ -197,6 +200,7 @@ pub async fn main() -> Result<(), String> {
cb.add_filter_ignore_str("async_tungstenite"); cb.add_filter_ignore_str("async_tungstenite");
cb.add_filter_ignore_str("tungstenite"); cb.add_filter_ignore_str("tungstenite");
cb.add_filter_ignore_str("netlink_proto"); cb.add_filter_ignore_str("netlink_proto");
cb.add_filter_ignore_str("netlink_sys");
if settingsr.logging.terminal.enabled { if settingsr.logging.terminal.enabled {
logs.push(TermLogger::new( logs.push(TermLogger::new(