network refactor for connection manager
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
mod table_db;
|
||||
|
||||
mod user_secret;
|
||||
use crate::xx::*;
|
||||
use data_encoding::BASE64URL_NOPAD;
|
||||
pub use user_secret::*;
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
mod wasm;
|
||||
@@ -11,44 +11,3 @@ pub use wasm::*;
|
||||
mod native;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub use native::*;
|
||||
|
||||
pub async fn save_user_secret(namespace: &str, key: &str, value: &[u8]) -> Result<bool, String> {
|
||||
let mut s = BASE64URL_NOPAD.encode(value);
|
||||
s.push('!');
|
||||
|
||||
save_user_secret_string(namespace, key, s.as_str()).await
|
||||
}
|
||||
|
||||
pub async fn load_user_secret(namespace: &str, key: &str) -> Result<Option<Vec<u8>>, String> {
|
||||
let mut s = match load_user_secret_string(namespace, key).await? {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
if s.pop() != Some('!') {
|
||||
return Err("User secret is not a buffer".to_owned());
|
||||
}
|
||||
|
||||
let mut bytes = Vec::<u8>::new();
|
||||
let res = BASE64URL_NOPAD.decode_len(s.len());
|
||||
match res {
|
||||
Ok(l) => {
|
||||
bytes.resize(l, 0u8);
|
||||
}
|
||||
Err(_) => {
|
||||
return Err("Failed to decode".to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
let res = BASE64URL_NOPAD.decode_mut(s.as_bytes(), &mut bytes);
|
||||
match res {
|
||||
Ok(_) => Ok(Some(bytes)),
|
||||
Err(_) => Err("Failed to decode".to_owned()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn remove_user_secret(namespace: &str, key: &str) -> Result<bool, String> {
|
||||
remove_user_secret_string(namespace, key).await
|
||||
}
|
||||
|
@@ -1,54 +0,0 @@
|
||||
use crate::intf::*;
|
||||
use crate::network_manager::*;
|
||||
use utils::async_peek_stream::*;
|
||||
|
||||
use async_std::net::*;
|
||||
use async_tls::TlsAcceptor;
|
||||
|
||||
pub trait TcpProtocolHandler: TcpProtocolHandlerClone + Send + Sync {
|
||||
fn on_accept(
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SendPinBoxFuture<Result<bool, String>>;
|
||||
}
|
||||
|
||||
pub trait TcpProtocolHandlerClone {
|
||||
fn clone_box(&self) -> Box<dyn TcpProtocolHandler>;
|
||||
}
|
||||
|
||||
impl<T> TcpProtocolHandlerClone for T
|
||||
where
|
||||
T: 'static + TcpProtocolHandler + Clone,
|
||||
{
|
||||
fn clone_box(&self) -> Box<dyn TcpProtocolHandler> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
impl Clone for Box<dyn TcpProtocolHandler> {
|
||||
fn clone(&self) -> Box<dyn TcpProtocolHandler> {
|
||||
self.clone_box()
|
||||
}
|
||||
}
|
||||
|
||||
pub type NewTcpProtocolHandler =
|
||||
dyn Fn(NetworkManager, bool, SocketAddr) -> Box<dyn TcpProtocolHandler> + Send;
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ListenerState {
|
||||
pub protocol_handlers: Vec<Box<dyn TcpProtocolHandler + 'static>>,
|
||||
pub tls_protocol_handlers: Vec<Box<dyn TcpProtocolHandler + 'static>>,
|
||||
pub tls_acceptor: Option<TlsAcceptor>,
|
||||
}
|
||||
|
||||
impl ListenerState {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
protocol_handlers: Vec::new(),
|
||||
tls_protocol_handlers: Vec::new(),
|
||||
tls_acceptor: None,
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,15 +1,15 @@
|
||||
mod listener_state;
|
||||
mod network_tcp;
|
||||
mod network_udp;
|
||||
mod protocol;
|
||||
mod public_dialinfo_discovery;
|
||||
mod start_protocols;
|
||||
|
||||
use crate::connection_manager::*;
|
||||
use crate::intf::*;
|
||||
use crate::network_manager::*;
|
||||
use crate::routing_table::*;
|
||||
use crate::*;
|
||||
use listener_state::*;
|
||||
use network_tcp::*;
|
||||
use protocol::tcp::RawTcpProtocolHandler;
|
||||
use protocol::udp::RawUdpProtocolHandler;
|
||||
use protocol::ws::WebsocketProtocolHandler;
|
||||
@@ -136,10 +136,18 @@ impl Network {
|
||||
this
|
||||
}
|
||||
|
||||
fn network_manager(&self) -> NetworkManager {
|
||||
self.inner.lock().network_manager.clone()
|
||||
}
|
||||
|
||||
fn routing_table(&self) -> RoutingTable {
|
||||
self.inner.lock().routing_table.clone()
|
||||
}
|
||||
|
||||
fn connection_manager(&self) -> ConnectionManager {
|
||||
self.inner.lock().network_manager.connection_manager()
|
||||
}
|
||||
|
||||
fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
|
||||
let cvec = certs(&mut BufReader::new(File::open(path)?))
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid TLS certificate"))?;
|
||||
@@ -223,63 +231,28 @@ impl Network {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_preferred_local_address(
|
||||
&self,
|
||||
local_port: u16,
|
||||
peer_socket_addr: &SocketAddr,
|
||||
) -> SocketAddr {
|
||||
match peer_socket_addr {
|
||||
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), local_port),
|
||||
SocketAddr::V6(_) => SocketAddr::new(
|
||||
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)),
|
||||
local_port,
|
||||
),
|
||||
fn get_preferred_local_address(&self, dial_info: &DialInfo) -> SocketAddr {
|
||||
let inner = self.inner.lock();
|
||||
|
||||
let local_port = match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => inner.udp_port,
|
||||
ProtocolType::TCP => inner.tcp_port,
|
||||
ProtocolType::WS => inner.ws_port,
|
||||
ProtocolType::WSS => inner.wss_port,
|
||||
};
|
||||
|
||||
match dial_info.address_type() {
|
||||
AddressType::IPV4 => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), local_port),
|
||||
AddressType::IPV6 => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), local_port),
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
async fn send_data_to_existing_connection(
|
||||
&self,
|
||||
descriptor: &ConnectionDescriptor,
|
||||
data: Vec<u8>,
|
||||
) -> Result<Option<Vec<u8>>, String> {
|
||||
match descriptor.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
// send over the best udp socket we have bound since UDP is not connection oriented
|
||||
let peer_socket_addr = descriptor.remote.to_socket_addr();
|
||||
if let Some(ph) = self.find_best_udp_protocol_handler(
|
||||
&peer_socket_addr,
|
||||
&descriptor.local.map(|sa| sa.to_socket_addr()),
|
||||
) {
|
||||
ph.clone()
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.map_err(logthru_net!())?;
|
||||
// Data was consumed
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
ProtocolType::TCP | ProtocolType::WS | ProtocolType::WSS => {
|
||||
// find an existing connection in the connection table if one exists
|
||||
let network_manager = self.inner.lock().network_manager.clone();
|
||||
if let Some(entry) = network_manager
|
||||
.connection_table()
|
||||
.get_connection(descriptor)
|
||||
{
|
||||
// connection exists, send over it
|
||||
entry.conn.send(data).await.map_err(logthru_net!())?;
|
||||
|
||||
// Data was consumed
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
}
|
||||
// connection or local socket didn't exist, we'll need to use dialinfo to create one
|
||||
// Pass the data back out so we don't own it any more
|
||||
Ok(Some(data))
|
||||
}
|
||||
|
||||
// Send data to a dial info, unbound, using a new connection from a random port
|
||||
// This creates a short-lived connection in the case of connection-oriented protocols
|
||||
// for the purpose of sending this one message.
|
||||
// This bypasses the connection table as it is not a 'node to node' connection.
|
||||
pub async fn send_data_unbound_to_dial_info(
|
||||
&self,
|
||||
dial_info: &DialInfo,
|
||||
@@ -305,61 +278,113 @@ impl Network {
|
||||
}
|
||||
}
|
||||
|
||||
// Initiate a new low-level protocol connection to a node
|
||||
pub async fn connect_to_dial_info(
|
||||
&self,
|
||||
local_addr: Option<SocketAddr>,
|
||||
dial_info: &DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
let connection_manager = self.connection_manager();
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
|
||||
Ok(match &dial_info {
|
||||
DialInfo::UDP(_) => {
|
||||
panic!("Do not attempt to connect to UDP dial info")
|
||||
}
|
||||
DialInfo::TCP(_) => {
|
||||
let local_addr =
|
||||
self.get_preferred_local_address(self.inner.lock().tcp_port, &peer_socket_addr);
|
||||
RawTcpProtocolHandler::connect(connection_manager, local_addr, dial_info)
|
||||
.await
|
||||
.map_err(logthru_net!())?
|
||||
}
|
||||
DialInfo::WS(_) => {
|
||||
let local_addr =
|
||||
self.get_preferred_local_address(self.inner.lock().ws_port, &peer_socket_addr);
|
||||
WebsocketProtocolHandler::connect(connection_manager, local_addr, dial_info)
|
||||
.await
|
||||
.map_err(logthru_net!(error))?
|
||||
}
|
||||
DialInfo::WSS(_) => {
|
||||
let local_addr =
|
||||
self.get_preferred_local_address(self.inner.lock().wss_port, &peer_socket_addr);
|
||||
WebsocketProtocolHandler::connect(connection_manager, local_addr, dial_info)
|
||||
.await
|
||||
.map_err(logthru_net!(error))?
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_data_to_existing_connection(
|
||||
&self,
|
||||
descriptor: &ConnectionDescriptor,
|
||||
data: Vec<u8>,
|
||||
) -> Result<Option<Vec<u8>>, String> {
|
||||
match descriptor.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
// send over the best udp socket we have bound since UDP is not connection oriented
|
||||
let peer_socket_addr = descriptor.remote.to_socket_addr();
|
||||
if let Some(ph) = self.find_best_udp_protocol_handler(
|
||||
&peer_socket_addr,
|
||||
&descriptor.local.map(|sa| sa.to_socket_addr()),
|
||||
) {
|
||||
ph.clone()
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.map_err(logthru_net!())?;
|
||||
// Data was consumed
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
ProtocolType::TCP | ProtocolType::WS | ProtocolType::WSS => {
|
||||
// find an existing connection in the connection table if one exists
|
||||
if let Some(conn) = self.connection_manager().get_connection(descriptor) {
|
||||
// connection exists, send over it
|
||||
conn.send(data).await.map_err(logthru_net!())?;
|
||||
|
||||
// Data was consumed
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
}
|
||||
// connection or local socket didn't exist, we'll need to use dialinfo to create one
|
||||
// Pass the data back out so we don't own it any more
|
||||
Ok(Some(data))
|
||||
}
|
||||
|
||||
// Send data directly to a dial info, possibly without knowing which node it is going to
|
||||
pub async fn send_data_to_dial_info(
|
||||
&self,
|
||||
dial_info: &DialInfo,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
let network_manager = self.inner.lock().network_manager.clone();
|
||||
// Handle connectionless protocol
|
||||
if dial_info.protocol_type() == ProtocolType::UDP {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
|
||||
return ph
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.map_err(logthru_net!());
|
||||
}
|
||||
return Err("no appropriate UDP protocol handler for dial_info".to_owned())
|
||||
.map_err(logthru_net!(error));
|
||||
}
|
||||
|
||||
let conn = match &dial_info {
|
||||
DialInfo::UDP(_) => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
|
||||
return ph
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.map_err(logthru_net!());
|
||||
} else {
|
||||
return Err("no appropriate UDP protocol handler for dial_info".to_owned())
|
||||
.map_err(logthru_net!(error));
|
||||
}
|
||||
}
|
||||
DialInfo::TCP(_) => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
let local_addr =
|
||||
self.get_preferred_local_address(self.inner.lock().tcp_port, &peer_socket_addr);
|
||||
RawTcpProtocolHandler::connect(network_manager, local_addr, peer_socket_addr)
|
||||
.await
|
||||
.map_err(logthru_net!())?
|
||||
}
|
||||
DialInfo::WS(_) => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
let local_addr =
|
||||
self.get_preferred_local_address(self.inner.lock().ws_port, &peer_socket_addr);
|
||||
WebsocketProtocolHandler::connect(network_manager, local_addr, dial_info)
|
||||
.await
|
||||
.map_err(logthru_net!(error))?
|
||||
}
|
||||
DialInfo::WSS(_) => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
let local_addr =
|
||||
self.get_preferred_local_address(self.inner.lock().wss_port, &peer_socket_addr);
|
||||
WebsocketProtocolHandler::connect(network_manager, local_addr, dial_info)
|
||||
.await
|
||||
.map_err(logthru_net!(error))?
|
||||
}
|
||||
};
|
||||
// Handle connection-oriented protocols
|
||||
let conn = self.connect_to_dial_info(dial_info).await?;
|
||||
|
||||
conn.send(data).await.map_err(logthru_net!(error))
|
||||
}
|
||||
|
||||
// Send data to node
|
||||
// We may not have dial info for a node, but have an existing connection for it
|
||||
// because an inbound connection happened first, and no FindNodeQ has happened to that
|
||||
// node yet to discover its dial info. The existing connection should be tried first
|
||||
// in this case.
|
||||
pub async fn send_data(&self, node_ref: NodeRef, data: Vec<u8>) -> Result<(), String> {
|
||||
let dial_info = node_ref.best_dial_info();
|
||||
let descriptor = node_ref.last_connection();
|
||||
|
||||
// First try to send data to the last socket we've seen this peer on
|
||||
let di_data = if let Some(descriptor) = descriptor {
|
||||
let data = if let Some(descriptor) = node_ref.last_connection() {
|
||||
match self
|
||||
.clone()
|
||||
.send_data_to_existing_connection(&descriptor, data)
|
||||
@@ -375,11 +400,30 @@ impl Network {
|
||||
};
|
||||
|
||||
// If that fails, try to make a connection or reach out to the peer via its dial info
|
||||
if let Some(di) = dial_info {
|
||||
self.clone().send_data_to_dial_info(&di, di_data).await
|
||||
} else {
|
||||
Err("couldn't send data, no dial info or peer address".to_owned())
|
||||
let dial_info = node_ref
|
||||
.best_dial_info()
|
||||
.ok_or_else(|| "couldn't send data, no dial info or peer address".to_owned())?;
|
||||
|
||||
// Handle connectionless protocol
|
||||
if dial_info.protocol_type() == ProtocolType::UDP {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
|
||||
return ph
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.map_err(logthru_net!());
|
||||
}
|
||||
return Err("no appropriate UDP protocol handler for dial_info".to_owned())
|
||||
.map_err(logthru_net!(error));
|
||||
}
|
||||
|
||||
// Handle connection-oriented protocols
|
||||
let local_addr = self.get_preferred_local_address(&dial_info);
|
||||
let conn = self
|
||||
.connection_manager()
|
||||
.get_or_create_connection(dial_info, Some(local_addr)); xxx implement this and pass thru to NetworkConnection::connect
|
||||
|
||||
conn.send(data).await.map_err(logthru_net!(error))
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
@@ -437,9 +481,10 @@ impl Network {
|
||||
pub async fn shutdown(&self) {
|
||||
info!("stopping network");
|
||||
|
||||
let network_manager = self.network_manager();
|
||||
let routing_table = self.routing_table();
|
||||
|
||||
// Reset state
|
||||
let network_manager = self.inner.lock().network_manager.clone();
|
||||
let routing_table = network_manager.routing_table();
|
||||
|
||||
// Drop all dial info
|
||||
routing_table.clear_dial_info_details();
|
||||
@@ -453,8 +498,6 @@ impl Network {
|
||||
//////////////////////////////////////////
|
||||
pub fn get_network_class(&self) -> Option<NetworkClass> {
|
||||
let inner = self.inner.lock();
|
||||
let routing_table = inner.routing_table.clone();
|
||||
|
||||
if !inner.network_started {
|
||||
return None;
|
||||
}
|
||||
@@ -466,7 +509,7 @@ impl Network {
|
||||
|
||||
// Go through our global dialinfo and see what our best network class is
|
||||
let mut network_class = NetworkClass::Invalid;
|
||||
for did in routing_table.global_dial_info_details() {
|
||||
for did in inner.routing_table.global_dial_info_details() {
|
||||
if let Some(nc) = did.network_class {
|
||||
if nc < network_class {
|
||||
network_class = nc;
|
||||
@@ -488,7 +531,7 @@ impl Network {
|
||||
) = {
|
||||
let inner = self.inner.lock();
|
||||
(
|
||||
inner.network_manager.routing_table(),
|
||||
inner.routing_table.clone(),
|
||||
inner.protocol_config.unwrap_or_default(),
|
||||
inner.udp_static_public_dialinfo,
|
||||
inner.tcp_static_public_dialinfo,
|
||||
|
@@ -1,6 +1,31 @@
|
||||
use super::*;
|
||||
use crate::connection_manager::*;
|
||||
use crate::intf::*;
|
||||
use utils::clone_stream::*;
|
||||
|
||||
use async_tls::TlsAcceptor;
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ListenerState {
|
||||
pub protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
|
||||
pub tls_protocol_handlers: Vec<Box<dyn ProtocolAcceptHandler + 'static>>,
|
||||
pub tls_acceptor: Option<TlsAcceptor>,
|
||||
}
|
||||
|
||||
impl ListenerState {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
protocol_handlers: Vec::new(),
|
||||
tls_protocol_handlers: Vec::new(),
|
||||
tls_acceptor: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
impl Network {
|
||||
fn get_or_create_tls_acceptor(&self) -> Result<TlsAcceptor, String> {
|
||||
if let Some(ts) = self.inner.lock().tls_acceptor.as_ref() {
|
||||
@@ -20,46 +45,44 @@ impl Network {
|
||||
tls_acceptor: &TlsAcceptor,
|
||||
stream: AsyncPeekStream,
|
||||
addr: SocketAddr,
|
||||
protocol_handlers: &[Box<dyn TcpProtocolHandler>],
|
||||
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
|
||||
tls_connection_initial_timeout: u64,
|
||||
) {
|
||||
match tls_acceptor.accept(stream).await {
|
||||
Ok(ts) => {
|
||||
let ps = AsyncPeekStream::new(CloneStream::new(ts));
|
||||
let mut first_packet = [0u8; PEEK_DETECT_LEN];
|
||||
) -> Result<Option<NetworkConnection>, String> {
|
||||
let ts = tls_acceptor
|
||||
.accept(stream)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(debug "TLS stream failed handshake"))?;
|
||||
let ps = AsyncPeekStream::new(CloneStream::new(ts));
|
||||
let mut first_packet = [0u8; PEEK_DETECT_LEN];
|
||||
|
||||
// Try the handlers but first get a chunk of data for them to process
|
||||
// Don't waste more than N seconds getting it though, in case someone
|
||||
// is trying to DoS us with a bunch of connections or something
|
||||
// read a chunk of the stream
|
||||
match io::timeout(
|
||||
Duration::from_micros(tls_connection_initial_timeout),
|
||||
ps.peek_exact(&mut first_packet),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(()) => (),
|
||||
Err(_) => return,
|
||||
}
|
||||
self.clone().try_handlers(ps, addr, protocol_handlers).await;
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("TLS stream failed handshake: {}", e);
|
||||
}
|
||||
}
|
||||
// Try the handlers but first get a chunk of data for them to process
|
||||
// Don't waste more than N seconds getting it though, in case someone
|
||||
// is trying to DoS us with a bunch of connections or something
|
||||
// read a chunk of the stream
|
||||
io::timeout(
|
||||
Duration::from_micros(tls_connection_initial_timeout),
|
||||
ps.peek_exact(&mut first_packet),
|
||||
)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
|
||||
self.try_handlers(ps, addr, protocol_handlers).await
|
||||
}
|
||||
|
||||
async fn try_handlers(
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
addr: SocketAddr,
|
||||
protocol_handlers: &[Box<dyn TcpProtocolHandler>],
|
||||
) {
|
||||
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
|
||||
) -> Result<Option<NetworkConnection>, String> {
|
||||
for ah in protocol_handlers.iter() {
|
||||
if ah.on_accept(stream.clone(), addr).await == Ok(true) {
|
||||
return;
|
||||
if let Some(nc) = ah.on_accept(stream.clone(), addr).await? {
|
||||
return Ok(Some(nc));
|
||||
}
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
|
||||
@@ -73,7 +96,7 @@ impl Network {
|
||||
};
|
||||
|
||||
// Create a reusable socket with no linger time, and no delay
|
||||
let socket = new_shared_tcp_socket(addr)?;
|
||||
let socket = new_bound_shared_tcp_socket(addr)?;
|
||||
// Listen on the socket
|
||||
socket
|
||||
.listen(128)
|
||||
@@ -94,6 +117,7 @@ impl Network {
|
||||
|
||||
// Spawn the socket task
|
||||
let this = self.clone();
|
||||
let connection_manager = self.connection_manager();
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
let jh = spawn(async move {
|
||||
@@ -104,10 +128,7 @@ impl Network {
|
||||
.for_each_concurrent(None, |tcp_stream| async {
|
||||
let tcp_stream = tcp_stream.unwrap();
|
||||
let listener_state = listener_state.clone();
|
||||
// match tcp_stream.set_nodelay(true) {
|
||||
// Ok(_) => (),
|
||||
// _ => continue,
|
||||
// };
|
||||
let connection_manager = connection_manager.clone();
|
||||
|
||||
// Limit the number of connections from the same IP address
|
||||
// and the number of total connections
|
||||
@@ -129,7 +150,6 @@ impl Network {
|
||||
let mut first_packet = [0u8; PEEK_DETECT_LEN];
|
||||
|
||||
// read a chunk of the stream
|
||||
trace!("reading chunk");
|
||||
if io::timeout(
|
||||
Duration::from_micros(connection_initial_timeout),
|
||||
ps.peek_exact(&mut first_packet),
|
||||
@@ -143,26 +163,35 @@ impl Network {
|
||||
}
|
||||
|
||||
// Run accept handlers on accepted stream
|
||||
trace!("packet ready");
|
||||
|
||||
// Check is this could be TLS
|
||||
let ls = listener_state.read().clone();
|
||||
if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 {
|
||||
trace!("trying TLS");
|
||||
this.clone()
|
||||
.try_tls_handlers(
|
||||
ls.tls_acceptor.as_ref().unwrap(),
|
||||
ps,
|
||||
addr,
|
||||
&ls.tls_protocol_handlers,
|
||||
tls_connection_initial_timeout,
|
||||
)
|
||||
.await;
|
||||
let conn = if ls.tls_acceptor.is_some() && first_packet[0] == 0x16 {
|
||||
this.try_tls_handlers(
|
||||
ls.tls_acceptor.as_ref().unwrap(),
|
||||
ps,
|
||||
addr,
|
||||
&ls.tls_protocol_handlers,
|
||||
tls_connection_initial_timeout,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
trace!("not TLS");
|
||||
this.clone()
|
||||
.try_handlers(ps, addr, &ls.protocol_handlers)
|
||||
.await;
|
||||
}
|
||||
this.try_handlers(ps, addr, &ls.protocol_handlers).await
|
||||
};
|
||||
let conn = match conn {
|
||||
Ok(Some(c)) => c,
|
||||
Ok(None) => {
|
||||
// No protocol handlers matched? drop it.
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
// Failed to negotiate connection? drop it.
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Register the new connection in the connection manager
|
||||
connection_manager.on_new_connection(conn).await;
|
||||
})
|
||||
.await;
|
||||
trace!("exited incoming loop for {}", addr);
|
||||
@@ -189,7 +218,7 @@ impl Network {
|
||||
&self,
|
||||
address: String,
|
||||
is_tls: bool,
|
||||
new_tcp_protocol_handler: Box<NewTcpProtocolHandler>,
|
||||
new_protocol_accept_handler: Box<NewProtocolAcceptHandler>,
|
||||
) -> Result<Vec<SocketAddress>, String> {
|
||||
let mut out = Vec::<SocketAddress>::new();
|
||||
// convert to socketaddrs
|
||||
@@ -218,17 +247,19 @@ impl Network {
|
||||
}
|
||||
ls.write()
|
||||
.tls_protocol_handlers
|
||||
.push(new_tcp_protocol_handler(
|
||||
self.inner.lock().network_manager.clone(),
|
||||
.push(new_protocol_accept_handler(
|
||||
self.connection_manager(),
|
||||
true,
|
||||
addr,
|
||||
));
|
||||
} else {
|
||||
ls.write().protocol_handlers.push(new_tcp_protocol_handler(
|
||||
self.inner.lock().network_manager.clone(),
|
||||
false,
|
||||
addr,
|
||||
));
|
||||
ls.write()
|
||||
.protocol_handlers
|
||||
.push(new_protocol_accept_handler(
|
||||
self.connection_manager(),
|
||||
false,
|
||||
addr,
|
||||
));
|
||||
}
|
||||
|
||||
// Return local dial infos we listen on
|
||||
|
@@ -68,7 +68,7 @@ impl Network {
|
||||
let mut port = inner.udp_port;
|
||||
// v4
|
||||
let socket_addr_v4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
|
||||
if let Ok(socket) = new_shared_udp_socket(socket_addr_v4) {
|
||||
if let Ok(socket) = new_bound_shared_udp_socket(socket_addr_v4) {
|
||||
// Pull the port if we randomly bound, so v6 can be on the same port
|
||||
port = socket
|
||||
.local_addr()
|
||||
@@ -91,7 +91,7 @@ impl Network {
|
||||
//v6
|
||||
let socket_addr_v6 =
|
||||
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), port);
|
||||
if let Ok(socket) = new_shared_udp_socket(socket_addr_v6) {
|
||||
if let Ok(socket) = new_bound_shared_udp_socket(socket_addr_v6) {
|
||||
// Make an async UdpSocket from the socket2 socket
|
||||
let std_udp_socket: std::net::UdpSocket = socket.into();
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
@@ -111,7 +111,7 @@ impl Network {
|
||||
log_net!("create_udp_inbound_socket on {:?}", &addr);
|
||||
|
||||
// Create a reusable socket
|
||||
let socket = new_shared_udp_socket(addr)?;
|
||||
let socket = new_bound_shared_udp_socket(addr)?;
|
||||
|
||||
// Make an async UdpSocket from the socket2 socket
|
||||
let std_udp_socket: std::net::UdpSocket = socket.into();
|
||||
|
@@ -3,7 +3,6 @@ pub mod udp;
|
||||
pub mod wrtc;
|
||||
pub mod ws;
|
||||
|
||||
use super::listener_state::*;
|
||||
use crate::xx::*;
|
||||
use crate::*;
|
||||
use socket2::{Domain, Protocol, Socket, Type};
|
||||
@@ -12,11 +11,17 @@ use socket2::{Domain, Protocol, Socket, Type};
|
||||
pub struct DummyNetworkConnection {}
|
||||
|
||||
impl DummyNetworkConnection {
|
||||
pub fn send(&self, _message: Vec<u8>) -> SystemPinBoxFuture<Result<(), String>> {
|
||||
Box::pin(async { Ok(()) })
|
||||
pub fn connection_descriptor(&self) -> ConnectionDescriptor {
|
||||
ConnectionDescriptor::new_no_local(PeerAddress::new(
|
||||
SocketAddress::default(),
|
||||
ProtocolType::UDP,
|
||||
))
|
||||
}
|
||||
pub fn recv(&self) -> SystemPinBoxFuture<Result<Vec<u8>, String>> {
|
||||
Box::pin(async { Ok(Vec::new()) })
|
||||
pub async fn send(&self, _message: Vec<u8>) -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,28 +36,53 @@ pub enum NetworkConnection {
|
||||
}
|
||||
|
||||
impl NetworkConnection {
|
||||
pub fn send(&self, message: Vec<u8>) -> SystemPinBoxFuture<Result<(), String>> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.send(message),
|
||||
Self::RawTcp(t) => t.send(message),
|
||||
Self::WsAccepted(w) => w.send(message),
|
||||
Self::Ws(w) => w.send(message),
|
||||
Self::Wss(w) => w.send(message),
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
panic!("Should not connect to UDP dialinfo");
|
||||
}
|
||||
ProtocolType::TCP => {
|
||||
tcp::RawTcpProtocolHandler::connect(local_address, dial_info).await
|
||||
}
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
ws::WebsocketProtocolHandler::connect(local_address, dial_info).await
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn recv(&self) -> SystemPinBoxFuture<Result<Vec<u8>, String>> {
|
||||
|
||||
pub fn connection_descriptor(&self) -> ConnectionDescriptor {
|
||||
match self {
|
||||
Self::Dummy(d) => d.recv(),
|
||||
Self::RawTcp(t) => t.recv(),
|
||||
Self::WsAccepted(w) => w.recv(),
|
||||
Self::Ws(w) => w.recv(),
|
||||
Self::Wss(w) => w.recv(),
|
||||
Self::Dummy(d) => d.connection_descriptor(),
|
||||
Self::RawTcp(t) => t.connection_descriptor(),
|
||||
Self::WsAccepted(w) => w.connection_descriptor(),
|
||||
Self::Ws(w) => w.connection_descriptor(),
|
||||
Self::Wss(w) => w.connection_descriptor(),
|
||||
}
|
||||
}
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.send(message).await,
|
||||
Self::RawTcp(t) => t.send(message).await,
|
||||
Self::WsAccepted(w) => w.send(message).await,
|
||||
Self::Ws(w) => w.send(message).await,
|
||||
Self::Wss(w) => w.send(message).await,
|
||||
}
|
||||
}
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.recv().await,
|
||||
Self::RawTcp(t) => t.recv().await,
|
||||
Self::WsAccepted(w) => w.recv().await,
|
||||
Self::Ws(w) => w.recv().await,
|
||||
Self::Wss(w) => w.recv().await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_shared_udp_socket(local_address: SocketAddr) -> Result<socket2::Socket, String> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
pub fn new_unbound_shared_udp_socket(domain: Domain) -> Result<socket2::Socket, String> {
|
||||
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))
|
||||
.map_err(|e| format!("Couldn't create UDP socket: {}", e))?;
|
||||
|
||||
@@ -66,7 +96,12 @@ pub fn new_shared_udp_socket(local_address: SocketAddr) -> Result<socket2::Socke
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_bound_shared_udp_socket(local_address: SocketAddr) -> Result<socket2::Socket, String> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
let socket = new_unbound_shared_udp_socket(domain)?;
|
||||
let socket2_addr = socket2::SockAddr::from(local_address);
|
||||
socket
|
||||
.bind(&socket2_addr)
|
||||
@@ -77,8 +112,7 @@ pub fn new_shared_udp_socket(local_address: SocketAddr) -> Result<socket2::Socke
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_shared_tcp_socket(local_address: SocketAddr) -> Result<socket2::Socket, String> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
pub fn new_unbound_shared_tcp_socket(domain: Domain) -> Result<socket2::Socket, String> {
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("failed to create TCP socket"))?;
|
||||
@@ -98,13 +132,18 @@ pub fn new_shared_tcp_socket(local_address: SocketAddr) -> Result<socket2::Socke
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> Result<socket2::Socket, String> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
|
||||
let socket = new_unbound_shared_tcp_socket(domain)?;
|
||||
|
||||
let socket2_addr = socket2::SockAddr::from(local_address);
|
||||
socket
|
||||
.bind(&socket2_addr)
|
||||
.map_err(|e| format!("failed to bind TCP socket: {}", e))?;
|
||||
|
||||
log_net!("created shared tcp socket on {:?}", &local_address);
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
@@ -1,7 +1,8 @@
|
||||
use super::*;
|
||||
use crate::connection_manager::*;
|
||||
use crate::intf::native::utils::async_peek_stream::*;
|
||||
use crate::intf::*;
|
||||
use crate::network_manager::{NetworkManager, MAX_MESSAGE_SIZE};
|
||||
use crate::network_manager::MAX_MESSAGE_SIZE;
|
||||
use crate::*;
|
||||
use async_std::net::*;
|
||||
use async_std::prelude::*;
|
||||
@@ -15,11 +16,14 @@ struct RawTcpNetworkConnectionInner {
|
||||
#[derive(Clone)]
|
||||
pub struct RawTcpNetworkConnection {
|
||||
inner: Arc<AsyncMutex<RawTcpNetworkConnectionInner>>,
|
||||
connection_descriptor: ConnectionDescriptor,
|
||||
}
|
||||
|
||||
impl fmt::Debug for RawTcpNetworkConnection {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", std::any::type_name::<Self>())
|
||||
f.debug_struct("RawTCPNetworkConnection")
|
||||
.field("connection_descriptor", &self.connection_descriptor)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,68 +40,63 @@ impl RawTcpNetworkConnection {
|
||||
RawTcpNetworkConnectionInner { stream }
|
||||
}
|
||||
|
||||
pub fn new(stream: AsyncPeekStream) -> Self {
|
||||
pub fn new(stream: AsyncPeekStream, connection_descriptor: ConnectionDescriptor) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(AsyncMutex::new(Self::new_inner(stream))),
|
||||
connection_descriptor,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RawTcpNetworkConnection {
|
||||
pub fn send(&self, message: Vec<u8>) -> SystemPinBoxFuture<Result<(), String>> {
|
||||
let inner = self.inner.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large TCP message".to_owned());
|
||||
}
|
||||
let len = message.len() as u16;
|
||||
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
|
||||
|
||||
let mut inner = inner.lock().await;
|
||||
inner
|
||||
.stream
|
||||
.write_all(&header)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
inner
|
||||
.stream
|
||||
.write_all(&message)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())
|
||||
})
|
||||
pub fn connection_descriptor(&self) -> ConnectionDescriptor {
|
||||
self.connection_descriptor.clone()
|
||||
}
|
||||
|
||||
pub fn recv(&self) -> SystemPinBoxFuture<Result<Vec<u8>, String>> {
|
||||
let inner = self.inner.clone();
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large TCP message".to_owned());
|
||||
}
|
||||
let len = message.len() as u16;
|
||||
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
|
||||
|
||||
Box::pin(async move {
|
||||
let mut header = [0u8; 4];
|
||||
let mut inner = inner.lock().await;
|
||||
let mut inner = self.inner.lock().await;
|
||||
inner
|
||||
.stream
|
||||
.write_all(&header)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
inner
|
||||
.stream
|
||||
.write_all(&message)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())
|
||||
}
|
||||
|
||||
inner
|
||||
.stream
|
||||
.read_exact(&mut header)
|
||||
.await
|
||||
.map_err(|e| format!("TCP recv error: {}", e))?;
|
||||
if header[0] != b'V' || header[1] != b'L' {
|
||||
return Err("received invalid TCP frame header".to_owned());
|
||||
}
|
||||
let len = ((header[3] as usize) << 8) | (header[2] as usize);
|
||||
if len > MAX_MESSAGE_SIZE {
|
||||
return Err("received too large TCP frame".to_owned());
|
||||
}
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
let mut header = [0u8; 4];
|
||||
let mut inner = self.inner.lock().await;
|
||||
|
||||
let mut out: Vec<u8> = vec![0u8; len];
|
||||
inner
|
||||
.stream
|
||||
.read_exact(&mut out)
|
||||
.await
|
||||
.map_err(map_to_string)?;
|
||||
Ok(out)
|
||||
})
|
||||
inner
|
||||
.stream
|
||||
.read_exact(&mut header)
|
||||
.await
|
||||
.map_err(|e| format!("TCP recv error: {}", e))?;
|
||||
if header[0] != b'V' || header[1] != b'L' {
|
||||
return Err("received invalid TCP frame header".to_owned());
|
||||
}
|
||||
let len = ((header[3] as usize) << 8) | (header[2] as usize);
|
||||
if len > MAX_MESSAGE_SIZE {
|
||||
return Err("received too large TCP frame".to_owned());
|
||||
}
|
||||
|
||||
let mut out: Vec<u8> = vec![0u8; len];
|
||||
inner
|
||||
.stream
|
||||
.read_exact(&mut out)
|
||||
.await
|
||||
.map_err(map_to_string)?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,32 +104,35 @@ impl RawTcpNetworkConnection {
|
||||
///
|
||||
|
||||
struct RawTcpProtocolHandlerInner {
|
||||
network_manager: NetworkManager,
|
||||
connection_manager: ConnectionManager,
|
||||
local_address: SocketAddr,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RawTcpProtocolHandler
|
||||
where
|
||||
Self: TcpProtocolHandler,
|
||||
Self: ProtocolAcceptHandler,
|
||||
{
|
||||
inner: Arc<Mutex<RawTcpProtocolHandlerInner>>,
|
||||
}
|
||||
|
||||
impl RawTcpProtocolHandler {
|
||||
fn new_inner(
|
||||
network_manager: NetworkManager,
|
||||
connection_manager: ConnectionManager,
|
||||
local_address: SocketAddr,
|
||||
) -> RawTcpProtocolHandlerInner {
|
||||
RawTcpProtocolHandlerInner {
|
||||
network_manager,
|
||||
connection_manager,
|
||||
local_address,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(network_manager: NetworkManager, local_address: SocketAddr) -> Self {
|
||||
pub fn new(connection_manager: ConnectionManager, local_address: SocketAddr) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(Mutex::new(Self::new_inner(network_manager, local_address))),
|
||||
inner: Arc::new(Mutex::new(Self::new_inner(
|
||||
connection_manager,
|
||||
local_address,
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -138,7 +140,7 @@ impl RawTcpProtocolHandler {
|
||||
self,
|
||||
stream: AsyncPeekStream,
|
||||
socket_addr: SocketAddr,
|
||||
) -> Result<bool, String> {
|
||||
) -> Result<Option<NetworkConnection>, String> {
|
||||
let mut peekbuf: [u8; PEEK_DETECT_LEN] = [0u8; PEEK_DETECT_LEN];
|
||||
let peeklen = stream
|
||||
.peek(&mut peekbuf)
|
||||
@@ -147,51 +149,47 @@ impl RawTcpProtocolHandler {
|
||||
.map_err(logthru_net!("could not peek tcp stream"))?;
|
||||
assert_eq!(peeklen, PEEK_DETECT_LEN);
|
||||
|
||||
let conn = NetworkConnection::RawTcp(RawTcpNetworkConnection::new(stream));
|
||||
let peer_addr = PeerAddress::new(
|
||||
SocketAddress::from_socket_addr(socket_addr),
|
||||
ProtocolType::TCP,
|
||||
);
|
||||
let (network_manager, local_address) = {
|
||||
let inner = self.inner.lock();
|
||||
(inner.network_manager.clone(), inner.local_address)
|
||||
(inner.connection_manager.clone(), inner.local_address)
|
||||
};
|
||||
network_manager
|
||||
.on_new_connection(
|
||||
ConnectionDescriptor::new(
|
||||
peer_addr,
|
||||
SocketAddress::from_socket_addr(local_address),
|
||||
),
|
||||
conn,
|
||||
)
|
||||
.await?;
|
||||
Ok(true)
|
||||
let conn = NetworkConnection::RawTcp(RawTcpNetworkConnection::new(
|
||||
stream,
|
||||
ConnectionDescriptor::new(peer_addr, SocketAddress::from_socket_addr(local_address)),
|
||||
));
|
||||
|
||||
Ok(Some(conn))
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
network_manager: NetworkManager,
|
||||
local_address: SocketAddr,
|
||||
remote_socket_addr: SocketAddr,
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
// Get remote socket address to connect to
|
||||
let remote_socket_addr = dial_info.to_socket_addr();
|
||||
|
||||
// Make a shared socket
|
||||
let socket = new_shared_tcp_socket(local_address)?;
|
||||
let socket = match local_address {
|
||||
Some(a) => new_bound_shared_tcp_socket(a)?,
|
||||
None => new_unbound_shared_tcp_socket(Domain::for_address(remote_socket_addr))?,
|
||||
};
|
||||
|
||||
// Connect to the remote address
|
||||
let remote_socket2_addr = socket2::SockAddr::from(remote_socket_addr);
|
||||
socket
|
||||
.connect(&remote_socket2_addr)
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "local_address={} remote_addr={}", local_address, remote_socket_addr))?;
|
||||
log_net!(
|
||||
"tcp connect successful: local_address={} remote_addr={}",
|
||||
local_address,
|
||||
remote_socket_addr
|
||||
);
|
||||
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
|
||||
|
||||
let std_stream: std::net::TcpStream = socket.into();
|
||||
let ts = TcpStream::from(std_stream);
|
||||
|
||||
// See what local address we ended up with and turn this into a stream
|
||||
let local_address = ts
|
||||
let actual_local_address = ts
|
||||
.local_addr()
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("could not get local address from TCP stream"))?;
|
||||
@@ -202,16 +200,13 @@ impl RawTcpProtocolHandler {
|
||||
);
|
||||
|
||||
// Wrap the stream in a network connection and register it
|
||||
let conn = NetworkConnection::RawTcp(RawTcpNetworkConnection::new(ps));
|
||||
network_manager
|
||||
.on_new_connection(
|
||||
ConnectionDescriptor::new(
|
||||
peer_addr,
|
||||
SocketAddress::from_socket_addr(local_address),
|
||||
),
|
||||
conn.clone(),
|
||||
)
|
||||
.await?;
|
||||
let conn = NetworkConnection::RawTcp(RawTcpNetworkConnection::new(
|
||||
ps,
|
||||
ConnectionDescriptor {
|
||||
local: Some(SocketAddress::from_socket_addr(actual_local_address)),
|
||||
remote: dial_info.to_peer_address(),
|
||||
},
|
||||
));
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
@@ -235,12 +230,12 @@ impl RawTcpProtocolHandler {
|
||||
}
|
||||
}
|
||||
|
||||
impl TcpProtocolHandler for RawTcpProtocolHandler {
|
||||
impl ProtocolAcceptHandler for RawTcpProtocolHandler {
|
||||
fn on_accept(
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SendPinBoxFuture<Result<bool, String>> {
|
||||
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>> {
|
||||
Box::pin(self.clone().on_accept_async(stream, peer_addr))
|
||||
}
|
||||
}
|
||||
|
@@ -1,7 +1,8 @@
|
||||
use super::*;
|
||||
use crate::connection_manager::*;
|
||||
use crate::intf::native::utils::async_peek_stream::*;
|
||||
use crate::intf::*;
|
||||
use crate::network_manager::{NetworkManager, MAX_MESSAGE_SIZE};
|
||||
use crate::network_manager::MAX_MESSAGE_SIZE;
|
||||
use crate::*;
|
||||
use async_std::io;
|
||||
use async_std::net::*;
|
||||
@@ -32,6 +33,7 @@ where
|
||||
T: io::Read + io::Write + Send + Unpin + 'static,
|
||||
{
|
||||
tls: bool,
|
||||
connection_descriptor: ConnectionDescriptor,
|
||||
inner: Arc<AsyncMutex<WebSocketNetworkConnectionInner<T>>>,
|
||||
}
|
||||
|
||||
@@ -42,6 +44,7 @@ where
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
tls: self.tls,
|
||||
connection_descriptor: self.connection_descriptor.clone(),
|
||||
inner: self.inner.clone(),
|
||||
}
|
||||
}
|
||||
@@ -61,7 +64,9 @@ where
|
||||
T: io::Read + io::Write + Send + Unpin + 'static,
|
||||
{
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.tls == other.tls && Arc::as_ptr(&self.inner) == Arc::as_ptr(&other.inner)
|
||||
self.tls == other.tls
|
||||
&& self.connection_descriptor == other.connection_descriptor
|
||||
&& Arc::as_ptr(&self.inner) == Arc::as_ptr(&other.inner)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,56 +76,56 @@ impl<T> WebsocketNetworkConnection<T>
|
||||
where
|
||||
T: io::Read + io::Write + Send + Unpin + 'static,
|
||||
{
|
||||
pub fn new(tls: bool, ws_stream: WebSocketStream<T>) -> Self {
|
||||
pub fn new(
|
||||
tls: bool,
|
||||
connection_descriptor: ConnectionDescriptor,
|
||||
ws_stream: WebSocketStream<T>,
|
||||
) -> Self {
|
||||
Self {
|
||||
tls,
|
||||
connection_descriptor,
|
||||
inner: Arc::new(AsyncMutex::new(WebSocketNetworkConnectionInner {
|
||||
ws_stream,
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send(&self, message: Vec<u8>) -> SystemPinBoxFuture<Result<(), String>> {
|
||||
let inner = self.inner.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("received too large WS message".to_owned());
|
||||
}
|
||||
let mut inner = inner.lock().await;
|
||||
inner
|
||||
.ws_stream
|
||||
.send(Message::binary(message))
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed to send websocket message"))
|
||||
})
|
||||
pub fn connection_descriptor(&self) -> ConnectionDescriptor {
|
||||
self.connection_descriptor.clone()
|
||||
}
|
||||
pub fn recv(&self) -> SystemPinBoxFuture<Result<Vec<u8>, String>> {
|
||||
let inner = self.inner.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let mut inner = inner.lock().await;
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("received too large WS message".to_owned());
|
||||
}
|
||||
let mut inner = self.inner.lock().await;
|
||||
inner
|
||||
.ws_stream
|
||||
.send(Message::binary(message))
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed to send websocket message"))
|
||||
}
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
let mut inner = self.inner.lock().await;
|
||||
|
||||
let out = match inner.ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(v))) => v,
|
||||
Some(Ok(_)) => {
|
||||
return Err("Unexpected WS message type".to_owned())
|
||||
.map_err(logthru_net!(error));
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
return Err(e.to_string()).map_err(logthru_net!(error));
|
||||
}
|
||||
None => {
|
||||
return Err("WS stream closed".to_owned()).map_err(logthru_net!());
|
||||
}
|
||||
};
|
||||
if out.len() > MAX_MESSAGE_SIZE {
|
||||
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
|
||||
} else {
|
||||
Ok(out)
|
||||
let out = match inner.ws_stream.next().await {
|
||||
Some(Ok(Message::Binary(v))) => v,
|
||||
Some(Ok(_)) => {
|
||||
return Err("Unexpected WS message type".to_owned()).map_err(logthru_net!(error));
|
||||
}
|
||||
})
|
||||
Some(Err(e)) => {
|
||||
return Err(e.to_string()).map_err(logthru_net!(error));
|
||||
}
|
||||
None => {
|
||||
return Err("WS stream closed".to_owned()).map_err(logthru_net!());
|
||||
}
|
||||
};
|
||||
if out.len() > MAX_MESSAGE_SIZE {
|
||||
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
|
||||
} else {
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,7 +133,7 @@ where
|
||||
///
|
||||
struct WebsocketProtocolHandlerInner {
|
||||
tls: bool,
|
||||
network_manager: NetworkManager,
|
||||
connection_manager: ConnectionManager,
|
||||
local_address: SocketAddr,
|
||||
request_path: Vec<u8>,
|
||||
connection_initial_timeout: u64,
|
||||
@@ -137,13 +142,17 @@ struct WebsocketProtocolHandlerInner {
|
||||
#[derive(Clone)]
|
||||
pub struct WebsocketProtocolHandler
|
||||
where
|
||||
Self: TcpProtocolHandler,
|
||||
Self: ProtocolAcceptHandler,
|
||||
{
|
||||
inner: Arc<WebsocketProtocolHandlerInner>,
|
||||
}
|
||||
impl WebsocketProtocolHandler {
|
||||
pub fn new(network_manager: NetworkManager, tls: bool, local_address: SocketAddr) -> Self {
|
||||
let config = network_manager.config();
|
||||
pub fn new(
|
||||
connection_manager: ConnectionManager,
|
||||
tls: bool,
|
||||
local_address: SocketAddr,
|
||||
) -> Self {
|
||||
let config = connection_manager.config();
|
||||
let c = config.get();
|
||||
let path = if tls {
|
||||
format!("GET {}", c.network.protocol.ws.path.trim_end_matches('/'))
|
||||
@@ -158,7 +167,7 @@ impl WebsocketProtocolHandler {
|
||||
|
||||
let inner = WebsocketProtocolHandlerInner {
|
||||
tls,
|
||||
network_manager,
|
||||
connection_manager,
|
||||
local_address,
|
||||
request_path: path.as_bytes().to_vec(),
|
||||
connection_initial_timeout,
|
||||
@@ -172,7 +181,7 @@ impl WebsocketProtocolHandler {
|
||||
self,
|
||||
ps: AsyncPeekStream,
|
||||
socket_addr: SocketAddr,
|
||||
) -> Result<bool, String> {
|
||||
) -> Result<Option<NetworkConnection>, String> {
|
||||
let request_path_len = self.inner.request_path.len() + 2;
|
||||
let mut peekbuf: Vec<u8> = vec![0u8; request_path_len];
|
||||
match io::timeout(
|
||||
@@ -197,7 +206,7 @@ impl WebsocketProtocolHandler {
|
||||
|
||||
if !matches_path {
|
||||
log_net!("not websocket");
|
||||
return Ok(false);
|
||||
return Ok(None);
|
||||
}
|
||||
log_net!("found websocket");
|
||||
|
||||
@@ -218,26 +227,19 @@ impl WebsocketProtocolHandler {
|
||||
|
||||
let conn = NetworkConnection::WsAccepted(WebsocketNetworkConnection::new(
|
||||
self.inner.tls,
|
||||
ConnectionDescriptor::new(
|
||||
peer_addr,
|
||||
SocketAddress::from_socket_addr(self.inner.local_address),
|
||||
),
|
||||
ws_stream,
|
||||
));
|
||||
self.inner
|
||||
.network_manager
|
||||
.clone()
|
||||
.on_new_connection(
|
||||
ConnectionDescriptor::new(
|
||||
peer_addr,
|
||||
SocketAddress::from_socket_addr(self.inner.local_address),
|
||||
),
|
||||
conn,
|
||||
)
|
||||
.await?;
|
||||
Ok(true)
|
||||
|
||||
Ok(Some(conn))
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
network_manager: NetworkManager,
|
||||
local_address: SocketAddr,
|
||||
dial_info: &DialInfo,
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
// Split dial info up
|
||||
let (tls, scheme) = match &dial_info {
|
||||
@@ -256,14 +258,17 @@ impl WebsocketProtocolHandler {
|
||||
let remote_socket_addr = dial_info.to_socket_addr();
|
||||
|
||||
// Make a shared socket
|
||||
let socket = new_shared_tcp_socket(local_address)?;
|
||||
let socket = match local_address {
|
||||
Some(a) => new_bound_shared_tcp_socket(a)?,
|
||||
None => new_unbound_shared_tcp_socket(Domain::for_address(remote_socket_addr))?,
|
||||
};
|
||||
|
||||
// Connect to the remote address
|
||||
let remote_socket2_addr = socket2::SockAddr::from(remote_socket_addr);
|
||||
socket
|
||||
.connect(&remote_socket2_addr)
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "local_address={} remote_socket_addr={}", local_address, remote_socket_addr))?;
|
||||
.map_err(logthru_net!(error "local_address={:?} remote_socket_addr={}", local_address, remote_socket_addr))?;
|
||||
let std_stream: std::net::TcpStream = socket.into();
|
||||
let tcp_stream = TcpStream::from(std_stream);
|
||||
|
||||
@@ -273,6 +278,11 @@ impl WebsocketProtocolHandler {
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!())?;
|
||||
|
||||
// Make our connection descriptor
|
||||
let connection_descriptor = ConnectionDescriptor {
|
||||
local: Some(SocketAddress::from_socket_addr(actual_local_addr)),
|
||||
remote: dial_info.to_peer_address(),
|
||||
};
|
||||
// Negotiate TLS if this is WSS
|
||||
if tls {
|
||||
let connector = TlsConnector::default();
|
||||
@@ -285,59 +295,32 @@ impl WebsocketProtocolHandler {
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
let conn = NetworkConnection::Wss(WebsocketNetworkConnection::new(tls, ws_stream));
|
||||
|
||||
// Make the connection descriptor peer address
|
||||
let peer_addr = PeerAddress::new(
|
||||
SocketAddress::from_socket_addr(remote_socket_addr),
|
||||
ProtocolType::WSS,
|
||||
);
|
||||
|
||||
// Register the WSS connection
|
||||
network_manager
|
||||
.on_new_connection(
|
||||
ConnectionDescriptor::new(
|
||||
peer_addr,
|
||||
SocketAddress::from_socket_addr(actual_local_addr),
|
||||
),
|
||||
conn.clone(),
|
||||
)
|
||||
.await?;
|
||||
Ok(conn)
|
||||
Ok(NetworkConnection::Wss(WebsocketNetworkConnection::new(
|
||||
tls,
|
||||
connection_descriptor,
|
||||
ws_stream,
|
||||
)))
|
||||
} else {
|
||||
let (ws_stream, _response) = client_async(request, tcp_stream)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
let conn = NetworkConnection::Ws(WebsocketNetworkConnection::new(tls, ws_stream));
|
||||
|
||||
// Make the connection descriptor peer address
|
||||
let peer_addr = PeerAddress::new(
|
||||
SocketAddress::from_socket_addr(remote_socket_addr),
|
||||
ProtocolType::WS,
|
||||
);
|
||||
|
||||
// Register the WS connection
|
||||
network_manager
|
||||
.on_new_connection(
|
||||
ConnectionDescriptor::new(
|
||||
peer_addr,
|
||||
SocketAddress::from_socket_addr(actual_local_addr),
|
||||
),
|
||||
conn.clone(),
|
||||
)
|
||||
.await?;
|
||||
Ok(conn)
|
||||
Ok(NetworkConnection::Ws(WebsocketNetworkConnection::new(
|
||||
tls,
|
||||
connection_descriptor,
|
||||
ws_stream,
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TcpProtocolHandler for WebsocketProtocolHandler {
|
||||
impl ProtocolAcceptHandler for WebsocketProtocolHandler {
|
||||
fn on_accept(
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<Result<bool, String>> {
|
||||
) -> SystemPinBoxFuture<Result<Option<NetworkConnection>, String>> {
|
||||
Box::pin(self.clone().on_accept_async(stream, peer_addr))
|
||||
}
|
||||
}
|
||||
|
@@ -78,7 +78,7 @@ impl Network {
|
||||
.start_tcp_listener(
|
||||
listen_address.clone(),
|
||||
false,
|
||||
Box::new(|n, t, a| Box::new(WebsocketProtocolHandler::new(n, t, a))),
|
||||
Box::new(|c, t, a| Box::new(WebsocketProtocolHandler::new(c, t, a))),
|
||||
)
|
||||
.await?;
|
||||
trace!("WS: listener started");
|
||||
|
@@ -1,7 +1,7 @@
|
||||
use crate::xx::*;
|
||||
use async_std::io::{Read, ReadExt, Result, Write};
|
||||
use core::pin::Pin;
|
||||
use core::task::{Context, Poll};
|
||||
use std::pin::Pin;
|
||||
|
||||
////////
|
||||
///
|
||||
|
43
veilid-core/src/intf/user_secret.rs
Normal file
43
veilid-core/src/intf/user_secret.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use super::*;
|
||||
use data_encoding::BASE64URL_NOPAD;
|
||||
|
||||
pub async fn save_user_secret(namespace: &str, key: &str, value: &[u8]) -> Result<bool, String> {
|
||||
let mut s = BASE64URL_NOPAD.encode(value);
|
||||
s.push('!');
|
||||
|
||||
save_user_secret_string(namespace, key, s.as_str()).await
|
||||
}
|
||||
|
||||
pub async fn load_user_secret(namespace: &str, key: &str) -> Result<Option<Vec<u8>>, String> {
|
||||
let mut s = match load_user_secret_string(namespace, key).await? {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
if s.pop() != Some('!') {
|
||||
return Err("User secret is not a buffer".to_owned());
|
||||
}
|
||||
|
||||
let mut bytes = Vec::<u8>::new();
|
||||
let res = BASE64URL_NOPAD.decode_len(s.len());
|
||||
match res {
|
||||
Ok(l) => {
|
||||
bytes.resize(l, 0u8);
|
||||
}
|
||||
Err(_) => {
|
||||
return Err("Failed to decode".to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
let res = BASE64URL_NOPAD.decode_mut(s.as_bytes(), &mut bytes);
|
||||
match res {
|
||||
Ok(_) => Ok(Some(bytes)),
|
||||
Err(_) => Err("Failed to decode".to_owned()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn remove_user_secret(namespace: &str, key: &str) -> Result<bool, String> {
|
||||
remove_user_secret_string(namespace, key).await
|
||||
}
|
@@ -8,14 +8,17 @@ use crate::xx::*;
|
||||
pub struct DummyNetworkConnection {}
|
||||
|
||||
impl DummyNetworkConnection {
|
||||
pub fn protocol_type(&self) -> ProtocolType {
|
||||
ProtocolType::UDP
|
||||
pub fn connection_descriptor(&self) -> ConnectionDescriptor {
|
||||
ConnectionDescriptor::new_no_local(PeerAddress::new(
|
||||
SocketAddress::default(),
|
||||
ProtocolType::UDP,
|
||||
))
|
||||
}
|
||||
pub fn send(&self, _message: Vec<u8>) -> SystemPinBoxFuture<Result<(), String>> {
|
||||
Box::pin(async { Ok(()) })
|
||||
pub async fn send(&self, _message: Vec<u8>) -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
pub fn recv(&self) -> SystemPinBoxFuture<Result<Vec<u8>, String>> {
|
||||
Box::pin(async { Ok(Vec::new()) })
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,16 +30,33 @@ pub enum NetworkConnection {
|
||||
}
|
||||
|
||||
impl NetworkConnection {
|
||||
pub fn send(&self, message: Vec<u8>) -> SystemPinBoxFuture<Result<(), String>> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.send(message),
|
||||
Self::WS(w) => w.send(message),
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
panic!("Should not connect to UDP dialinfo");
|
||||
}
|
||||
ProtocolType::TCP => {
|
||||
panic!("TCP dial info is not support on WASM targets");
|
||||
}
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
ws::WebsocketProtocolHandler::connect(local_address, dial_info).await
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn recv(&self) -> SystemPinBoxFuture<Result<Vec<u8>, String>> {
|
||||
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.recv(),
|
||||
Self::WS(w) => w.recv(),
|
||||
Self::Dummy(d) => d.send(message).await,
|
||||
Self::WS(w) => w.send(message).await,
|
||||
}
|
||||
}
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.recv().await,
|
||||
Self::WS(w) => w.recv().await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -14,6 +14,7 @@ struct WebsocketNetworkConnectionInner {
|
||||
#[derive(Clone)]
|
||||
pub struct WebsocketNetworkConnection {
|
||||
tls: bool,
|
||||
connection_descriptor: ConnectionDescriptor,
|
||||
inner: Arc<Mutex<WebsocketNetworkConnectionInner>>,
|
||||
}
|
||||
|
||||
@@ -32,52 +33,49 @@ impl PartialEq for WebsocketNetworkConnection {
|
||||
impl Eq for WebsocketNetworkConnection {}
|
||||
|
||||
impl WebsocketNetworkConnection {
|
||||
pub fn new(tls: bool, ws_meta: WsMeta, ws_stream: WsStream) -> Self {
|
||||
pub fn new(tls: bool, connection_descriptor: ConnectionDescriptor, ws_stream: WsStream) -> Self {
|
||||
let ws = ws_stream.wrapped().clone();
|
||||
Self {
|
||||
tls,
|
||||
connection_descriptor,
|
||||
inner: Arc::new(Mutex::new(WebsocketNetworkConnectionInner {
|
||||
ws_stream,
|
||||
ws,
|
||||
})),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WebsocketNetworkConnection {
|
||||
pub fn send(&self, message: Vec<u8>) -> SystemPinBoxFuture<Result<(), String>> {
|
||||
let inner = self.inner.clone();
|
||||
Box::pin(async move {
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large WS message".to_owned()).map_err(logthru_net!(error));
|
||||
}
|
||||
inner
|
||||
.lock()
|
||||
.ws
|
||||
.send_with_u8_array(&message)
|
||||
.map_err(|_| "failed to send to websocket".to_owned())
|
||||
.map_err(logthru_net!(error))
|
||||
})
|
||||
pub fn connection_descriptor(&self) -> ConnectionDescriptor {
|
||||
self.connection_descriptor.clone()
|
||||
}
|
||||
pub fn recv(&self) -> SystemPinBoxFuture<Result<Vec<u8>, String>> {
|
||||
let inner = self.inner.clone();
|
||||
Box::pin(async move {
|
||||
let out = match inner.lock().ws_stream.next().await {
|
||||
Some(WsMessage::Binary(v)) => v,
|
||||
Some(_) => {
|
||||
return Err("Unexpected WS message type".to_owned())
|
||||
.map_err(logthru_net!(error));
|
||||
}
|
||||
None => {
|
||||
return Err("WS stream closed".to_owned()).map_err(logthru_net!(error));
|
||||
}
|
||||
};
|
||||
if out.len() > MAX_MESSAGE_SIZE {
|
||||
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
|
||||
} else {
|
||||
Ok(out)
|
||||
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large WS message".to_owned()).map_err(logthru_net!(error));
|
||||
}
|
||||
self.inner
|
||||
.lock()
|
||||
.ws
|
||||
.send_with_u8_array(&message)
|
||||
.map_err(|_| "failed to send to websocket".to_owned())
|
||||
.map_err(logthru_net!(error))
|
||||
}
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
let out = match self.inner.lock().ws_stream.next().await {
|
||||
Some(WsMessage::Binary(v)) => v,
|
||||
Some(_) => {
|
||||
return Err("Unexpected WS message type".to_owned())
|
||||
.map_err(logthru_net!(error));
|
||||
}
|
||||
})
|
||||
None => {
|
||||
return Err("WS stream closed".to_owned()).map_err(logthru_net!(error));
|
||||
}
|
||||
};
|
||||
if out.len() > MAX_MESSAGE_SIZE {
|
||||
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
|
||||
} else {
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,7 +86,7 @@ pub struct WebsocketProtocolHandler {}
|
||||
|
||||
impl WebsocketProtocolHandler {
|
||||
pub async fn connect(
|
||||
network_manager: NetworkManager,
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: &DialInfo,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
let url = dial_info
|
||||
@@ -113,18 +111,18 @@ impl WebsocketProtocolHandler {
|
||||
.map_err(logthru_net!(error))
|
||||
}
|
||||
};
|
||||
let peer_addr = dial_info.to_peer_address();
|
||||
|
||||
let (ws, wsio) = WsMeta::connect(url, None)
|
||||
let (_, wsio) = WsMeta::connect(url, None)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
|
||||
let conn = NetworkConnection::WS(WebsocketNetworkConnection::new(tls, ws, wsio));
|
||||
network_manager
|
||||
.on_new_connection(ConnectionDescriptor::new_no_local(peer_addr), conn.clone())
|
||||
.await?;
|
||||
// Make our connection descriptor
|
||||
let connection_descriptor = ConnectionDescriptor {
|
||||
local: None,
|
||||
remote: dial_info.to_peer_address(),
|
||||
};
|
||||
|
||||
Ok(conn)
|
||||
Ok(NetworkConnection::WS(WebsocketNetworkConnection::new(tls, connection_descriptor, wsio)))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user