This commit is contained in:
John Smith
2022-07-10 17:36:50 -04:00
parent cd0cd78e30
commit 7e0d7dad06
108 changed files with 1378 additions and 1535 deletions

View File

@@ -18,14 +18,16 @@ impl ConnectionHandle {
self.descriptor.clone()
}
pub fn send(&self, message: Vec<u8>) -> Result<(), String> {
self.channel.send(message).map_err(map_to_string)
pub fn send(&self, message: Vec<u8>) -> EyreResult<()> {
self.channel
.send(message)
.wrap_err("failed to send to connection")
}
pub async fn send_async(&self, message: Vec<u8>) -> Result<(), String> {
pub async fn send_async(&self, message: Vec<u8>) -> EyreResult<()> {
self.channel
.send_async(message)
.await
.map_err(map_to_string)
.wrap_err("failed to send_async to connection")
}
}

View File

@@ -1,33 +1,17 @@
use super::*;
use alloc::collections::btree_map::Entry;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(ThisError, Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddressFilterError {
#[error("Count exceeded")]
CountExceeded,
#[error("Rate exceeded")]
RateExceeded,
}
impl fmt::Display for AddressFilterError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}",
match *self {
Self::CountExceeded => "Count exceeded",
Self::RateExceeded => "Rate exceeded",
}
)
}
}
impl std::error::Error for AddressFilterError {}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(ThisError, Debug, Clone, Copy, PartialEq, Eq)]
#[error("Address not in table")]
pub struct AddressNotInTableError {}
impl fmt::Display for AddressNotInTableError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Address not in table")
}
}
impl std::error::Error for AddressNotInTableError {}
#[derive(Debug)]
pub struct ConnectionLimits {

View File

@@ -142,13 +142,13 @@ impl ConnectionManager {
&self,
inner: &mut ConnectionManagerInner,
conn: ProtocolNetworkConnection,
) -> Result<ConnectionHandle, String> {
) -> EyreResult<ConnectionHandle> {
log_net!("on_new_protocol_network_connection: {:?}", conn);
// Wrap with NetworkConnection object to start the connection processing loop
let stop_token = match &inner.stop_source {
Some(ss) => ss.token(),
None => return Err("not creating connection because we are stopping".to_owned()),
None => bail!("not creating connection because we are stopping"),
};
let conn = NetworkConnection::from_protocol(self.clone(), stop_token, conn);
@@ -165,7 +165,7 @@ impl ConnectionManager {
&self,
local_addr: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<ConnectionHandle, String> {
) -> EyreResult<ConnectionHandle> {
let killed = {
let mut inner = self.arc.inner.lock();
let inner = match &mut *inner {
@@ -274,7 +274,7 @@ impl ConnectionManager {
let inner = match &mut *inner {
Some(v) => v,
None => {
return Err("shutting down".to_owned());
bail!("shutting down");
}
};
self.on_new_protocol_network_connection(inner, conn)
@@ -336,7 +336,7 @@ impl ConnectionManager {
pub(super) async fn on_accepted_protocol_network_connection(
&self,
conn: ProtocolNetworkConnection,
) -> Result<(), String> {
) -> EyreResult<()> {
// Get channel sender
let sender = {
let mut inner = self.arc.inner.lock();

View File

@@ -3,6 +3,39 @@ use alloc::collections::btree_map::Entry;
use futures_util::StreamExt;
use hashlink::LruCache;
///////////////////////////////////////////////////////////////////////////////
#[derive(ThisError, Debug, Clone, Eq, PartialEq)]
pub enum ConnectionTableAddError {
#[error("Connection already added to table")]
AlreadyExists,
#[error("Connection address was filtered")]
AddressFilter(AddressFilterError),
}
impl ConnectionTableAddError {
pub fn already_exists() -> Self {
ConnectionTableAddError::AlreadyExists
}
pub fn address_filter(err: AddressFilterError) -> Self {
ConnectionTableAddError::AddressFilter(err)
}
}
///////////////////////////////////////////////////////////////////////////////
#[derive(ThisError, Debug, Clone, Eq, PartialEq)]
pub enum ConnectionTableRemoveError {
#[error("Connection not in table")]
NotInTable,
}
impl ConnectionTableRemoveError {
pub fn not_in_table() -> Self {
ConnectionTableRemoveError::NotInTable
}
}
///////////////////////////////////////////////////////////////////////////////
#[derive(Debug)]
pub struct ConnectionTable {
max_connections: Vec<usize>,
@@ -53,20 +86,22 @@ impl ConnectionTable {
while unord.next().await.is_some() {}
}
pub fn add_connection(&mut self, conn: NetworkConnection) -> Result<(), String> {
pub fn add_connection(
&mut self,
conn: NetworkConnection,
) -> Result<(), ConnectionTableAddError> {
let descriptor = conn.connection_descriptor();
let ip_addr = descriptor.remote_address().to_ip_addr();
let index = protocol_to_index(descriptor.protocol_type());
if self.conn_by_descriptor[index].contains_key(&descriptor) {
return Err(format!(
"Connection already added to table: {:?}",
descriptor
));
return Err(ConnectionTableAddError::already_exists());
}
// Filter by ip for connection limits
self.address_filter.add(ip_addr).map_err(map_to_string)?;
self.address_filter
.add(ip_addr)
.map_err(ConnectionTableAddError::address_filter)?;
// Add the connection to the table
let res = self.conn_by_descriptor[index].insert(descriptor.clone(), conn);
@@ -164,11 +199,11 @@ impl ConnectionTable {
pub fn remove_connection(
&mut self,
descriptor: ConnectionDescriptor,
) -> Result<NetworkConnection, String> {
) -> Result<NetworkConnection, ConnectionTableRemoveError> {
let index = protocol_to_index(descriptor.protocol_type());
let conn = self.conn_by_descriptor[index]
.remove(&descriptor)
.ok_or_else(|| format!("Connection not in table: {:?}", descriptor))?;
.ok_or_else(|| ConnectionTableRemoveError::not_in_table())?;
self.remove_connection_records(descriptor);
Ok(conn)

View File

@@ -127,8 +127,8 @@ struct NetworkManagerInner {
struct NetworkManagerUnlockedInner {
// Background processes
rolling_transfers_task: TickTask,
relay_management_task: TickTask,
rolling_transfers_task: TickTask<EyreReport>,
relay_management_task: TickTask<EyreReport>,
}
#[derive(Clone)]
@@ -236,7 +236,7 @@ impl NetworkManager {
}
#[instrument(level = "debug", skip_all, err)]
pub async fn init(&self, update_callback: UpdateCallback) -> Result<(), String> {
pub async fn init(&self, update_callback: UpdateCallback) -> EyreResult<()> {
let routing_table = RoutingTable::new(self.clone());
routing_table.init().await?;
self.inner.lock().routing_table = Some(routing_table.clone());
@@ -257,7 +257,7 @@ impl NetworkManager {
}
#[instrument(level = "debug", skip_all, err)]
pub async fn internal_startup(&self) -> Result<(), String> {
pub async fn internal_startup(&self) -> EyreResult<()> {
trace!("NetworkManager::internal_startup begin");
if self.inner.lock().components.is_some() {
debug!("NetworkManager::internal_startup already started");
@@ -292,7 +292,7 @@ impl NetworkManager {
}
#[instrument(level = "debug", skip_all, err)]
pub async fn startup(&self) -> Result<(), String> {
pub async fn startup(&self) -> EyreResult<()> {
if let Err(e) = self.internal_startup().await {
self.shutdown().await;
return Err(e);
@@ -387,7 +387,7 @@ impl NetworkManager {
}
#[instrument(level = "debug", skip_all, err)]
async fn restart_net(&self, net: Network) -> Result<(), String> {
async fn restart_net(&self, net: Network) -> EyreResult<()> {
net.shutdown().await;
self.send_network_update();
net.startup().await?;
@@ -395,7 +395,7 @@ impl NetworkManager {
Ok(())
}
pub async fn tick(&self) -> Result<(), String> {
pub async fn tick(&self) -> EyreResult<()> {
let (routing_table, net, receipt_manager) = {
let inner = self.inner.lock();
let components = inner.components.as_ref().unwrap();
@@ -481,7 +481,7 @@ impl NetworkManager {
expected_returns: u32,
extra_data: D,
callback: impl ReceiptCallback,
) -> Result<Vec<u8>, String> {
) -> EyreResult<Vec<u8>> {
let receipt_manager = self.receipt_manager();
let routing_table = self.routing_table();
@@ -490,7 +490,7 @@ impl NetworkManager {
let receipt = Receipt::try_new(0, nonce, routing_table.node_id(), extra_data)?;
let out = receipt
.to_signed_data(&routing_table.node_id_secret())
.map_err(|_| "failed to generate signed receipt".to_owned())?;
.wrap_err("failed to generate signed receipt")?;
// Record the receipt for later
let exp_ts = intf::get_timestamp() + expiration_us;
@@ -505,7 +505,7 @@ impl NetworkManager {
&self,
expiration_us: u64,
extra_data: D,
) -> Result<(Vec<u8>, EventualValueFuture<ReceiptEvent>), String> {
) -> EyreResult<(Vec<u8>, EventualValueFuture<ReceiptEvent>)> {
let receipt_manager = self.receipt_manager();
let routing_table = self.routing_table();
@@ -514,7 +514,7 @@ impl NetworkManager {
let receipt = Receipt::try_new(0, nonce, routing_table.node_id(), extra_data)?;
let out = receipt
.to_signed_data(&routing_table.node_id_secret())
.map_err(|_| "failed to generate signed receipt".to_owned())?;
.wrap_err("failed to generate signed receipt")?;
// Record the receipt for later
let exp_ts = intf::get_timestamp() + expiration_us;
@@ -530,11 +530,11 @@ impl NetworkManager {
pub async fn handle_out_of_band_receipt<R: AsRef<[u8]>>(
&self,
receipt_data: R,
) -> Result<(), String> {
) -> EyreResult<()> {
let receipt_manager = self.receipt_manager();
let receipt = Receipt::from_signed_data(receipt_data.as_ref())
.map_err(|_| "failed to parse signed out-of-band receipt".to_owned())?;
.wrap_err("failed to parse signed out-of-band receipt")?;
receipt_manager.handle_receipt(receipt, None).await
}
@@ -545,11 +545,11 @@ impl NetworkManager {
&self,
receipt_data: R,
inbound_nr: NodeRef,
) -> Result<(), String> {
) -> EyreResult<()> {
let receipt_manager = self.receipt_manager();
let receipt = Receipt::from_signed_data(receipt_data.as_ref())
.map_err(|_| "failed to parse signed in-band receipt".to_owned())?;
.wrap_err("failed to parse signed in-band receipt")?;
receipt_manager
.handle_receipt(receipt, Some(inbound_nr))
@@ -558,7 +558,7 @@ impl NetworkManager {
// Process a received signal
#[instrument(level = "trace", skip(self), err)]
pub async fn handle_signal(&self, signal_info: SignalInfo) -> Result<(), String> {
pub async fn handle_signal(&self, signal_info: SignalInfo) -> EyreResult<()> {
match signal_info {
SignalInfo::ReverseConnect { receipt, peer_info } => {
let routing_table = self.routing_table();
@@ -573,7 +573,7 @@ impl NetworkManager {
// Make a reverse connection to the peer and send the receipt to it
rpc.rpc_call_return_receipt(Destination::Direct(peer_nr), None, receipt)
.await
.map_err(map_to_string)?;
.wrap_err("rpc failure")?;
}
SignalInfo::HolePunch { receipt, peer_info } => {
let routing_table = self.routing_table();
@@ -589,7 +589,7 @@ impl NetworkManager {
peer_nr.filter_protocols(ProtocolSet::only(ProtocolType::UDP));
let hole_punch_dial_info_detail = peer_nr
.first_filtered_dial_info_detail(Some(RoutingDomain::PublicInternet))
.ok_or_else(|| "No hole punch capable dialinfo found for node".to_owned())?;
.ok_or_else(|| eyre!("No hole punch capable dialinfo found for node"))?;
// Now that we picked a specific dialinfo, further restrict the noderef to the specific address type
let mut filter = peer_nr.take_filter().unwrap();
@@ -611,7 +611,7 @@ impl NetworkManager {
// Return the receipt using the same dial info send the receipt to it
rpc.rpc_call_return_receipt(Destination::Direct(peer_nr), None, receipt)
.await
.map_err(map_to_string)?;
.wrap_err("rpc failure")?;
}
}
@@ -625,7 +625,7 @@ impl NetworkManager {
dest_node_id: DHTKey,
version: u8,
body: B,
) -> Result<Vec<u8>, String> {
) -> EyreResult<Vec<u8>> {
// DH to get encryption key
let routing_table = self.routing_table();
let node_id = routing_table.node_id();
@@ -639,7 +639,7 @@ impl NetworkManager {
let envelope = Envelope::new(version, ts, nonce, node_id, dest_node_id);
envelope
.to_encrypted_data(self.crypto.clone(), body.as_ref(), &node_id_secret)
.map_err(|_| "envelope failed to encode".to_owned())
.wrap_err("envelope failed to encode")
}
// Called by the RPC handler when we want to issue an RPC request or response
@@ -652,7 +652,7 @@ impl NetworkManager {
node_ref: NodeRef,
envelope_node_id: Option<DHTKey>,
body: B,
) -> Result<SendDataKind, String> {
) -> EyreResult<SendDataKind> {
let via_node_id = node_ref.node_id();
let envelope_node_id = envelope_node_id.unwrap_or(via_node_id);
@@ -671,11 +671,12 @@ impl NetworkManager {
{
#[allow(clippy::absurd_extreme_comparisons)]
if node_min > MAX_VERSION || node_max < MIN_VERSION {
return Err(format!(
bail!(
"can't talk to this node {} because version is unsupported: ({},{})",
via_node_id, node_min, node_max
))
.map_err(logthru_rpc!(warn));
via_node_id,
node_min,
node_max
);
}
cmp::min(node_max, MAX_VERSION)
} else {
@@ -703,7 +704,7 @@ impl NetworkManager {
&self,
dial_info: DialInfo,
rcpt_data: Vec<u8>,
) -> Result<(), String> {
) -> EyreResult<()> {
// Do we need to validate the outgoing receipt? Probably not
// because it is supposed to be opaque and the
// recipient/originator does the validation
@@ -717,8 +718,8 @@ impl NetworkManager {
}
// Figure out how to reach a node
#[instrument(level = "trace", skip(self), ret, err)]
fn get_contact_method(&self, mut target_node_ref: NodeRef) -> Result<ContactMethod, String> {
#[instrument(level = "trace", skip(self), ret)]
fn get_contact_method(&self, mut target_node_ref: NodeRef) -> ContactMethod {
let routing_table = self.routing_table();
// Get our network class and protocol config and node id
@@ -727,14 +728,14 @@ impl NetworkManager {
// Scope noderef down to protocols we can do outbound
if !target_node_ref.filter_protocols(our_protocol_config.outbound) {
return Ok(ContactMethod::Unreachable);
return ContactMethod::Unreachable;
}
// Get the best matching local direct dial info if we have it
let opt_target_local_did =
target_node_ref.first_filtered_dial_info_detail(Some(RoutingDomain::LocalNetwork));
if let Some(target_local_did) = opt_target_local_did {
return Ok(ContactMethod::Direct(target_local_did.dial_info));
return ContactMethod::Direct(target_local_did.dial_info);
}
// Get the best match internet dial info if we have it
@@ -744,7 +745,7 @@ impl NetworkManager {
// Do we need to signal before going inbound?
if !target_public_did.class.requires_signal() {
// Go direct without signaling
return Ok(ContactMethod::Direct(target_public_did.dial_info));
return ContactMethod::Direct(target_public_did.dial_info);
}
// Get the target's inbound relay, it must have one or it is not reachable
@@ -767,10 +768,10 @@ impl NetworkManager {
) {
// Can we receive a direct reverse connection?
if !reverse_did.class.requires_signal() {
return Ok(ContactMethod::SignalReverse(
return ContactMethod::SignalReverse(
inbound_relay_nr,
target_node_ref,
));
);
}
}
@@ -798,16 +799,16 @@ impl NetworkManager {
)
.is_some();
if target_has_udp_dialinfo && self_has_udp_dialinfo {
return Ok(ContactMethod::SignalHolePunch(
return ContactMethod::SignalHolePunch(
inbound_relay_nr,
udp_target_nr,
));
);
}
}
// Otherwise we have to inbound relay
}
return Ok(ContactMethod::InboundRelay(inbound_relay_nr));
return ContactMethod::InboundRelay(inbound_relay_nr);
}
}
}
@@ -818,22 +819,17 @@ impl NetworkManager {
.first_filtered_dial_info_detail(Some(RoutingDomain::PublicInternet))
.is_some()
{
return Ok(ContactMethod::InboundRelay(target_inbound_relay_nr));
return ContactMethod::InboundRelay(target_inbound_relay_nr);
}
}
// If we can't reach the node by other means, try our outbound relay if we have one
if let Some(relay_node) = self.relay_node() {
return Ok(ContactMethod::OutboundRelay(relay_node));
return ContactMethod::OutboundRelay(relay_node);
}
// Otherwise, we can't reach this node
debug!("unable to reach node {:?}", target_node_ref);
// trace!(
// "unable to reach node {:?}: {}",
// target_node_ref,
// target_node_ref.operate(|e| format!("{:#?}", e))
// );
Ok(ContactMethod::Unreachable)
ContactMethod::Unreachable
}
// Send a reverse connection signal and wait for the return receipt over it
@@ -844,13 +840,11 @@ impl NetworkManager {
relay_nr: NodeRef,
target_nr: NodeRef,
data: Vec<u8>,
) -> Result<(), String> {
) -> EyreResult<()> {
// Build a return receipt for the signal
let receipt_timeout =
ms_to_us(self.config.get().network.reverse_connection_receipt_time_ms);
let (receipt, eventual_value) = self
.generate_single_shot_receipt(receipt_timeout, [])
.map_err(map_to_string)?;
let (receipt, eventual_value) = self.generate_single_shot_receipt(receipt_timeout, [])?;
// Get our peer info
let peer_info = self.routing_table().get_own_peer_info();
@@ -863,32 +857,25 @@ impl NetworkManager {
SignalInfo::ReverseConnect { receipt, peer_info },
)
.await
.map_err(logthru_net!("failed to send signal to {:?}", relay_nr))
.map_err(map_to_string)?;
.wrap_err("failed to send signal")?;
// Wait for the return receipt
let inbound_nr = match eventual_value.await.take_value().unwrap() {
ReceiptEvent::ReturnedOutOfBand => {
return Err("reverse connect receipt should be returned in-band".to_owned());
bail!("reverse connect receipt should be returned in-band");
}
ReceiptEvent::ReturnedInBand { inbound_noderef } => inbound_noderef,
ReceiptEvent::Expired => {
return Err(format!(
"reverse connect receipt expired from {:?}",
target_nr
));
bail!("reverse connect receipt expired from {:?}", target_nr);
}
ReceiptEvent::Cancelled => {
return Err(format!(
"reverse connect receipt cancelled from {:?}",
target_nr
));
bail!("reverse connect receipt cancelled from {:?}", target_nr);
}
};
// We expect the inbound noderef to be the same as the target noderef
// if they aren't the same, we should error on this and figure out what then hell is up
if target_nr != inbound_nr {
error!("unexpected noderef mismatch on reverse connect");
bail!("unexpected noderef mismatch on reverse connect");
}
// And now use the existing connection to send over
@@ -899,10 +886,10 @@ impl NetworkManager {
.await?
{
None => Ok(()),
Some(_) => Err("unable to send over reverse connection".to_owned()),
Some(_) => bail!("unable to send over reverse connection"),
}
} else {
Err("no reverse connection available".to_owned())
bail!("no reverse connection available")
}
}
@@ -914,7 +901,7 @@ impl NetworkManager {
relay_nr: NodeRef,
target_nr: NodeRef,
data: Vec<u8>,
) -> Result<(), String> {
) -> EyreResult<()> {
// Ensure we are filtered down to UDP (the only hole punch protocol supported today)
assert!(target_nr
.filter_ref()
@@ -923,17 +910,14 @@ impl NetworkManager {
// Build a return receipt for the signal
let receipt_timeout = ms_to_us(self.config.get().network.hole_punch_receipt_time_ms);
let (receipt, eventual_value) = self
.generate_single_shot_receipt(receipt_timeout, [])
.map_err(map_to_string)?;
let (receipt, eventual_value) = self.generate_single_shot_receipt(receipt_timeout, [])?;
// Get our peer info
let peer_info = self.routing_table().get_own_peer_info();
// Get the udp direct dialinfo for the hole punch
let hole_punch_did = target_nr
.first_filtered_dial_info_detail(Some(RoutingDomain::PublicInternet))
.ok_or_else(|| "No hole punch capable dialinfo found for node".to_owned())?;
.ok_or_else(|| eyre!("No hole punch capable dialinfo found for node"))?;
// Do our half of the hole punch by sending an empty packet
// Both sides will do this and then the receipt will get sent over the punched hole
@@ -949,30 +933,30 @@ impl NetworkManager {
SignalInfo::HolePunch { receipt, peer_info },
)
.await
.map_err(logthru_net!("failed to send signal to {:?}", relay_nr))
.map_err(map_to_string)?;
.wrap_err("failed to send signal")?;
// Wait for the return receipt
let inbound_nr = match eventual_value.await.take_value().unwrap() {
ReceiptEvent::ReturnedOutOfBand => {
return Err("hole punch receipt should be returned in-band".to_owned());
bail!("hole punch receipt should be returned in-band");
}
ReceiptEvent::ReturnedInBand { inbound_noderef } => inbound_noderef,
ReceiptEvent::Expired => {
return Err(format!("hole punch receipt expired from {}", target_nr));
bail!("hole punch receipt expired from {}", target_nr);
}
ReceiptEvent::Cancelled => {
return Err(format!("hole punch receipt cancelled from {}", target_nr));
bail!("hole punch receipt cancelled from {}", target_nr);
}
};
// We expect the inbound noderef to be the same as the target noderef
// if they aren't the same, we should error on this and figure out what then hell is up
if target_nr != inbound_nr {
return Err(format!(
bail!(
"unexpected noderef mismatch on hole punch {}, expected {}",
inbound_nr, target_nr
));
inbound_nr,
target_nr
);
}
// And now use the existing connection to send over
@@ -983,10 +967,10 @@ impl NetworkManager {
.await?
{
None => Ok(()),
Some(_) => Err("unable to send over hole punch".to_owned()),
Some(_) => bail!("unable to send over hole punch"),
}
} else {
Err("no hole punch available".to_owned())
bail!("no hole punch available")
}
}
@@ -1003,7 +987,7 @@ impl NetworkManager {
&self,
node_ref: NodeRef,
data: Vec<u8>,
) -> SystemPinBoxFuture<Result<SendDataKind, String>> {
) -> SystemPinBoxFuture<EyreResult<SendDataKind>> {
let this = self.clone();
Box::pin(async move {
// First try to send data to the last socket we've seen this peer on
@@ -1028,11 +1012,7 @@ impl NetworkManager {
log_net!("send_data via dialinfo to {:?}", node_ref);
// If we don't have last_connection, try to reach out to the peer via its dial info
match this
.get_contact_method(node_ref.clone())
.map_err(logthru_net!(debug))
.map(logthru_net!("get_contact_method for {:?}", node_ref))?
{
match this.get_contact_method(node_ref.clone()) {
ContactMethod::OutboundRelay(relay_nr) | ContactMethod::InboundRelay(relay_nr) => {
this.send_data(relay_nr, data)
.await
@@ -1057,14 +1037,13 @@ impl NetworkManager {
.do_hole_punch(relay_nr, target_node_ref, data)
.await
.map(|_| SendDataKind::GlobalDirect),
ContactMethod::Unreachable => Err("Can't send to this node".to_owned()),
ContactMethod::Unreachable => Err(eyre!("Can't send to this node")),
}
.map_err(logthru_net!(debug))
})
}
// Direct bootstrap request handler (separate fallback mechanism from cheaper TXT bootstrap mechanism)
async fn handle_boot_request(&self, descriptor: ConnectionDescriptor) -> Result<(), String> {
async fn handle_boot_request(&self, descriptor: ConnectionDescriptor) -> EyreResult<()> {
let routing_table = self.routing_table();
// Get a bunch of nodes with the various
@@ -1087,12 +1066,13 @@ impl NetworkManager {
// Bootstrap reply was sent
Ok(())
}
Some(_) => Err("bootstrap reply could not be sent".to_owned()),
Some(_) => Err(eyre!("bootstrap reply could not be sent")),
}
}
// Direct bootstrap request
pub async fn boot_request(&self, dial_info: DialInfo) -> Result<Vec<PeerInfo>, String> {
#[instrument(level = "trace", err, skip(self))]
pub async fn boot_request(&self, dial_info: DialInfo) -> EyreResult<Vec<PeerInfo>> {
let timeout_ms = {
let c = self.config.get();
c.network.rpc.timeout_ms
@@ -1105,8 +1085,8 @@ impl NetworkManager {
.await?;
let bootstrap_peerinfo: Vec<PeerInfo> =
deserialize_json(std::str::from_utf8(&out_data).map_err(map_to_string)?)
.map_err(map_to_string)?;
deserialize_json(std::str::from_utf8(&out_data).wrap_err("bad utf8 in boot peerinfo")?)
.wrap_err("failed to deserialize boot peerinfo")?;
Ok(bootstrap_peerinfo)
}
@@ -1119,7 +1099,7 @@ impl NetworkManager {
&self,
data: &[u8],
descriptor: ConnectionDescriptor,
) -> Result<bool, String> {
) -> EyreResult<bool> {
log_net!(
"envelope of {} bytes received from {:?}",
data.len(),
@@ -1138,7 +1118,7 @@ impl NetworkManager {
// Ensure we can read the magic number
if data.len() < 4 {
return Err("short packet".to_owned());
bail!("short packet");
}
// Is this a direct bootstrap request instead of an envelope?
@@ -1154,13 +1134,7 @@ impl NetworkManager {
}
// Decode envelope header (may fail signature validation)
let envelope = Envelope::from_signed_data(data).map_err(|_| {
format!(
"envelope failed to decode from {:?}: {} bytes",
descriptor,
data.len()
)
})?;
let envelope = Envelope::from_signed_data(data).wrap_err("envelope failed to decode")?;
// Get routing table and rpc processor
let (routing_table, rpc) = {
@@ -1185,18 +1159,18 @@ impl NetworkManager {
let ets = envelope.get_timestamp();
if let Some(tsbehind) = tsbehind {
if tsbehind > 0 && (ts > ets && ts - ets > tsbehind) {
return Err(format!(
bail!(
"envelope time was too far in the past: {}ms ",
timestamp_to_secs(ts - ets) * 1000f64
));
);
}
}
if let Some(tsahead) = tsahead {
if tsahead > 0 && (ts < ets && ets - ts > tsahead) {
return Err(format!(
bail!(
"envelope time was too far in the future: {}ms",
timestamp_to_secs(ets - ts) * 1000f64
));
);
}
}
@@ -1211,12 +1185,9 @@ impl NetworkManager {
let relay_nr = if self.check_client_whitelist(sender_id) {
// Full relay allowed, do a full resolve_node
rpc.resolve_node(recipient_id).await.map_err(|e| {
format!(
"failed to resolve recipient node for relay, dropping outbound relayed packet...: {:?}",
e
)
})?
rpc.resolve_node(recipient_id).await.wrap_err(
"failed to resolve recipient node for relay, dropping outbound relayed packet",
)?
} else {
// If this is not a node in the client whitelist, only allow inbound relay
// which only performs a lightweight lookup before passing the packet back out
@@ -1226,7 +1197,7 @@ impl NetworkManager {
// should be mutually in each others routing tables. The node needing the relay will be
// pinging this node regularly to keep itself in the routing table
routing_table.lookup_node_ref(recipient_id).ok_or_else(|| {
format!(
eyre!(
"Inbound relay asked for recipient not in routing table: sender_id={:?} recipient={:?}",
sender_id, recipient_id
)
@@ -1236,7 +1207,7 @@ impl NetworkManager {
// Relay the packet to the desired destination
self.send_data(relay_nr, data.to_vec())
.await
.map_err(|e| format!("failed to forward envelope: {}", e))?;
.wrap_err("failed to forward envelope")?;
// Inform caller that we dealt with the envelope, but did not process it locally
return Ok(false);
}
@@ -1248,19 +1219,20 @@ impl NetworkManager {
// xxx: punish nodes that send messages that fail to decrypt eventually
let body = envelope
.decrypt_body(self.crypto(), data, &node_id_secret)
.map_err(|_| "failed to decrypt envelope body".to_owned())?;
.wrap_err("failed to decrypt envelope body")?;
// Cache the envelope information in the routing table
let source_noderef = routing_table
.register_node_with_existing_connection(envelope.get_sender_id(), descriptor, ts)
.map_err(|e| format!("node id registration failed: {}", e))?;
let source_noderef = routing_table.register_node_with_existing_connection(
envelope.get_sender_id(),
descriptor,
ts,
)?;
source_noderef.operate_mut(|e| e.set_min_max_version(envelope.get_min_max_version()));
// xxx: deal with spoofing and flooding here?
// Pass message to RPC system
rpc.enqueue_message(envelope, body, source_noderef)
.map_err(|e| format!("enqueing rpc message failed: {}", e))?;
rpc.enqueue_message(envelope, body, source_noderef)?;
// Inform caller that we dealt with the envelope locally
Ok(true)
@@ -1273,7 +1245,7 @@ impl NetworkManager {
stop_token: StopToken,
_last_ts: u64,
cur_ts: u64,
) -> Result<(), String> {
) -> EyreResult<()> {
// Get our node's current node info and network class and do the right thing
let routing_table = self.routing_table();
let node_info = routing_table.get_own_node_info();
@@ -1354,7 +1326,7 @@ impl NetworkManager {
stop_token: StopToken,
last_ts: u64,
cur_ts: u64,
) -> Result<(), String> {
) -> EyreResult<()> {
// log_net!("--- network manager rolling_transfers task");
{
let inner = &mut *self.inner.lock();

View File

@@ -61,7 +61,7 @@ struct NetworkUnlockedInner {
network_manager: NetworkManager,
connection_manager: ConnectionManager,
// Background processes
update_network_class_task: TickTask,
update_network_class_task: TickTask<EyreReport>,
}
#[derive(Clone)]
@@ -266,7 +266,7 @@ impl Network {
// See if our interface addresses have changed, if so we need to punt the network
// and redo all our addresses. This is overkill, but anything more accurate
// would require inspection of routing tables that we dont want to bother with
pub async fn check_interface_addresses(&self) -> Result<bool, String> {
pub async fn check_interface_addresses(&self) -> EyreResult<bool> {
let mut inner = self.inner.lock();
if !inner.interfaces.refresh().await? {
return Ok(false);
@@ -286,7 +286,7 @@ impl Network {
&self,
dial_info: DialInfo,
data: Vec<u8>,
) -> Result<(), String> {
) -> EyreResult<()> {
let data_len = data.len();
let res = match dial_info.protocol_type() {
ProtocolType::UDP => {
@@ -300,7 +300,8 @@ impl Network {
ProtocolType::WS | ProtocolType::WSS => {
WebsocketProtocolHandler::send_unbound_message(dial_info.clone(), data).await
}
};
}
.wrap_err("low level network error");
if res.is_ok() {
// Network accounting
self.network_manager()
@@ -320,7 +321,7 @@ impl Network {
dial_info: DialInfo,
data: Vec<u8>,
timeout_ms: u32,
) -> Result<Vec<u8>, String> {
) -> EyreResult<Vec<u8>> {
let data_len = data.len();
let out = match dial_info.protocol_type() {
ProtocolType::UDP => {
@@ -358,7 +359,7 @@ impl Network {
&self,
descriptor: ConnectionDescriptor,
data: Vec<u8>,
) -> Result<Option<Vec<u8>>, String> {
) -> EyreResult<Option<Vec<u8>>> {
let data_len = data.len();
// Handle connectionless protocol
@@ -369,12 +370,10 @@ impl Network {
&peer_socket_addr,
&descriptor.local().map(|sa| sa.to_socket_addr()),
) {
log_net!(
"send_data_to_existing_connection connectionless to {:?}",
descriptor
);
ph.clone().send_message(data, peer_socket_addr).await?;
ph.clone()
.send_message(data, peer_socket_addr)
.await
.wrap_err("sending data to existing conection")?;
// Network accounting
self.network_manager()
@@ -389,10 +388,10 @@ impl Network {
// Try to send to the exact existing connection if one exists
if let Some(conn) = self.connection_manager().get_connection(descriptor).await {
log_net!("send_data_to_existing_connection to {:?}", descriptor);
// connection exists, send over it
conn.send_async(data).await?;
conn.send_async(data)
.await
.wrap_err("sending data to existing connection")?;
// Network accounting
self.network_manager()
@@ -413,13 +412,16 @@ impl Network {
&self,
dial_info: DialInfo,
data: Vec<u8>,
) -> Result<(), String> {
) -> EyreResult<()> {
let data_len = data.len();
// Handle connectionless protocol
if dial_info.protocol_type() == ProtocolType::UDP {
let peer_socket_addr = dial_info.to_socket_addr();
if let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None) {
let res = ph.send_message(data, peer_socket_addr).await;
let res = ph
.send_message(data, peer_socket_addr)
.await
.wrap_err("failed to send data to dial info");
if res.is_ok() {
// Network accounting
self.network_manager()
@@ -427,7 +429,7 @@ impl Network {
}
return res;
}
return Err("no appropriate UDP protocol handler for dial_info".to_owned());
bail!("no appropriate UDP protocol handler for dial_info");
}
// Handle connection-oriented protocols
@@ -453,7 +455,7 @@ impl Network {
}
#[instrument(level = "debug", err, skip_all)]
pub async fn startup(&self) -> Result<(), String> {
pub async fn startup(&self) -> EyreResult<()> {
// initialize interfaces
let mut interfaces = NetworkInterfaces::new();
interfaces.refresh().await?;
@@ -604,7 +606,7 @@ impl Network {
//////////////////////////////////////////
pub async fn tick(&self) -> Result<(), String> {
pub async fn tick(&self) -> EyreResult<()> {
let network_class = self.get_network_class().unwrap_or(NetworkClass::Invalid);
let routing_table = self.routing_table();

View File

@@ -250,7 +250,7 @@ impl DiscoveryContext {
// If we know we are not behind NAT, check our firewall status
#[instrument(level = "trace", skip(self), err)]
pub async fn protocol_process_no_nat(&self) -> Result<(), String> {
pub async fn protocol_process_no_nat(&self) -> EyreResult<()> {
let (node_1, external_1_dial_info) = {
let inner = self.inner.lock();
(
@@ -281,7 +281,7 @@ impl DiscoveryContext {
// If we know we are behind NAT check what kind
#[instrument(level = "trace", skip(self), ret, err)]
pub async fn protocol_process_nat(&self) -> Result<bool, String> {
pub async fn protocol_process_nat(&self) -> EyreResult<bool> {
let (node_1, external_1_dial_info, external_1_address, protocol_type, address_type) = {
let inner = self.inner.lock();
(
@@ -375,7 +375,7 @@ impl Network {
&self,
context: &DiscoveryContext,
protocol_type: ProtocolType,
) -> Result<(), String> {
) -> EyreResult<()> {
let mut retry_count = {
let c = self.config.get();
c.network.restricted_nat_retries
@@ -437,7 +437,7 @@ impl Network {
&self,
context: &DiscoveryContext,
protocol_type: ProtocolType,
) -> Result<(), String> {
) -> EyreResult<()> {
// Start doing ipv6 protocol
context.protocol_begin(protocol_type, AddressType::IPV6);
@@ -479,7 +479,7 @@ impl Network {
stop_token: StopToken,
_l: u64,
_t: u64,
) -> Result<(), String> {
) -> EyreResult<()> {
// Ensure we aren't trying to update this without clearing it first
let old_network_class = self.inner.lock().network_class;
assert_eq!(old_network_class, None);

View File

@@ -25,14 +25,14 @@ impl ListenerState {
/////////////////////////////////////////////////////////////////
impl Network {
fn get_or_create_tls_acceptor(&self) -> Result<TlsAcceptor, String> {
fn get_or_create_tls_acceptor(&self) -> EyreResult<TlsAcceptor> {
if let Some(ts) = self.inner.lock().tls_acceptor.as_ref() {
return Ok(ts.clone());
}
let server_config = self
.load_server_config()
.map_err(|e| format!("Couldn't create TLS configuration: {}", e))?;
.wrap_err("Couldn't create TLS configuration")?;
let acceptor = TlsAcceptor::from(Arc::new(server_config));
self.inner.lock().tls_acceptor = Some(acceptor.clone());
Ok(acceptor)
@@ -45,12 +45,11 @@ impl Network {
addr: SocketAddr,
protocol_handlers: &[Box<dyn ProtocolAcceptHandler>],
tls_connection_initial_timeout_ms: u32,
) -> Result<Option<ProtocolNetworkConnection>, String> {
) -> EyreResult<Option<ProtocolNetworkConnection>> {
let tls_stream = tls_acceptor
.accept(stream)
.await
.map_err(map_to_string)
.map_err(logthru_net!(debug "TLS stream failed handshake"))?;
.wrap_err("TLS stream failed handshake")?;
let ps = AsyncPeekStream::new(tls_stream);
let mut first_packet = [0u8; PEEK_DETECT_LEN];
@@ -63,8 +62,8 @@ impl Network {
ps.peek_exact(&mut first_packet),
)
.await
.map_err(map_to_string)?
.map_err(map_to_string)?;
.wrap_err("tls initial timeout")?
.wrap_err("failed to peek tls stream")?;
self.try_handlers(ps, addr, protocol_handlers).await
}
@@ -74,9 +73,13 @@ impl Network {
stream: AsyncPeekStream,
addr: SocketAddr,
protocol_accept_handlers: &[Box<dyn ProtocolAcceptHandler>],
) -> Result<Option<ProtocolNetworkConnection>, String> {
) -> EyreResult<Option<ProtocolNetworkConnection>> {
for ah in protocol_accept_handlers.iter() {
if let Some(nc) = ah.on_accept(stream.clone(), addr).await? {
if let Some(nc) = ah
.on_accept(stream.clone(), addr)
.await
.wrap_err("io error")?
{
return Ok(Some(nc));
}
}
@@ -114,7 +117,7 @@ impl Network {
return;
}
};
// XXX limiting
// XXX limiting here instead for connection table? may be faster and avoids tls negotiation
log_net!("TCP connection from: {}", addr);
@@ -185,7 +188,7 @@ impl Network {
}
}
async fn spawn_socket_listener(&self, addr: SocketAddr) -> Result<(), String> {
async fn spawn_socket_listener(&self, addr: SocketAddr) -> EyreResult<()> {
// Get config
let (connection_initial_timeout_ms, tls_connection_initial_timeout_ms) = {
let c = self.config.get();
@@ -196,11 +199,12 @@ impl Network {
};
// Create a reusable socket with no linger time, and no delay
let socket = new_bound_shared_tcp_socket(addr)?;
let socket = new_bound_shared_tcp_socket(addr)
.wrap_err("failed to create bound shared tcp socket")?;
// Listen on the socket
socket
.listen(128)
.map_err(|e| format!("Couldn't listen on TCP socket: {}", e))?;
.wrap_err("Couldn't listen on TCP socket")?;
// Make an async tcplistener from the socket2 socket
let std_listener: std::net::TcpListener = socket.into();
@@ -209,7 +213,7 @@ impl Network {
let listener = TcpListener::from(std_listener);
} else if #[cfg(feature="rt-tokio")] {
std_listener.set_nonblocking(true).expect("failed to set nonblocking");
let listener = TcpListener::from_std(std_listener).map_err(map_to_string)?;
let listener = TcpListener::from_std(std_listener).wrap_err("failed to create tokio tcp listener")?;
}
}
@@ -279,7 +283,7 @@ impl Network {
port: u16,
is_tls: bool,
new_protocol_accept_handler: Box<NewProtocolAcceptHandler>,
) -> Result<Vec<SocketAddress>, String> {
) -> EyreResult<Vec<SocketAddress>> {
let mut out = Vec::<SocketAddress>::new();
for ip_addr in ip_addrs {

View File

@@ -3,7 +3,7 @@ use sockets::*;
use stop_token::future::FutureExt;
impl Network {
pub(super) async fn create_udp_listener_tasks(&self) -> Result<(), String> {
pub(super) async fn create_udp_listener_tasks(&self) -> EyreResult<()> {
// Spawn socket tasks
let mut task_count = {
let c = self.config.get();
@@ -73,7 +73,7 @@ impl Network {
.on_recv_envelope(&data[..size], descriptor)
.await
{
log_net!(error "failed to process received udp envelope: {}", e);
log_net!(debug "failed to process received udp envelope: {}", e);
}
}
Ok(Err(_)) => {
@@ -110,7 +110,7 @@ impl Network {
Ok(())
}
pub(super) async fn create_udp_outbound_sockets(&self) -> Result<(), String> {
pub(super) async fn create_udp_outbound_sockets(&self) -> EyreResult<()> {
let mut inner = self.inner.lock();
let mut port = inner.udp_port;
// v4
@@ -119,9 +119,9 @@ impl Network {
// Pull the port if we randomly bound, so v6 can be on the same port
port = socket
.local_addr()
.map_err(map_to_string)?
.wrap_err("failed to get local address")?
.as_socket_ipv4()
.ok_or_else(|| "expected ipv4 address type".to_owned())?
.ok_or_else(|| eyre!("expected ipv4 address type"))?
.port();
// Make an async UdpSocket from the socket2 socket
@@ -131,7 +131,7 @@ impl Network {
let udp_socket = UdpSocket::from(std_udp_socket);
} else if #[cfg(feature="rt-tokio")] {
std_udp_socket.set_nonblocking(true).expect("failed to set nonblocking");
let udp_socket = UdpSocket::from_std(std_udp_socket).map_err(map_to_string)?;
let udp_socket = UdpSocket::from_std(std_udp_socket).wrap_err("failed to make outbound v4 tokio udpsocket")?;
}
}
let socket_arc = Arc::new(udp_socket);
@@ -152,7 +152,7 @@ impl Network {
let udp_socket = UdpSocket::from(std_udp_socket);
} else if #[cfg(feature="rt-tokio")] {
std_udp_socket.set_nonblocking(true).expect("failed to set nonblocking");
let udp_socket = UdpSocket::from_std(std_udp_socket).map_err(map_to_string)?;
let udp_socket = UdpSocket::from_std(std_udp_socket).wrap_err("failed to make outbound v6 tokio udpsocket")?;
}
}
let socket_arc = Arc::new(udp_socket);
@@ -166,7 +166,7 @@ impl Network {
Ok(())
}
async fn create_udp_inbound_socket(&self, addr: SocketAddr) -> Result<(), String> {
async fn create_udp_inbound_socket(&self, addr: SocketAddr) -> EyreResult<()> {
log_net!("create_udp_inbound_socket on {:?}", &addr);
// Create a reusable socket
@@ -179,7 +179,7 @@ impl Network {
let udp_socket = UdpSocket::from(std_udp_socket);
} else if #[cfg(feature="rt-tokio")] {
std_udp_socket.set_nonblocking(true).expect("failed to set nonblocking");
let udp_socket = UdpSocket::from_std(std_udp_socket).map_err(map_to_string)?;
let udp_socket = UdpSocket::from_std(std_udp_socket).wrap_err("failed to make inbound tokio udpsocket")?;
}
}
let socket_arc = Arc::new(udp_socket);
@@ -200,7 +200,7 @@ impl Network {
&self,
ip_addrs: Vec<IpAddr>,
port: u16,
) -> Result<Vec<DialInfo>, String> {
) -> EyreResult<Vec<DialInfo>> {
let mut out = Vec::<DialInfo>::new();
for ip_addr in ip_addrs {

View File

@@ -6,6 +6,7 @@ pub mod ws;
use super::*;
use crate::xx::*;
use std::io;
#[derive(Debug)]
pub enum ProtocolNetworkConnection {
@@ -21,7 +22,7 @@ impl ProtocolNetworkConnection {
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<ProtocolNetworkConnection, String> {
) -> io::Result<ProtocolNetworkConnection> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
panic!("Should not connect to UDP dialinfo");
@@ -35,7 +36,7 @@ impl ProtocolNetworkConnection {
}
}
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> io::Result<()> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr();
@@ -55,7 +56,7 @@ impl ProtocolNetworkConnection {
dial_info: DialInfo,
data: Vec<u8>,
timeout_ms: u32,
) -> Result<Vec<u8>, String> {
) -> io::Result<Vec<u8>> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr();
@@ -102,7 +103,7 @@ impl ProtocolNetworkConnection {
// }
// }
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
match self {
Self::Dummy(d) => d.send(message),
Self::RawTcp(t) => t.send(message).await,
@@ -111,7 +112,7 @@ impl ProtocolNetworkConnection {
Self::Wss(w) => w.send(message).await,
}
}
pub async fn recv(&self) -> Result<Vec<u8>, String> {
pub async fn recv(&self) -> io::Result<Vec<u8>> {
match self {
Self::Dummy(d) => d.recv(),
Self::RawTcp(t) => t.recv().await,

View File

@@ -1,6 +1,8 @@
use crate::xx::*;
use crate::*;
use async_io::Async;
use std::io;
cfg_if! {
if #[cfg(feature="rt-async-std")] {
pub use async_std::net::{TcpStream, TcpListener, Shutdown, UdpSocket};
@@ -19,12 +21,12 @@ cfg_if! {
use winapi::ctypes::c_int;
use std::os::windows::io::AsRawSocket;
fn set_exclusiveaddruse(socket: &Socket) -> Result<(), String> {
fn set_exclusiveaddruse(socket: &Socket) -> io::Result<()> {
unsafe {
let optval:c_int = 1;
if setsockopt(socket.as_raw_socket().try_into().unwrap(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (&optval as *const c_int).cast(),
std::mem::size_of::<c_int>() as c_int) == SOCKET_ERROR {
return Err("Unable to SO_EXCLUSIVEADDRUSE".to_owned());
return Err(io::Error::last_os_error());
}
Ok(())
}
@@ -32,49 +34,37 @@ cfg_if! {
}
}
pub fn new_unbound_shared_udp_socket(domain: Domain) -> Result<Socket, String> {
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))
.map_err(|e| format!("Couldn't create UDP socket: {}", e))?;
pub fn new_unbound_shared_udp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
if domain == Domain::IPV6 {
socket
.set_only_v6(true)
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
socket.set_only_v6(true)?;
}
socket
.set_reuse_address(true)
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
socket.set_reuse_address(true)?;
cfg_if! {
if #[cfg(unix)] {
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
socket.set_reuse_port(true)?;
}
}
Ok(socket)
}
pub fn new_bound_shared_udp_socket(local_address: SocketAddr) -> Result<Socket, String> {
pub fn new_bound_shared_udp_socket(local_address: SocketAddr) -> io::Result<Socket> {
let domain = Domain::for_address(local_address);
let socket = new_unbound_shared_udp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
socket.bind(&socket2_addr).map_err(|e| {
format!(
"failed to bind UDP socket to '{}' in domain '{:?}': {} ",
local_address, domain, e
)
})?;
socket.bind(&socket2_addr)?;
log_net!("created bound shared udp socket on {:?}", &local_address);
Ok(socket)
}
pub fn new_bound_first_udp_socket(local_address: SocketAddr) -> Result<Socket, String> {
pub fn new_bound_first_udp_socket(local_address: SocketAddr) -> io::Result<Socket> {
let domain = Domain::for_address(local_address);
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))
.map_err(|e| format!("Couldn't create UDP socket: {}", e))?;
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
if domain == Domain::IPV6 {
socket
.set_only_v6(true)
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
socket.set_only_v6(true)?;
}
// Bind the socket -first- before turning on 'reuse address' this way it will
// fail if the port is already taken
@@ -87,18 +77,15 @@ pub fn new_bound_first_udp_socket(local_address: SocketAddr) -> Result<Socket, S
}
}
socket
.bind(&socket2_addr)
.map_err(|e| format!("failed to bind UDP socket: {}", e))?;
socket.bind(&socket2_addr)?;
// Set 'reuse address' so future binds to this port will succeed
// This does not work on Windows, where reuse options can not be set after the bind
cfg_if! {
if #[cfg(unix)] {
socket
.set_reuse_address(true)
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
.set_reuse_address(true)?;
socket.set_reuse_port(true)?;
}
}
log_net!("created bound first udp socket on {:?}", &local_address);
@@ -106,10 +93,8 @@ pub fn new_bound_first_udp_socket(local_address: SocketAddr) -> Result<Socket, S
Ok(socket)
}
pub fn new_unbound_shared_tcp_socket(domain: Domain) -> Result<Socket, String> {
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
.map_err(map_to_string)
.map_err(logthru_net!("failed to create TCP socket"))?;
pub fn new_unbound_shared_tcp_socket(domain: Domain) -> io::Result<Socket> {
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
log_net!(error "Couldn't set TCP linger: {}", e);
}
@@ -117,43 +102,33 @@ pub fn new_unbound_shared_tcp_socket(domain: Domain) -> Result<Socket, String> {
log_net!(error "Couldn't set TCP nodelay: {}", e);
}
if domain == Domain::IPV6 {
socket
.set_only_v6(true)
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
socket.set_only_v6(true)?;
}
socket
.set_reuse_address(true)
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
socket.set_reuse_address(true)?;
cfg_if! {
if #[cfg(unix)] {
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
socket.set_reuse_port(true)?;
}
}
Ok(socket)
}
pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> Result<Socket, String> {
pub fn new_bound_shared_tcp_socket(local_address: SocketAddr) -> io::Result<Socket> {
let domain = Domain::for_address(local_address);
let socket = new_unbound_shared_tcp_socket(domain)?;
let socket2_addr = SockAddr::from(local_address);
socket
.bind(&socket2_addr)
.map_err(|e| format!("failed to bind TCP socket: {}", e))?;
socket.bind(&socket2_addr)?;
log_net!("created bound shared tcp socket on {:?}", &local_address);
Ok(socket)
}
pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> Result<Socket, String> {
pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> io::Result<Socket> {
let domain = Domain::for_address(local_address);
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
.map_err(map_to_string)
.map_err(logthru_net!("failed to create TCP socket"))?;
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
if let Err(e) = socket.set_linger(Some(core::time::Duration::from_secs(0))) {
log_net!(error "Couldn't set TCP linger: {}", e);
}
@@ -161,9 +136,7 @@ pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> Result<Socket, S
log_net!(error "Couldn't set TCP nodelay: {}", e);
}
if domain == Domain::IPV6 {
socket
.set_only_v6(true)
.map_err(|e| format!("Couldn't set IPV6_V6ONLY: {}", e))?;
socket.set_only_v6(true)?;
}
// On windows, do SO_EXCLUSIVEADDRUSE before the bind to ensure the port is fully available
@@ -176,18 +149,15 @@ pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> Result<Socket, S
// Bind the socket -first- before turning on 'reuse address' this way it will
// fail if the port is already taken
let socket2_addr = SockAddr::from(local_address);
socket
.bind(&socket2_addr)
.map_err(|e| format!("failed to bind TCP socket: {}", e))?;
socket.bind(&socket2_addr)?;
// Set 'reuse address' so future binds to this port will succeed
// This does not work on Windows, where reuse options can not be set after the bind
cfg_if! {
if #[cfg(unix)] {
socket
.set_reuse_address(true)
.map_err(|e| format!("Couldn't set reuse address: {}", e))?;
socket.set_reuse_port(true).map_err(|e| format!("Couldn't set reuse port: {}", e))?;
.set_reuse_address(true)?;
socket.set_reuse_port(true)?;
}
}
log_net!("created bound first tcp socket on {:?}", &local_address);
@@ -196,7 +166,7 @@ pub fn new_bound_first_tcp_socket(local_address: SocketAddr) -> Result<Socket, S
}
// Non-blocking connect is tricky when you want to start with a prepared socket
pub async fn nonblocking_connect(socket: Socket, addr: SocketAddr) -> std::io::Result<TcpStream> {
pub async fn nonblocking_connect(socket: Socket, addr: SocketAddr) -> io::Result<TcpStream> {
// Set for non blocking connect
socket.set_nonblocking(true)?;

View File

@@ -42,47 +42,45 @@ impl RawTcpNetworkConnection {
// }
// }
async fn send_internal(stream: &mut AsyncPeekStream, message: Vec<u8>) -> Result<(), String> {
async fn send_internal(stream: &mut AsyncPeekStream, message: Vec<u8>) -> io::Result<()> {
log_net!("sending TCP message of size {}", message.len());
if message.len() > MAX_MESSAGE_SIZE {
return Err("sending too large TCP message".to_owned());
bail_io_error_other!("sending too large TCP message");
}
let len = message.len() as u16;
let header = [b'V', b'L', len as u8, (len >> 8) as u8];
stream.write_all(&header).await.map_err(map_to_string)?;
stream.write_all(&message).await.map_err(map_to_string)
stream.write_all(&header).await?;
stream.write_all(&message).await
}
#[instrument(level="trace", err, skip(self, message), fields(message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
let mut stream = self.stream.clone();
Self::send_internal(&mut stream, message).await
}
pub async fn recv_internal(stream: &mut AsyncPeekStream) -> Result<Vec<u8>, String> {
pub async fn recv_internal(stream: &mut AsyncPeekStream) -> io::Result<Vec<u8>> {
let mut header = [0u8; 4];
stream
.read_exact(&mut header)
.await
.map_err(|e| format!("TCP recv error: {}", e))?;
stream.read_exact(&mut header).await?;
if header[0] != b'V' || header[1] != b'L' {
return Err("received invalid TCP frame header".to_owned());
bail_io_error_other!("received invalid TCP frame header");
}
let len = ((header[3] as usize) << 8) | (header[2] as usize);
if len > MAX_MESSAGE_SIZE {
return Err("received too large TCP frame".to_owned());
bail_io_error_other!("received too large TCP frame");
}
let mut out: Vec<u8> = vec![0u8; len];
stream.read_exact(&mut out).await.map_err(map_to_string)?;
stream.read_exact(&mut out).await?;
Ok(out)
}
#[instrument(level="trace", err, skip(self), fields(ret.len))]
pub async fn recv(&self) -> Result<Vec<u8>, String> {
pub async fn recv(&self) -> io::Result<Vec<u8>> {
let mut stream = self.stream.clone();
let out = Self::recv_internal(&mut stream).await?;
tracing::Span::current().record("ret.len", &out.len());
@@ -121,14 +119,10 @@ impl RawTcpProtocolHandler {
self,
stream: AsyncPeekStream,
socket_addr: SocketAddr,
) -> Result<Option<ProtocolNetworkConnection>, String> {
) -> io::Result<Option<ProtocolNetworkConnection>> {
log_net!("TCP: on_accept_async: enter");
let mut peekbuf: [u8; PEEK_DETECT_LEN] = [0u8; PEEK_DETECT_LEN];
let peeklen = stream
.peek(&mut peekbuf)
.await
.map_err(map_to_string)
.map_err(logthru_net!("could not peek tcp stream"))?;
let peeklen = stream.peek(&mut peekbuf).await?;
assert_eq!(peeklen, PEEK_DETECT_LEN);
let peer_addr = PeerAddress::new(
@@ -150,7 +144,7 @@ impl RawTcpProtocolHandler {
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<ProtocolNetworkConnection, String> {
) -> io::Result<ProtocolNetworkConnection> {
// Get remote socket address to connect to
let remote_socket_addr = dial_info.to_socket_addr();
@@ -163,15 +157,10 @@ impl RawTcpProtocolHandler {
};
// Non-blocking connect to remote address
let ts = nonblocking_connect(socket, remote_socket_addr).await
.map_err(map_to_string)
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
let ts = nonblocking_connect(socket, remote_socket_addr).await?;
// See what local address we ended up with and turn this into a stream
let actual_local_address = ts
.local_addr()
.map_err(map_to_string)
.map_err(logthru_net!("could not get local address from TCP stream"))?;
let actual_local_address = ts.local_addr()?;
#[cfg(feature = "rt-tokio")]
let ts = ts.compat();
let ps = AsyncPeekStream::new(ts);
@@ -189,12 +178,9 @@ impl RawTcpProtocolHandler {
}
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
pub async fn send_unbound_message(
socket_addr: SocketAddr,
data: Vec<u8>,
) -> Result<(), String> {
pub async fn send_unbound_message(socket_addr: SocketAddr, data: Vec<u8>) -> io::Result<()> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large unbound TCP message".to_owned());
bail_io_error_other!("sending too large unbound TCP message");
}
trace!(
"sending unbound message of length {} to {}",
@@ -206,10 +192,7 @@ impl RawTcpProtocolHandler {
let socket = new_unbound_shared_tcp_socket(socket2::Domain::for_address(socket_addr))?;
// Non-blocking connect to remote address
let ts = nonblocking_connect(socket, socket_addr)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "remote_addr={}", socket_addr))?;
let ts = nonblocking_connect(socket, socket_addr).await?;
// See what local address we ended up with and turn this into a stream
// let actual_local_address = ts
@@ -231,9 +214,9 @@ impl RawTcpProtocolHandler {
socket_addr: SocketAddr,
data: Vec<u8>,
timeout_ms: u32,
) -> Result<Vec<u8>, String> {
) -> io::Result<Vec<u8>> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large unbound TCP message".to_owned());
bail_io_error_other!("sending too large unbound TCP message");
}
trace!(
"sending unbound message of length {} to {}",
@@ -245,10 +228,7 @@ impl RawTcpProtocolHandler {
let socket = new_unbound_shared_tcp_socket(socket2::Domain::for_address(socket_addr))?;
// Non-blocking connect to remote address
let ts = nonblocking_connect(socket, socket_addr)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "remote_addr={}", socket_addr))?;
let ts = nonblocking_connect(socket, socket_addr).await?;
// See what local address we ended up with and turn this into a stream
// let actual_local_address = ts
@@ -265,7 +245,7 @@ impl RawTcpProtocolHandler {
let out = timeout(timeout_ms, RawTcpNetworkConnection::recv_internal(&mut ps))
.await
.map_err(map_to_string)??;
.map_err(|e| e.to_io())??;
tracing::Span::current().record("ret.len", &out.len());
Ok(out)
@@ -277,7 +257,7 @@ impl ProtocolAcceptHandler for RawTcpProtocolHandler {
&self,
stream: AsyncPeekStream,
peer_addr: SocketAddr,
) -> SystemPinBoxFuture<core::result::Result<Option<ProtocolNetworkConnection>, String>> {
) -> SystemPinBoxFuture<io::Result<Option<ProtocolNetworkConnection>>> {
Box::pin(self.clone().on_accept_async(stream, peer_addr))
}
}

View File

@@ -12,27 +12,30 @@ impl RawUdpProtocolHandler {
}
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.from))]
pub async fn recv_message(
&self,
data: &mut [u8],
) -> Result<(usize, ConnectionDescriptor), String> {
let (size, remote_addr) = self.socket.recv_from(data).await.map_err(map_to_string)?;
if size > MAX_MESSAGE_SIZE {
return Err("received too large UDP message".to_owned());
}
trace!(
"receiving UDP message of length {} from {}",
size,
remote_addr
);
pub async fn recv_message(&self, data: &mut [u8]) -> io::Result<(usize, ConnectionDescriptor)> {
let (size, remote_addr) = loop {
match self.socket.recv_from(data).await {
Ok((size, remote_addr)) => {
if size > MAX_MESSAGE_SIZE {
bail_io_error_other!("received too large UDP message");
}
break (size, remote_addr);
}
Err(e) => {
if e.kind() == io::ErrorKind::ConnectionReset {
// Ignore icmp
} else {
return Err(e);
}
}
}
};
let peer_addr = PeerAddress::new(
SocketAddress::from_socket_addr(remote_addr),
ProtocolType::UDP,
);
let local_socket_addr = self.socket.local_addr().map_err(map_to_string)?;
let local_socket_addr = self.socket.local_addr()?;
let descriptor = ConnectionDescriptor::new(
peer_addr,
SocketAddress::from_socket_addr(local_socket_addr),
@@ -44,45 +47,24 @@ impl RawUdpProtocolHandler {
}
#[instrument(level = "trace", err, skip(self, data), fields(data.len = data.len(), ret.len, ret.from))]
pub async fn send_message(&self, data: Vec<u8>, socket_addr: SocketAddr) -> Result<(), String> {
pub async fn send_message(&self, data: Vec<u8>, socket_addr: SocketAddr) -> io::Result<()> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large UDP message".to_owned()).map_err(logthru_net!(error));
bail_io_error_other!("sending too large UDP message");
}
log_net!(
"sending UDP message of length {} to {}",
data.len(),
socket_addr
);
let len = self
.socket
.send_to(&data, socket_addr)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "failed udp send: addr={}", socket_addr))?;
let len = self.socket.send_to(&data, socket_addr).await?;
if len != data.len() {
Err("UDP partial send".to_owned()).map_err(logthru_net!(error))
} else {
Ok(())
bail_io_error_other!("UDP partial send")
}
Ok(())
}
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
pub async fn send_unbound_message(
socket_addr: SocketAddr,
data: Vec<u8>,
) -> Result<(), String> {
pub async fn send_unbound_message(socket_addr: SocketAddr, data: Vec<u8>) -> io::Result<()> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large unbound UDP message".to_owned())
.map_err(logthru_net!(error));
bail_io_error_other!("sending too large unbound UDP message");
}
log_net!(
"sending unbound message of length {} to {}",
data.len(),
socket_addr
);
// get local wildcard address for bind
let local_socket_addr = match socket_addr {
@@ -91,20 +73,13 @@ impl RawUdpProtocolHandler {
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
}
};
let socket = UdpSocket::bind(local_socket_addr)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "failed to bind unbound udp socket"))?;
let len = socket
.send_to(&data, socket_addr)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "failed unbound udp send: addr={}", socket_addr))?;
let socket = UdpSocket::bind(local_socket_addr).await?;
let len = socket.send_to(&data, socket_addr).await?;
if len != data.len() {
Err("UDP partial unbound send".to_owned()).map_err(logthru_net!(error))
} else {
Ok(())
bail_io_error_other!("UDP partial unbound send")
}
Ok(())
}
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len(), ret.len))]
@@ -112,16 +87,10 @@ impl RawUdpProtocolHandler {
socket_addr: SocketAddr,
data: Vec<u8>,
timeout_ms: u32,
) -> Result<Vec<u8>, String> {
) -> io::Result<Vec<u8>> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large unbound UDP message".to_owned())
.map_err(logthru_net!(error));
bail_io_error_other!("sending too large unbound UDP message");
}
log_net!(
"sending unbound message of length {} to {}",
data.len(),
socket_addr
);
// get local wildcard address for bind
let local_socket_addr = match socket_addr {
@@ -132,29 +101,21 @@ impl RawUdpProtocolHandler {
};
// get unspecified bound socket
let socket = UdpSocket::bind(local_socket_addr)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "failed to bind unbound udp socket"))?;
let len = socket
.send_to(&data, socket_addr)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "failed unbound udp send: addr={}", socket_addr))?;
let socket = UdpSocket::bind(local_socket_addr).await?;
let len = socket.send_to(&data, socket_addr).await?;
if len != data.len() {
return Err("UDP partial unbound send".to_owned()).map_err(logthru_net!(error));
bail_io_error_other!("UDP partial unbound send");
}
// receive single response
let mut out = vec![0u8; MAX_MESSAGE_SIZE];
let (len, from_addr) = timeout(timeout_ms, socket.recv_from(&mut out))
.await
.map_err(map_to_string)?
.map_err(map_to_string)?;
.map_err(|e| e.to_io())??;
// if the from address is not the same as the one we sent to, then drop this
if from_addr != socket_addr {
return Err(format!(
bail_io_error_other!(format!(
"Unbound response received from wrong address: addr={}",
from_addr,
));

View File

@@ -17,6 +17,25 @@ cfg_if! {
}
}
fn to_io(err: async_tungstenite::tungstenite::Error) -> io::Error {
let kind = match err {
async_tungstenite::tungstenite::Error::ConnectionClosed => io::ErrorKind::ConnectionReset,
async_tungstenite::tungstenite::Error::AlreadyClosed => io::ErrorKind::NotConnected,
async_tungstenite::tungstenite::Error::Io(x) => {
return x;
}
async_tungstenite::tungstenite::Error::Tls(_) => io::ErrorKind::InvalidData,
async_tungstenite::tungstenite::Error::Capacity(_) => io::ErrorKind::Other,
async_tungstenite::tungstenite::Error::Protocol(_) => io::ErrorKind::Other,
async_tungstenite::tungstenite::Error::SendQueueFull(_) => io::ErrorKind::Other,
async_tungstenite::tungstenite::Error::Utf8 => io::ErrorKind::Other,
async_tungstenite::tungstenite::Error::Url(_) => io::ErrorKind::Other,
async_tungstenite::tungstenite::Error::Http(_) => io::ErrorKind::Other,
async_tungstenite::tungstenite::Error::HttpFormat(_) => io::ErrorKind::Other,
};
io::Error::new(kind, err)
}
pub type WebSocketNetworkConnectionAccepted = WebsocketNetworkConnection<AsyncPeekStream>;
pub struct WebsocketNetworkConnection<T>
@@ -62,41 +81,49 @@ where
// }
#[instrument(level = "trace", err, skip(self, message), fields(message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
if message.len() > MAX_MESSAGE_SIZE {
return Err("received too large WS message".to_owned());
bail_io_error_other!("received too large WS message");
}
self.stream
.clone()
.send(Message::binary(message))
.await
.map_err(map_to_string)
.map_err(logthru_net!(error "failed to send websocket message"))
.map_err(to_io)
}
#[instrument(level = "trace", err, skip(self), fields(ret.len))]
pub async fn recv(&self) -> Result<Vec<u8>, String> {
pub async fn recv(&self) -> io::Result<Vec<u8>> {
let out = match self.stream.clone().next().await {
Some(Ok(Message::Binary(v))) => v,
Some(Ok(Message::Close(e))) => {
return Err(format!("WS connection closed: {:?}", e));
Some(Ok(Message::Binary(v))) => {
if v.len() > MAX_MESSAGE_SIZE {
return Err(io::Error::new(
io::ErrorKind::ConnectionReset,
"too large ws message",
));
}
v
}
Some(Ok(Message::Close(_))) => {
return Err(io::Error::new(io::ErrorKind::ConnectionReset, "closeframe"))
}
Some(Ok(x)) => {
return Err(format!("Unexpected WS message type: {:?}", x));
}
Some(Err(e)) => {
return Err(e.to_string()).map_err(logthru_net!(error));
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Unexpected WS message type: {:?}", x),
));
}
Some(Err(e)) => return Err(to_io(e)),
None => {
return Err("WS stream closed".to_owned());
return Err(io::Error::new(
io::ErrorKind::ConnectionReset,
"connection ended",
))
}
};
if out.len() > MAX_MESSAGE_SIZE {
Err("sending too large WS message".to_owned()).map_err(logthru_net!(error))
} else {
tracing::Span::current().record("ret.len", &out.len());
Ok(out)
}
tracing::Span::current().record("ret.len", &out.len());
Ok(out)
}
}
@@ -145,21 +172,18 @@ impl WebsocketProtocolHandler {
self,
ps: AsyncPeekStream,
socket_addr: SocketAddr,
) -> Result<Option<ProtocolNetworkConnection>, String> {
) -> io::Result<Option<ProtocolNetworkConnection>> {
log_net!("WS: on_accept_async: enter");
let request_path_len = self.arc.request_path.len() + 2;
let mut peekbuf: Vec<u8> = vec![0u8; request_path_len];
match timeout(
if let Err(_) = timeout(
self.arc.connection_initial_timeout_ms,
ps.peek_exact(&mut peekbuf),
)
.await
{
Ok(_) => (),
Err(e) => {
return Err(e.to_string());
}
return Ok(None);
}
// Check for websocket path
@@ -169,15 +193,12 @@ impl WebsocketProtocolHandler {
&& peekbuf[request_path_len - 1] == b' '));
if !matches_path {
log_net!("WS: not websocket");
return Ok(None);
}
log_net!("WS: found websocket");
let ws_stream = accept_async(ps)
.await
.map_err(map_to_string)
.map_err(logthru_net!("failed websockets handshake"))?;
.map_err(|e| io_error_other!(format!("failed websockets handshake: {}", e)))?;
// Wrap the websocket in a NetworkConnection and register it
let protocol_type = if self.arc.tls {
@@ -205,7 +226,7 @@ impl WebsocketProtocolHandler {
async fn connect_internal(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<ProtocolNetworkConnection, String> {
) -> io::Result<ProtocolNetworkConnection> {
// Split dial info up
let (tls, scheme) = match &dial_info {
DialInfo::WS(_) => (false, "ws"),
@@ -213,9 +234,9 @@ impl WebsocketProtocolHandler {
_ => panic!("invalid dialinfo for WS/WSS protocol"),
};
let request = dial_info.request().unwrap();
let split_url = SplitUrl::from_str(&request)?;
let split_url = SplitUrl::from_str(&request).map_err(to_io_error_other)?;
if split_url.scheme != scheme {
return Err("invalid websocket url scheme".to_string());
bail_io_error_other!("invalid websocket url scheme");
}
let domain = split_url.host.clone();
@@ -231,12 +252,10 @@ impl WebsocketProtocolHandler {
};
// Non-blocking connect to remote address
let tcp_stream = nonblocking_connect(socket, remote_socket_addr).await
.map_err(map_to_string)
.map_err(logthru_net!(error "local_address={:?} remote_addr={}", local_address, remote_socket_addr))?;
let tcp_stream = nonblocking_connect(socket, remote_socket_addr).await?;
// See what local address we ended up with
let actual_local_addr = tcp_stream.local_addr().map_err(map_to_string)?;
let actual_local_addr = tcp_stream.local_addr()?;
#[cfg(feature = "rt-tokio")]
let tcp_stream = tcp_stream.compat();
@@ -249,15 +268,10 @@ impl WebsocketProtocolHandler {
// Negotiate TLS if this is WSS
if tls {
let connector = TlsConnector::default();
let tls_stream = connector
.connect(domain.to_string(), tcp_stream)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
let tls_stream = connector.connect(domain.to_string(), tcp_stream).await?;
let (ws_stream, _response) = client_async(request, tls_stream)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
.map_err(to_io_error_other)?;
Ok(ProtocolNetworkConnection::Wss(
WebsocketNetworkConnection::new(descriptor, ws_stream),
@@ -265,8 +279,7 @@ impl WebsocketProtocolHandler {
} else {
let (ws_stream, _response) = client_async(request, tcp_stream)
.await
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
.map_err(to_io_error_other)?;
Ok(ProtocolNetworkConnection::Ws(
WebsocketNetworkConnection::new(descriptor, ws_stream),
))
@@ -277,19 +290,17 @@ impl WebsocketProtocolHandler {
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<ProtocolNetworkConnection, String> {
) -> io::Result<ProtocolNetworkConnection> {
Self::connect_internal(local_address, dial_info).await
}
#[instrument(level = "trace", err, skip(data), fields(data.len = data.len()))]
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> Result<(), String> {
pub async fn send_unbound_message(dial_info: DialInfo, data: Vec<u8>) -> io::Result<()> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large unbound WS message".to_owned());
bail_io_error_other!("sending too large unbound WS message");
}
let protconn = Self::connect_internal(None, dial_info.clone())
.await
.map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?;
let protconn = Self::connect_internal(None, dial_info.clone()).await?;
protconn.send(data).await
}
@@ -299,19 +310,17 @@ impl WebsocketProtocolHandler {
dial_info: DialInfo,
data: Vec<u8>,
timeout_ms: u32,
) -> Result<Vec<u8>, String> {
) -> io::Result<Vec<u8>> {
if data.len() > MAX_MESSAGE_SIZE {
return Err("sending too large unbound WS message".to_owned());
bail_io_error_other!("sending too large unbound WS message");
}
let protconn = Self::connect_internal(None, dial_info.clone())
.await
.map_err(|e| format!("failed to connect websocket for unbound message: {}", e))?;
let protconn = Self::connect_internal(None, dial_info.clone()).await?;
protconn.send(data).await?;
let out = timeout(timeout_ms, protconn.recv())
.await
.map_err(map_to_string)??;
.map_err(|e| e.to_io())??;
tracing::Span::current().record("ret.len", &out.len());
Ok(out)
@@ -323,7 +332,7 @@ impl ProtocolAcceptHandler for WebsocketProtocolHandler {
&self,
stream: AsyncPeekStream,
peer_addr: SocketAddr,
) -> SystemPinBoxFuture<Result<Option<ProtocolNetworkConnection>, String>> {
) -> SystemPinBoxFuture<io::Result<Option<ProtocolNetworkConnection>>> {
Box::pin(self.clone().on_accept_async(stream, peer_addr))
}
}

View File

@@ -164,7 +164,7 @@ impl Network {
/////////////////////////////////////////////////////
fn find_available_udp_port(&self) -> Result<u16, String> {
fn find_available_udp_port(&self) -> EyreResult<u16> {
// If the address is empty, iterate ports until we find one we can use.
let mut udp_port = 5150u16;
loop {
@@ -175,14 +175,14 @@ impl Network {
break;
}
if udp_port == 65535 {
return Err("Could not find free udp port to listen on".to_owned());
bail!("Could not find free udp port to listen on");
}
udp_port += 1;
}
Ok(udp_port)
}
fn find_available_tcp_port(&self) -> Result<u16, String> {
fn find_available_tcp_port(&self) -> EyreResult<u16> {
// If the address is empty, iterate ports until we find one we can use.
let mut tcp_port = 5150u16;
loop {
@@ -193,17 +193,14 @@ impl Network {
break;
}
if tcp_port == 65535 {
return Err("Could not find free tcp port to listen on".to_owned());
bail!("Could not find free tcp port to listen on");
}
tcp_port += 1;
}
Ok(tcp_port)
}
async fn allocate_udp_port(
&self,
listen_address: String,
) -> Result<(u16, Vec<IpAddr>), String> {
async fn allocate_udp_port(&self, listen_address: String) -> EyreResult<(u16, Vec<IpAddr>)> {
if listen_address.is_empty() {
// If listen address is empty, find us a port iteratively
let port = self.find_available_udp_port()?;
@@ -217,21 +214,17 @@ impl Network {
// If the address is specified, only use the specified port and fail otherwise
let sockaddrs = listen_address_to_socket_addrs(&listen_address)?;
if sockaddrs.is_empty() {
return Err(format!("No valid listen address: {}", listen_address));
bail!("No valid listen address: {}", listen_address);
}
let port = sockaddrs[0].port();
if self.bind_first_udp_port(port) {
Ok((port, sockaddrs.iter().map(|s| s.ip()).collect()))
} else {
Err("Could not find free udp port to listen on".to_owned())
if !self.bind_first_udp_port(port) {
bail!("Could not find free udp port to listen on");
}
Ok((port, sockaddrs.iter().map(|s| s.ip()).collect()))
}
}
async fn allocate_tcp_port(
&self,
listen_address: String,
) -> Result<(u16, Vec<IpAddr>), String> {
async fn allocate_tcp_port(&self, listen_address: String) -> EyreResult<(u16, Vec<IpAddr>)> {
if listen_address.is_empty() {
// If listen address is empty, find us a port iteratively
let port = self.find_available_tcp_port()?;
@@ -245,20 +238,19 @@ impl Network {
// If the address is specified, only use the specified port and fail otherwise
let sockaddrs = listen_address_to_socket_addrs(&listen_address)?;
if sockaddrs.is_empty() {
return Err(format!("No valid listen address: {}", listen_address));
bail!("No valid listen address: {}", listen_address);
}
let port = sockaddrs[0].port();
if self.bind_first_tcp_port(port) {
Ok((port, sockaddrs.iter().map(|s| s.ip()).collect()))
} else {
Err("Could not find free tcp port to listen on".to_owned())
if !self.bind_first_tcp_port(port) {
bail!("Could not find free tcp port to listen on");
}
Ok((port, sockaddrs.iter().map(|s| s.ip()).collect()))
}
}
/////////////////////////////////////////////////////
pub(super) async fn start_udp_listeners(&self) -> Result<(), String> {
pub(super) async fn start_udp_listeners(&self) -> EyreResult<()> {
trace!("starting udp listeners");
let routing_table = self.routing_table();
let (listen_address, public_address, enable_local_peer_scope) = {
@@ -319,7 +311,7 @@ impl Network {
// Resolve statically configured public dialinfo
let mut public_sockaddrs = public_address
.to_socket_addrs()
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
.wrap_err(format!("Unable to resolve address: {}", public_address))?;
// Add all resolved addresses as public dialinfo
for pdi_addr in &mut public_sockaddrs {
@@ -364,7 +356,7 @@ impl Network {
self.create_udp_listener_tasks().await
}
pub(super) async fn start_ws_listeners(&self) -> Result<(), String> {
pub(super) async fn start_ws_listeners(&self) -> EyreResult<()> {
trace!("starting ws listeners");
let routing_table = self.routing_table();
let (listen_address, url, path, enable_local_peer_scope) = {
@@ -405,9 +397,9 @@ impl Network {
// Add static public dialinfo if it's configured
if let Some(url) = url.as_ref() {
let mut split_url = SplitUrl::from_str(url)?;
let mut split_url = SplitUrl::from_str(url).wrap_err("couldn't split url")?;
if split_url.scheme.to_ascii_lowercase() != "ws" {
return Err("WS URL must use 'ws://' scheme".to_owned());
bail!("WS URL must use 'ws://' scheme");
}
split_url.scheme = "ws".to_owned();
@@ -415,13 +407,11 @@ impl Network {
let global_socket_addrs = split_url
.host_port(80)
.to_socket_addrs()
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
.wrap_err("failed to resolve ws url")?;
for gsa in global_socket_addrs {
let pdi = DialInfo::try_ws(SocketAddress::from_socket_addr(gsa), url.clone())
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
.wrap_err("try_ws failed")?;
routing_table.register_dial_info(
RoutingDomain::PublicInternet,
@@ -458,9 +448,7 @@ impl Network {
}
// Build dial info request url
let local_url = format!("ws://{}/{}", socket_address, path);
let local_di = DialInfo::try_ws(socket_address, local_url)
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
let local_di = DialInfo::try_ws(socket_address, local_url).wrap_err("try_ws failed")?;
if url.is_none() && (socket_address.address().is_global() || enable_local_peer_scope) {
// Register public dial info
@@ -490,7 +478,7 @@ impl Network {
Ok(())
}
pub(super) async fn start_wss_listeners(&self) -> Result<(), String> {
pub(super) async fn start_wss_listeners(&self) -> EyreResult<()> {
trace!("starting wss listeners");
let routing_table = self.routing_table();
@@ -538,7 +526,7 @@ impl Network {
// Add static public dialinfo if it's configured
let mut split_url = SplitUrl::from_str(url)?;
if split_url.scheme.to_ascii_lowercase() != "wss" {
return Err("WSS URL must use 'wss://' scheme".to_owned());
bail!("WSS URL must use 'wss://' scheme");
}
split_url.scheme = "wss".to_owned();
@@ -546,13 +534,10 @@ impl Network {
let global_socket_addrs = split_url
.host_port(443)
.to_socket_addrs()
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
.wrap_err("failed to resolve wss url")?;
for gsa in global_socket_addrs {
let pdi = DialInfo::try_wss(SocketAddress::from_socket_addr(gsa), url.clone())
.map_err(map_to_string)
.map_err(logthru_net!(error))?;
.wrap_err("try_wss failed")?;
routing_table.register_dial_info(
RoutingDomain::PublicInternet,
@@ -581,7 +566,7 @@ impl Network {
registered_addresses.insert(gsa.ip());
}
} else {
return Err("WSS URL must be specified due to TLS requirements".to_owned());
bail!("WSS URL must be specified due to TLS requirements");
}
if static_public {
@@ -594,7 +579,7 @@ impl Network {
Ok(())
}
pub(super) async fn start_tcp_listeners(&self) -> Result<(), String> {
pub(super) async fn start_tcp_listeners(&self) -> EyreResult<()> {
trace!("starting tcp listeners");
let routing_table = self.routing_table();
@@ -659,7 +644,7 @@ impl Network {
// Resolve statically configured public dialinfo
let mut public_sockaddrs = public_address
.to_socket_addrs()
.map_err(|e| format!("Unable to resolve address: {}\n{}", public_address, e))?;
.wrap_err("failed to resolve tcp address")?;
// Add all resolved addresses as public dialinfo
for pdi_addr in &mut public_sockaddrs {

View File

@@ -1,5 +1,6 @@
use super::*;
use futures_util::{FutureExt, StreamExt};
use std::io;
use stop_token::prelude::*;
cfg_if::cfg_if! {
@@ -15,7 +16,7 @@ cfg_if::cfg_if! {
&self,
stream: AsyncPeekStream,
peer_addr: SocketAddr,
) -> SystemPinBoxFuture<Result<Option<ProtocolNetworkConnection>, String>>;
) -> SystemPinBoxFuture<io::Result<Option<ProtocolNetworkConnection>>>;
}
pub trait ProtocolAcceptHandlerClone {
@@ -52,13 +53,13 @@ impl DummyNetworkConnection {
pub fn descriptor(&self) -> ConnectionDescriptor {
self.descriptor.clone()
}
pub fn close(&self) -> Result<(), String> {
// pub fn close(&self) -> Result<(), String> {
// Ok(())
// }
pub fn send(&self, _message: Vec<u8>) -> io::Result<()> {
Ok(())
}
pub fn send(&self, _message: Vec<u8>) -> Result<(), String> {
Ok(())
}
pub fn recv(&self) -> Result<Vec<u8>, String> {
pub fn recv(&self) -> io::Result<Vec<u8>> {
Ok(Vec::new())
}
}
@@ -178,7 +179,7 @@ impl NetworkConnection {
protocol_connection: &ProtocolNetworkConnection,
stats: Arc<Mutex<NetworkConnectionStats>>,
message: Vec<u8>,
) -> Result<(), String> {
) -> io::Result<()> {
let ts = intf::get_timestamp();
let out = protocol_connection.send(message).await;
if out.is_ok() {
@@ -190,7 +191,7 @@ impl NetworkConnection {
async fn recv_internal(
protocol_connection: &ProtocolNetworkConnection,
stats: Arc<Mutex<NetworkConnectionStats>>,
) -> Result<Vec<u8>, String> {
) -> io::Result<Vec<u8>> {
let ts = intf::get_timestamp();
let out = protocol_connection.recv().await;
if out.is_ok() {
@@ -222,7 +223,7 @@ impl NetworkConnection {
) -> SystemPinBoxFuture<()> {
Box::pin(async move {
log_net!(
"Starting process_connection loop for {:?}",
"== Starting process_connection loop for {:?}",
descriptor.green()
);
@@ -236,7 +237,7 @@ impl NetworkConnection {
let new_timer = || {
intf::sleep(connection_inactivity_timeout_ms).then(|_| async {
// timeout
log_net!("connection timeout on {:?}", descriptor.green());
log_net!("== Connection timeout on {:?}", descriptor.green());
RecvLoopAction::Timeout
})
};
@@ -288,7 +289,7 @@ impl NetworkConnection {
.on_recv_envelope(message.as_slice(), descriptor)
.await
{
log_net!(error e);
log_net!(debug "failed to process received envelope: {}", e);
RecvLoopAction::Finish
} else {
RecvLoopAction::Recv
@@ -296,7 +297,7 @@ impl NetworkConnection {
}
Err(e) => {
// Connection unable to receive, closed
log_net!(warn e);
log_net!(debug e);
RecvLoopAction::Finish
}
}

View File

@@ -56,7 +56,7 @@ impl Network {
&self,
dial_info: DialInfo,
data: Vec<u8>,
) -> Result<(), String> {
) -> EyreResult<()> {
let data_len = data.len();
let res = match dial_info.protocol_type() {
@@ -90,7 +90,7 @@ impl Network {
dial_info: DialInfo,
data: Vec<u8>,
timeout_ms: u32,
) -> Result<Vec<u8>, String> {
) -> EyreResult<Vec<u8>> {
let data_len = data.len();
let out = match dial_info.protocol_type() {
ProtocolType::UDP => {
@@ -124,7 +124,7 @@ impl Network {
&self,
descriptor: ConnectionDescriptor,
data: Vec<u8>,
) -> Result<Option<Vec<u8>>, String> {
) -> EyreResult<Option<Vec<u8>>> {
let data_len = data.len();
match descriptor.protocol_type() {
ProtocolType::UDP => {
@@ -161,7 +161,7 @@ impl Network {
&self,
dial_info: DialInfo,
data: Vec<u8>,
) -> Result<(), String> {
) -> EyreResult<()> {
let data_len = data.len();
if dial_info.protocol_type() == ProtocolType::UDP {
return Err("no support for UDP protocol".to_owned()).map_err(logthru_net!(error))
@@ -187,7 +187,7 @@ impl Network {
/////////////////////////////////////////////////////////////////
pub async fn startup(&self) -> Result<(), String> {
pub async fn startup(&self) -> EyreResult<()> {
// get protocol config
self.inner.lock().protocol_config = Some({
let c = self.config.get();
@@ -269,7 +269,7 @@ impl Network {
}
//////////////////////////////////////////
pub async fn tick(&self) -> Result<(), String> {
pub async fn tick(&self) -> EyreResult<()> {
Ok(())
}
}

View File

@@ -15,7 +15,7 @@ impl ProtocolNetworkConnection {
pub async fn connect(
local_address: Option<SocketAddr>,
dial_info: DialInfo,
) -> Result<ProtocolNetworkConnection, String> {
) -> io::Result<ProtocolNetworkConnection> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
panic!("UDP dial info is not supported on WASM targets");
@@ -32,7 +32,7 @@ impl ProtocolNetworkConnection {
pub async fn send_unbound_message(
dial_info: DialInfo,
data: Vec<u8>,
) -> Result<(), String> {
) -> io::Result<()> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
panic!("UDP dial info is not supported on WASM targets");
@@ -50,7 +50,7 @@ impl ProtocolNetworkConnection {
dial_info: DialInfo,
data: Vec<u8>,
timeout_ms: u32,
) -> Result<Vec<u8>, String> {
) -> io::Result<Vec<u8>> {
match dial_info.protocol_type() {
ProtocolType::UDP => {
panic!("UDP dial info is not supported on WASM targets");
@@ -72,20 +72,20 @@ impl ProtocolNetworkConnection {
}
}
pub async fn close(&self) -> Result<(), String> {
match self {
Self::Dummy(d) => d.close(),
Self::Ws(w) => w.close().await,
}
}
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
// pub async fn close(&self) -> io::Result<()> {
// match self {
// Self::Dummy(d) => d.close(),
// Self::Ws(w) => w.close().await,
// }
// }
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
match self {
Self::Dummy(d) => d.send(message),
Self::Ws(w) => w.send(message).await,
}
}
pub async fn recv(&self) -> Result<Vec<u8>, String> {
pub async fn recv(&self) -> io::Result<Vec<u8>> {
match self {
Self::Dummy(d) => d.recv(),
Self::Ws(w) => w.recv().await,

View File

@@ -1,12 +1,19 @@
use super::*;
use ws_stream_wasm::*;
use futures_util::{StreamExt, SinkExt};
use std::io;
struct WebsocketNetworkConnectionInner {
ws_meta: WsMeta,
ws_stream: CloneStream<WsStream>,
}
fn to_io(err: WsErr) -> io::Error {
let kind = match err {
WsErr::InvalidWsState {supplied:_} => io::ErrorKind::
}
}
#[derive(Clone)]
pub struct WebsocketNetworkConnection {
descriptor: ConnectionDescriptor,
@@ -36,15 +43,15 @@ impl WebsocketNetworkConnection {
self.descriptor.clone()
}
#[instrument(level = "trace", err, skip(self))]
pub async fn close(&self) -> Result<(), String> {
self.inner.ws_meta.close().await.map_err(map_to_string).map(drop)
}
// #[instrument(level = "trace", err, skip(self))]
// pub async fn close(&self) -> Result<(), String> {
// self.inner.ws_meta.close().await.map_err(map_to_string).map(drop)
// }
#[instrument(level = "trace", err, skip(self, message), fields(message.len = message.len()))]
pub async fn send(&self, message: Vec<u8>) -> Result<(), String> {
pub async fn send(&self, message: Vec<u8>) -> io::Result<()> {
if message.len() > MAX_MESSAGE_SIZE {
return Err("sending too large WS message".to_owned()).map_err(logthru_net!(error));
bail_io_error_other!("sending too large WS message");
}
self.inner.ws_stream.clone()
.send(WsMessage::Binary(message)).await