This commit is contained in:
John Smith
2021-11-27 19:56:56 -05:00
parent 028e02f942
commit 45489d0e9c
9 changed files with 109 additions and 94 deletions

View File

@@ -10,7 +10,7 @@ pub trait TcpProtocolHandler: TcpProtocolHandlerClone + Send + Sync {
&self,
stream: AsyncPeekStream,
peer_addr: SocketAddr,
) -> SendPinBoxFuture<Result<bool, ()>>;
) -> SendPinBoxFuture<Result<bool, String>>;
}
pub trait TcpProtocolHandlerClone {

View File

@@ -125,9 +125,12 @@ impl RawTcpProtocolHandler {
self,
stream: AsyncPeekStream,
socket_addr: SocketAddr,
) -> Result<bool, ()> {
) -> Result<bool, String> {
let mut peekbuf: [u8; PEEK_DETECT_LEN] = [0u8; PEEK_DETECT_LEN];
let peeklen = stream.peek(&mut peekbuf).await.map_err(drop)?;
let peeklen = stream
.peek(&mut peekbuf)
.await
.map_err(|e| format!("could not peek tcp stream: {}", e))?;
assert_eq!(peeklen, PEEK_DETECT_LEN);
let conn = NetworkConnection::RawTcp(RawTcpNetworkConnection::new(stream));
@@ -150,12 +153,13 @@ impl RawTcpProtocolHandler {
network_manager: NetworkManager,
preferred_local_address: Option<SocketAddr>,
remote_socket_addr: SocketAddr,
) -> Result<NetworkConnection, ()> {
) -> Result<NetworkConnection, String> {
// Make a low level socket that can connect to the remote socket address
// and attempt to reuse the local address that our listening socket uses
// for hole-punch compatibility
let domain = Domain::for_address(remote_socket_addr);
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)).map_err(drop)?;
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
.map_err(|e| format!("could not create tcp socket: {}", e))?;
if let Err(e) = socket.set_linger(None) {
warn!("Couldn't set TCP linger: {}", e);
}
@@ -183,12 +187,16 @@ impl RawTcpProtocolHandler {
// Connect to the remote address
let remote_socket2_addr = socket2::SockAddr::from(remote_socket_addr);
socket.connect(&remote_socket2_addr).map_err(drop)?;
socket
.connect(&remote_socket2_addr)
.map_err(|e| format!("couln't connect tcp: {}", e))?;
let std_stream: std::net::TcpStream = socket.into();
let ts = TcpStream::from(std_stream);
// See what local address we ended up with and turn this into a stream
let local_address = ts.local_addr().map_err(drop)?;
let local_address = ts
.local_addr()
.map_err(|e| format!("couldn't get local address for tcp socket: {}", e))?;
let ps = AsyncPeekStream::new(ts);
let peer_addr = PeerAddress::new(
Address::from_socket_addr(remote_socket_addr),
@@ -227,7 +235,7 @@ impl TcpProtocolHandler for RawTcpProtocolHandler {
&self,
stream: AsyncPeekStream,
peer_addr: SocketAddr,
) -> SendPinBoxFuture<Result<bool, ()>> {
) -> SendPinBoxFuture<Result<bool, String>> {
Box::pin(self.clone().on_accept_async(stream, peer_addr))
}
}

View File

@@ -168,7 +168,7 @@ impl WebsocketProtocolHandler {
self,
ps: AsyncPeekStream,
socket_addr: SocketAddr,
) -> Result<bool, ()> {
) -> Result<bool, String> {
let request_path_len = self.inner.request_path.len() + 2;
let mut peekbuf: Vec<u8> = vec![0u8; request_path_len];
match io::timeout(
@@ -179,8 +179,7 @@ impl WebsocketProtocolHandler {
{
Ok(_) => (),
Err(e) => {
trace!("failed to peek stream: {:?}", e);
return Err(());
return Err(format!("failed to peek stream: {:?}", e));
}
}
// Check for websocket path
@@ -198,8 +197,7 @@ impl WebsocketProtocolHandler {
let ws_stream = match accept_async(ps).await {
Ok(s) => s,
Err(e) => {
trace!("failed websockets handshake: {:?}", e);
return Err(());
return Err(format!("failed websockets handshake: {:?}", e));
}
};
@@ -234,7 +232,7 @@ impl WebsocketProtocolHandler {
pub async fn connect(
network_manager: NetworkManager,
dial_info: &DialInfo,
) -> Result<NetworkConnection, ()> {
) -> Result<NetworkConnection, String> {
let (tls, request, domain, port, protocol_type) = match &dial_info {
DialInfo::WS(di) => (
false,
@@ -255,9 +253,13 @@ impl WebsocketProtocolHandler {
let tcp_stream = TcpStream::connect(format!("{}:{}", &domain, &port))
.await
.map_err(drop)?;
let local_addr = tcp_stream.local_addr().map_err(drop)?;
let peer_socket_addr = tcp_stream.peer_addr().map_err(drop)?;
.map_err(|e| format!("failed to connect tcp stream: {}", e))?;
let local_addr = tcp_stream
.local_addr()
.map_err(|e| format!("can't get local address for tcp stream: {}", e))?;
let peer_socket_addr = tcp_stream
.peer_addr()
.map_err(|e| format!("can't get peer address for tcp stream: {}", e))?;
let peer_addr = PeerAddress::new(
Address::from_socket_addr(peer_socket_addr),
peer_socket_addr.port(),
@@ -266,8 +268,13 @@ impl WebsocketProtocolHandler {
if tls {
let connector = TlsConnector::default();
let tls_stream = connector.connect(domain, tcp_stream).await.map_err(drop)?;
let (ws_stream, _response) = client_async(request, tls_stream).await.map_err(drop)?;
let tls_stream = connector
.connect(domain, tcp_stream)
.await
.map_err(|e| format!("can't connect tls: {}", e))?;
let (ws_stream, _response) = client_async(request, tls_stream)
.await
.map_err(|e| format!("wss negotation failed: {}", e))?;
let conn = NetworkConnection::Wss(WebsocketNetworkConnection::new(tls, ws_stream));
network_manager
.on_new_connection(
@@ -277,7 +284,9 @@ impl WebsocketProtocolHandler {
.await?;
Ok(conn)
} else {
let (ws_stream, _response) = client_async(request, tcp_stream).await.map_err(drop)?;
let (ws_stream, _response) = client_async(request, tcp_stream)
.await
.map_err(|e| format!("ws negotiate failed: {}", e))?;
let conn = NetworkConnection::Ws(WebsocketNetworkConnection::new(tls, ws_stream));
network_manager
.on_new_connection(
@@ -295,7 +304,7 @@ impl TcpProtocolHandler for WebsocketProtocolHandler {
&self,
stream: AsyncPeekStream,
peer_addr: SocketAddr,
) -> SystemPinBoxFuture<Result<bool, ()>> {
) -> SystemPinBoxFuture<Result<bool, String>> {
Box::pin(self.clone().on_accept_async(stream, peer_addr))
}
}