refactor
This commit is contained in:
@@ -18,14 +18,16 @@ impl ConnectionHandle {
|
||||
self.descriptor.clone()
|
||||
}
|
||||
|
||||
pub fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
self.channel.send(message).map_err(map_to_string)
|
||||
pub fn send(&self, message: Vec<u8>) -> EyreResult<()> {
|
||||
self.channel
|
||||
.send(message)
|
||||
.wrap_err("failed to send to connection")
|
||||
}
|
||||
pub async fn send_async(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
pub async fn send_async(&self, message: Vec<u8>) -> EyreResult<()> {
|
||||
self.channel
|
||||
.send_async(message)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.wrap_err("failed to send_async to connection")
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,33 +1,17 @@
|
||||
use super::*;
|
||||
use alloc::collections::btree_map::Entry;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[derive(ThisError, Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AddressFilterError {
|
||||
#[error("Count exceeded")]
|
||||
CountExceeded,
|
||||
#[error("Rate exceeded")]
|
||||
RateExceeded,
|
||||
}
|
||||
impl fmt::Display for AddressFilterError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match *self {
|
||||
Self::CountExceeded => "Count exceeded",
|
||||
Self::RateExceeded => "Rate exceeded",
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
impl std::error::Error for AddressFilterError {}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[derive(ThisError, Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[error("Address not in table")]
|
||||
pub struct AddressNotInTableError {}
|
||||
impl fmt::Display for AddressNotInTableError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "Address not in table")
|
||||
}
|
||||
}
|
||||
impl std::error::Error for AddressNotInTableError {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ConnectionLimits {
|
||||
|
@@ -142,13 +142,13 @@ impl ConnectionManager {
|
||||
&self,
|
||||
inner: &mut ConnectionManagerInner,
|
||||
conn: ProtocolNetworkConnection,
|
||||
) -> Result<ConnectionHandle, String> {
|
||||
) -> EyreResult<ConnectionHandle> {
|
||||
log_net!("on_new_protocol_network_connection: {:?}", conn);
|
||||
|
||||
// Wrap with NetworkConnection object to start the connection processing loop
|
||||
let stop_token = match &inner.stop_source {
|
||||
Some(ss) => ss.token(),
|
||||
None => return Err("not creating connection because we are stopping".to_owned()),
|
||||
None => bail!("not creating connection because we are stopping"),
|
||||
};
|
||||
|
||||
let conn = NetworkConnection::from_protocol(self.clone(), stop_token, conn);
|
||||
@@ -165,7 +165,7 @@ impl ConnectionManager {
|
||||
&self,
|
||||
local_addr: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<ConnectionHandle, String> {
|
||||
) -> EyreResult<ConnectionHandle> {
|
||||
let killed = {
|
||||
let mut inner = self.arc.inner.lock();
|
||||
let inner = match &mut *inner {
|
||||
@@ -274,7 +274,7 @@ impl ConnectionManager {
|
||||
let inner = match &mut *inner {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
return Err("shutting down".to_owned());
|
||||
bail!("shutting down");
|
||||
}
|
||||
};
|
||||
self.on_new_protocol_network_connection(inner, conn)
|
||||
@@ -336,7 +336,7 @@ impl ConnectionManager {
|
||||
pub(super) async fn on_accepted_protocol_network_connection(
|
||||
&self,
|
||||
conn: ProtocolNetworkConnection,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
// Get channel sender
|
||||
let sender = {
|
||||
let mut inner = self.arc.inner.lock();
|
||||
|
@@ -3,6 +3,39 @@ use alloc::collections::btree_map::Entry;
|
||||
use futures_util::StreamExt;
|
||||
use hashlink::LruCache;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
#[derive(ThisError, Debug, Clone, Eq, PartialEq)]
|
||||
pub enum ConnectionTableAddError {
|
||||
#[error("Connection already added to table")]
|
||||
AlreadyExists,
|
||||
#[error("Connection address was filtered")]
|
||||
AddressFilter(AddressFilterError),
|
||||
}
|
||||
|
||||
impl ConnectionTableAddError {
|
||||
pub fn already_exists() -> Self {
|
||||
ConnectionTableAddError::AlreadyExists
|
||||
}
|
||||
pub fn address_filter(err: AddressFilterError) -> Self {
|
||||
ConnectionTableAddError::AddressFilter(err)
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
#[derive(ThisError, Debug, Clone, Eq, PartialEq)]
|
||||
pub enum ConnectionTableRemoveError {
|
||||
#[error("Connection not in table")]
|
||||
NotInTable,
|
||||
}
|
||||
|
||||
impl ConnectionTableRemoveError {
|
||||
pub fn not_in_table() -> Self {
|
||||
ConnectionTableRemoveError::NotInTable
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ConnectionTable {
|
||||
max_connections: Vec<usize>,
|
||||
@@ -53,20 +86,22 @@ impl ConnectionTable {
|
||||
while unord.next().await.is_some() {}
|
||||
}
|
||||
|
||||
pub fn add_connection(&mut self, conn: NetworkConnection) -> Result<(), String> {
|
||||
pub fn add_connection(
|
||||
&mut self,
|
||||
conn: NetworkConnection,
|
||||
) -> Result<(), ConnectionTableAddError> {
|
||||
let descriptor = conn.connection_descriptor();
|
||||
let ip_addr = descriptor.remote_address().to_ip_addr();
|
||||
|
||||
let index = protocol_to_index(descriptor.protocol_type());
|
||||
if self.conn_by_descriptor[index].contains_key(&descriptor) {
|
||||
return Err(format!(
|
||||
"Connection already added to table: {:?}",
|
||||
descriptor
|
||||
));
|
||||
return Err(ConnectionTableAddError::already_exists());
|
||||
}
|
||||
|
||||
// Filter by ip for connection limits
|
||||
self.address_filter.add(ip_addr).map_err(map_to_string)?;
|
||||
self.address_filter
|
||||
.add(ip_addr)
|
||||
.map_err(ConnectionTableAddError::address_filter)?;
|
||||
|
||||
// Add the connection to the table
|
||||
let res = self.conn_by_descriptor[index].insert(descriptor.clone(), conn);
|
||||
@@ -164,11 +199,11 @@ impl ConnectionTable {
|
||||
pub fn remove_connection(
|
||||
&mut self,
|
||||
descriptor: ConnectionDescriptor,
|
||||
) -> Result<NetworkConnection, String> {
|
||||
) -> Result<NetworkConnection, ConnectionTableRemoveError> {
|
||||
let index = protocol_to_index(descriptor.protocol_type());
|
||||
let conn = self.conn_by_descriptor[index]
|
||||
.remove(&descriptor)
|
||||
.ok_or_else(|| format!("Connection not in table: {:?}", descriptor))?;
|
||||
.ok_or_else(|| ConnectionTableRemoveError::not_in_table())?;
|
||||
|
||||
self.remove_connection_records(descriptor);
|
||||
Ok(conn)
|
||||
|
@@ -127,8 +127,8 @@ struct NetworkManagerInner {
|
||||
|
||||
struct NetworkManagerUnlockedInner {
|
||||
// Background processes
|
||||
rolling_transfers_task: TickTask,
|
||||
relay_management_task: TickTask,
|
||||
rolling_transfers_task: TickTask<EyreReport>,
|
||||
relay_management_task: TickTask<EyreReport>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -236,7 +236,7 @@ impl NetworkManager {
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
pub async fn init(&self, update_callback: UpdateCallback) -> Result<(), String> {
|
||||
pub async fn init(&self, update_callback: UpdateCallback) -> EyreResult<()> {
|
||||
let routing_table = RoutingTable::new(self.clone());
|
||||
routing_table.init().await?;
|
||||
self.inner.lock().routing_table = Some(routing_table.clone());
|
||||
@@ -257,7 +257,7 @@ impl NetworkManager {
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
pub async fn internal_startup(&self) -> Result<(), String> {
|
||||
pub async fn internal_startup(&self) -> EyreResult<()> {
|
||||
trace!("NetworkManager::internal_startup begin");
|
||||
if self.inner.lock().components.is_some() {
|
||||
debug!("NetworkManager::internal_startup already started");
|
||||
@@ -292,7 +292,7 @@ impl NetworkManager {
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
pub async fn startup(&self) -> Result<(), String> {
|
||||
pub async fn startup(&self) -> EyreResult<()> {
|
||||
if let Err(e) = self.internal_startup().await {
|
||||
self.shutdown().await;
|
||||
return Err(e);
|
||||
@@ -387,7 +387,7 @@ impl NetworkManager {
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn restart_net(&self, net: Network) -> Result<(), String> {
|
||||
async fn restart_net(&self, net: Network) -> EyreResult<()> {
|
||||
net.shutdown().await;
|
||||
self.send_network_update();
|
||||
net.startup().await?;
|
||||
@@ -395,7 +395,7 @@ impl NetworkManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn tick(&self) -> Result<(), String> {
|
||||
pub async fn tick(&self) -> EyreResult<()> {
|
||||
let (routing_table, net, receipt_manager) = {
|
||||
let inner = self.inner.lock();
|
||||
let components = inner.components.as_ref().unwrap();
|
||||
@@ -481,7 +481,7 @@ impl NetworkManager {
|
||||
expected_returns: u32,
|
||||
extra_data: D,
|
||||
callback: impl ReceiptCallback,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
) -> EyreResult<Vec<u8>> {
|
||||
let receipt_manager = self.receipt_manager();
|
||||
let routing_table = self.routing_table();
|
||||
|
||||
@@ -490,7 +490,7 @@ impl NetworkManager {
|
||||
let receipt = Receipt::try_new(0, nonce, routing_table.node_id(), extra_data)?;
|
||||
let out = receipt
|
||||
.to_signed_data(&routing_table.node_id_secret())
|
||||
.map_err(|_| "failed to generate signed receipt".to_owned())?;
|
||||
.wrap_err("failed to generate signed receipt")?;
|
||||
|
||||
// Record the receipt for later
|
||||
let exp_ts = intf::get_timestamp() + expiration_us;
|
||||
@@ -505,7 +505,7 @@ impl NetworkManager {
|
||||
&self,
|
||||
expiration_us: u64,
|
||||
extra_data: D,
|
||||
) -> Result<(Vec<u8>, EventualValueFuture<ReceiptEvent>), String> {
|
||||
) -> EyreResult<(Vec<u8>, EventualValueFuture<ReceiptEvent>)> {
|
||||
let receipt_manager = self.receipt_manager();
|
||||
let routing_table = self.routing_table();
|
||||
|
||||
@@ -514,7 +514,7 @@ impl NetworkManager {
|
||||
let receipt = Receipt::try_new(0, nonce, routing_table.node_id(), extra_data)?;
|
||||
let out = receipt
|
||||
.to_signed_data(&routing_table.node_id_secret())
|
||||
.map_err(|_| "failed to generate signed receipt".to_owned())?;
|
||||
.wrap_err("failed to generate signed receipt")?;
|
||||
|
||||
// Record the receipt for later
|
||||
let exp_ts = intf::get_timestamp() + expiration_us;
|
||||
@@ -530,11 +530,11 @@ impl NetworkManager {
|
||||
pub async fn handle_out_of_band_receipt<R: AsRef<[u8]>>(
|
||||
&self,
|
||||
receipt_data: R,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
let receipt_manager = self.receipt_manager();
|
||||
|
||||
let receipt = Receipt::from_signed_data(receipt_data.as_ref())
|
||||
.map_err(|_| "failed to parse signed out-of-band receipt".to_owned())?;
|
||||
.wrap_err("failed to parse signed out-of-band receipt")?;
|
||||
|
||||
receipt_manager.handle_receipt(receipt, None).await
|
||||
}
|
||||
@@ -545,11 +545,11 @@ impl NetworkManager {
|
||||
&self,
|
||||
receipt_data: R,
|
||||
inbound_nr: NodeRef,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
let receipt_manager = self.receipt_manager();
|
||||
|
||||
let receipt = Receipt::from_signed_data(receipt_data.as_ref())
|
||||
.map_err(|_| "failed to parse signed in-band receipt".to_owned())?;
|
||||
.wrap_err("failed to parse signed in-band receipt")?;
|
||||
|
||||
receipt_manager
|
||||
.handle_receipt(receipt, Some(inbound_nr))
|
||||
@@ -558,7 +558,7 @@ impl NetworkManager {
|
||||
|
||||
// Process a received signal
|
||||
#[instrument(level = "trace", skip(self), err)]
|
||||
pub async fn handle_signal(&self, signal_info: SignalInfo) -> Result<(), String> {
|
||||
pub async fn handle_signal(&self, signal_info: SignalInfo) -> EyreResult<()> {
|
||||
match signal_info {
|
||||
SignalInfo::ReverseConnect { receipt, peer_info } => {
|
||||
let routing_table = self.routing_table();
|
||||
@@ -573,7 +573,7 @@ impl NetworkManager {
|
||||
// Make a reverse connection to the peer and send the receipt to it
|
||||
rpc.rpc_call_return_receipt(Destination::Direct(peer_nr), None, receipt)
|
||||
.await
|
||||
.map_err(map_to_string)?;
|
||||
.wrap_err("rpc failure")?;
|
||||
}
|
||||
SignalInfo::HolePunch { receipt, peer_info } => {
|
||||
let routing_table = self.routing_table();
|
||||
@@ -589,7 +589,7 @@ impl NetworkManager {
|
||||
peer_nr.filter_protocols(ProtocolSet::only(ProtocolType::UDP));
|
||||
let hole_punch_dial_info_detail = peer_nr
|
||||
.first_filtered_dial_info_detail(Some(RoutingDomain::PublicInternet))
|
||||
.ok_or_else(|| "No hole punch capable dialinfo found for node".to_owned())?;
|
||||
.ok_or_else(|| eyre!("No hole punch capable dialinfo found for node"))?;
|
||||
|
||||
// Now that we picked a specific dialinfo, further restrict the noderef to the specific address type
|
||||
let mut filter = peer_nr.take_filter().unwrap();
|
||||
@@ -611,7 +611,7 @@ impl NetworkManager {
|
||||
// Return the receipt using the same dial info send the receipt to it
|
||||
rpc.rpc_call_return_receipt(Destination::Direct(peer_nr), None, receipt)
|
||||
.await
|
||||
.map_err(map_to_string)?;
|
||||
.wrap_err("rpc failure")?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -625,7 +625,7 @@ impl NetworkManager {
|
||||
dest_node_id: DHTKey,
|
||||
version: u8,
|
||||
body: B,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
) -> EyreResult<Vec<u8>> {
|
||||
// DH to get encryption key
|
||||
let routing_table = self.routing_table();
|
||||
let node_id = routing_table.node_id();
|
||||
@@ -639,7 +639,7 @@ impl NetworkManager {
|
||||
let envelope = Envelope::new(version, ts, nonce, node_id, dest_node_id);
|
||||
envelope
|
||||
.to_encrypted_data(self.crypto.clone(), body.as_ref(), &node_id_secret)
|
||||
.map_err(|_| "envelope failed to encode".to_owned())
|
||||
.wrap_err("envelope failed to encode")
|
||||
}
|
||||
|
||||
// Called by the RPC handler when we want to issue an RPC request or response
|
||||
@@ -652,7 +652,7 @@ impl NetworkManager {
|
||||
node_ref: NodeRef,
|
||||
envelope_node_id: Option<DHTKey>,
|
||||
body: B,
|
||||
) -> Result<SendDataKind, String> {
|
||||
) -> EyreResult<SendDataKind> {
|
||||
let via_node_id = node_ref.node_id();
|
||||
let envelope_node_id = envelope_node_id.unwrap_or(via_node_id);
|
||||
|
||||
@@ -671,11 +671,12 @@ impl NetworkManager {
|
||||
{
|
||||
#[allow(clippy::absurd_extreme_comparisons)]
|
||||
if node_min > MAX_VERSION || node_max < MIN_VERSION {
|
||||
return Err(format!(
|
||||
bail!(
|
||||
"can't talk to this node {} because version is unsupported: ({},{})",
|
||||
via_node_id, node_min, node_max
|
||||
))
|
||||
.map_err(logthru_rpc!(warn));
|
||||
via_node_id,
|
||||
node_min,
|
||||
node_max
|
||||
);
|
||||
}
|
||||
cmp::min(node_max, MAX_VERSION)
|
||||
} else {
|
||||
@@ -703,7 +704,7 @@ impl NetworkManager {
|
||||
&self,
|
||||
dial_info: DialInfo,
|
||||
rcpt_data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
// Do we need to validate the outgoing receipt? Probably not
|
||||
// because it is supposed to be opaque and the
|
||||
// recipient/originator does the validation
|
||||
@@ -717,8 +718,8 @@ impl NetworkManager {
|
||||
}
|
||||
|
||||
// Figure out how to reach a node
|
||||
#[instrument(level = "trace", skip(self), ret, err)]
|
||||
fn get_contact_method(&self, mut target_node_ref: NodeRef) -> Result<ContactMethod, String> {
|
||||
#[instrument(level = "trace", skip(self), ret)]
|
||||
fn get_contact_method(&self, mut target_node_ref: NodeRef) -> ContactMethod {
|
||||
let routing_table = self.routing_table();
|
||||
|
||||
// Get our network class and protocol config and node id
|
||||
@@ -727,14 +728,14 @@ impl NetworkManager {
|
||||
|
||||
// Scope noderef down to protocols we can do outbound
|
||||
if !target_node_ref.filter_protocols(our_protocol_config.outbound) {
|
||||
return Ok(ContactMethod::Unreachable);
|
||||
return ContactMethod::Unreachable;
|
||||
}
|
||||
|
||||
// Get the best matching local direct dial info if we have it
|
||||
let opt_target_local_did =
|
||||
target_node_ref.first_filtered_dial_info_detail(Some(RoutingDomain::LocalNetwork));
|
||||
if let Some(target_local_did) = opt_target_local_did {
|
||||
return Ok(ContactMethod::Direct(target_local_did.dial_info));
|
||||
return ContactMethod::Direct(target_local_did.dial_info);
|
||||
}
|
||||
|
||||
// Get the best match internet dial info if we have it
|
||||
@@ -744,7 +745,7 @@ impl NetworkManager {
|
||||
// Do we need to signal before going inbound?
|
||||
if !target_public_did.class.requires_signal() {
|
||||
// Go direct without signaling
|
||||
return Ok(ContactMethod::Direct(target_public_did.dial_info));
|
||||
return ContactMethod::Direct(target_public_did.dial_info);
|
||||
}
|
||||
|
||||
// Get the target's inbound relay, it must have one or it is not reachable
|
||||
@@ -767,10 +768,10 @@ impl NetworkManager {
|
||||
) {
|
||||
// Can we receive a direct reverse connection?
|
||||
if !reverse_did.class.requires_signal() {
|
||||
return Ok(ContactMethod::SignalReverse(
|
||||
return ContactMethod::SignalReverse(
|
||||
inbound_relay_nr,
|
||||
target_node_ref,
|
||||
));
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -798,16 +799,16 @@ impl NetworkManager {
|
||||
)
|
||||
.is_some();
|
||||
if target_has_udp_dialinfo && self_has_udp_dialinfo {
|
||||
return Ok(ContactMethod::SignalHolePunch(
|
||||
return ContactMethod::SignalHolePunch(
|
||||
inbound_relay_nr,
|
||||
udp_target_nr,
|
||||
));
|
||||
);
|
||||
}
|
||||
}
|
||||
// Otherwise we have to inbound relay
|
||||
}
|
||||
|
||||
return Ok(ContactMethod::InboundRelay(inbound_relay_nr));
|
||||
return ContactMethod::InboundRelay(inbound_relay_nr);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -818,22 +819,17 @@ impl NetworkManager {
|
||||
.first_filtered_dial_info_detail(Some(RoutingDomain::PublicInternet))
|
||||
.is_some()
|
||||
{
|
||||
return Ok(ContactMethod::InboundRelay(target_inbound_relay_nr));
|
||||
return ContactMethod::InboundRelay(target_inbound_relay_nr);
|
||||
}
|
||||
}
|
||||
|
||||
// If we can't reach the node by other means, try our outbound relay if we have one
|
||||
if let Some(relay_node) = self.relay_node() {
|
||||
return Ok(ContactMethod::OutboundRelay(relay_node));
|
||||
return ContactMethod::OutboundRelay(relay_node);
|
||||
}
|
||||
// Otherwise, we can't reach this node
|
||||
debug!("unable to reach node {:?}", target_node_ref);
|
||||
// trace!(
|
||||
// "unable to reach node {:?}: {}",
|
||||
// target_node_ref,
|
||||
// target_node_ref.operate(|e| format!("{:#?}", e))
|
||||
// );
|
||||
Ok(ContactMethod::Unreachable)
|
||||
ContactMethod::Unreachable
|
||||
}
|
||||
|
||||
// Send a reverse connection signal and wait for the return receipt over it
|
||||
@@ -844,13 +840,11 @@ impl NetworkManager {
|
||||
relay_nr: NodeRef,
|
||||
target_nr: NodeRef,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
// Build a return receipt for the signal
|
||||
let receipt_timeout =
|
||||
ms_to_us(self.config.get().network.reverse_connection_receipt_time_ms);
|
||||
let (receipt, eventual_value) = self
|
||||
.generate_single_shot_receipt(receipt_timeout, [])
|
||||
.map_err(map_to_string)?;
|
||||
let (receipt, eventual_value) = self.generate_single_shot_receipt(receipt_timeout, [])?;
|
||||
|
||||
// Get our peer info
|
||||
let peer_info = self.routing_table().get_own_peer_info();
|
||||
@@ -863,32 +857,25 @@ impl NetworkManager {
|
||||
SignalInfo::ReverseConnect { receipt, peer_info },
|
||||
)
|
||||
.await
|
||||
.map_err(logthru_net!("failed to send signal to {:?}", relay_nr))
|
||||
.map_err(map_to_string)?;
|
||||
.wrap_err("failed to send signal")?;
|
||||
// Wait for the return receipt
|
||||
let inbound_nr = match eventual_value.await.take_value().unwrap() {
|
||||
ReceiptEvent::ReturnedOutOfBand => {
|
||||
return Err("reverse connect receipt should be returned in-band".to_owned());
|
||||
bail!("reverse connect receipt should be returned in-band");
|
||||
}
|
||||
ReceiptEvent::ReturnedInBand { inbound_noderef } => inbound_noderef,
|
||||
ReceiptEvent::Expired => {
|
||||
return Err(format!(
|
||||
"reverse connect receipt expired from {:?}",
|
||||
target_nr
|
||||
));
|
||||
bail!("reverse connect receipt expired from {:?}", target_nr);
|
||||
}
|
||||
ReceiptEvent::Cancelled => {
|
||||
return Err(format!(
|
||||
"reverse connect receipt cancelled from {:?}",
|
||||
target_nr
|
||||
));
|
||||
bail!("reverse connect receipt cancelled from {:?}", target_nr);
|
||||
}
|
||||
};
|
||||
|
||||
// We expect the inbound noderef to be the same as the target noderef
|
||||
// if they aren't the same, we should error on this and figure out what then hell is up
|
||||
if target_nr != inbound_nr {
|
||||
error!("unexpected noderef mismatch on reverse connect");
|
||||
bail!("unexpected noderef mismatch on reverse connect");
|
||||
}
|
||||
|
||||
// And now use the existing connection to send over
|
||||
@@ -899,10 +886,10 @@ impl NetworkManager {
|
||||
.await?
|
||||
{
|
||||
None => Ok(()),
|
||||
Some(_) => Err("unable to send over reverse connection".to_owned()),
|
||||
Some(_) => bail!("unable to send over reverse connection"),
|
||||
}
|
||||
} else {
|
||||
Err("no reverse connection available".to_owned())
|
||||
bail!("no reverse connection available")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -914,7 +901,7 @@ impl NetworkManager {
|
||||
relay_nr: NodeRef,
|
||||
target_nr: NodeRef,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
// Ensure we are filtered down to UDP (the only hole punch protocol supported today)
|
||||
assert!(target_nr
|
||||
.filter_ref()
|
||||
@@ -923,17 +910,14 @@ impl NetworkManager {
|
||||
|
||||
// Build a return receipt for the signal
|
||||
let receipt_timeout = ms_to_us(self.config.get().network.hole_punch_receipt_time_ms);
|
||||
let (receipt, eventual_value) = self
|
||||
.generate_single_shot_receipt(receipt_timeout, [])
|
||||
.map_err(map_to_string)?;
|
||||
|
||||
let (receipt, eventual_value) = self.generate_single_shot_receipt(receipt_timeout, [])?;
|
||||
// Get our peer info
|
||||
let peer_info = self.routing_table().get_own_peer_info();
|
||||
|
||||
// Get the udp direct dialinfo for the hole punch
|
||||
let hole_punch_did = target_nr
|
||||
.first_filtered_dial_info_detail(Some(RoutingDomain::PublicInternet))
|
||||
.ok_or_else(|| "No hole punch capable dialinfo found for node".to_owned())?;
|
||||
.ok_or_else(|| eyre!("No hole punch capable dialinfo found for node"))?;
|
||||
|
||||
// Do our half of the hole punch by sending an empty packet
|
||||
// Both sides will do this and then the receipt will get sent over the punched hole
|
||||
@@ -949,30 +933,30 @@ impl NetworkManager {
|
||||
SignalInfo::HolePunch { receipt, peer_info },
|
||||
)
|
||||
.await
|
||||
.map_err(logthru_net!("failed to send signal to {:?}", relay_nr))
|
||||
.map_err(map_to_string)?;
|
||||
.wrap_err("failed to send signal")?;
|
||||
|
||||
// Wait for the return receipt
|
||||
let inbound_nr = match eventual_value.await.take_value().unwrap() {
|
||||
ReceiptEvent::ReturnedOutOfBand => {
|
||||
return Err("hole punch receipt should be returned in-band".to_owned());
|
||||
bail!("hole punch receipt should be returned in-band");
|
||||
}
|
||||
ReceiptEvent::ReturnedInBand { inbound_noderef } => inbound_noderef,
|
||||
ReceiptEvent::Expired => {
|
||||
return Err(format!("hole punch receipt expired from {}", target_nr));
|
||||
bail!("hole punch receipt expired from {}", target_nr);
|
||||
}
|
||||
ReceiptEvent::Cancelled => {
|
||||
return Err(format!("hole punch receipt cancelled from {}", target_nr));
|
||||
bail!("hole punch receipt cancelled from {}", target_nr);
|
||||
}
|
||||
};
|
||||
|
||||
// We expect the inbound noderef to be the same as the target noderef
|
||||
// if they aren't the same, we should error on this and figure out what then hell is up
|
||||
if target_nr != inbound_nr {
|
||||
return Err(format!(
|
||||
bail!(
|
||||
"unexpected noderef mismatch on hole punch {}, expected {}",
|
||||
inbound_nr, target_nr
|
||||
));
|
||||
inbound_nr,
|
||||
target_nr
|
||||
);
|
||||
}
|
||||
|
||||
// And now use the existing connection to send over
|
||||
@@ -983,10 +967,10 @@ impl NetworkManager {
|
||||
.await?
|
||||
{
|
||||
None => Ok(()),
|
||||
Some(_) => Err("unable to send over hole punch".to_owned()),
|
||||
Some(_) => bail!("unable to send over hole punch"),
|
||||
}
|
||||
} else {
|
||||
Err("no hole punch available".to_owned())
|
||||
bail!("no hole punch available")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1003,7 +987,7 @@ impl NetworkManager {
|
||||
&self,
|
||||
node_ref: NodeRef,
|
||||
data: Vec<u8>,
|
||||
) -> SystemPinBoxFuture<Result<SendDataKind, String>> {
|
||||
) -> SystemPinBoxFuture<EyreResult<SendDataKind>> {
|
||||
let this = self.clone();
|
||||
Box::pin(async move {
|
||||
// First try to send data to the last socket we've seen this peer on
|
||||
@@ -1028,11 +1012,7 @@ impl NetworkManager {
|
||||
|
||||
log_net!("send_data via dialinfo to {:?}", node_ref);
|
||||
// If we don't have last_connection, try to reach out to the peer via its dial info
|
||||
match this
|
||||
.get_contact_method(node_ref.clone())
|
||||
.map_err(logthru_net!(debug))
|
||||
.map(logthru_net!("get_contact_method for {:?}", node_ref))?
|
||||
{
|
||||
match this.get_contact_method(node_ref.clone()) {
|
||||
ContactMethod::OutboundRelay(relay_nr) | ContactMethod::InboundRelay(relay_nr) => {
|
||||
this.send_data(relay_nr, data)
|
||||
.await
|
||||
@@ -1057,14 +1037,13 @@ impl NetworkManager {
|
||||
.do_hole_punch(relay_nr, target_node_ref, data)
|
||||
.await
|
||||
.map(|_| SendDataKind::GlobalDirect),
|
||||
ContactMethod::Unreachable => Err("Can't send to this node".to_owned()),
|
||||
ContactMethod::Unreachable => Err(eyre!("Can't send to this node")),
|
||||
}
|
||||
.map_err(logthru_net!(debug))
|
||||
})
|
||||
}
|
||||
|
||||
// Direct bootstrap request handler (separate fallback mechanism from cheaper TXT bootstrap mechanism)
|
||||
async fn handle_boot_request(&self, descriptor: ConnectionDescriptor) -> Result<(), String> {
|
||||
async fn handle_boot_request(&self, descriptor: ConnectionDescriptor) -> EyreResult<()> {
|
||||
let routing_table = self.routing_table();
|
||||
|
||||
// Get a bunch of nodes with the various
|
||||
@@ -1087,12 +1066,13 @@ impl NetworkManager {
|
||||
// Bootstrap reply was sent
|
||||
Ok(())
|
||||
}
|
||||
Some(_) => Err("bootstrap reply could not be sent".to_owned()),
|
||||
Some(_) => Err(eyre!("bootstrap reply could not be sent")),
|
||||
}
|
||||
}
|
||||
|
||||
// Direct bootstrap request
|
||||
pub async fn boot_request(&self, dial_info: DialInfo) -> Result<Vec<PeerInfo>, String> {
|
||||
#[instrument(level = "trace", err, skip(self))]
|
||||
pub async fn boot_request(&self, dial_info: DialInfo) -> EyreResult<Vec<PeerInfo>> {
|
||||
let timeout_ms = {
|
||||
let c = self.config.get();
|
||||
c.network.rpc.timeout_ms
|
||||
@@ -1105,8 +1085,8 @@ impl NetworkManager {
|
||||
.await?;
|
||||
|
||||
let bootstrap_peerinfo: Vec<PeerInfo> =
|
||||
deserialize_json(std::str::from_utf8(&out_data).map_err(map_to_string)?)
|
||||
.map_err(map_to_string)?;
|
||||
deserialize_json(std::str::from_utf8(&out_data).wrap_err("bad utf8 in boot peerinfo")?)
|
||||
.wrap_err("failed to deserialize boot peerinfo")?;
|
||||
|
||||
Ok(bootstrap_peerinfo)
|
||||
}
|
||||
@@ -1119,7 +1099,7 @@ impl NetworkManager {
|
||||
&self,
|
||||
data: &[u8],
|
||||
descriptor: ConnectionDescriptor,
|
||||
) -> Result<bool, String> {
|
||||
) -> EyreResult<bool> {
|
||||
log_net!(
|
||||
"envelope of {} bytes received from {:?}",
|
||||
data.len(),
|
||||
@@ -1138,7 +1118,7 @@ impl NetworkManager {
|
||||
|
||||
// Ensure we can read the magic number
|
||||
if data.len() < 4 {
|
||||
return Err("short packet".to_owned());
|
||||
bail!("short packet");
|
||||
}
|
||||
|
||||
// Is this a direct bootstrap request instead of an envelope?
|
||||
@@ -1154,13 +1134,7 @@ impl NetworkManager {
|
||||
}
|
||||
|
||||
// Decode envelope header (may fail signature validation)
|
||||
let envelope = Envelope::from_signed_data(data).map_err(|_| {
|
||||
format!(
|
||||
"envelope failed to decode from {:?}: {} bytes",
|
||||
descriptor,
|
||||
data.len()
|
||||
)
|
||||
})?;
|
||||
let envelope = Envelope::from_signed_data(data).wrap_err("envelope failed to decode")?;
|
||||
|
||||
// Get routing table and rpc processor
|
||||
let (routing_table, rpc) = {
|
||||
@@ -1185,18 +1159,18 @@ impl NetworkManager {
|
||||
let ets = envelope.get_timestamp();
|
||||
if let Some(tsbehind) = tsbehind {
|
||||
if tsbehind > 0 && (ts > ets && ts - ets > tsbehind) {
|
||||
return Err(format!(
|
||||
bail!(
|
||||
"envelope time was too far in the past: {}ms ",
|
||||
timestamp_to_secs(ts - ets) * 1000f64
|
||||
));
|
||||
);
|
||||
}
|
||||
}
|
||||
if let Some(tsahead) = tsahead {
|
||||
if tsahead > 0 && (ts < ets && ets - ts > tsahead) {
|
||||
return Err(format!(
|
||||
bail!(
|
||||
"envelope time was too far in the future: {}ms",
|
||||
timestamp_to_secs(ets - ts) * 1000f64
|
||||
));
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1211,12 +1185,9 @@ impl NetworkManager {
|
||||
|
||||
let relay_nr = if self.check_client_whitelist(sender_id) {
|
||||
// Full relay allowed, do a full resolve_node
|
||||
rpc.resolve_node(recipient_id).await.map_err(|e| {
|
||||
format!(
|
||||
"failed to resolve recipient node for relay, dropping outbound relayed packet...: {:?}",
|
||||
e
|
||||
)
|
||||
})?
|
||||
rpc.resolve_node(recipient_id).await.wrap_err(
|
||||
"failed to resolve recipient node for relay, dropping outbound relayed packet",
|
||||
)?
|
||||
} else {
|
||||
// If this is not a node in the client whitelist, only allow inbound relay
|
||||
// which only performs a lightweight lookup before passing the packet back out
|
||||
@@ -1226,7 +1197,7 @@ impl NetworkManager {
|
||||
// should be mutually in each others routing tables. The node needing the relay will be
|
||||
// pinging this node regularly to keep itself in the routing table
|
||||
routing_table.lookup_node_ref(recipient_id).ok_or_else(|| {
|
||||
format!(
|
||||
eyre!(
|
||||
"Inbound relay asked for recipient not in routing table: sender_id={:?} recipient={:?}",
|
||||
sender_id, recipient_id
|
||||
)
|
||||
@@ -1236,7 +1207,7 @@ impl NetworkManager {
|
||||
// Relay the packet to the desired destination
|
||||
self.send_data(relay_nr, data.to_vec())
|
||||
.await
|
||||
.map_err(|e| format!("failed to forward envelope: {}", e))?;
|
||||
.wrap_err("failed to forward envelope")?;
|
||||
// Inform caller that we dealt with the envelope, but did not process it locally
|
||||
return Ok(false);
|
||||
}
|
||||
@@ -1248,19 +1219,20 @@ impl NetworkManager {
|
||||
// xxx: punish nodes that send messages that fail to decrypt eventually
|
||||
let body = envelope
|
||||
.decrypt_body(self.crypto(), data, &node_id_secret)
|
||||
.map_err(|_| "failed to decrypt envelope body".to_owned())?;
|
||||
.wrap_err("failed to decrypt envelope body")?;
|
||||
|
||||
// Cache the envelope information in the routing table
|
||||
let source_noderef = routing_table
|
||||
.register_node_with_existing_connection(envelope.get_sender_id(), descriptor, ts)
|
||||
.map_err(|e| format!("node id registration failed: {}", e))?;
|
||||
let source_noderef = routing_table.register_node_with_existing_connection(
|
||||
envelope.get_sender_id(),
|
||||
descriptor,
|
||||
ts,
|
||||
)?;
|
||||
source_noderef.operate_mut(|e| e.set_min_max_version(envelope.get_min_max_version()));
|
||||
|
||||
// xxx: deal with spoofing and flooding here?
|
||||
|
||||
// Pass message to RPC system
|
||||
rpc.enqueue_message(envelope, body, source_noderef)
|
||||
.map_err(|e| format!("enqueing rpc message failed: {}", e))?;
|
||||
rpc.enqueue_message(envelope, body, source_noderef)?;
|
||||
|
||||
// Inform caller that we dealt with the envelope locally
|
||||
Ok(true)
|
||||
@@ -1273,7 +1245,7 @@ impl NetworkManager {
|
||||
stop_token: StopToken,
|
||||
_last_ts: u64,
|
||||
cur_ts: u64,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
// Get our node's current node info and network class and do the right thing
|
||||
let routing_table = self.routing_table();
|
||||
let node_info = routing_table.get_own_node_info();
|
||||
@@ -1354,7 +1326,7 @@ impl NetworkManager {
|
||||
stop_token: StopToken,
|
||||
last_ts: u64,
|
||||
cur_ts: u64,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
// log_net!("--- network manager rolling_transfers task");
|
||||
{
|
||||
let inner = &mut *self.inner.lock();
|
||||
|
@@ -61,7 +61,7 @@ struct NetworkUnlockedInner {
|
||||
network_manager: NetworkManager,
|
||||
connection_manager: ConnectionManager,
|
||||
// Background processes
|
||||
update_network_class_task: TickTask,
|
||||
update_network_class_task: TickTask<EyreReport>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -266,7 +266,7 @@ impl Network {
|
||||
// See if our interface addresses have changed, if so we need to punt the network
|
||||
// and redo all our addresses. This is overkill, but anything more accurate
|
||||
// would require inspection of routing tables that we dont want to bother with
|
||||
pub async fn check_interface_addresses(&self) -> Result<bool, String> {
|
||||
pub async fn check_interface_addresses(&self) -> EyreResult<bool> {
|
||||
let mut inner = self.inner.lock();
|
||||
if !inner.interfaces.refresh().await? {
|
||||
return Ok(false);
|
||||
@@ -286,7 +286,7 @@ impl Network {
|
||||
&self,
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
let data_len = data.len();
|
||||
let res = match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
@@ -300,7 +300,8 @@ impl Network {
|
||||
ProtocolType::WS | ProtocolType::WSS => {
|
||||
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data).await
|
||||
}
|
||||
};
|
||||
}
|
||||
.wrap_err("low level network error");
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@@ -320,7 +321,7 @@ impl Network {
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
timeout_ms: u32,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
) -> EyreResult<Vec<u8>> {
|
||||
let data_len = data.len();
|
||||
let out = match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
@@ -358,7 +359,7 @@ impl Network {
|
||||
&self,
|
||||
descriptor: ConnectionDescriptor,
|
||||
data: Vec<u8>,
|
||||
) -> Result<Option<Vec<u8>>, String> {
|
||||
) -> EyreResult<Option<Vec<u8>>> {
|
||||
let data_len = data.len();
|
||||
|
||||
// Handle connectionless protocol
|
||||
@@ -369,12 +370,10 @@ impl Network {
|
||||
&peer_socket_addr,
|
||||
&descriptor.local().map(|sa| sa.to_socket_addr()),
|
||||
) {
|
||||
log_net!(
|
||||
"send_data_to_existing_connection connectionless to {:?}",
|
||||
descriptor
|
||||
);
|
||||
|
||||
ph.clone().send_message(data, peer_socket_addr).await?;
|
||||
ph.clone()
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.wrap_err("sending data to existing conection")?;
|
||||
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@@ -389,10 +388,10 @@ impl Network {
|
||||
|
||||
// Try to send to the exact existing connection if one exists
|
||||
if let Some(conn) = self.connection_manager().get_connection(descriptor).await {
|
||||
log_net!("send_data_to_existing_connection to {:?}", descriptor);
|
||||
|
||||
// connection exists, send over it
|
||||
conn.send_async(data).await?;
|
||||
conn.send_async(data)
|
||||
.await
|
||||
.wrap_err("sending data to existing connection")?;
|
||||
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@@ -413,13 +412,16 @@ impl Network {
|
||||
&self,
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
let data_len = data.len();
|
||||
// Handle connectionless protocol
|
||||
if dial_info.protocol_type() == ProtocolType::UDP {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
|
||||
let res = ph.send_message(data, peer_socket_addr).await;
|
||||
let res = ph
|
||||
.send_message(data, peer_socket_addr)
|
||||
.await
|
||||
.wrap_err("failed to send data to dial info");
|
||||
if res.is_ok() {
|
||||
// Network accounting
|
||||
self.network_manager()
|
||||
@@ -427,7 +429,7 @@ impl Network {
|
||||
}
|
||||
return res;
|
||||
}
|
||||
return Err("no appropriate UDP protocol handler for dial_info".to_owned());
|
||||
bail!("no appropriate UDP protocol handler for dial_info");
|
||||
}
|
||||
|
||||
// Handle connection-oriented protocols
|
||||
@@ -453,7 +455,7 @@ impl Network {
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", err, skip_all)]
|
||||
pub async fn startup(&self) -> Result<(), String> {
|
||||
pub async fn startup(&self) -> EyreResult<()> {
|
||||
// initialize interfaces
|
||||
let mut interfaces = NetworkInterfaces::new();
|
||||
interfaces.refresh().await?;
|
||||
@@ -604,7 +606,7 @@ impl Network {
|
||||
|
||||
//////////////////////////////////////////
|
||||
|
||||
pub async fn tick(&self) -> Result<(), String> {
|
||||
pub async fn tick(&self) -> EyreResult<()> {
|
||||
let network_class = self.get_network_class().unwrap_or(NetworkClass::Invalid);
|
||||
let routing_table = self.routing_table();
|
||||
|
||||
|
@@ -250,7 +250,7 @@ impl DiscoveryContext {
|
||||
|
||||
// If we know we are not behind NAT, check our firewall status
|
||||
#[instrument(level = "trace", skip(self), err)]
|
||||
pub async fn protocol_process_no_nat(&self) -> Result<(), String> {
|
||||
pub async fn protocol_process_no_nat(&self) -> EyreResult<()> {
|
||||
let (node_1, external_1_dial_info) = {
|
||||
let inner = self.inner.lock();
|
||||
(
|
||||
@@ -281,7 +281,7 @@ impl DiscoveryContext {
|
||||
|
||||
// If we know we are behind NAT check what kind
|
||||
#[instrument(level = "trace", skip(self), ret, err)]
|
||||
pub async fn protocol_process_nat(&self) -> Result<bool, String> {
|
||||
pub async fn protocol_process_nat(&self) -> EyreResult<bool> {
|
||||
let (node_1, external_1_dial_info, external_1_address, protocol_type, address_type) = {
|
||||
let inner = self.inner.lock();
|
||||
(
|
||||
@@ -375,7 +375,7 @@ impl Network {
|
||||
&self,
|
||||
context: &DiscoveryContext,
|
||||
protocol_type: ProtocolType,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
let mut retry_count = {
|
||||
let c = self.config.get();
|
||||
c.network.restricted_nat_retries
|
||||
@@ -437,7 +437,7 @@ impl Network {
|
||||
&self,
|
||||
context: &DiscoveryContext,
|
||||
protocol_type: ProtocolType,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
// Start doing ipv6 protocol
|
||||
context.protocol_begin(protocol_type, AddressType::IPV6);
|
||||
|
||||
@@ -479,7 +479,7 @@ impl Network {
|
||||
stop_token: StopToken,
|
||||
_l: u64,
|
||||
_t: u64,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
// Ensure we aren't trying to update this without clearing it first
|
||||
let old_network_class = self.inner.lock().network_class;
|
||||
assert_eq!(old_network_class, None);
|
||||
|
@@ -25,14 +25,14 @@ impl ListenerState {
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
impl Network {
|
||||
fn get_or_create_tls_acceptor(&self) -> Result<TlsAcceptor, String> {
|
||||
fn get_or_create_tls_acceptor(&self) -> EyreResult<TlsAcceptor> {
|
||||
if let Some(ts) = self.inner.lock().tls_acceptor.as_ref() {
|
||||
return Ok(ts.clone());
|
||||
}
|
||||
|
||||
let server_config = self
|
||||
.load_server_config()
|
||||
.map_err(|e| format!("Couldn't create TLS configuration: {}", e))?;
|
||||
.wrap_err("Couldn't create TLS configuration")?;
|
||||
let acceptor = TlsAcceptor::from(Arc::new(server_config));
|
||||
self.inner.lock().tls_acceptor = Some(acceptor.clone());
|
||||
Ok(acceptor)
|
||||
@@ -45,12 +45,11 @@ impl Network {
|
||||
addr: SocketAddr,
|
||||
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
|
||||
tls_connection_initial_timeout_ms: u32,
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
) -> EyreResult<Option<ProtocolNetworkConnection>> {
|
||||
let tls_stream = tls_acceptor
|
||||
.accept(stream)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(debug "TLS stream failed handshake"))?;
|
||||
.wrap_err("TLS stream failed handshake")?;
|
||||
let ps = AsyncPeekStream::new(tls_stream);
|
||||
let mut first_packet = [0u8; PEEK_DETECT_LEN];
|
||||
|
||||
@@ -63,8 +62,8 @@ impl Network {
|
||||
ps.peek_exact(&mut first_packet),
|
||||
)
|
||||
.await
|
||||
.map_err(map_to_string)?
|
||||
.map_err(map_to_string)?;
|
||||
.wrap_err("tls initial timeout")?
|
||||
.wrap_err("failed to peek tls stream")?;
|
||||
|
||||
self.try_handlers(ps, addr, protocol_handlers).await
|
||||
}
|
||||
@@ -74,9 +73,13 @@ impl Network {
|
||||
stream: AsyncPeekStream,
|
||||
addr: SocketAddr,
|
||||
protocol_accept_handlers: &[Box<dyn ProtocolAcceptHandler>],
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
) -> EyreResult<Option<ProtocolNetworkConnection>> {
|
||||
for ah in protocol_accept_handlers.iter() {
|
||||
if let Some(nc) = ah.on_accept(stream.clone(), addr).await? {
|
||||
if let Some(nc) = ah
|
||||
.on_accept(stream.clone(), addr)
|
||||
.await
|
||||
.wrap_err("io error")?
|
||||
{
|
||||
return Ok(Some(nc));
|
||||
}
|
||||
}
|
||||
@@ -114,7 +117,7 @@ impl Network {
|
||||
return;
|
||||
}
|
||||
};
|
||||
// XXX limiting
|
||||
// XXX limiting here instead for connection table? may be faster and avoids tls negotiation
|
||||
|
||||
log_net!("TCP connection from: {}", addr);
|
||||
|
||||
@@ -185,7 +188,7 @@ impl Network {
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
|
||||
async fn spawn_socket_listener(&self, addr: SocketAddr) -> EyreResult<()> {
|
||||
// Get config
|
||||
let (connection_initial_timeout_ms, tls_connection_initial_timeout_ms) = {
|
||||
let c = self.config.get();
|
||||
@@ -196,11 +199,12 @@ impl Network {
|
||||
};
|
||||
|
||||
// Create a reusable socket with no linger time, and no delay
|
||||
let socket = new_bound_shared_tcp_socket(addr)?;
|
||||
let socket = new_bound_shared_tcp_socket(addr)
|
||||
.wrap_err("failed to create bound shared tcp socket")?;
|
||||
// Listen on the socket
|
||||
socket
|
||||
.listen(128)
|
||||
.map_err(|e| format!("Couldn't listen on TCP socket: {}", e))?;
|
||||
.wrap_err("Couldn't listen on TCP socket")?;
|
||||
|
||||
// Make an async tcplistener from the socket2 socket
|
||||
let std_listener: std::net::TcpListener = socket.into();
|
||||
@@ -209,7 +213,7 @@ impl Network {
|
||||
let listener = TcpListener::from(std_listener);
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
std_listener.set_nonblocking(true).expect("failed to set nonblocking");
|
||||
let listener = TcpListener::from_std(std_listener).map_err(map_to_string)?;
|
||||
let listener = TcpListener::from_std(std_listener).wrap_err("failed to create tokio tcp listener")?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,7 +283,7 @@ impl Network {
|
||||
port: u16,
|
||||
is_tls: bool,
|
||||
new_protocol_accept_handler: Box<NewProtocolAcceptHandler>,
|
||||
) -> Result<Vec<SocketAddress>, String> {
|
||||
) -> EyreResult<Vec<SocketAddress>> {
|
||||
let mut out = Vec::<SocketAddress>::new();
|
||||
|
||||
for ip_addr in ip_addrs {
|
||||
|
@@ -3,7 +3,7 @@ use sockets::*;
|
||||
use stop_token::future::FutureExt;
|
||||
|
||||
impl Network {
|
||||
pub(super) async fn create_udp_listener_tasks(&self) -> Result<(), String> {
|
||||
pub(super) async fn create_udp_listener_tasks(&self) -> EyreResult<()> {
|
||||
// Spawn socket tasks
|
||||
let mut task_count = {
|
||||
let c = self.config.get();
|
||||
@@ -73,7 +73,7 @@ impl Network {
|
||||
.on_recv_envelope(&data[..size], descriptor)
|
||||
.await
|
||||
{
|
||||
log_net!(error "failed to process received udp envelope: {}", e);
|
||||
log_net!(debug "failed to process received udp envelope: {}", e);
|
||||
}
|
||||
}
|
||||
Ok(Err(_)) => {
|
||||
@@ -110,7 +110,7 @@ impl Network {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) async fn create_udp_outbound_sockets(&self) -> Result<(), String> {
|
||||
pub(super) async fn create_udp_outbound_sockets(&self) -> EyreResult<()> {
|
||||
let mut inner = self.inner.lock();
|
||||
let mut port = inner.udp_port;
|
||||
// v4
|
||||
@@ -119,9 +119,9 @@ impl Network {
|
||||
// Pull the port if we randomly bound, so v6 can be on the same port
|
||||
port = socket
|
||||
.local_addr()
|
||||
.map_err(map_to_string)?
|
||||
.wrap_err("failed to get local address")?
|
||||
.as_socket_ipv4()
|
||||
.ok_or_else(|| "expected ipv4 address type".to_owned())?
|
||||
.ok_or_else(|| eyre!("expected ipv4 address type"))?
|
||||
.port();
|
||||
|
||||
// Make an async UdpSocket from the socket2 socket
|
||||
@@ -131,7 +131,7 @@ impl Network {
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
std_udp_socket.set_nonblocking(true).expect("failed to set nonblocking");
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket).map_err(map_to_string)?;
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket).wrap_err("failed to make outbound v4 tokio udpsocket")?;
|
||||
}
|
||||
}
|
||||
let socket_arc = Arc::new(udp_socket);
|
||||
@@ -152,7 +152,7 @@ impl Network {
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
std_udp_socket.set_nonblocking(true).expect("failed to set nonblocking");
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket).map_err(map_to_string)?;
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket).wrap_err("failed to make outbound v6 tokio udpsocket")?;
|
||||
}
|
||||
}
|
||||
let socket_arc = Arc::new(udp_socket);
|
||||
@@ -166,7 +166,7 @@ impl Network {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_udp_inbound_socket(&self, addr: SocketAddr) -> Result<(), String> {
|
||||
async fn create_udp_inbound_socket(&self, addr: SocketAddr) -> EyreResult<()> {
|
||||
log_net!("create_udp_inbound_socket on {:?}", &addr);
|
||||
|
||||
// Create a reusable socket
|
||||
@@ -179,7 +179,7 @@ impl Network {
|
||||
let udp_socket = UdpSocket::from(std_udp_socket);
|
||||
} else if #[cfg(feature="rt-tokio")] {
|
||||
std_udp_socket.set_nonblocking(true).expect("failed to set nonblocking");
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket).map_err(map_to_string)?;
|
||||
let udp_socket = UdpSocket::from_std(std_udp_socket).wrap_err("failed to make inbound tokio udpsocket")?;
|
||||
}
|
||||
}
|
||||
let socket_arc = Arc::new(udp_socket);
|
||||
@@ -200,7 +200,7 @@ impl Network {
|
||||
&self,
|
||||
ip_addrs: Vec<IpAddr>,
|
||||
port: u16,
|
||||
) -> Result<Vec<DialInfo>, String> {
|
||||
) -> EyreResult<Vec<DialInfo>> {
|
||||
let mut out = Vec::<DialInfo>::new();
|
||||
|
||||
for ip_addr in ip_addrs {
|
||||
|
@@ -6,6 +6,7 @@ pub mod ws;
|
||||
|
||||
use super::*;
|
||||
use crate::xx::*;
|
||||
use std::io;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ProtocolNetworkConnection {
|
||||
@@ -21,7 +22,7 @@ impl ProtocolNetworkConnection {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
) -> io::Result<ProtocolNetworkConnection> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
panic!("Should not connect to UDP dialinfo");
|
||||
@@ -35,7 +36,7 @@ impl ProtocolNetworkConnection {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
|
||||
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> io::Result<()> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
@@ -55,7 +56,7 @@ impl ProtocolNetworkConnection {
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
timeout_ms: u32,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
) -> io::Result<Vec<u8>> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
let peer_socket_addr = dial_info.to_socket_addr();
|
||||
@@ -102,7 +103,7 @@ impl ProtocolNetworkConnection {
|
||||
// }
|
||||
// }
|
||||
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.send(message),
|
||||
Self::RawTcp(t) => t.send(message).await,
|
||||
@@ -111,7 +112,7 @@ impl ProtocolNetworkConnection {
|
||||
Self::Wss(w) => w.send(message).await,
|
||||
}
|
||||
}
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
pub async fn recv(&self) -> io::Result<Vec<u8>> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.recv(),
|
||||
Self::RawTcp(t) => t.recv().await,
|
||||
|
@@ -1,6 +1,8 @@
|
||||
use crate::xx::*;
|
||||
use crate::*;
|
||||
use async_io::Async;
|
||||
use std::io;
|
||||
|
||||
cfg_if! {
|
||||
if #[cfg(feature="rt-async-std")] {
|
||||
pub use async_std::net::{TcpStream, TcpListener, Shutdown, UdpSocket};
|
||||
@@ -19,12 +21,12 @@ cfg_if! {
|
||||
use winapi::ctypes::c_int;
|
||||
use std::os::windows::io::AsRawSocket;
|
||||
|
||||
fn set_exclusiveaddruse(socket: &Socket) -> Result<(), String> {
|
||||
fn set_exclusiveaddruse(socket: &Socket) -> io::Result<()> {
|
||||
unsafe {
|
||||
let optval:c_int = 1;
|
||||
if setsockopt(socket.as_raw_socket().try_into().unwrap(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (&optval as *const c_int).cast(),
|
||||
std::mem::size_of::<c_int>() as c_int) == SOCKET_ERROR {
|
||||
return Err("Unable to SO_EXCLUSIVEADDRUSE".to_owned());
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -32,49 +34,37 @@ cfg_if! {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_unbound_shared_udp_socket(domain: Domain) -> Result<Socket, String> {
|
||||
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))
|
||||
.map_err(|e| format!("Couldn't create UDP socket: {}", e))?;
|
||||
pub fn new_unbound_shared_udp_socket(domain: Domain) -> io::Result<Socket> {
|
||||
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
|
||||
if domain == Domain::IPV6 {
|
||||
socket
|
||||
.set_only_v6(true)
|
||||
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
socket
|
||||
.set_reuse_address(true)
|
||||
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
|
||||
socket.set_reuse_address(true)?;
|
||||
|
||||
cfg_if! {
|
||||
if #[cfg(unix)] {
|
||||
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
|
||||
socket.set_reuse_port(true)?;
|
||||
}
|
||||
}
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_bound_shared_udp_socket(local_address: SocketAddr) -> Result<Socket, String> {
|
||||
pub fn new_bound_shared_udp_socket(local_address: SocketAddr) -> io::Result<Socket> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
let socket = new_unbound_shared_udp_socket(domain)?;
|
||||
let socket2_addr = SockAddr::from(local_address);
|
||||
socket.bind(&socket2_addr).map_err(|e| {
|
||||
format!(
|
||||
"failed to bind UDP socket to '{}' in domain '{:?}': {} ",
|
||||
local_address, domain, e
|
||||
)
|
||||
})?;
|
||||
socket.bind(&socket2_addr)?;
|
||||
|
||||
log_net!("created bound shared udp socket on {:?}", &local_address);
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_bound_first_udp_socket(local_address: SocketAddr) -> Result<Socket, String> {
|
||||
pub fn new_bound_first_udp_socket(local_address: SocketAddr) -> io::Result<Socket> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))
|
||||
.map_err(|e| format!("Couldn't create UDP socket: {}", e))?;
|
||||
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
|
||||
if domain == Domain::IPV6 {
|
||||
socket
|
||||
.set_only_v6(true)
|
||||
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
// Bind the socket -first- before turning on 'reuse address' this way it will
|
||||
// fail if the port is already taken
|
||||
@@ -87,18 +77,15 @@ pub fn new_bound_first_udp_socket(local_address: SocketAddr) -> Result<Socket, S
|
||||
}
|
||||
}
|
||||
|
||||
socket
|
||||
.bind(&socket2_addr)
|
||||
.map_err(|e| format!("failed to bind UDP socket: {}", e))?;
|
||||
socket.bind(&socket2_addr)?;
|
||||
|
||||
// Set 'reuse address' so future binds to this port will succeed
|
||||
// This does not work on Windows, where reuse options can not be set after the bind
|
||||
cfg_if! {
|
||||
if #[cfg(unix)] {
|
||||
socket
|
||||
.set_reuse_address(true)
|
||||
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
|
||||
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
|
||||
.set_reuse_address(true)?;
|
||||
socket.set_reuse_port(true)?;
|
||||
}
|
||||
}
|
||||
log_net!("created bound first udp socket on {:?}", &local_address);
|
||||
@@ -106,10 +93,8 @@ pub fn new_bound_first_udp_socket(local_address: SocketAddr) -> Result<Socket, S
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_unbound_shared_tcp_socket(domain: Domain) -> Result<Socket, String> {
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("failed to create TCP socket"))?;
|
||||
pub fn new_unbound_shared_tcp_socket(domain: Domain) -> io::Result<Socket> {
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
|
||||
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
|
||||
log_net!(error "Couldn't set TCP linger: {}", e);
|
||||
}
|
||||
@@ -117,43 +102,33 @@ pub fn new_unbound_shared_tcp_socket(domain: Domain) -> Result<Socket, String> {
|
||||
log_net!(error "Couldn't set TCP nodelay: {}", e);
|
||||
}
|
||||
if domain == Domain::IPV6 {
|
||||
socket
|
||||
.set_only_v6(true)
|
||||
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
socket
|
||||
.set_reuse_address(true)
|
||||
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
|
||||
socket.set_reuse_address(true)?;
|
||||
cfg_if! {
|
||||
if #[cfg(unix)] {
|
||||
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
|
||||
socket.set_reuse_port(true)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> Result<Socket, String> {
|
||||
pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> io::Result<Socket> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
|
||||
let socket = new_unbound_shared_tcp_socket(domain)?;
|
||||
|
||||
let socket2_addr = SockAddr::from(local_address);
|
||||
socket
|
||||
.bind(&socket2_addr)
|
||||
.map_err(|e| format!("failed to bind TCP socket: {}", e))?;
|
||||
socket.bind(&socket2_addr)?;
|
||||
|
||||
log_net!("created bound shared tcp socket on {:?}", &local_address);
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> Result<Socket, String> {
|
||||
pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> io::Result<Socket> {
|
||||
let domain = Domain::for_address(local_address);
|
||||
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("failed to create TCP socket"))?;
|
||||
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
|
||||
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
|
||||
log_net!(error "Couldn't set TCP linger: {}", e);
|
||||
}
|
||||
@@ -161,9 +136,7 @@ pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> Result<Socket, S
|
||||
log_net!(error "Couldn't set TCP nodelay: {}", e);
|
||||
}
|
||||
if domain == Domain::IPV6 {
|
||||
socket
|
||||
.set_only_v6(true)
|
||||
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
|
||||
// On windows, do SO_EXCLUSIVEADDRUSE before the bind to ensure the port is fully available
|
||||
@@ -176,18 +149,15 @@ pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> Result<Socket, S
|
||||
// Bind the socket -first- before turning on 'reuse address' this way it will
|
||||
// fail if the port is already taken
|
||||
let socket2_addr = SockAddr::from(local_address);
|
||||
socket
|
||||
.bind(&socket2_addr)
|
||||
.map_err(|e| format!("failed to bind TCP socket: {}", e))?;
|
||||
socket.bind(&socket2_addr)?;
|
||||
|
||||
// Set 'reuse address' so future binds to this port will succeed
|
||||
// This does not work on Windows, where reuse options can not be set after the bind
|
||||
cfg_if! {
|
||||
if #[cfg(unix)] {
|
||||
socket
|
||||
.set_reuse_address(true)
|
||||
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
|
||||
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
|
||||
.set_reuse_address(true)?;
|
||||
socket.set_reuse_port(true)?;
|
||||
}
|
||||
}
|
||||
log_net!("created bound first tcp socket on {:?}", &local_address);
|
||||
@@ -196,7 +166,7 @@ pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> Result<Socket, S
|
||||
}
|
||||
|
||||
// Non-blocking connect is tricky when you want to start with a prepared socket
|
||||
pub async fn nonblocking_connect(socket: Socket, addr: SocketAddr) -> std::io::Result<TcpStream> {
|
||||
pub async fn nonblocking_connect(socket: Socket, addr: SocketAddr) -> io::Result<TcpStream> {
|
||||
// Set for non blocking connect
|
||||
socket.set_nonblocking(true)?;
|
||||
|
||||
|
@@ -42,47 +42,45 @@ impl RawTcpNetworkConnection {
|
||||
// }
|
||||
// }
|
||||
|
||||
async fn send_internal(stream: &mut AsyncPeekStream, message: Vec<u8>) -> Result<(), String> {
|
||||
async fn send_internal(stream: &mut AsyncPeekStream, message: Vec<u8>) -> io::Result<()> {
|
||||
log_net!("sending TCP message of size {}", message.len());
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large TCP message".to_owned());
|
||||
bail_io_error_other!("sending too large TCP message");
|
||||
}
|
||||
let len = message.len() as u16;
|
||||
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
|
||||
|
||||
stream.write_all(&header).await.map_err(map_to_string)?;
|
||||
stream.write_all(&message).await.map_err(map_to_string)
|
||||
stream.write_all(&header).await?;
|
||||
stream.write_all(&message).await
|
||||
}
|
||||
|
||||
#[instrument(level="trace", err, skip(self, message), fields(message.len = message.len()))]
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
|
||||
let mut stream = self.stream.clone();
|
||||
Self::send_internal(&mut stream, message).await
|
||||
}
|
||||
|
||||
pub async fn recv_internal(stream: &mut AsyncPeekStream) -> Result<Vec<u8>, String> {
|
||||
pub async fn recv_internal(stream: &mut AsyncPeekStream) -> io::Result<Vec<u8>> {
|
||||
let mut header = [0u8; 4];
|
||||
|
||||
stream
|
||||
.read_exact(&mut header)
|
||||
.await
|
||||
.map_err(|e| format!("TCP recv error: {}", e))?;
|
||||
stream.read_exact(&mut header).await?;
|
||||
|
||||
if header[0] != b'V' || header[1] != b'L' {
|
||||
return Err("received invalid TCP frame header".to_owned());
|
||||
bail_io_error_other!("received invalid TCP frame header");
|
||||
}
|
||||
let len = ((header[3] as usize) << 8) | (header[2] as usize);
|
||||
if len > MAX_MESSAGE_SIZE {
|
||||
return Err("received too large TCP frame".to_owned());
|
||||
bail_io_error_other!("received too large TCP frame");
|
||||
}
|
||||
|
||||
let mut out: Vec<u8> = vec![0u8; len];
|
||||
stream.read_exact(&mut out).await.map_err(map_to_string)?;
|
||||
stream.read_exact(&mut out).await?;
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
#[instrument(level="trace", err, skip(self), fields(ret.len))]
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
pub async fn recv(&self) -> io::Result<Vec<u8>> {
|
||||
let mut stream = self.stream.clone();
|
||||
let out = Self::recv_internal(&mut stream).await?;
|
||||
tracing::Span::current().record("ret.len", &out.len());
|
||||
@@ -121,14 +119,10 @@ impl RawTcpProtocolHandler {
|
||||
self,
|
||||
stream: AsyncPeekStream,
|
||||
socket_addr: SocketAddr,
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
) -> io::Result<Option<ProtocolNetworkConnection>> {
|
||||
log_net!("TCP: on_accept_async: enter");
|
||||
let mut peekbuf: [u8; PEEK_DETECT_LEN] = [0u8; PEEK_DETECT_LEN];
|
||||
let peeklen = stream
|
||||
.peek(&mut peekbuf)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("could not peek tcp stream"))?;
|
||||
let peeklen = stream.peek(&mut peekbuf).await?;
|
||||
assert_eq!(peeklen, PEEK_DETECT_LEN);
|
||||
|
||||
let peer_addr = PeerAddress::new(
|
||||
@@ -150,7 +144,7 @@ impl RawTcpProtocolHandler {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
) -> io::Result<ProtocolNetworkConnection> {
|
||||
// Get remote socket address to connect to
|
||||
let remote_socket_addr = dial_info.to_socket_addr();
|
||||
|
||||
@@ -163,15 +157,10 @@ impl RawTcpProtocolHandler {
|
||||
};
|
||||
|
||||
// Non-blocking connect to remote address
|
||||
let ts = nonblocking_connect(socket, remote_socket_addr).await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
|
||||
let ts = nonblocking_connect(socket, remote_socket_addr).await?;
|
||||
|
||||
// See what local address we ended up with and turn this into a stream
|
||||
let actual_local_address = ts
|
||||
.local_addr()
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("could not get local address from TCP stream"))?;
|
||||
let actual_local_address = ts.local_addr()?;
|
||||
#[cfg(feature = "rt-tokio")]
|
||||
let ts = ts.compat();
|
||||
let ps = AsyncPeekStream::new(ts);
|
||||
@@ -189,12 +178,9 @@ impl RawTcpProtocolHandler {
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
|
||||
pub async fn send_unbound_message(
|
||||
socket_addr: SocketAddr,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
pub async fn send_unbound_message(socket_addr: SocketAddr, data: Vec<u8>) -> io::Result<()> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound TCP message".to_owned());
|
||||
bail_io_error_other!("sending too large unbound TCP message");
|
||||
}
|
||||
trace!(
|
||||
"sending unbound message of length {} to {}",
|
||||
@@ -206,10 +192,7 @@ impl RawTcpProtocolHandler {
|
||||
let socket = new_unbound_shared_tcp_socket(socket2::Domain::for_address(socket_addr))?;
|
||||
|
||||
// Non-blocking connect to remote address
|
||||
let ts = nonblocking_connect(socket, socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "remote_addr={}", socket_addr))?;
|
||||
let ts = nonblocking_connect(socket, socket_addr).await?;
|
||||
|
||||
// See what local address we ended up with and turn this into a stream
|
||||
// let actual_local_address = ts
|
||||
@@ -231,9 +214,9 @@ impl RawTcpProtocolHandler {
|
||||
socket_addr: SocketAddr,
|
||||
data: Vec<u8>,
|
||||
timeout_ms: u32,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
) -> io::Result<Vec<u8>> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound TCP message".to_owned());
|
||||
bail_io_error_other!("sending too large unbound TCP message");
|
||||
}
|
||||
trace!(
|
||||
"sending unbound message of length {} to {}",
|
||||
@@ -245,10 +228,7 @@ impl RawTcpProtocolHandler {
|
||||
let socket = new_unbound_shared_tcp_socket(socket2::Domain::for_address(socket_addr))?;
|
||||
|
||||
// Non-blocking connect to remote address
|
||||
let ts = nonblocking_connect(socket, socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "remote_addr={}", socket_addr))?;
|
||||
let ts = nonblocking_connect(socket, socket_addr).await?;
|
||||
|
||||
// See what local address we ended up with and turn this into a stream
|
||||
// let actual_local_address = ts
|
||||
@@ -265,7 +245,7 @@ impl RawTcpProtocolHandler {
|
||||
|
||||
let out = timeout(timeout_ms, RawTcpNetworkConnection::recv_internal(&mut ps))
|
||||
.await
|
||||
.map_err(map_to_string)??;
|
||||
.map_err(|e| e.to_io())??;
|
||||
|
||||
tracing::Span::current().record("ret.len", &out.len());
|
||||
Ok(out)
|
||||
@@ -277,7 +257,7 @@ impl ProtocolAcceptHandler for RawTcpProtocolHandler {
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<core::result::Result<Option<ProtocolNetworkConnection>, String>> {
|
||||
) -> SystemPinBoxFuture<io::Result<Option<ProtocolNetworkConnection>>> {
|
||||
Box::pin(self.clone().on_accept_async(stream, peer_addr))
|
||||
}
|
||||
}
|
||||
|
@@ -12,27 +12,30 @@ impl RawUdpProtocolHandler {
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.from))]
|
||||
pub async fn recv_message(
|
||||
&self,
|
||||
data: &mut [u8],
|
||||
) -> Result<(usize, ConnectionDescriptor), String> {
|
||||
let (size, remote_addr) = self.socket.recv_from(data).await.map_err(map_to_string)?;
|
||||
|
||||
if size > MAX_MESSAGE_SIZE {
|
||||
return Err("received too large UDP message".to_owned());
|
||||
}
|
||||
|
||||
trace!(
|
||||
"receiving UDP message of length {} from {}",
|
||||
size,
|
||||
remote_addr
|
||||
);
|
||||
pub async fn recv_message(&self, data: &mut [u8]) -> io::Result<(usize, ConnectionDescriptor)> {
|
||||
let (size, remote_addr) = loop {
|
||||
match self.socket.recv_from(data).await {
|
||||
Ok((size, remote_addr)) => {
|
||||
if size > MAX_MESSAGE_SIZE {
|
||||
bail_io_error_other!("received too large UDP message");
|
||||
}
|
||||
break (size, remote_addr);
|
||||
}
|
||||
Err(e) => {
|
||||
if e.kind() == io::ErrorKind::ConnectionReset {
|
||||
// Ignore icmp
|
||||
} else {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let peer_addr = PeerAddress::new(
|
||||
SocketAddress::from_socket_addr(remote_addr),
|
||||
ProtocolType::UDP,
|
||||
);
|
||||
let local_socket_addr = self.socket.local_addr().map_err(map_to_string)?;
|
||||
let local_socket_addr = self.socket.local_addr()?;
|
||||
let descriptor = ConnectionDescriptor::new(
|
||||
peer_addr,
|
||||
SocketAddress::from_socket_addr(local_socket_addr),
|
||||
@@ -44,45 +47,24 @@ impl RawUdpProtocolHandler {
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.from))]
|
||||
pub async fn send_message(&self, data: Vec<u8>, socket_addr: SocketAddr) -> Result<(), String> {
|
||||
pub async fn send_message(&self, data: Vec<u8>, socket_addr: SocketAddr) -> io::Result<()> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large UDP message".to_owned()).map_err(logthru_net!(error));
|
||||
bail_io_error_other!("sending too large UDP message");
|
||||
}
|
||||
|
||||
log_net!(
|
||||
"sending UDP message of length {} to {}",
|
||||
data.len(),
|
||||
socket_addr
|
||||
);
|
||||
|
||||
let len = self
|
||||
.socket
|
||||
.send_to(&data, socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed udp send: addr={}", socket_addr))?;
|
||||
|
||||
let len = self.socket.send_to(&data, socket_addr).await?;
|
||||
if len != data.len() {
|
||||
Err("UDP partial send".to_owned()).map_err(logthru_net!(error))
|
||||
} else {
|
||||
Ok(())
|
||||
bail_io_error_other!("UDP partial send")
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
|
||||
pub async fn send_unbound_message(
|
||||
socket_addr: SocketAddr,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
pub async fn send_unbound_message(socket_addr: SocketAddr, data: Vec<u8>) -> io::Result<()> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound UDP message".to_owned())
|
||||
.map_err(logthru_net!(error));
|
||||
bail_io_error_other!("sending too large unbound UDP message");
|
||||
}
|
||||
log_net!(
|
||||
"sending unbound message of length {} to {}",
|
||||
data.len(),
|
||||
socket_addr
|
||||
);
|
||||
|
||||
// get local wildcard address for bind
|
||||
let local_socket_addr = match socket_addr {
|
||||
@@ -91,20 +73,13 @@ impl RawUdpProtocolHandler {
|
||||
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
|
||||
}
|
||||
};
|
||||
let socket = UdpSocket::bind(local_socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed to bind unbound udp socket"))?;
|
||||
let len = socket
|
||||
.send_to(&data, socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed unbound udp send: addr={}", socket_addr))?;
|
||||
let socket = UdpSocket::bind(local_socket_addr).await?;
|
||||
let len = socket.send_to(&data, socket_addr).await?;
|
||||
if len != data.len() {
|
||||
Err("UDP partial unbound send".to_owned()).map_err(logthru_net!(error))
|
||||
} else {
|
||||
Ok(())
|
||||
bail_io_error_other!("UDP partial unbound send")
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len(), ret.len))]
|
||||
@@ -112,16 +87,10 @@ impl RawUdpProtocolHandler {
|
||||
socket_addr: SocketAddr,
|
||||
data: Vec<u8>,
|
||||
timeout_ms: u32,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
) -> io::Result<Vec<u8>> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound UDP message".to_owned())
|
||||
.map_err(logthru_net!(error));
|
||||
bail_io_error_other!("sending too large unbound UDP message");
|
||||
}
|
||||
log_net!(
|
||||
"sending unbound message of length {} to {}",
|
||||
data.len(),
|
||||
socket_addr
|
||||
);
|
||||
|
||||
// get local wildcard address for bind
|
||||
let local_socket_addr = match socket_addr {
|
||||
@@ -132,29 +101,21 @@ impl RawUdpProtocolHandler {
|
||||
};
|
||||
|
||||
// get unspecified bound socket
|
||||
let socket = UdpSocket::bind(local_socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed to bind unbound udp socket"))?;
|
||||
let len = socket
|
||||
.send_to(&data, socket_addr)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed unbound udp send: addr={}", socket_addr))?;
|
||||
let socket = UdpSocket::bind(local_socket_addr).await?;
|
||||
let len = socket.send_to(&data, socket_addr).await?;
|
||||
if len != data.len() {
|
||||
return Err("UDP partial unbound send".to_owned()).map_err(logthru_net!(error));
|
||||
bail_io_error_other!("UDP partial unbound send");
|
||||
}
|
||||
|
||||
// receive single response
|
||||
let mut out = vec![0u8; MAX_MESSAGE_SIZE];
|
||||
let (len, from_addr) = timeout(timeout_ms, socket.recv_from(&mut out))
|
||||
.await
|
||||
.map_err(map_to_string)?
|
||||
.map_err(map_to_string)?;
|
||||
.map_err(|e| e.to_io())??;
|
||||
|
||||
// if the from address is not the same as the one we sent to, then drop this
|
||||
if from_addr != socket_addr {
|
||||
return Err(format!(
|
||||
bail_io_error_other!(format!(
|
||||
"Unbound response received from wrong address: addr={}",
|
||||
from_addr,
|
||||
));
|
||||
|
@@ -17,6 +17,25 @@ cfg_if! {
|
||||
}
|
||||
}
|
||||
|
||||
fn to_io(err: async_tungstenite::tungstenite::Error) -> io::Error {
|
||||
let kind = match err {
|
||||
async_tungstenite::tungstenite::Error::ConnectionClosed => io::ErrorKind::ConnectionReset,
|
||||
async_tungstenite::tungstenite::Error::AlreadyClosed => io::ErrorKind::NotConnected,
|
||||
async_tungstenite::tungstenite::Error::Io(x) => {
|
||||
return x;
|
||||
}
|
||||
async_tungstenite::tungstenite::Error::Tls(_) => io::ErrorKind::InvalidData,
|
||||
async_tungstenite::tungstenite::Error::Capacity(_) => io::ErrorKind::Other,
|
||||
async_tungstenite::tungstenite::Error::Protocol(_) => io::ErrorKind::Other,
|
||||
async_tungstenite::tungstenite::Error::SendQueueFull(_) => io::ErrorKind::Other,
|
||||
async_tungstenite::tungstenite::Error::Utf8 => io::ErrorKind::Other,
|
||||
async_tungstenite::tungstenite::Error::Url(_) => io::ErrorKind::Other,
|
||||
async_tungstenite::tungstenite::Error::Http(_) => io::ErrorKind::Other,
|
||||
async_tungstenite::tungstenite::Error::HttpFormat(_) => io::ErrorKind::Other,
|
||||
};
|
||||
io::Error::new(kind, err)
|
||||
}
|
||||
|
||||
pub type WebSocketNetworkConnectionAccepted = WebsocketNetworkConnection<AsyncPeekStream>;
|
||||
|
||||
pub struct WebsocketNetworkConnection<T>
|
||||
@@ -62,41 +81,49 @@ where
|
||||
// }
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, message), fields(message.len = message.len()))]
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("received too large WS message".to_owned());
|
||||
bail_io_error_other!("received too large WS message");
|
||||
}
|
||||
self.stream
|
||||
.clone()
|
||||
.send(Message::binary(message))
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "failed to send websocket message"))
|
||||
.map_err(to_io)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self), fields(ret.len))]
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
pub async fn recv(&self) -> io::Result<Vec<u8>> {
|
||||
let out = match self.stream.clone().next().await {
|
||||
Some(Ok(Message::Binary(v))) => v,
|
||||
Some(Ok(Message::Close(e))) => {
|
||||
return Err(format!("WS connection closed: {:?}", e));
|
||||
Some(Ok(Message::Binary(v))) => {
|
||||
if v.len() > MAX_MESSAGE_SIZE {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionReset,
|
||||
"too large ws message",
|
||||
));
|
||||
}
|
||||
v
|
||||
}
|
||||
Some(Ok(Message::Close(_))) => {
|
||||
return Err(io::Error::new(io::ErrorKind::ConnectionReset, "closeframe"))
|
||||
}
|
||||
Some(Ok(x)) => {
|
||||
return Err(format!("Unexpected WS message type: {:?}", x));
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
return Err(e.to_string()).map_err(logthru_net!(error));
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Unexpected WS message type: {:?}", x),
|
||||
));
|
||||
}
|
||||
Some(Err(e)) => return Err(to_io(e)),
|
||||
None => {
|
||||
return Err("WS stream closed".to_owned());
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionReset,
|
||||
"connection ended",
|
||||
))
|
||||
}
|
||||
};
|
||||
if out.len() > MAX_MESSAGE_SIZE {
|
||||
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
|
||||
} else {
|
||||
tracing::Span::current().record("ret.len", &out.len());
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
tracing::Span::current().record("ret.len", &out.len());
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,21 +172,18 @@ impl WebsocketProtocolHandler {
|
||||
self,
|
||||
ps: AsyncPeekStream,
|
||||
socket_addr: SocketAddr,
|
||||
) -> Result<Option<ProtocolNetworkConnection>, String> {
|
||||
) -> io::Result<Option<ProtocolNetworkConnection>> {
|
||||
log_net!("WS: on_accept_async: enter");
|
||||
let request_path_len = self.arc.request_path.len() + 2;
|
||||
|
||||
let mut peekbuf: Vec<u8> = vec![0u8; request_path_len];
|
||||
match timeout(
|
||||
if let Err(_) = timeout(
|
||||
self.arc.connection_initial_timeout_ms,
|
||||
ps.peek_exact(&mut peekbuf),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(e) => {
|
||||
return Err(e.to_string());
|
||||
}
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Check for websocket path
|
||||
@@ -169,15 +193,12 @@ impl WebsocketProtocolHandler {
|
||||
&& peekbuf[request_path_len - 1] == b' '));
|
||||
|
||||
if !matches_path {
|
||||
log_net!("WS: not websocket");
|
||||
return Ok(None);
|
||||
}
|
||||
log_net!("WS: found websocket");
|
||||
|
||||
let ws_stream = accept_async(ps)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!("failed websockets handshake"))?;
|
||||
.map_err(|e| io_error_other!(format!("failed websockets handshake: {}", e)))?;
|
||||
|
||||
// Wrap the websocket in a NetworkConnection and register it
|
||||
let protocol_type = if self.arc.tls {
|
||||
@@ -205,7 +226,7 @@ impl WebsocketProtocolHandler {
|
||||
async fn connect_internal(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
) -> io::Result<ProtocolNetworkConnection> {
|
||||
// Split dial info up
|
||||
let (tls, scheme) = match &dial_info {
|
||||
DialInfo::WS(_) => (false, "ws"),
|
||||
@@ -213,9 +234,9 @@ impl WebsocketProtocolHandler {
|
||||
_ => panic!("invalid dialinfo for WS/WSS protocol"),
|
||||
};
|
||||
let request = dial_info.request().unwrap();
|
||||
let split_url = SplitUrl::from_str(&request)?;
|
||||
let split_url = SplitUrl::from_str(&request).map_err(to_io_error_other)?;
|
||||
if split_url.scheme != scheme {
|
||||
return Err("invalid websocket url scheme".to_string());
|
||||
bail_io_error_other!("invalid websocket url scheme");
|
||||
}
|
||||
let domain = split_url.host.clone();
|
||||
|
||||
@@ -231,12 +252,10 @@ impl WebsocketProtocolHandler {
|
||||
};
|
||||
|
||||
// Non-blocking connect to remote address
|
||||
let tcp_stream = nonblocking_connect(socket, remote_socket_addr).await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
|
||||
let tcp_stream = nonblocking_connect(socket, remote_socket_addr).await?;
|
||||
|
||||
// See what local address we ended up with
|
||||
let actual_local_addr = tcp_stream.local_addr().map_err(map_to_string)?;
|
||||
let actual_local_addr = tcp_stream.local_addr()?;
|
||||
|
||||
#[cfg(feature = "rt-tokio")]
|
||||
let tcp_stream = tcp_stream.compat();
|
||||
@@ -249,15 +268,10 @@ impl WebsocketProtocolHandler {
|
||||
// Negotiate TLS if this is WSS
|
||||
if tls {
|
||||
let connector = TlsConnector::default();
|
||||
let tls_stream = connector
|
||||
.connect(domain.to_string(), tcp_stream)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
let tls_stream = connector.connect(domain.to_string(), tcp_stream).await?;
|
||||
let (ws_stream, _response) = client_async(request, tls_stream)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
.map_err(to_io_error_other)?;
|
||||
|
||||
Ok(ProtocolNetworkConnection::Wss(
|
||||
WebsocketNetworkConnection::new(descriptor, ws_stream),
|
||||
@@ -265,8 +279,7 @@ impl WebsocketProtocolHandler {
|
||||
} else {
|
||||
let (ws_stream, _response) = client_async(request, tcp_stream)
|
||||
.await
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
.map_err(to_io_error_other)?;
|
||||
Ok(ProtocolNetworkConnection::Ws(
|
||||
WebsocketNetworkConnection::new(descriptor, ws_stream),
|
||||
))
|
||||
@@ -277,19 +290,17 @@ impl WebsocketProtocolHandler {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
) -> io::Result<ProtocolNetworkConnection> {
|
||||
Self::connect_internal(local_address, dial_info).await
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
|
||||
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
|
||||
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> io::Result<()> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound WS message".to_owned());
|
||||
bail_io_error_other!("sending too large unbound WS message");
|
||||
}
|
||||
|
||||
let protconn = Self::connect_internal(None, dial_info.clone())
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?;
|
||||
let protconn = Self::connect_internal(None, dial_info.clone()).await?;
|
||||
|
||||
protconn.send(data).await
|
||||
}
|
||||
@@ -299,19 +310,17 @@ impl WebsocketProtocolHandler {
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
timeout_ms: u32,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
) -> io::Result<Vec<u8>> {
|
||||
if data.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large unbound WS message".to_owned());
|
||||
bail_io_error_other!("sending too large unbound WS message");
|
||||
}
|
||||
|
||||
let protconn = Self::connect_internal(None, dial_info.clone())
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?;
|
||||
let protconn = Self::connect_internal(None, dial_info.clone()).await?;
|
||||
|
||||
protconn.send(data).await?;
|
||||
let out = timeout(timeout_ms, protconn.recv())
|
||||
.await
|
||||
.map_err(map_to_string)??;
|
||||
.map_err(|e| e.to_io())??;
|
||||
|
||||
tracing::Span::current().record("ret.len", &out.len());
|
||||
Ok(out)
|
||||
@@ -323,7 +332,7 @@ impl ProtocolAcceptHandler for WebsocketProtocolHandler {
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<Result<Option<ProtocolNetworkConnection>, String>> {
|
||||
) -> SystemPinBoxFuture<io::Result<Option<ProtocolNetworkConnection>>> {
|
||||
Box::pin(self.clone().on_accept_async(stream, peer_addr))
|
||||
}
|
||||
}
|
||||
|
@@ -164,7 +164,7 @@ impl Network {
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
fn find_available_udp_port(&self) -> Result<u16, String> {
|
||||
fn find_available_udp_port(&self) -> EyreResult<u16> {
|
||||
// If the address is empty, iterate ports until we find one we can use.
|
||||
let mut udp_port = 5150u16;
|
||||
loop {
|
||||
@@ -175,14 +175,14 @@ impl Network {
|
||||
break;
|
||||
}
|
||||
if udp_port == 65535 {
|
||||
return Err("Could not find free udp port to listen on".to_owned());
|
||||
bail!("Could not find free udp port to listen on");
|
||||
}
|
||||
udp_port += 1;
|
||||
}
|
||||
Ok(udp_port)
|
||||
}
|
||||
|
||||
fn find_available_tcp_port(&self) -> Result<u16, String> {
|
||||
fn find_available_tcp_port(&self) -> EyreResult<u16> {
|
||||
// If the address is empty, iterate ports until we find one we can use.
|
||||
let mut tcp_port = 5150u16;
|
||||
loop {
|
||||
@@ -193,17 +193,14 @@ impl Network {
|
||||
break;
|
||||
}
|
||||
if tcp_port == 65535 {
|
||||
return Err("Could not find free tcp port to listen on".to_owned());
|
||||
bail!("Could not find free tcp port to listen on");
|
||||
}
|
||||
tcp_port += 1;
|
||||
}
|
||||
Ok(tcp_port)
|
||||
}
|
||||
|
||||
async fn allocate_udp_port(
|
||||
&self,
|
||||
listen_address: String,
|
||||
) -> Result<(u16, Vec<IpAddr>), String> {
|
||||
async fn allocate_udp_port(&self, listen_address: String) -> EyreResult<(u16, Vec<IpAddr>)> {
|
||||
if listen_address.is_empty() {
|
||||
// If listen address is empty, find us a port iteratively
|
||||
let port = self.find_available_udp_port()?;
|
||||
@@ -217,21 +214,17 @@ impl Network {
|
||||
// If the address is specified, only use the specified port and fail otherwise
|
||||
let sockaddrs = listen_address_to_socket_addrs(&listen_address)?;
|
||||
if sockaddrs.is_empty() {
|
||||
return Err(format!("No valid listen address: {}", listen_address));
|
||||
bail!("No valid listen address: {}", listen_address);
|
||||
}
|
||||
let port = sockaddrs[0].port();
|
||||
if self.bind_first_udp_port(port) {
|
||||
Ok((port, sockaddrs.iter().map(|s| s.ip()).collect()))
|
||||
} else {
|
||||
Err("Could not find free udp port to listen on".to_owned())
|
||||
if !self.bind_first_udp_port(port) {
|
||||
bail!("Could not find free udp port to listen on");
|
||||
}
|
||||
Ok((port, sockaddrs.iter().map(|s| s.ip()).collect()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn allocate_tcp_port(
|
||||
&self,
|
||||
listen_address: String,
|
||||
) -> Result<(u16, Vec<IpAddr>), String> {
|
||||
async fn allocate_tcp_port(&self, listen_address: String) -> EyreResult<(u16, Vec<IpAddr>)> {
|
||||
if listen_address.is_empty() {
|
||||
// If listen address is empty, find us a port iteratively
|
||||
let port = self.find_available_tcp_port()?;
|
||||
@@ -245,20 +238,19 @@ impl Network {
|
||||
// If the address is specified, only use the specified port and fail otherwise
|
||||
let sockaddrs = listen_address_to_socket_addrs(&listen_address)?;
|
||||
if sockaddrs.is_empty() {
|
||||
return Err(format!("No valid listen address: {}", listen_address));
|
||||
bail!("No valid listen address: {}", listen_address);
|
||||
}
|
||||
let port = sockaddrs[0].port();
|
||||
if self.bind_first_tcp_port(port) {
|
||||
Ok((port, sockaddrs.iter().map(|s| s.ip()).collect()))
|
||||
} else {
|
||||
Err("Could not find free tcp port to listen on".to_owned())
|
||||
if !self.bind_first_tcp_port(port) {
|
||||
bail!("Could not find free tcp port to listen on");
|
||||
}
|
||||
Ok((port, sockaddrs.iter().map(|s| s.ip()).collect()))
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
pub(super) async fn start_udp_listeners(&self) -> Result<(), String> {
|
||||
pub(super) async fn start_udp_listeners(&self) -> EyreResult<()> {
|
||||
trace!("starting udp listeners");
|
||||
let routing_table = self.routing_table();
|
||||
let (listen_address, public_address, enable_local_peer_scope) = {
|
||||
@@ -319,7 +311,7 @@ impl Network {
|
||||
// Resolve statically configured public dialinfo
|
||||
let mut public_sockaddrs = public_address
|
||||
.to_socket_addrs()
|
||||
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
|
||||
.wrap_err(format!("Unable to resolve address: {}", public_address))?;
|
||||
|
||||
// Add all resolved addresses as public dialinfo
|
||||
for pdi_addr in &mut public_sockaddrs {
|
||||
@@ -364,7 +356,7 @@ impl Network {
|
||||
self.create_udp_listener_tasks().await
|
||||
}
|
||||
|
||||
pub(super) async fn start_ws_listeners(&self) -> Result<(), String> {
|
||||
pub(super) async fn start_ws_listeners(&self) -> EyreResult<()> {
|
||||
trace!("starting ws listeners");
|
||||
let routing_table = self.routing_table();
|
||||
let (listen_address, url, path, enable_local_peer_scope) = {
|
||||
@@ -405,9 +397,9 @@ impl Network {
|
||||
|
||||
// Add static public dialinfo if it's configured
|
||||
if let Some(url) = url.as_ref() {
|
||||
let mut split_url = SplitUrl::from_str(url)?;
|
||||
let mut split_url = SplitUrl::from_str(url).wrap_err("couldn't split url")?;
|
||||
if split_url.scheme.to_ascii_lowercase() != "ws" {
|
||||
return Err("WS URL must use 'ws://' scheme".to_owned());
|
||||
bail!("WS URL must use 'ws://' scheme");
|
||||
}
|
||||
split_url.scheme = "ws".to_owned();
|
||||
|
||||
@@ -415,13 +407,11 @@ impl Network {
|
||||
let global_socket_addrs = split_url
|
||||
.host_port(80)
|
||||
.to_socket_addrs()
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
.wrap_err("failed to resolve ws url")?;
|
||||
|
||||
for gsa in global_socket_addrs {
|
||||
let pdi = DialInfo::try_ws(SocketAddress::from_socket_addr(gsa), url.clone())
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
.wrap_err("try_ws failed")?;
|
||||
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::PublicInternet,
|
||||
@@ -458,9 +448,7 @@ impl Network {
|
||||
}
|
||||
// Build dial info request url
|
||||
let local_url = format!("ws://{}/{}", socket_address, path);
|
||||
let local_di = DialInfo::try_ws(socket_address, local_url)
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
let local_di = DialInfo::try_ws(socket_address, local_url).wrap_err("try_ws failed")?;
|
||||
|
||||
if url.is_none() && (socket_address.address().is_global() || enable_local_peer_scope) {
|
||||
// Register public dial info
|
||||
@@ -490,7 +478,7 @@ impl Network {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) async fn start_wss_listeners(&self) -> Result<(), String> {
|
||||
pub(super) async fn start_wss_listeners(&self) -> EyreResult<()> {
|
||||
trace!("starting wss listeners");
|
||||
|
||||
let routing_table = self.routing_table();
|
||||
@@ -538,7 +526,7 @@ impl Network {
|
||||
// Add static public dialinfo if it's configured
|
||||
let mut split_url = SplitUrl::from_str(url)?;
|
||||
if split_url.scheme.to_ascii_lowercase() != "wss" {
|
||||
return Err("WSS URL must use 'wss://' scheme".to_owned());
|
||||
bail!("WSS URL must use 'wss://' scheme");
|
||||
}
|
||||
split_url.scheme = "wss".to_owned();
|
||||
|
||||
@@ -546,13 +534,10 @@ impl Network {
|
||||
let global_socket_addrs = split_url
|
||||
.host_port(443)
|
||||
.to_socket_addrs()
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
|
||||
.wrap_err("failed to resolve wss url")?;
|
||||
for gsa in global_socket_addrs {
|
||||
let pdi = DialInfo::try_wss(SocketAddress::from_socket_addr(gsa), url.clone())
|
||||
.map_err(map_to_string)
|
||||
.map_err(logthru_net!(error))?;
|
||||
.wrap_err("try_wss failed")?;
|
||||
|
||||
routing_table.register_dial_info(
|
||||
RoutingDomain::PublicInternet,
|
||||
@@ -581,7 +566,7 @@ impl Network {
|
||||
registered_addresses.insert(gsa.ip());
|
||||
}
|
||||
} else {
|
||||
return Err("WSS URL must be specified due to TLS requirements".to_owned());
|
||||
bail!("WSS URL must be specified due to TLS requirements");
|
||||
}
|
||||
|
||||
if static_public {
|
||||
@@ -594,7 +579,7 @@ impl Network {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) async fn start_tcp_listeners(&self) -> Result<(), String> {
|
||||
pub(super) async fn start_tcp_listeners(&self) -> EyreResult<()> {
|
||||
trace!("starting tcp listeners");
|
||||
|
||||
let routing_table = self.routing_table();
|
||||
@@ -659,7 +644,7 @@ impl Network {
|
||||
// Resolve statically configured public dialinfo
|
||||
let mut public_sockaddrs = public_address
|
||||
.to_socket_addrs()
|
||||
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
|
||||
.wrap_err("failed to resolve tcp address")?;
|
||||
|
||||
// Add all resolved addresses as public dialinfo
|
||||
for pdi_addr in &mut public_sockaddrs {
|
||||
|
@@ -1,5 +1,6 @@
|
||||
use super::*;
|
||||
use futures_util::{FutureExt, StreamExt};
|
||||
use std::io;
|
||||
use stop_token::prelude::*;
|
||||
|
||||
cfg_if::cfg_if! {
|
||||
@@ -15,7 +16,7 @@ cfg_if::cfg_if! {
|
||||
&self,
|
||||
stream: AsyncPeekStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> SystemPinBoxFuture<Result<Option<ProtocolNetworkConnection>, String>>;
|
||||
) -> SystemPinBoxFuture<io::Result<Option<ProtocolNetworkConnection>>>;
|
||||
}
|
||||
|
||||
pub trait ProtocolAcceptHandlerClone {
|
||||
@@ -52,13 +53,13 @@ impl DummyNetworkConnection {
|
||||
pub fn descriptor(&self) -> ConnectionDescriptor {
|
||||
self.descriptor.clone()
|
||||
}
|
||||
pub fn close(&self) -> Result<(), String> {
|
||||
// pub fn close(&self) -> Result<(), String> {
|
||||
// Ok(())
|
||||
// }
|
||||
pub fn send(&self, _message: Vec<u8>) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
pub fn send(&self, _message: Vec<u8>) -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
pub fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
pub fn recv(&self) -> io::Result<Vec<u8>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
@@ -178,7 +179,7 @@ impl NetworkConnection {
|
||||
protocol_connection: &ProtocolNetworkConnection,
|
||||
stats: Arc<Mutex<NetworkConnectionStats>>,
|
||||
message: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
) -> io::Result<()> {
|
||||
let ts = intf::get_timestamp();
|
||||
let out = protocol_connection.send(message).await;
|
||||
if out.is_ok() {
|
||||
@@ -190,7 +191,7 @@ impl NetworkConnection {
|
||||
async fn recv_internal(
|
||||
protocol_connection: &ProtocolNetworkConnection,
|
||||
stats: Arc<Mutex<NetworkConnectionStats>>,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
) -> io::Result<Vec<u8>> {
|
||||
let ts = intf::get_timestamp();
|
||||
let out = protocol_connection.recv().await;
|
||||
if out.is_ok() {
|
||||
@@ -222,7 +223,7 @@ impl NetworkConnection {
|
||||
) -> SystemPinBoxFuture<()> {
|
||||
Box::pin(async move {
|
||||
log_net!(
|
||||
"Starting process_connection loop for {:?}",
|
||||
"== Starting process_connection loop for {:?}",
|
||||
descriptor.green()
|
||||
);
|
||||
|
||||
@@ -236,7 +237,7 @@ impl NetworkConnection {
|
||||
let new_timer = || {
|
||||
intf::sleep(connection_inactivity_timeout_ms).then(|_| async {
|
||||
// timeout
|
||||
log_net!("connection timeout on {:?}", descriptor.green());
|
||||
log_net!("== Connection timeout on {:?}", descriptor.green());
|
||||
RecvLoopAction::Timeout
|
||||
})
|
||||
};
|
||||
@@ -288,7 +289,7 @@ impl NetworkConnection {
|
||||
.on_recv_envelope(message.as_slice(), descriptor)
|
||||
.await
|
||||
{
|
||||
log_net!(error e);
|
||||
log_net!(debug "failed to process received envelope: {}", e);
|
||||
RecvLoopAction::Finish
|
||||
} else {
|
||||
RecvLoopAction::Recv
|
||||
@@ -296,7 +297,7 @@ impl NetworkConnection {
|
||||
}
|
||||
Err(e) => {
|
||||
// Connection unable to receive, closed
|
||||
log_net!(warn e);
|
||||
log_net!(debug e);
|
||||
RecvLoopAction::Finish
|
||||
}
|
||||
}
|
||||
|
@@ -56,7 +56,7 @@ impl Network {
|
||||
&self,
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
let data_len = data.len();
|
||||
|
||||
let res = match dial_info.protocol_type() {
|
||||
@@ -90,7 +90,7 @@ impl Network {
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
timeout_ms: u32,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
) -> EyreResult<Vec<u8>> {
|
||||
let data_len = data.len();
|
||||
let out = match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
@@ -124,7 +124,7 @@ impl Network {
|
||||
&self,
|
||||
descriptor: ConnectionDescriptor,
|
||||
data: Vec<u8>,
|
||||
) -> Result<Option<Vec<u8>>, String> {
|
||||
) -> EyreResult<Option<Vec<u8>>> {
|
||||
let data_len = data.len();
|
||||
match descriptor.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
@@ -161,7 +161,7 @@ impl Network {
|
||||
&self,
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
) -> EyreResult<()> {
|
||||
let data_len = data.len();
|
||||
if dial_info.protocol_type() == ProtocolType::UDP {
|
||||
return Err("no support for UDP protocol".to_owned()).map_err(logthru_net!(error))
|
||||
@@ -187,7 +187,7 @@ impl Network {
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
pub async fn startup(&self) -> Result<(), String> {
|
||||
pub async fn startup(&self) -> EyreResult<()> {
|
||||
// get protocol config
|
||||
self.inner.lock().protocol_config = Some({
|
||||
let c = self.config.get();
|
||||
@@ -269,7 +269,7 @@ impl Network {
|
||||
}
|
||||
|
||||
//////////////////////////////////////////
|
||||
pub async fn tick(&self) -> Result<(), String> {
|
||||
pub async fn tick(&self) -> EyreResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@@ -15,7 +15,7 @@ impl ProtocolNetworkConnection {
|
||||
pub async fn connect(
|
||||
local_address: Option<SocketAddr>,
|
||||
dial_info: DialInfo,
|
||||
) -> Result<ProtocolNetworkConnection, String> {
|
||||
) -> io::Result<ProtocolNetworkConnection> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
panic!("UDP dial info is not supported on WASM targets");
|
||||
@@ -32,7 +32,7 @@ impl ProtocolNetworkConnection {
|
||||
pub async fn send_unbound_message(
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
) -> io::Result<()> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
panic!("UDP dial info is not supported on WASM targets");
|
||||
@@ -50,7 +50,7 @@ impl ProtocolNetworkConnection {
|
||||
dial_info: DialInfo,
|
||||
data: Vec<u8>,
|
||||
timeout_ms: u32,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
) -> io::Result<Vec<u8>> {
|
||||
match dial_info.protocol_type() {
|
||||
ProtocolType::UDP => {
|
||||
panic!("UDP dial info is not supported on WASM targets");
|
||||
@@ -72,20 +72,20 @@ impl ProtocolNetworkConnection {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.close(),
|
||||
Self::Ws(w) => w.close().await,
|
||||
}
|
||||
}
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
// pub async fn close(&self) -> io::Result<()> {
|
||||
// match self {
|
||||
// Self::Dummy(d) => d.close(),
|
||||
// Self::Ws(w) => w.close().await,
|
||||
// }
|
||||
// }
|
||||
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.send(message),
|
||||
Self::Ws(w) => w.send(message).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn recv(&self) -> Result<Vec<u8>, String> {
|
||||
pub async fn recv(&self) -> io::Result<Vec<u8>> {
|
||||
match self {
|
||||
Self::Dummy(d) => d.recv(),
|
||||
Self::Ws(w) => w.recv().await,
|
||||
|
@@ -1,12 +1,19 @@
|
||||
use super::*;
|
||||
use ws_stream_wasm::*;
|
||||
use futures_util::{StreamExt, SinkExt};
|
||||
use std::io;
|
||||
|
||||
struct WebsocketNetworkConnectionInner {
|
||||
ws_meta: WsMeta,
|
||||
ws_stream: CloneStream<WsStream>,
|
||||
}
|
||||
|
||||
fn to_io(err: WsErr) -> io::Error {
|
||||
let kind = match err {
|
||||
WsErr::InvalidWsState {supplied:_} => io::ErrorKind::
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WebsocketNetworkConnection {
|
||||
descriptor: ConnectionDescriptor,
|
||||
@@ -36,15 +43,15 @@ impl WebsocketNetworkConnection {
|
||||
self.descriptor.clone()
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", err, skip(self))]
|
||||
pub async fn close(&self) -> Result<(), String> {
|
||||
self.inner.ws_meta.close().await.map_err(map_to_string).map(drop)
|
||||
}
|
||||
// #[instrument(level = "trace", err, skip(self))]
|
||||
// pub async fn close(&self) -> Result<(), String> {
|
||||
// self.inner.ws_meta.close().await.map_err(map_to_string).map(drop)
|
||||
// }
|
||||
|
||||
#[instrument(level = "trace", err, skip(self, message), fields(message.len = message.len()))]
|
||||
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
|
||||
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
|
||||
if message.len() > MAX_MESSAGE_SIZE {
|
||||
return Err("sending too large WS message".to_owned()).map_err(logthru_net!(error));
|
||||
bail_io_error_other!("sending too large WS message");
|
||||
}
|
||||
self.inner.ws_stream.clone()
|
||||
.send(WsMessage::Binary(message)).await
|
||||
|
Reference in New Issue
Block a user