diff --git a/src/compression.rs b/src/compression.rs index 2b44e788..42e1757c 100644 --- a/src/compression.rs +++ b/src/compression.rs @@ -223,7 +223,27 @@ pub struct CompressionStepConfig { /// Silence all printing within a compression step. See `silent` to silence all outputs between compression steps as well. #[clap(long)] pub quiet: bool, - + + // Fused lambda tags + #[clap(long, value_parser = clap::value_parser!(FusedLambdaTags), default_value="")] + pub fused_lambda_tags: FusedLambdaTags, +} + +#[derive(Debug, Clone, Serialize)] +pub struct FusedLambdaTags { + tags: Option>, +} + +// parse from a string like "1,2,3" +impl std::str::FromStr for FusedLambdaTags { + type Err = String; + fn from_str(s: &str) -> Result { + if s.is_empty() { + return Ok(FusedLambdaTags { tags: None }) + } + let tags = s.split(",").map(|s| s.parse::().unwrap()).collect(); + Ok(FusedLambdaTags { tags: Some(tags) }) + } } impl CompressionStepConfig { @@ -373,6 +393,12 @@ impl Pattern { } } } + + match_locations.retain(|node| + !invalid_match_location(set, + &cfg.fused_lambda_tags.tags, + *node) + ); if cfg.eta_long { @@ -659,6 +685,26 @@ pub struct SharedData { pub cfg: CompressionStepConfig, pub multistep_cfg: MultistepCompressionConfig, pub tracking: Option, + pub fused_lambda_tags: Option>, +} + +fn invalid_metavar_location(shared : &Arc, node: Idx) -> bool { + fused_lambda_location(&shared.set, &shared.fused_lambda_tags, node) +} + +fn invalid_match_location(set : &ExprSet, fused_lambda_tags: &Option>, node: Idx) -> bool { + fused_lambda_location(set, fused_lambda_tags, node) +} + +fn fused_lambda_location(set : &ExprSet, fused_lambda_tags: &Option>, node: Idx) -> bool { + if let Some(fused_lambda_tags) = fused_lambda_tags { + if let Node::Lam(_, tag) = &set[node] { + if fused_lambda_tags.contains(tag) { + return true + } + } + } + false } /// Used for debugging tracking information @@ -1166,16 +1212,19 @@ fn get_ivars_expansions(original_pattern: &Pattern, arg_of_loc: &FxHashMap = original_pattern.match_locations.iter() - .filter(|loc| + .filter(|loc:&&Idx| arg_of_loc[loc].shifted_id == - arg_of_loc_ivar[loc].shifted_id).cloned().collect(); + arg_of_loc_ivar[loc].shifted_id + && !invalid_metavar_location(shared, arg_of_loc[loc].shifted_id) + ).cloned().collect(); if locs.is_empty() { continue; } ivars_expansions.push((ExpandsTo::IVar(ivar as i32), locs)); } // also consider one ivar greater, if this is within the arity limit. This will match at all the same locations as the original. if original_pattern.first_zid_of_ivar.len() < shared.cfg.max_arity { let ivar = original_pattern.first_zid_of_ivar.len(); - let locs = original_pattern.match_locations.clone(); + let mut locs = original_pattern.match_locations.clone(); + locs.retain(|loc| !invalid_metavar_location(shared, arg_of_loc[loc].shifted_id)); ivars_expansions.push((ExpandsTo::IVar(ivar as i32), locs)); } ivars_expansions @@ -2085,6 +2134,15 @@ pub fn compression_step( worklist.push(HeapItem::new(single_hole)); let crit = CriticalMultithreadData::new(donelist, worklist, cfg); + + let fused_copy: Option> = if let Some(fused_tags) = &cfg.fused_lambda_tags.tags { + let mut fused_copy = FxHashSet::default(); + fused_copy.extend(fused_tags.iter().cloned()); + Some(fused_copy) + } else { + None + }; + let shared = Arc::new(SharedData { crit: Mutex::new(crit), programs: programs.to_vec(), @@ -2117,6 +2175,7 @@ pub fn compression_step( cfg: cfg.clone(), multistep_cfg: multistep_cfg.clone(), tracking, + fused_lambda_tags: fused_copy, }); if !shared.cfg.quiet { println!("built SharedData: {:?}ms", tstart.elapsed().as_millis()) }