Skip to content

Commit

Permalink
implemented weighting
Browse files Browse the repository at this point in the history
  • Loading branch information
mlb2251 committed Nov 29, 2023
1 parent f96ba0e commit 323da8c
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/bin/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn main() {

let input = args.fmt.load_programs_and_tasks(&args.file).unwrap();

let (step_results, json_res) = multistep_compression(&input.train_programs, input.tasks, input.name_mapping, None, &args.multistep);
let (step_results, json_res) = multistep_compression(&input.train_programs, input.tasks, None, input.name_mapping, None, &args.multistep);

let out_path = &args.out;
if let Some(out_path_dir) = out_path.parent() {
Expand Down
55 changes: 39 additions & 16 deletions src/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,10 @@ pub struct SharedData {
pub root_idxs_of_task: Vec<Vec<usize>>,
pub cost_of_node_all: Vec<i32>,
pub init_cost: i32,
pub init_cost_weighted: i32,
pub init_cost_by_root_idx: Vec<i32>,
pub init_cost_by_root_idx_weighted: Vec<f32>,
pub weight_by_root_idx: Vec<f32>,
pub first_train_cost: i32,
pub stats: Mutex<Stats>,
pub cfg: CompressionStepConfig,
Expand Down Expand Up @@ -1404,13 +1407,13 @@ impl CompressionStepResult {
let inv = done.to_invention(inv_name, shared);
let rewritten = rewrite_fast(&done, shared, &Node::Prim(inv.name.clone().into()), &shared.cost_fn);

let expected_cost = shared.init_cost - done.compressive_utility;
let expected_cost = shared.init_cost_weighted - done.compressive_utility;
// let final_cost = rewritten.cost();
let final_cost = shared.root_idxs_of_task.iter().map(|root_idxs|
root_idxs.iter().map(|idx| rewritten[*idx].cost(&shared.cost_fn)).min().unwrap()
root_idxs.iter().map(|idx| (rewritten[*idx].cost(&shared.cost_fn) as f32 * shared.weight_by_root_idx[*idx]).round() as i32).min().unwrap()
).sum::<i32>();
if expected_cost != final_cost && !shared.cfg.quiet { println!("*** expected cost {expected_cost} != final cost {final_cost}") }
let multiplier = shared.init_cost as f64 / final_cost as f64;
let multiplier = shared.init_cost_weighted as f64 / final_cost as f64;
let multiplier_wrt_orig = very_first_cost as f64 / final_cost as f64;
let uses = done.usages;
let use_exprs: Vec<Idx> = done.pattern.match_locations.clone();
Expand Down Expand Up @@ -1558,8 +1561,8 @@ fn compressive_utility(pattern: &Pattern, shared: &SharedData) -> UtilityCalcula

let (cumulative_utility_of_node, corrected_utils) = bottom_up_utility_correction(pattern,shared,&utility_of_loc_once);

let compressive_utility: i32 = shared.init_cost - shared.root_idxs_of_task.iter().map(|root_idxs|
root_idxs.iter().map(|idx| shared.init_cost_by_root_idx[*idx] - cumulative_utility_of_node[shared.roots[*idx]]).min().unwrap()
let compressive_utility: i32 = shared.init_cost_weighted - shared.root_idxs_of_task.iter().map(|root_idxs|
root_idxs.iter().map(|idx| (shared.init_cost_by_root_idx_weighted[*idx] - (cumulative_utility_of_node[shared.roots[*idx]] as f32 * shared.weight_by_root_idx[*idx])).round() as i32).min().unwrap()
).sum::<i32>();

// pattern.match_locations.
Expand Down Expand Up @@ -1769,6 +1772,7 @@ fn use_counts(pattern: &Pattern, zip_of_zid: &[Vec<ZNode>], arg_of_zid_node: &[F
pub fn multistep_compression_internal(
train_programs: &[ExprOwned],
tasks: Option<Vec<String>>,
weights: Option<Vec<f32>>,
name_mapping: Option<Vec<(String, String)>>,
follow: Option<Vec<Invention>>,
cfg: &MultistepCompressionConfig
Expand All @@ -1792,14 +1796,16 @@ pub fn multistep_compression_internal(
cfg.step.no_opt();
}

let very_first_cost = min_cost(train_programs, &tasks, cost_fn);
let very_first_cost = min_cost(train_programs, &weights, &tasks, cost_fn);

let tasks: Vec<String> = tasks.unwrap_or_else(|| {
(0..train_programs.len())
.map(|i| i.to_string())
.collect()
});

let weights: Vec<f32> = weights.unwrap_or_else(|| vec![1.0; train_programs.len()]);

let mut name_mapping = name_mapping.unwrap_or_default();


Expand All @@ -1820,6 +1826,7 @@ pub fn multistep_compression_internal(
&inv_name,
&cfg,
&tasks,
&weights,
very_first_cost,
&name_mapping,
);
Expand All @@ -1846,10 +1853,10 @@ pub fn multistep_compression_internal(

if !cfg.step.quiet { println!("{}","\n=======Compression Summary=======".blue().bold()) }
if !cfg.step.quiet { println!("Found {} inventions", step_results.len()) }
let rewritten_cost = min_cost(&rewritten, &Some(tasks.clone()), cost_fn);
let rewritten_cost = min_cost(&rewritten, &Some(weights.clone()), &Some(tasks.clone()), cost_fn);
if !cfg.step.quiet { println!("Cost Improvement: ({:.2}x better) {} -> {}", compression_factor(very_first_cost, rewritten_cost), very_first_cost, rewritten_cost) }
for res in step_results.iter() {
let rewritten_cost = min_cost(&res.rewritten, &Some(tasks.clone()), cost_fn);
let rewritten_cost = min_cost(&res.rewritten, &Some(weights.clone()), &Some(tasks.clone()), cost_fn);
if !cfg.step.quiet { println!("{} ({:.2}x wrt orig): {}" , res.inv.name.clone().blue(), compression_factor(very_first_cost, rewritten_cost), res) }
}
if !cfg.step.quiet { println!("Time: {}ms", tstart.elapsed().as_millis()) }
Expand All @@ -1868,6 +1875,7 @@ pub fn compression_step(
new_inv_name: &str, // name of the new invention, like "inv4"
multistep_cfg: &MultistepCompressionConfig,
tasks: &[String],
weights: &[f32],
very_first_cost: i32,
name_mapping: &[(String, String)],
) -> Vec<CompressionStepResult> {
Expand Down Expand Up @@ -1911,9 +1919,13 @@ pub fn compression_step(
let tasks_of_node: Vec<FxHashSet<usize>> = associate_tasks(&roots, &set, &corpus_span, &task_of_root_idx);

let init_cost_by_root_idx: Vec<i32> = roots.iter().map(|idx| analyzed_cost[*idx]).collect();
let init_cost_by_root_idx_weighted: Vec<f32> = init_cost_by_root_idx.iter().zip(weights.iter()).map(|(cost,weight)| (*cost as f32 * weight)).collect();
let init_cost: i32 = root_idxs_of_task.iter().map(|root_idxs|
root_idxs.iter().map(|idx| init_cost_by_root_idx[*idx]).min().unwrap()
).sum();
let init_cost_weighted: i32 = root_idxs_of_task.iter().map(|root_idxs|
root_idxs.iter().map(|idx| init_cost_by_root_idx_weighted[*idx].round() as i32).min().unwrap()
).sum();
let first_train_cost = roots.iter().map(|idx| analyzed_cost[*idx]).sum(); // This is used for --verbose-print

if !cfg.quiet { println!("associate_tasks() and other task stuff: {:?}ms", tstart.elapsed().as_millis()) }
Expand Down Expand Up @@ -2009,8 +2021,8 @@ pub fn compression_step(
let body_utility = analyzed_cost[node];

// compressive_utility for arity-0 is cost_of_node_all[node] minus the penalty of using the new prim
let compressive_utility: i32 = init_cost - root_idxs_of_task.iter().map(|root_idxs|
root_idxs.iter().map(|idx| init_cost_by_root_idx[*idx] - num_paths_to_node_by_root_idx[*idx][node] * (analyzed_cost[node] - cost_fn.cost_prim_default))
let compressive_utility: i32 = init_cost_weighted - root_idxs_of_task.iter().map(|root_idxs|
root_idxs.iter().map(|idx| (init_cost_by_root_idx_weighted[*idx] - weights[*idx] * (num_paths_to_node_by_root_idx[*idx][node] * (analyzed_cost[node] - cost_fn.cost_prim_default)) as f32).round() as i32)
.min().unwrap()
).sum::<i32>();

Expand Down Expand Up @@ -2096,7 +2108,10 @@ pub fn compression_step(
root_idxs_of_task,
cost_of_node_all,
init_cost,
init_cost_weighted,
init_cost_by_root_idx,
init_cost_by_root_idx_weighted,
weight_by_root_idx: weights.to_vec(),
first_train_cost,
stats: Mutex::new(stats),
cfg: cfg.clone(),
Expand Down Expand Up @@ -2197,7 +2212,14 @@ pub fn compression_step(
}

/// toplevel entrypoint to compression used by most apis
pub fn multistep_compression(programs: &[String], tasks: Option<Vec<String>>, name_mapping: Option<Vec<(String,String)>>, follow: Option<Vec<Invention>>, cfg: &MultistepCompressionConfig) -> (Vec<CompressionStepResult>, serde_json::Value) {
pub fn multistep_compression(
programs: &[String],
tasks: Option<Vec<String>>,
weights: Option<Vec<f32>>,
name_mapping: Option<Vec<(String,String)>>,
follow: Option<Vec<Invention>>,
cfg: &MultistepCompressionConfig
)-> (Vec<CompressionStepResult>, serde_json::Value) {
let mut programs = programs.to_vec();
let mut cfg = cfg.clone();

Expand Down Expand Up @@ -2238,22 +2260,23 @@ pub fn multistep_compression(programs: &[String], tasks: Option<Vec<String>>, na

let step_results = multistep_compression_internal(
&train_programs,
tasks.clone(),
tasks.clone(),
weights.clone(),
name_mapping,
follow,
&cfg,
);

// write everything to json
let json_res = json_of_step_results(&step_results, &train_programs, tasks, &cost_fn, &cfg);
let json_res = json_of_step_results(&step_results, &train_programs, weights, tasks, &cost_fn, &cfg);

(step_results, json_res)
}

pub fn json_of_step_results(step_results: &[CompressionStepResult], train_programs: &Vec<ExprOwned>, tasks: Option<Vec<String>>, cost_fn: &ExprCost, cfg: &MultistepCompressionConfig) -> serde_json::Value {
pub fn json_of_step_results(step_results: &[CompressionStepResult], train_programs: &Vec<ExprOwned>, weights: Option<Vec<f32>>, tasks: Option<Vec<String>>, cost_fn: &ExprCost, cfg: &MultistepCompressionConfig) -> serde_json::Value {
let rewritten: &Vec<ExprOwned> = step_results.iter().last().map(|res| &res.rewritten).unwrap_or(train_programs);
let original_cost = min_cost(train_programs, &tasks, cost_fn);
let final_cost = min_cost(rewritten, &tasks, cost_fn);
let original_cost = min_cost(train_programs, &weights, &tasks, cost_fn);
let final_cost = min_cost(rewritten, &weights, &tasks, cost_fn);
let rewritten = step_results.iter().last().map(|res| &res.rewritten).unwrap_or(train_programs).iter().map(|p| p.to_string()).collect::<Vec<String>>();
let rewritten_dreamcoder = if !cfg.step.rewritten_dreamcoder { None } else {
let rewritten_dreamcoder = step_results.iter().last().map(|res| res.rewritten_dreamcoder.clone().unwrap()).unwrap_or_else(||train_programs.iter().map(
Expand Down
6 changes: 3 additions & 3 deletions src/rewriting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ pub fn rewrite_fast(
if !shared.cfg.no_mismatch_check && !shared.cfg.utility_by_rewrite {
assert_eq!(
shared.root_idxs_of_task.iter().map(|root_idxs|
root_idxs.iter().map(|idx| rewritten_exprs[*idx].cost(cost_fn)).min().unwrap()
root_idxs.iter().map(|idx| (rewritten_exprs[*idx].cost(cost_fn) as f32 * shared.weight_by_root_idx[*idx]).round() as i32).min().unwrap()
).sum::<i32>(),
shared.init_cost - pattern.util_calc.util,
shared.init_cost_weighted - pattern.util_calc.util,
"\n{}\n", pattern.info(shared)
);
}
Expand Down Expand Up @@ -185,7 +185,7 @@ pub fn rewrite_with_inventions(
// cfg.step.rewritten_dreamcoder = true;
// cfg.step.rewritten_intermediates = true;

let (step_results, json_res) = multistep_compression(programs, None, None, follow, &cfg);
let (step_results, json_res) = multistep_compression(programs, None, None, None, follow, &cfg);

// return the last one - note that if an abstraction wasn't used anywhere it will not be included in the step_results so this
// may be shorter than invs.len(), however we do ensure that we continue searching for the rest of the abstractions if this happens
Expand Down
8 changes: 5 additions & 3 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@ use crate::*;
use lambdas::*;


pub fn min_cost(programs: &[ExprOwned], tasks: &Option<Vec<String>>, cost_fn: &ExprCost) -> i32 {
pub fn min_cost(programs: &[ExprOwned], weights: &Option<Vec<f32>>, tasks: &Option<Vec<String>>, cost_fn: &ExprCost) -> i32 {
let weights = weights.clone().unwrap_or(vec![1.0; programs.len()]);
if let Some(tasks) = tasks {
let mut unique_tasks = tasks.to_vec();
unique_tasks.sort();
unique_tasks.dedup();
unique_tasks.iter().map(|task|
tasks.iter().zip(programs.iter()).filter_map(|(t,p)| if task == t { Some(p.cost(cost_fn)) } else { None }).min().unwrap()
tasks.iter().zip(programs.iter().zip(weights.iter())).filter_map(|(t,(p,w))| if task == t { Some((p.cost(cost_fn) as f32 * w).round() as i32) } else { None })
.min().unwrap()
).sum::<i32>()
} else {
programs.iter().map(|e| e.cost(cost_fn)).sum::<i32>()
programs.iter().zip(weights.iter()).map(|(e,w)| (e.cost(cost_fn) as f32 * w).round() as i32).sum::<i32>()
}
}

Expand Down
1 change: 1 addition & 0 deletions tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ fn run_compression(inputs: &Input, cfg: &MultistepCompressionConfig) -> Value {
multistep_compression(
&inputs.train_programs,
inputs.tasks.clone(),
None,
inputs.name_mapping.clone(),
None,
cfg,
Expand Down

0 comments on commit 323da8c

Please sign in to comment.