Skip to content

Commit

Permalink
feat(stageleft): allow developers to add their own re-export rewrites (
Browse files Browse the repository at this point in the history
  • Loading branch information
shadaj authored Jan 9, 2025
1 parent 487e8fa commit b968f5b
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 147 deletions.
2 changes: 1 addition & 1 deletion hydro_lang/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ build = [ "dep:dfir_lang" ]

[dependencies]
bincode = "1.3.1"
ctor = "0.2.8"
hydro_deploy = { path = "../hydro_deploy/core", version = "^0.11.0", optional = true }
dfir_rs = { path = "../dfir_rs", version = "^0.11.0", default-features = false, features = ["deploy_integration"] }
dfir_lang = { path = "../dfir_lang", version = "^0.11.0", optional = true }
Expand All @@ -46,7 +47,6 @@ stageleft_tool = { path = "../stageleft_tool", version = "^0.5.0" }

[dev-dependencies]
async-ssh2-lite = { version = "0.5.0", features = ["vendored-openssl"] }
ctor = "0.2.8"
hydro_deploy = { path = "../hydro_deploy/core", version = "^0.11.0" }
insta = "1.39"
tokio-test = "0.4.4"
Expand Down
6 changes: 6 additions & 0 deletions hydro_lang/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ mod staging_util;
#[cfg(feature = "deploy")]
pub mod test_util;

#[ctor::ctor]
fn add_private_reexports() {
stageleft::add_private_reexport(vec!["tokio", "time", "instant"], vec!["tokio", "time"]);
stageleft::add_private_reexport(vec!["bytes", "bytes"], vec!["bytes"]);
}

#[stageleft::runtime]
#[cfg(test)]
mod tests {
Expand Down
2 changes: 1 addition & 1 deletion stageleft/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use runtime_support::FreeVariableWithContext;
use crate::runtime_support::get_final_crate_name;

mod type_name;
pub use type_name::quote_type;
pub use type_name::{add_private_reexport, quote_type};

#[cfg(windows)]
#[macro_export]
Expand Down
214 changes: 70 additions & 144 deletions stageleft/src/type_name.rs
Original file line number Diff line number Diff line change
@@ -1,167 +1,93 @@
use std::sync::{LazyLock, RwLock};

use proc_macro2::Span;
use syn::visit_mut::VisitMut;
use syn::{parse_quote, TypeInfer};

use crate::runtime_support::get_final_crate_name;

/// Rewrites use of alloc::string::* to use std::string::*
struct RewriteAlloc {
type ReexportsSet = LazyLock<RwLock<Vec<(Vec<&'static str>, Vec<&'static str>)>>>;
static PRIVATE_REEXPORTS: ReexportsSet = LazyLock::new(|| {
RwLock::new(vec![
(vec!["alloc"], vec!["std"]),
(vec!["core", "ops", "range"], vec!["std", "ops"]),
(vec!["core", "slice", "iter"], vec!["std", "slice"]),
(vec!["core", "iter", "adapters", "*"], vec!["std", "iter"]),
(
vec!["std", "collections", "hash", "map"],
vec!["std", "collections", "hash_map"],
),
(vec!["std", "vec", "into_iter"], vec!["std", "vec"]),
])
});

/// Adds a private module re-export transformation to the type quoting system.
///
/// Sometimes, the [`quote_type`] function may produce an uncompilable reference to a
/// type inside a private module if the type is re-exported from a public module
/// (because Rust's `type_name` only gives the path to the original definition).
///
/// This function adds a rewrite rule for such cases, where the `from` path is
/// replaced with the `to` path. The paths are given as a list of strings, where
/// each string is a segment of the path. The `from` path is matched against the
/// beginning of the type path, and if it matches, the `to` path is substituted
/// in its place. The `from` path may contain a wildcard `*` to glob a segment.
///
/// # Example
/// ```rust
/// stageleft::add_private_reexport(
/// vec!["std", "collections", "hash", "map"],
/// vec!["std", "collections", "hash_map"],
/// );
/// ```
pub fn add_private_reexport(from: Vec<&'static str>, to: Vec<&'static str>) {
let mut transformations = PRIVATE_REEXPORTS.write().unwrap();
transformations.push((from, to));
}

struct RewritePrivateReexports {
mapping: Option<(String, String)>,
}

impl VisitMut for RewriteAlloc {
impl VisitMut for RewritePrivateReexports {
fn visit_path_mut(&mut self, i: &mut syn::Path) {
if i.segments.iter().take(1).collect::<Vec<_>>()
== vec![&syn::PathSegment::from(syn::Ident::new(
"alloc",
Span::call_site(),
))]
{
*i.segments.first_mut().unwrap() =
syn::PathSegment::from(syn::Ident::new("std", Span::call_site()));
} else if i.segments.iter().take(3).collect::<Vec<_>>()
== vec![
&syn::PathSegment::from(syn::Ident::new("core", Span::call_site())),
&syn::PathSegment::from(syn::Ident::new("ops", Span::call_site())),
&syn::PathSegment::from(syn::Ident::new("range", Span::call_site())),
]
{
*i = syn::Path {
leading_colon: i.leading_colon,
segments: syn::punctuated::Punctuated::from_iter(
vec![
syn::PathSegment::from(syn::Ident::new("std", Span::call_site())),
syn::PathSegment::from(syn::Ident::new("ops", Span::call_site())),
]
.into_iter()
.chain(i.segments.iter().skip(3).cloned()),
),
};
} else if i.segments.iter().take(3).collect::<Vec<_>>()
== vec![
&syn::PathSegment::from(syn::Ident::new("core", Span::call_site())),
&syn::PathSegment::from(syn::Ident::new("slice", Span::call_site())),
&syn::PathSegment::from(syn::Ident::new("iter", Span::call_site())),
]
{
*i = syn::Path {
leading_colon: i.leading_colon,
segments: syn::punctuated::Punctuated::from_iter(
vec![
syn::PathSegment::from(syn::Ident::new("std", Span::call_site())),
syn::PathSegment::from(syn::Ident::new("slice", Span::call_site())),
]
.into_iter()
.chain(i.segments.iter().skip(3).cloned()),
),
};
} else if i.segments.iter().take(3).collect::<Vec<_>>()
== vec![
&syn::PathSegment::from(syn::Ident::new("core", Span::call_site())),
&syn::PathSegment::from(syn::Ident::new("iter", Span::call_site())),
&syn::PathSegment::from(syn::Ident::new("adapters", Span::call_site())),
]
{
*i = syn::Path {
leading_colon: i.leading_colon,
segments: syn::punctuated::Punctuated::from_iter(
vec![
syn::PathSegment::from(syn::Ident::new("std", Span::call_site())),
syn::PathSegment::from(syn::Ident::new("iter", Span::call_site())),
]
.into_iter()
.chain(i.segments.iter().skip(4).cloned()),
),
};
} else if i.segments.iter().take(4).collect::<Vec<_>>()
== vec![
&syn::PathSegment::from(syn::Ident::new("std", Span::call_site())),
&syn::PathSegment::from(syn::Ident::new("collections", Span::call_site())),
&syn::PathSegment::from(syn::Ident::new("hash", Span::call_site())),
&syn::PathSegment::from(syn::Ident::new("map", Span::call_site())),
]
{
*i = syn::Path {
leading_colon: i.leading_colon,
segments: syn::punctuated::Punctuated::from_iter(
vec![
syn::PathSegment::from(syn::Ident::new("std", Span::call_site())),
syn::PathSegment::from(syn::Ident::new("collections", Span::call_site())),
syn::PathSegment::from(syn::Ident::new("hash_map", Span::call_site())),
]
.into_iter()
.chain(i.segments.iter().skip(4).cloned()),
),
};
} else if i.segments.iter().take(3).collect::<Vec<_>>()
== vec![
&syn::PathSegment::from(syn::Ident::new("std", Span::call_site())),
&syn::PathSegment::from(syn::Ident::new("vec", Span::call_site())),
&syn::PathSegment::from(syn::Ident::new("into_iter", Span::call_site())),
]
{
*i = syn::Path {
leading_colon: i.leading_colon,
segments: syn::punctuated::Punctuated::from_iter(
vec![
syn::PathSegment::from(syn::Ident::new("std", Span::call_site())),
syn::PathSegment::from(syn::Ident::new("vec", Span::call_site())),
]
.into_iter()
.chain(i.segments.iter().skip(3).cloned()),
),
};
} else if i.segments.iter().take(3).collect::<Vec<_>>()
== vec![
&syn::PathSegment::from(syn::Ident::new("tokio", Span::call_site())),
&syn::PathSegment::from(syn::Ident::new("time", Span::call_site())),
&syn::PathSegment::from(syn::Ident::new("instant", Span::call_site())),
]
{
*i = syn::Path {
leading_colon: i.leading_colon,
segments: syn::punctuated::Punctuated::from_iter(
vec![
syn::PathSegment::from(syn::Ident::new("tokio", Span::call_site())),
syn::PathSegment::from(syn::Ident::new("time", Span::call_site())),
]
.into_iter()
.chain(i.segments.iter().skip(3).cloned()),
),
};
} else if i.segments.iter().take(2).collect::<Vec<_>>()
== vec![
&syn::PathSegment::from(syn::Ident::new("bytes", Span::call_site())),
&syn::PathSegment::from(syn::Ident::new("bytes", Span::call_site())),
]
{
*i = syn::Path {
leading_colon: i.leading_colon,
segments: syn::punctuated::Punctuated::from_iter(
vec![syn::PathSegment::from(syn::Ident::new(
"bytes",
Span::call_site(),
))]
.into_iter()
.chain(i.segments.iter().skip(2).cloned()),
),
};
} else if let Some((macro_name, final_name)) = &self.mapping {
let transformations = PRIVATE_REEXPORTS.read().unwrap();
for (from, to) in transformations.iter() {
#[expect(clippy::cmp_owned, reason = "buggy lint for syn::Ident::to_string")]
if i.segments.len() >= from.len()
&& from
.iter()
.zip(i.segments.iter())
.all(|(f, s)| *f == "*" || *f == s.ident.to_string())
{
*i = syn::Path {
leading_colon: i.leading_colon,
segments: syn::punctuated::Punctuated::from_iter(
to.iter()
.map(|s| syn::PathSegment::from(syn::Ident::new(s, Span::call_site())))
.chain(i.segments.iter().skip(from.len()).cloned()),
),
};

drop(transformations);
self.visit_path_mut(i);
return;
}
}
drop(transformations);

if let Some((macro_name, final_name)) = &self.mapping {
if i.segments.first().unwrap().ident == macro_name {
*i.segments.first_mut().unwrap() =
syn::parse2(get_final_crate_name(final_name)).unwrap();

i.segments.insert(1, parse_quote!(__staged));
} else {
syn::visit_mut::visit_path_mut(self, i);
return;
}
} else {
syn::visit_mut::visit_path_mut(self, i);
return;
}

self.visit_path_mut(i);
}
}

Expand Down Expand Up @@ -202,7 +128,7 @@ pub fn quote_type<T>() -> syn::Type {
});
let mapping = super::runtime_support::MACRO_TO_CRATE.with(|m| m.borrow().clone());
ElimClosureToInfer.visit_type_mut(&mut t_type);
RewriteAlloc { mapping }.visit_type_mut(&mut t_type);
RewritePrivateReexports { mapping }.visit_type_mut(&mut t_type);

t_type
}
9 changes: 8 additions & 1 deletion stageleft_tool/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,14 @@ impl VisitMut for GenFinalPubVistor {
}
}

i.vis = parse_quote!(pub);
let is_ctor = i
.attrs
.iter()
.any(|a| a.path().to_token_stream().to_string() == "ctor :: ctor");

if !is_ctor {
i.vis = parse_quote!(pub);
}

syn::visit_mut::visit_item_fn_mut(self, i);
}
Expand Down

0 comments on commit b968f5b

Please sign in to comment.