refactor websocket veilid_config and update scripts

This commit is contained in:
John Smith
2021-12-09 16:00:47 -05:00
parent de36b0d6d6
commit ea8ffea1c9
19 changed files with 797 additions and 145 deletions
+32 -32
View File
@@ -860,11 +860,11 @@ impl Network {
pub async fn start_ws_listeners(&self) -> Result<(), String> {
let routing_table = self.routing_table();
let (listen_address, public_address, path) = {
let (listen_address, url, path) = {
let c = self.config.get();
(
c.network.protocol.ws.listen_address.clone(),
c.network.protocol.ws.public_address.clone(),
c.network.protocol.ws.url.clone(),
c.network.protocol.ws.path.clone(),
)
};
@@ -888,21 +888,20 @@ impl Network {
);
// Add static public dialinfo if it's configured
if let Some(public_address) = public_address.as_ref() {
let (public_fqdn, public_port) = split_port(public_address).map_err(|_| {
"invalid WS public address, port not specified correctly".to_owned()
})?;
let public_port = public_port
.ok_or_else(|| "port must be specified for public WS address".to_owned())?;
if let Some(url) = url.as_ref() {
let split_url = SplitUrl::from_str(url)?;
if split_url.scheme.to_ascii_lowercase() != "ws" {
return Err("WS URL must use 'ws://' scheme".to_owned());
}
routing_table.register_global_dial_info(
DialInfo::ws(fqdn, public_port, public_fqdn),
Some(NetworkClass::Server),
DialInfoOrigin::Static,
);
} else {
routing_table.register_global_dial_info(
DialInfo::ws(fqdn, port, path.clone()),
DialInfo::ws(
split_url.host,
split_url.port.unwrap_or(80),
split_url
.path
.map(|p| p.to_string())
.unwrap_or_else(|| "/".to_string()),
),
Some(NetworkClass::Server),
DialInfoOrigin::Static,
);
@@ -914,11 +913,11 @@ impl Network {
pub async fn start_wss_listeners(&self) -> Result<(), String> {
let routing_table = self.routing_table();
let (listen_address, public_address, path) = {
let (listen_address, url, path) = {
let c = self.config.get();
(
c.network.protocol.wss.listen_address.clone(),
c.network.protocol.wss.public_address.clone(),
c.network.protocol.wss.url.clone(),
c.network.protocol.wss.path.clone(),
)
};
@@ -943,24 +942,25 @@ impl Network {
);
// Add static public dialinfo if it's configured
if let Some(public_address) = public_address.as_ref() {
let (public_fqdn, public_port) = split_port(public_address).map_err(|_| {
"invalid WSS public address, port not specified correctly".to_owned()
})?;
let public_port = public_port
.ok_or_else(|| "port must be specified for public WSS address".to_owned())?;
if let Some(url) = url.as_ref() {
let split_url = SplitUrl::from_str(url)?;
if split_url.scheme.to_ascii_lowercase() != "wss" {
return Err("WSS URL must use 'wss://' scheme".to_owned());
}
routing_table.register_global_dial_info(
DialInfo::wss(fqdn, public_port, public_fqdn),
None,
DialInfo::wss(
split_url.host,
split_url.port.unwrap_or(443),
split_url
.path
.map(|p| p.to_string())
.unwrap_or_else(|| "/".to_string()),
),
Some(NetworkClass::Server),
DialInfoOrigin::Static,
);
} else {
routing_table.register_global_dial_info(
DialInfo::wss(fqdn, port, path.clone()),
None,
DialInfoOrigin::Static,
);
return Err("WSS URL must be specified due to TLS requirements".to_owned());
}
self.inner.lock().wss_listen = true;
+5 -1
View File
@@ -227,10 +227,11 @@ impl NetworkManager {
}
pub async fn tick(&self) -> Result<(), String> {
let (net, lease_manager, receipt_manager) = {
let (routing_table, net, lease_manager, receipt_manager) = {
let inner = self.inner.lock();
let components = inner.components.as_ref().unwrap();
(
inner.routing_table.as_ref().unwrap().clone(),
components.net.clone(),
components.lease_manager.clone(),
components.receipt_manager.clone(),
@@ -244,6 +245,9 @@ impl NetworkManager {
net.startup().await?;
}
// Run the routing table tick
routing_table.tick().await?;
// Run the low level network tick
net.tick().await?;
+7 -5
View File
@@ -205,14 +205,14 @@ impl RoutingTable {
});
info!(
"Local Dial Info: {} ({:?})",
"Local Dial Info: {}",
NodeDialInfoSingle {
node_id: NodeId::new(inner.node_id),
dial_info
}
.to_string(),
origin,
);
debug!(" Origin: {:?}", origin);
}
pub fn clear_local_dial_info(&self) {
@@ -281,15 +281,15 @@ impl RoutingTable {
});
info!(
"Public Dial Info: {} ({:?}#{:?})",
"Public Dial Info: {}",
NodeDialInfoSingle {
node_id: NodeId::new(inner.node_id),
dial_info
}
.to_string(),
origin,
network_class,
);
debug!(" Origin: {:?}", origin);
debug!(" Network Class: {:?}", network_class);
}
pub fn clear_global_dial_info(&self) {
@@ -613,6 +613,8 @@ impl RoutingTable {
c.network.bootstrap.clone()
};
trace!("Bootstrap task with: {:?}", bootstrap);
// Map all bootstrap entries to a single key with multiple dialinfo
let mut bsmap: BTreeMap<DHTKey, Vec<DialInfo>> = BTreeMap::new();
for b in bootstrap {
@@ -311,6 +311,131 @@ pub async fn test_sleep() {
}
}
macro_rules! assert_split_url {
($url:expr, $scheme:expr, $host:expr) => {
assert_eq!(
SplitUrl::from_str($url),
Ok(SplitUrl::new($scheme, None, $host, None, None))
);
};
($url:expr, $scheme:expr, $host:expr, $port:expr) => {
assert_eq!(
SplitUrl::from_str($url),
Ok(SplitUrl::new($scheme, None, $host, $port, None))
);
};
($url:expr, $scheme:expr, $host:expr, $port:expr, $path:expr) => {
assert_eq!(
SplitUrl::from_str($url),
Ok(SplitUrl::new(
$scheme,
None,
$host,
$port,
Some(SplitUrlPath::new(
$path,
Option::<String>::None,
Option::<String>::None
))
))
);
};
($url:expr, $scheme:expr, $host:expr, $port:expr, $path:expr, $frag:expr, $query:expr) => {
assert_eq!(
SplitUrl::from_str($url),
Ok(SplitUrl::new(
$scheme,
None,
$host,
$port,
Some(SplitUrlPath::new($path, $frag, $query))
))
);
};
}
macro_rules! assert_split_url_parse {
($url:expr) => {
let url = $url;
let su1 = SplitUrl::from_str(url).expect("should parse");
assert_eq!(su1.to_string(), url);
};
}
macro_rules! assert_err {
($ex:expr) => {
if let Ok(v) = $ex {
panic!("assertion failed, expected Err(..), got {:?}", v);
}
};
}
pub async fn test_split_url() {
info!("testing split_url");
assert_split_url!("http://foo", "http", "foo");
assert_split_url!("http://foo:1234", "http", "foo", Some(1234));
assert_split_url!("http://foo:1234/", "http", "foo", Some(1234), "");
assert_split_url!(
"http://foo:1234/asdf/qwer",
"http",
"foo",
Some(1234),
"asdf/qwer"
);
assert_split_url!("http://foo/", "http", "foo", None, "");
assert_split_url!("http://foo/asdf/qwer", "http", "foo", None, "asdf/qwer");
assert_split_url!(
"http://foo/asdf/qwer#3",
"http",
"foo",
None,
"asdf/qwer",
Some("3"),
Option::<String>::None
);
assert_split_url!(
"http://foo/asdf/qwer?xxx",
"http",
"foo",
None,
"asdf/qwer",
Option::<String>::None,
Some("xxx")
);
assert_split_url!(
"http://foo/asdf/qwer#yyy?xxx",
"http",
"foo",
None,
"asdf/qwer",
Some("yyy"),
Some("xxx")
);
assert_err!(SplitUrl::from_str("://asdf"));
assert_err!(SplitUrl::from_str(""));
assert_err!(SplitUrl::from_str("::"));
assert_err!(SplitUrl::from_str("://:"));
assert_err!(SplitUrl::from_str("a://:"));
assert_err!(SplitUrl::from_str("a://:1243"));
assert_err!(SplitUrl::from_str("a://:65536"));
assert_err!(SplitUrl::from_str("a://:-16"));
assert_err!(SplitUrl::from_str("a:///"));
assert_err!(SplitUrl::from_str("a:///qwer:"));
assert_err!(SplitUrl::from_str("a:///qwer://"));
assert_err!(SplitUrl::from_str("a://qwer://"));
assert_split_url_parse!("sch://foo:bar@baz.com:1234/fnord#qux?zuz");
assert_split_url_parse!("sch://foo:bar@baz.com:1234/fnord#qux");
assert_split_url_parse!("sch://foo:bar@baz.com:1234/fnord?zuz");
assert_split_url_parse!("sch://foo:bar@baz.com:1234/fnord/");
assert_split_url_parse!("sch://foo:bar@baz.com:1234//");
assert_split_url_parse!("sch://foo:bar@baz.com:1234");
assert_split_url_parse!("sch://@baz.com:1234");
assert_split_url_parse!("sch://baz.com/asdf/asdf");
assert_split_url_parse!("sch://baz.com/");
assert_split_url_parse!("s://s");
}
pub async fn test_protected_store() {
info!("testing protected store");
@@ -518,6 +643,7 @@ pub async fn test_all() {
test_log().await;
test_get_timestamp().await;
test_tools().await;
test_split_url().await;
test_get_random_u64().await;
test_get_random_u32().await;
test_sleep().await;
@@ -189,13 +189,16 @@ pub fn config_callback(key: String) -> Result<Box<dyn core::any::Any>, String> {
"network.tls.certificate_path" => Ok(Box::new(get_certfile_path())),
"network.tls.private_key_path" => Ok(Box::new(get_keyfile_path())),
"network.tls.connection_initial_timeout" => Ok(Box::new(2_000_000u64)),
"network.application.path" => Ok(Box::new(String::from("/app"))),
"network.application.https.enabled" => Ok(Box::new(true)),
"network.application.https.enabled" => Ok(Box::new(false)),
"network.application.https.listen_address" => Ok(Box::new(String::from("[::1]:5150"))),
"network.application.http.enabled" => Ok(Box::new(true)),
"network.application.https.path" => Ok(Box::new(String::from("app"))),
"network.application.https.url" => Ok(Box::new(Option::<String>::None)),
"network.application.http.enabled" => Ok(Box::new(false)),
"network.application.http.listen_address" => Ok(Box::new(String::from("[::1]:5150"))),
"network.application.http.path" => Ok(Box::new(String::from("app"))),
"network.application.http.url" => Ok(Box::new(Option::<String>::None)),
"network.protocol.udp.enabled" => Ok(Box::new(true)),
"network.protocol.udp.socket_pool_size" => Ok(Box::new(0u32)),
"network.protocol.udp.socket_pool_size" => Ok(Box::new(16u32)),
"network.protocol.udp.listen_address" => Ok(Box::new(String::from("[::1]:5150"))),
"network.protocol.udp.public_address" => Ok(Box::new(Option::<String>::None)),
"network.protocol.tcp.connect" => Ok(Box::new(true)),
@@ -203,23 +206,27 @@ pub fn config_callback(key: String) -> Result<Box<dyn core::any::Any>, String> {
"network.protocol.tcp.max_connections" => Ok(Box::new(32u32)),
"network.protocol.tcp.listen_address" => Ok(Box::new(String::from("[::1]:5150"))),
"network.protocol.tcp.public_address" => Ok(Box::new(Option::<String>::None)),
"network.protocol.ws.connect" => Ok(Box::new(true)),
"network.protocol.ws.listen" => Ok(Box::new(true)),
"network.protocol.ws.connect" => Ok(Box::new(false)),
"network.protocol.ws.listen" => Ok(Box::new(false)),
"network.protocol.ws.max_connections" => Ok(Box::new(16u32)),
"network.protocol.ws.listen_address" => Ok(Box::new(String::from("[::1]:5150"))),
"network.protocol.ws.path" => Ok(Box::new(String::from("/ws"))),
"network.protocol.ws.public_address" => Ok(Box::new(Option::<String>::None)),
"network.protocol.wss.connect" => Ok(Box::new(true)),
"network.protocol.wss.listen" => Ok(Box::new(true)),
"network.protocol.ws.path" => Ok(Box::new(String::from("ws"))),
"network.protocol.ws.url" => Ok(Box::new(Option::<String>::None)),
"network.protocol.wss.connect" => Ok(Box::new(false)),
"network.protocol.wss.listen" => Ok(Box::new(false)),
"network.protocol.wss.max_connections" => Ok(Box::new(16u32)),
"network.protocol.wss.listen_address" => Ok(Box::new(String::from("[::1]:5150"))),
"network.protocol.wss.path" => Ok(Box::new(String::from("/ws"))),
"network.protocol.wss.public_address" => Ok(Box::new(Option::<String>::None)),
"network.protocol.wss.path" => Ok(Box::new(String::from("ws"))),
"network.protocol.wss.url" => Ok(Box::new(Option::<String>::None)),
"network.leases.max_server_signal_leases" => Ok(Box::new(256u32)),
"network.leases.max_server_relay_leases" => Ok(Box::new(8u32)),
"network.leases.max_client_signal_leases" => Ok(Box::new(2u32)),
"network.leases.max_client_relay_leases" => Ok(Box::new(2u32)),
_ => Err(format!("config key '{}' doesn't exist", key)),
_ => {
let err = format!("config key '{}' doesn't exist", key);
debug!("{}", err);
Err(err)
}
}
}
@@ -278,13 +285,16 @@ pub async fn test_config() {
assert_eq!(inner.network.tls.private_key_path, get_keyfile_path());
assert_eq!(inner.network.tls.connection_initial_timeout, 2_000_000u64);
assert_eq!(inner.network.application.path, "/app");
assert_eq!(inner.network.application.https.enabled, true);
assert_eq!(inner.network.application.https.enabled, false);
assert_eq!(inner.network.application.https.listen_address, "[::1]:5150");
assert_eq!(inner.network.application.http.enabled, true);
assert_eq!(inner.network.application.https.path, "app");
assert_eq!(inner.network.application.https.url, None);
assert_eq!(inner.network.application.http.enabled, false);
assert_eq!(inner.network.application.http.listen_address, "[::1]:5150");
assert_eq!(inner.network.application.http.path, "app");
assert_eq!(inner.network.application.http.url, None);
assert_eq!(inner.network.protocol.udp.enabled, true);
assert_eq!(inner.network.protocol.udp.socket_pool_size, 0u32);
assert_eq!(inner.network.protocol.udp.socket_pool_size, 16u32);
assert_eq!(inner.network.protocol.udp.listen_address, "[::1]:5150");
assert_eq!(inner.network.protocol.udp.public_address, None);
assert_eq!(inner.network.protocol.tcp.connect, true);
@@ -292,18 +302,18 @@ pub async fn test_config() {
assert_eq!(inner.network.protocol.tcp.max_connections, 32u32);
assert_eq!(inner.network.protocol.tcp.listen_address, "[::1]:5150");
assert_eq!(inner.network.protocol.tcp.public_address, None);
assert_eq!(inner.network.protocol.ws.connect, true);
assert_eq!(inner.network.protocol.ws.listen, true);
assert_eq!(inner.network.protocol.ws.connect, false);
assert_eq!(inner.network.protocol.ws.listen, false);
assert_eq!(inner.network.protocol.ws.max_connections, 16u32);
assert_eq!(inner.network.protocol.ws.listen_address, "[::1]:5150");
assert_eq!(inner.network.protocol.ws.path, "/ws");
assert_eq!(inner.network.protocol.ws.public_address, None);
assert_eq!(inner.network.protocol.wss.connect, true);
assert_eq!(inner.network.protocol.wss.listen, true);
assert_eq!(inner.network.protocol.ws.path, "ws");
assert_eq!(inner.network.protocol.ws.url, None);
assert_eq!(inner.network.protocol.wss.connect, false);
assert_eq!(inner.network.protocol.wss.listen, false);
assert_eq!(inner.network.protocol.wss.max_connections, 16u32);
assert_eq!(inner.network.protocol.wss.listen_address, "[::1]:5150");
assert_eq!(inner.network.protocol.wss.path, "/ws");
assert_eq!(inner.network.protocol.wss.public_address, None);
assert_eq!(inner.network.protocol.wss.path, "ws");
assert_eq!(inner.network.protocol.wss.url, None);
}
pub async fn test_all() {
+3 -3
View File
@@ -385,7 +385,7 @@ impl DialInfo {
let addr: IpAddr = di
.fqdn
.parse()
.map_err(|e| format!("Failed to parse WS fqdn: {}", e))?;
.map_err(|e| format!("Failed to parse WSS fqdn: {}", e))?;
Ok(addr)
}
}
@@ -896,7 +896,7 @@ impl fmt::Debug for VeilidAPIInner {
impl Drop for VeilidAPIInner {
fn drop(&mut self) {
if let Some(core) = self.core.take() {
intf::spawn_local(core.internal_shutdown()).detach();
intf::spawn_local(core.shutdown()).detach();
}
}
}
@@ -953,7 +953,7 @@ impl VeilidAPI {
pub async fn shutdown(self) {
let core = { self.inner.lock().core.take() };
if let Some(core) = core {
core.internal_shutdown().await;
core.shutdown().await;
}
}
+102 -10
View File
@@ -14,17 +14,20 @@ cfg_if! {
pub struct VeilidConfigHTTPS {
pub enabled: bool,
pub listen_address: String,
pub path: String,
pub url: Option<String>, // Fixed URL is not optional for TLS-based protocols and is dynamically validated
}
#[derive(Default, Clone)]
pub struct VeilidConfigHTTP {
pub enabled: bool,
pub listen_address: String,
pub path: String,
pub url: Option<String>,
}
#[derive(Default, Clone)]
pub struct VeilidConfigApplication {
pub path: String,
pub https: VeilidConfigHTTPS,
pub http: VeilidConfigHTTP,
}
@@ -53,7 +56,7 @@ pub struct VeilidConfigWS {
pub max_connections: u32,
pub listen_address: String,
pub path: String,
pub public_address: Option<String>,
pub url: Option<String>,
}
#[derive(Default, Clone)]
@@ -63,7 +66,7 @@ pub struct VeilidConfigWSS {
pub max_connections: u32,
pub listen_address: String,
pub path: String,
pub public_address: Option<String>,
pub url: Option<String>, // Fixed URL is not optional for TLS-based protocols and is dynamically validated
}
#[derive(Default, Clone)]
@@ -184,9 +187,11 @@ impl VeilidConfig {
macro_rules! get_config {
($key:expr) => {
let keyname = &stringify!($key)[6..];
$key = *cb(keyname.to_owned())?
.downcast()
.map_err(|_| format!("incorrect type for key: {}", keyname))?;
$key = *cb(keyname.to_owned())?.downcast().map_err(|_| {
let err = format!("incorrect type for key: {}", keyname);
debug!("{}", err);
err
})?;
};
}
@@ -232,11 +237,14 @@ impl VeilidConfig {
get_config!(inner.network.tls.certificate_path);
get_config!(inner.network.tls.private_key_path);
get_config!(inner.network.tls.connection_initial_timeout);
get_config!(inner.network.application.path);
get_config!(inner.network.application.https.enabled);
get_config!(inner.network.application.https.listen_address);
get_config!(inner.network.application.https.path);
get_config!(inner.network.application.https.url);
get_config!(inner.network.application.http.enabled);
get_config!(inner.network.application.http.listen_address);
get_config!(inner.network.application.http.path);
get_config!(inner.network.application.http.url);
get_config!(inner.network.protocol.udp.enabled);
get_config!(inner.network.protocol.udp.socket_pool_size);
get_config!(inner.network.protocol.udp.listen_address);
@@ -251,13 +259,13 @@ impl VeilidConfig {
get_config!(inner.network.protocol.ws.max_connections);
get_config!(inner.network.protocol.ws.listen_address);
get_config!(inner.network.protocol.ws.path);
get_config!(inner.network.protocol.ws.public_address);
get_config!(inner.network.protocol.ws.url);
get_config!(inner.network.protocol.wss.connect);
get_config!(inner.network.protocol.wss.listen);
get_config!(inner.network.protocol.wss.max_connections);
get_config!(inner.network.protocol.wss.listen_address);
get_config!(inner.network.protocol.wss.path);
get_config!(inner.network.protocol.wss.public_address);
get_config!(inner.network.protocol.wss.url);
get_config!(inner.network.leases.max_server_signal_leases);
get_config!(inner.network.leases.max_server_relay_leases);
get_config!(inner.network.leases.max_client_signal_leases);
@@ -266,7 +274,12 @@ impl VeilidConfig {
// Initialize node id as early as possible because it is used
// for encryption purposes all over the program
self.init_node_id().await
self.init_node_id().await?;
// Validate settings
self.validate().await?;
Ok(())
}
pub async fn terminate(&self) {
@@ -277,6 +290,85 @@ impl VeilidConfig {
self.inner.read()
}
async fn validate(&self) -> Result<(), String> {
let inner = self.inner.read();
if inner.network.protocol.udp.enabled {
// Validate UDP settings
if inner.network.protocol.udp.socket_pool_size == 0 {
return Err("UDP socket pool size must be > 0 in config key 'network.protocol.udp.socket_pool_size'".to_owned());
}
}
if inner.network.protocol.tcp.listen {
// Validate TCP settings
if inner.network.protocol.tcp.max_connections == 0 {
return Err("TCP max connections must be > 0 in config key 'network.protocol.tcp.max_connections'".to_owned());
}
}
if inner.network.protocol.ws.listen {
// Validate WS settings
if inner.network.protocol.ws.max_connections == 0 {
return Err("WS max connections must be > 0 in config key 'network.protocol.ws.max_connections'".to_owned());
}
if inner.network.application.https.enabled
&& inner.network.application.https.path == inner.network.protocol.ws.path
{
return Err("WS path conflicts with HTTPS application path in config key 'network.protocol.ws.path'".to_owned());
}
if inner.network.application.http.enabled
&& inner.network.application.http.path == inner.network.protocol.ws.path
{
return Err("WS path conflicts with HTTP application path in config key 'network.protocol.ws.path'".to_owned());
}
}
if inner.network.protocol.wss.listen {
// Validate WSS settings
if inner.network.protocol.wss.max_connections == 0 {
return Err("WSS max connections must be > 0 in config key 'network.protocol.wss.max_connections'".to_owned());
}
if inner
.network
.protocol
.wss
.url
.as_ref()
.map(|u| u.is_empty())
.unwrap_or_default()
{
return Err(
"WSS URL must be specified in config key 'network.protocol.wss.url'".to_owned(),
);
}
if inner.network.application.https.enabled
&& inner.network.application.https.path == inner.network.protocol.wss.path
{
return Err("WSS path conflicts with HTTPS application path in config key 'network.protocol.ws.path'".to_owned());
}
if inner.network.application.http.enabled
&& inner.network.application.http.path == inner.network.protocol.wss.path
{
return Err("WSS path conflicts with HTTP application path in config key 'network.protocol.ws.path'".to_owned());
}
}
if inner.network.application.https.enabled {
// Validate HTTPS settings
if inner
.network
.application
.https
.url
.as_ref()
.map(|u| u.is_empty())
.unwrap_or_default()
{
return Err(
"HTTPS URL must be specified in config key 'network.application.https.url'"
.to_owned(),
);
}
}
Ok(())
}
// Get the node id from config if one is specified
async fn init_node_id(&self) -> Result<(), String> {
let mut inner = self.inner.write();
+8 -4
View File
@@ -163,15 +163,13 @@ impl VeilidCore {
match self.internal_startup(&mut *inner, setup).await {
Ok(v) => Ok(v),
Err(e) => {
self.clone().internal_shutdown().await;
Self::internal_shutdown(&mut *inner).await;
Err(e)
}
}
}
// stop the node gracefully because the veilid api was dropped
pub(crate) async fn internal_shutdown(self) {
let mut inner = self.inner.lock();
async fn internal_shutdown(inner: &mut VeilidCoreInner) {
trace!("VeilidCore::internal_shutdown starting");
// Detach the API object
@@ -204,5 +202,11 @@ impl VeilidCore {
trace!("VeilidCore::shutdown complete");
}
// stop the node gracefully because the veilid api was dropped
pub(crate) async fn shutdown(self) {
let mut inner = self.inner.lock();
Self::internal_shutdown(&mut *inner);
}
//
}
+2
View File
@@ -7,12 +7,14 @@ mod ip_addr_port;
mod ip_extra;
mod single_future;
mod single_shot_eventual;
mod split_url;
mod tick_task;
mod tools;
pub use cfg_if::*;
pub use log::*;
pub use parking_lot::*;
pub use split_url::*;
pub use static_assertions::*;
pub type PinBox<T> = Pin<Box<T>>;
+325
View File
@@ -0,0 +1,325 @@
// Loose subset interpretation of the URL standard
// Not using full Url crate here for no_std compatibility
//
// Caveats:
// No support for query string parsing
// No support for paths with ';' parameters
// URLs must convert to UTF8
// Only IP address and DNS hostname host fields are supported
use super::IpAddr;
use core::fmt;
use core::str::FromStr;
fn is_alphanum(c: u8) -> bool {
matches!(c,
b'A'..=b'Z'
| b'a'..=b'z'
| b'0'..=b'9'
)
}
fn is_mark(c: u8) -> bool {
matches!(
c,
b'-' | b'_' | b'.' | b'!' | b'~' | b'*' | b'\'' | b'(' | b')'
)
}
fn is_unreserved(c: u8) -> bool {
is_alphanum(c) || is_mark(c)
}
fn must_encode_userinfo(c: u8) -> bool {
!(is_unreserved(c) || matches!(c, b'%' | b':' | b';' | b'&' | b'=' | b'+' | b'$' | b','))
}
fn must_encode_path(c: u8) -> bool {
!(is_unreserved(c)
|| matches!(
c,
b'%' | b'/' | b':' | b'@' | b'&' | b'=' | b'+' | b'$' | b','
))
}
fn is_valid_host<H: AsRef<str>>(host: H) -> bool {
if host.as_ref().is_empty() {
return false;
}
if IpAddr::from_str(host.as_ref()).is_err() {
for ch in host.as_ref().chars() {
if !matches!(ch,
'A'..='Z' | 'a'..='z' | '0'..='9' | '-' | '.' )
{
return false;
}
}
}
true
}
fn is_valid_scheme<H: AsRef<str>>(host: H) -> bool {
let mut chars = host.as_ref().chars();
if let Some(ch) = chars.next() {
if !matches!(ch, 'A'..='Z' | 'a'..='z') {
return false;
}
} else {
return false;
}
for ch in chars {
if !matches!(ch,
'A'..='Z' | 'a'..='z' | '0'..='9' | '-' | '+' | '.' )
{
return false;
}
}
true
}
fn hex_decode(h: u8) -> Result<u8, String> {
match h {
b'0'..=b'9' => Ok(h - b'0'),
b'A'..=b'F' => Ok(h - b'A' + 10),
b'a'..=b'f' => Ok(h - b'a' + 10),
_ => Err("Unexpected character in percent encoding".to_owned()),
}
}
fn hex_encode(c: u8) -> (char, char) {
let c0 = c >> 4;
let c1 = c & 15;
(
if c0 < 10 {
char::from_u32((b'0' + c0) as u32).unwrap()
} else {
char::from_u32((b'A' + c0 - 10) as u32).unwrap()
},
if c1 < 10 {
char::from_u32((b'0' + c1) as u32).unwrap()
} else {
char::from_u32((b'A' + c1 - 10) as u32).unwrap()
},
)
}
fn url_decode<S: AsRef<str>>(s: S) -> Result<String, String> {
let url = s.as_ref().to_owned();
if !url.is_ascii() {
return Err("URL is not in ASCII encoding".to_owned());
}
let url_bytes = url.as_bytes();
let mut dec_bytes: Vec<u8> = Vec::with_capacity(url_bytes.len());
let mut i = 0;
let end = url_bytes.len();
while i < end {
let mut b = url_bytes[i];
i += 1;
if b == b'%' {
if (i + 1) >= end {
return Err("Invalid URL encoding".to_owned());
}
b = hex_decode(url_bytes[i])? << 4 | hex_decode(url_bytes[i + 1])?;
i += 2;
}
dec_bytes.push(b);
}
String::from_utf8(dec_bytes).map_err(|e| format!("Decoded URL is not valid UTF-8: {}", e))
}
fn url_encode<S: AsRef<str>>(s: S, must_encode: impl Fn(u8) -> bool) -> String {
let bytes = s.as_ref().as_bytes();
let mut out = String::new();
for b in bytes {
if must_encode(*b) {
let (c0, c1) = hex_encode(*b);
out.push('%');
out.push(c0);
out.push(c1);
} else {
out.push(char::from_u32(*b as u32).unwrap())
}
}
out
}
fn convert_port<N>(port_str: N) -> Result<u16, String>
where
N: AsRef<str>,
{
port_str
.as_ref()
.parse::<u16>()
.map_err(|e| format!("Invalid port: {}", e))
}
///////////////////////////////////////////////////////////////////////////////
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct SplitUrlPath {
pub path: String,
pub fragment: Option<String>,
pub query: Option<String>,
}
impl SplitUrlPath {
pub fn new<P, F, Q>(path: P, fragment: Option<F>, query: Option<Q>) -> Self
where
P: AsRef<str>,
F: AsRef<str>,
Q: AsRef<str>,
{
Self {
path: path.as_ref().to_owned(),
fragment: fragment.map(|f| f.as_ref().to_owned()),
query: query.map(|f| f.as_ref().to_owned()),
}
}
}
impl FromStr for SplitUrlPath {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(if let Some((p, q)) = s.split_once('?') {
if let Some((p, f)) = p.split_once('#') {
SplitUrlPath::new(url_decode(p)?, Some(url_decode(f)?), Some(q))
} else {
SplitUrlPath::new(url_decode(p)?, Option::<String>::None, Some(q))
}
} else if let Some((p, f)) = s.split_once('#') {
SplitUrlPath::new(url_decode(p)?, Some(url_decode(f)?), Option::<String>::None)
} else {
SplitUrlPath::new(
url_decode(s)?,
Option::<String>::None,
Option::<String>::None,
)
})
}
}
impl fmt::Display for SplitUrlPath {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(fragment) = &self.fragment {
if let Some(query) = &self.query {
write!(
f,
"{}#{}?{}",
url_encode(&self.path, must_encode_path),
url_encode(fragment, must_encode_path),
query
)
} else {
write!(f, "{}#{}", self.path, fragment)
}
} else if let Some(query) = &self.query {
write!(f, "{}?{}", url_encode(&self.path, must_encode_path), query)
} else {
write!(f, "{}", url_encode(&self.path, must_encode_path))
}
}
}
///////////////////////////////////////////////////////////////////////////////
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct SplitUrl {
pub scheme: String,
pub userinfo: Option<String>,
pub host: String,
pub port: Option<u16>,
pub path: Option<SplitUrlPath>,
}
impl SplitUrl {
pub fn new<S, H>(
scheme: S,
userinfo: Option<String>,
host: H,
port: Option<u16>,
path: Option<SplitUrlPath>,
) -> Self
where
S: AsRef<str>,
H: AsRef<str>,
{
Self {
scheme: scheme.as_ref().to_owned(),
userinfo,
host: host.as_ref().to_owned(),
port,
path,
}
}
}
impl FromStr for SplitUrl {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if let Some((scheme, mut rest)) = s.split_once("://") {
if !is_valid_scheme(scheme) {
return Err("Invalid scheme specified".to_owned());
}
let userinfo = {
if let Some((userinfo_str, after)) = rest.split_once("@") {
rest = after;
Some(url_decode(userinfo_str)?)
} else {
None
}
};
if let Some((host, rest)) = rest.rsplit_once(':') {
if !is_valid_host(host) {
return Err("Invalid host specified".to_owned());
}
if let Some((portstr, path)) = rest.split_once('/') {
let port = convert_port(portstr)?;
let path = SplitUrlPath::from_str(path)?;
Ok(SplitUrl::new(
scheme,
userinfo,
host,
Some(port),
Some(path),
))
} else {
let port = convert_port(rest)?;
Ok(SplitUrl::new(scheme, userinfo, host, Some(port), None))
}
} else if let Some((host, path)) = rest.split_once('/') {
if !is_valid_host(host) {
return Err("Invalid host specified".to_owned());
}
let path = SplitUrlPath::from_str(path)?;
Ok(SplitUrl::new(scheme, userinfo, host, None, Some(path)))
} else {
if !is_valid_host(rest) {
return Err("Invalid host specified".to_owned());
}
Ok(SplitUrl::new(scheme, userinfo, rest, None, None))
}
} else {
Err("No scheme specified".to_owned())
}
}
}
impl fmt::Display for SplitUrl {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let hostname = {
if let Some(userinfo) = &self.userinfo {
let userinfo = url_encode(userinfo, must_encode_userinfo);
if let Some(port) = self.port {
format!("{}@{}:{}", userinfo, self.host, port)
} else {
format!("{}@{}", userinfo, self.host)
}
} else {
self.host.clone()
}
};
if let Some(path) = &self.path {
write!(f, "{}://{}/{}", self.scheme, hostname, path)
} else {
write!(f, "{}://{}", self.scheme, hostname)
}
}
}