use super::*; use crate::crypto::*; use alloc::fmt; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// pub struct NodeRefBaseCommon { routing_table: RoutingTable, node_id: DHTKey, entry: Arc, filter: Option, sequencing: Sequencing, #[cfg(feature = "tracking")] track_id: usize, } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// pub trait NodeRefBase: Sized { // Common field access fn common(&self) -> &NodeRefBaseCommon; fn common_mut(&mut self) -> &mut NodeRefBaseCommon; // Implementation-specific operators fn operate(&self, f: F) -> T where F: FnOnce(&RoutingTableInner, &BucketEntryInner) -> T; fn operate_mut(&self, f: F) -> T where F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner) -> T; // Filtering fn filter_ref(&self) -> Option<&NodeRefFilter> { self.common().filter.as_ref() } fn take_filter(&mut self) -> Option { self.common_mut().filter.take() } fn set_filter(&mut self, filter: Option) { self.common_mut().filter = filter } fn set_sequencing(&mut self, sequencing: Sequencing) { self.common_mut().sequencing = sequencing; } fn sequencing(&self) -> Sequencing { self.common().sequencing } fn merge_filter(&mut self, filter: NodeRefFilter) { let common_mut = self.common_mut(); if let Some(self_filter) = common_mut.filter.take() { common_mut.filter = Some(self_filter.filtered(&filter)); } else { common_mut.filter = Some(filter); } } fn is_filter_dead(&self) -> bool { if let Some(filter) = &self.common().filter { filter.is_dead() } else { false } } fn routing_domain_set(&self) -> RoutingDomainSet { self.common() .filter .as_ref() .map(|f| f.routing_domain_set) .unwrap_or(RoutingDomainSet::all()) } fn dial_info_filter(&self) -> DialInfoFilter { self.common() .filter .as_ref() .map(|f| f.dial_info_filter.clone()) .unwrap_or(DialInfoFilter::all()) } fn best_routing_domain(&self) -> Option { self.operate(|rti, e| { e.best_routing_domain( rti, self.common() .filter .as_ref() .map(|f| f.routing_domain_set) .unwrap_or(RoutingDomainSet::all()), ) }) } // Accessors fn routing_table(&self) -> RoutingTable { self.common().routing_table.clone() } fn node_id(&self) -> DHTKey { self.common().node_id } fn has_updated_since_last_network_change(&self) -> bool { self.operate(|_rti, e| e.has_updated_since_last_network_change()) } fn set_updated_since_last_network_change(&self) { self.operate_mut(|_rti, e| e.set_updated_since_last_network_change(true)); } fn update_node_status(&self, node_status: NodeStatus) { self.operate_mut(|_rti, e| { e.update_node_status(node_status); }); } fn min_max_version(&self) -> Option { self.operate(|_rti, e| e.min_max_version()) } fn set_min_max_version(&self, min_max_version: VersionRange) { self.operate_mut(|_rti, e| e.set_min_max_version(min_max_version)) } fn state(&self, cur_ts: u64) -> BucketEntryState { self.operate(|_rti, e| e.state(cur_ts)) } fn peer_stats(&self) -> PeerStats { self.operate(|_rti, e| e.peer_stats().clone()) } // Per-RoutingDomain accessors fn make_peer_info(&self, routing_domain: RoutingDomain) -> Option { self.operate(|_rti, e| e.make_peer_info(self.node_id(), routing_domain)) } fn node_info(&self, routing_domain: RoutingDomain) -> Option { self.operate(|_rti, e| e.node_info(routing_domain).cloned()) } fn signed_node_info_has_valid_signature(&self, routing_domain: RoutingDomain) -> bool { self.operate(|_rti, e| { e.signed_node_info(routing_domain) .map(|sni| sni.has_valid_signature()) .unwrap_or(false) }) } fn node_info_ts(&self, routing_domain: RoutingDomain) -> u64 { self.operate(|_rti, e| { e.signed_node_info(routing_domain) .map(|sni| sni.timestamp()) .unwrap_or(0u64) }) } fn has_seen_our_node_info_ts( &self, routing_domain: RoutingDomain, our_node_info_ts: u64, ) -> bool { self.operate(|_rti, e| e.has_seen_our_node_info_ts(routing_domain, our_node_info_ts)) } fn set_our_node_info_ts(&self, routing_domain: RoutingDomain, seen_ts: u64) { self.operate_mut(|_rti, e| e.set_our_node_info_ts(routing_domain, seen_ts)); } fn network_class(&self, routing_domain: RoutingDomain) -> Option { self.operate(|_rt, e| e.node_info(routing_domain).map(|n| n.network_class)) } fn outbound_protocols(&self, routing_domain: RoutingDomain) -> Option { self.operate(|_rt, e| e.node_info(routing_domain).map(|n| n.outbound_protocols)) } fn address_types(&self, routing_domain: RoutingDomain) -> Option { self.operate(|_rt, e| e.node_info(routing_domain).map(|n| n.address_types)) } fn node_info_outbound_filter(&self, routing_domain: RoutingDomain) -> DialInfoFilter { let mut dif = DialInfoFilter::all(); if let Some(outbound_protocols) = self.outbound_protocols(routing_domain) { dif = dif.with_protocol_type_set(outbound_protocols); } if let Some(address_types) = self.address_types(routing_domain) { dif = dif.with_address_type_set(address_types); } dif } fn relay(&self, routing_domain: RoutingDomain) -> Option { self.operate_mut(|rti, e| { e.signed_node_info(routing_domain) .and_then(|n| n.relay_peer_info()) .and_then(|t| { // If relay is ourselves, then return None, because we can't relay through ourselves // and to contact this node we should have had an existing inbound connection if t.node_id.key == rti.unlocked_inner.node_id { return None; } // Register relay node and return noderef rti.register_node_with_signed_node_info( self.routing_table(), routing_domain, t.node_id.key, t.signed_node_info, false, ) }) }) } // Filtered accessors fn first_filtered_dial_info_detail(&self) -> Option { let routing_domain_set = self.routing_domain_set(); let dial_info_filter = self.dial_info_filter(); let (sort, dial_info_filter) = match self.common().sequencing { Sequencing::NoPreference => (None, dial_info_filter), Sequencing::PreferOrdered => ( Some(DialInfoDetail::ordered_sequencing_sort), dial_info_filter, ), Sequencing::EnsureOrdered => ( Some(DialInfoDetail::ordered_sequencing_sort), dial_info_filter.filtered( &DialInfoFilter::all().with_protocol_type_set(ProtocolType::all_ordered_set()), ), ), }; self.operate(|_rt, e| { for routing_domain in routing_domain_set { if let Some(ni) = e.node_info(routing_domain) { let filter = |did: &DialInfoDetail| did.matches_filter(&dial_info_filter); if let Some(did) = ni.first_filtered_dial_info_detail(sort, filter) { return Some(did); } } } None }) } fn all_filtered_dial_info_details(&self) -> Vec { let routing_domain_set = self.routing_domain_set(); let dial_info_filter = self.dial_info_filter(); let (sort, dial_info_filter) = match self.common().sequencing { Sequencing::NoPreference => (None, dial_info_filter), Sequencing::PreferOrdered => ( Some(DialInfoDetail::ordered_sequencing_sort), dial_info_filter, ), Sequencing::EnsureOrdered => ( Some(DialInfoDetail::ordered_sequencing_sort), dial_info_filter.filtered( &DialInfoFilter::all().with_protocol_type_set(ProtocolType::all_ordered_set()), ), ), }; let mut out = Vec::new(); self.operate(|_rt, e| { for routing_domain in routing_domain_set { if let Some(ni) = e.node_info(routing_domain) { let filter = |did: &DialInfoDetail| did.matches_filter(&dial_info_filter); if let Some(did) = ni.first_filtered_dial_info_detail(sort, filter) { out.push(did); } } } }); out.remove_duplicates(); out } fn last_connection(&self) -> Option { // Get the last connections and the last time we saw anything with this connection // Filtered first and then sorted by most recent self.operate(|rti, e| { let last_connections = e.last_connections(rti, true, self.common().filter.clone()); last_connections.first().map(|x| x.0) }) } fn clear_last_connections(&self) { self.operate_mut(|_rti, e| e.clear_last_connections()) } fn set_last_connection(&self, connection_descriptor: ConnectionDescriptor, ts: u64) { self.operate_mut(|rti, e| { e.set_last_connection(connection_descriptor, ts); rti.touch_recent_peer(self.common().node_id, connection_descriptor); }) } fn has_any_dial_info(&self) -> bool { self.operate(|_rti, e| { for rtd in RoutingDomain::all() { if let Some(sni) = e.signed_node_info(rtd) { if sni.has_any_dial_info() { return true; } } } false }) } fn stats_question_sent(&self, ts: u64, bytes: u64, expects_answer: bool) { self.operate_mut(|rti, e| { rti.transfer_stats_accounting().add_up(bytes); e.question_sent(ts, bytes, expects_answer); }) } fn stats_question_rcvd(&self, ts: u64, bytes: u64) { self.operate_mut(|rti, e| { rti.transfer_stats_accounting().add_down(bytes); e.question_rcvd(ts, bytes); }) } fn stats_answer_sent(&self, bytes: u64) { self.operate_mut(|rti, e| { rti.transfer_stats_accounting().add_up(bytes); e.answer_sent(bytes); }) } fn stats_answer_rcvd(&self, send_ts: u64, recv_ts: u64, bytes: u64) { self.operate_mut(|rti, e| { rti.transfer_stats_accounting().add_down(bytes); rti.latency_stats_accounting() .record_latency(recv_ts - send_ts); e.answer_rcvd(send_ts, recv_ts, bytes); }) } fn stats_question_lost(&self) { self.operate_mut(|_rti, e| { e.question_lost(); }) } fn stats_failed_to_send(&self, ts: u64, expects_answer: bool) { self.operate_mut(|_rti, e| { e.failed_to_send(ts, expects_answer); }) } } //////////////////////////////////////////////////////////////////////////////////// /// Reference to a routing table entry /// Keeps entry in the routing table until all references are gone pub struct NodeRef { common: NodeRefBaseCommon, } impl NodeRef { pub fn new( routing_table: RoutingTable, node_id: DHTKey, entry: Arc, filter: Option, ) -> Self { entry.ref_count.fetch_add(1u32, Ordering::Relaxed); Self { common: NodeRefBaseCommon { routing_table, node_id, entry, filter, sequencing: Sequencing::NoPreference, #[cfg(feature = "tracking")] track_id: entry.track(), }, } } pub fn filtered_clone(&self, filter: NodeRefFilter) -> Self { let mut out = self.clone(); out.merge_filter(filter); out } pub fn locked<'a>(&self, rti: &'a RoutingTableInner) -> NodeRefLocked<'a> { NodeRefLocked::new(rti, self.clone()) } pub fn locked_mut<'a>(&self, rti: &'a mut RoutingTableInner) -> NodeRefLockedMut<'a> { NodeRefLockedMut::new(rti, self.clone()) } } impl NodeRefBase for NodeRef { fn common(&self) -> &NodeRefBaseCommon { &self.common } fn common_mut(&mut self) -> &mut NodeRefBaseCommon { &mut self.common } fn operate(&self, f: F) -> T where F: FnOnce(&RoutingTableInner, &BucketEntryInner) -> T, { let inner = &*self.common.routing_table.inner.read(); self.common.entry.with(inner, f) } fn operate_mut(&self, f: F) -> T where F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner) -> T, { let inner = &mut *self.common.routing_table.inner.write(); self.common.entry.with_mut(inner, f) } } impl Clone for NodeRef { fn clone(&self) -> Self { self.common .entry .ref_count .fetch_add(1u32, Ordering::Relaxed); Self { common: NodeRefBaseCommon { routing_table: self.common.routing_table.clone(), node_id: self.common.node_id, entry: self.common.entry.clone(), filter: self.common.filter.clone(), sequencing: self.common.sequencing, #[cfg(feature = "tracking")] track_id: self.common.entry.write().track(), }, } } } impl fmt::Display for NodeRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.common.node_id.encode()) } } impl fmt::Debug for NodeRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("NodeRef") .field("node_id", &self.common.node_id) .field("filter", &self.common.filter) .field("sequencing", &self.common.sequencing) .finish() } } impl Drop for NodeRef { fn drop(&mut self) { #[cfg(feature = "tracking")] self.common.entry.write().untrack(self.track_id); // drop the noderef and queue a bucket kick if it was the last one let new_ref_count = self .common .entry .ref_count .fetch_sub(1u32, Ordering::Relaxed) - 1; if new_ref_count == 0 { self.common .routing_table .queue_bucket_kick(self.common.node_id); } } } //////////////////////////////////////////////////////////////////////////////////// /// Locked reference to a routing table entry /// For internal use inside the RoutingTable module where you have /// already locked a RoutingTableInner /// Keeps entry in the routing table until all references are gone pub struct NodeRefLocked<'a> { inner: Mutex<&'a RoutingTableInner>, nr: NodeRef, } impl<'a> NodeRefLocked<'a> { pub fn new(inner: &'a RoutingTableInner, nr: NodeRef) -> Self { Self { inner: Mutex::new(inner), nr, } } } impl<'a> NodeRefBase for NodeRefLocked<'a> { fn common(&self) -> &NodeRefBaseCommon { &self.nr.common } fn common_mut(&mut self) -> &mut NodeRefBaseCommon { &mut self.nr.common } fn operate(&self, f: F) -> T where F: FnOnce(&RoutingTableInner, &BucketEntryInner) -> T, { let inner = &*self.inner.lock(); self.nr.common.entry.with(inner, f) } fn operate_mut(&self, _f: F) -> T where F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner) -> T, { panic!("need to locked_mut() for this operation") } } impl<'a> fmt::Display for NodeRefLocked<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.nr) } } impl<'a> fmt::Debug for NodeRefLocked<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("NodeRefLocked") .field("nr", &self.nr) .finish() } } //////////////////////////////////////////////////////////////////////////////////// /// Mutable locked reference to a routing table entry /// For internal use inside the RoutingTable module where you have /// already locked a RoutingTableInner /// Keeps entry in the routing table until all references are gone pub struct NodeRefLockedMut<'a> { inner: Mutex<&'a mut RoutingTableInner>, nr: NodeRef, } impl<'a> NodeRefLockedMut<'a> { pub fn new(inner: &'a mut RoutingTableInner, nr: NodeRef) -> Self { Self { inner: Mutex::new(inner), nr, } } } impl<'a> NodeRefBase for NodeRefLockedMut<'a> { fn common(&self) -> &NodeRefBaseCommon { &self.nr.common } fn common_mut(&mut self) -> &mut NodeRefBaseCommon { &mut self.nr.common } fn operate(&self, f: F) -> T where F: FnOnce(&RoutingTableInner, &BucketEntryInner) -> T, { let inner = &*self.inner.lock(); self.nr.common.entry.with(inner, f) } fn operate_mut(&self, f: F) -> T where F: FnOnce(&mut RoutingTableInner, &mut BucketEntryInner) -> T, { let inner = &mut *self.inner.lock(); self.nr.common.entry.with_mut(inner, f) } } impl<'a> fmt::Display for NodeRefLockedMut<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.nr) } } impl<'a> fmt::Debug for NodeRefLockedMut<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("NodeRefLockedMut") .field("nr", &self.nr) .finish() } }