diff --git a/veilid-core/src/network_manager/native/network_class_discovery.rs b/veilid-core/src/network_manager/native/network_class_discovery.rs index bd06082b..b4ad5f0b 100644 --- a/veilid-core/src/network_manager/native/network_class_discovery.rs +++ b/veilid-core/src/network_manager/native/network_class_discovery.rs @@ -491,6 +491,15 @@ impl Network { assert_eq!(old_network_class, None); let protocol_config = self.inner.lock().protocol_config.unwrap_or_default(); + let tcp_same_port = if protocol_config.inbound.contains(ProtocolType::TCP) + && protocol_config.inbound.contains(ProtocolType::WS) + { + let inner = self.inner.lock(); + inner.tcp_port == inner.ws_port + } else { + false + }; + let mut unord = FuturesUnordered::new(); // Do UDPv4+v6 at the same time as everything else @@ -536,11 +545,10 @@ impl Network { // Do TCPv4 + WSv4 in series because they may use the same connection 5-tuple if protocol_config.family_global.contains(AddressType::IPV4) { - unord.push( - async { - // TCPv4 - let mut out = Vec::::new(); - if protocol_config.inbound.contains(ProtocolType::TCP) { + if protocol_config.inbound.contains(ProtocolType::TCP) { + unord.push( + async { + // TCPv4 let tcpv4_context = DiscoveryContext::new(self.routing_table(), self.clone()); if let Err(e) = self @@ -550,11 +558,16 @@ impl Network { log_net!(debug "Failed TCPv4 dialinfo discovery: {}", e); return None; } - out.push(tcpv4_context); + Some(vec![tcpv4_context]) } + .boxed(), + ); + } - // WSv4 - if protocol_config.inbound.contains(ProtocolType::WS) { + if protocol_config.inbound.contains(ProtocolType::WS) && !tcp_same_port { + unord.push( + async { + // WSv4 let wsv4_context = DiscoveryContext::new(self.routing_table(), self.clone()); if let Err(e) = self @@ -564,21 +577,19 @@ impl Network { log_net!(debug "Failed WSv4 dialinfo discovery: {}", e); return None; } - out.push(wsv4_context); + Some(vec![wsv4_context]) } - Some(out) - } - .boxed(), - ); + .boxed(), + ); + } } // Do TCPv6 + WSv6 in series because they may use the same connection 5-tuple if protocol_config.family_global.contains(AddressType::IPV6) { - unord.push( - async { - // TCPv6 - let mut out = Vec::::new(); - if protocol_config.inbound.contains(ProtocolType::TCP) { + if protocol_config.inbound.contains(ProtocolType::TCP) { + unord.push( + async { + // TCPv6 let tcpv6_context = DiscoveryContext::new(self.routing_table(), self.clone()); if let Err(e) = self @@ -588,11 +599,16 @@ impl Network { log_net!(debug "Failed TCPv6 dialinfo discovery: {}", e); return None; } - out.push(tcpv6_context); + Some(vec![tcpv6_context]) } + .boxed(), + ); + } - // WSv6 - if protocol_config.inbound.contains(ProtocolType::WS) { + // WSv6 + if protocol_config.inbound.contains(ProtocolType::WS) && !tcp_same_port { + unord.push( + async { let wsv6_context = DiscoveryContext::new(self.routing_table(), self.clone()); if let Err(e) = self @@ -602,12 +618,11 @@ impl Network { log_net!(debug "Failed WSv6 dialinfo discovery: {}", e); return None; } - out.push(wsv6_context); + Some(vec![wsv6_context]) } - Some(out) - } - .boxed(), - ); + .boxed(), + ); + } } // Wait for all discovery futures to complete and collect contexts @@ -659,6 +674,19 @@ impl Network { ) { log_net!(warn "Failed to register detected public dial info: {}", e); } + + // duplicate for same port + if tcp_same_port && pdi.dial_info.protocol_type() == ProtocolType::TCP { + let ws_dial_info = + ctx.make_dial_info(pdi.dial_info.socket_address(), ProtocolType::WS); + if let Err(e) = routing_table.register_dial_info( + RoutingDomain::PublicInternet, + ws_dial_info, + pdi.class, + ) { + log_net!(warn "Failed to register detected public dial info: {}", e); + } + } } } // Update network class