Skip to content

Commit 0c83d06

Browse files
committed
update example folder
1 parent 26978a3 commit 0c83d06

8 files changed

+3231
-3275
lines changed

autodiscjax/utils/create_modules.py

+7-22
Original file line numberDiff line numberDiff line change
@@ -50,25 +50,10 @@ def create_intervention_module(intervention_config):
5050
if intervention_config.intervention_type == "set_uniform":
5151
intervention_fn = grn.PiecewiseSetConstantIntervention(
5252
time_to_interval_fn=grn.TimeToInterval(intervals=intervention_config.controlled_intervals))
53-
intervention_params_tree = DictTree()
54-
for y_idx in intervention_config.controlled_node_ids:
55-
intervention_params_tree.y[y_idx] = "placeholder"
56-
intervention_params_treedef = jtu.tree_structure(intervention_params_tree)
57-
intervention_params_shape = jtu.tree_map(lambda _: (len(intervention_config.controlled_intervals),), intervention_params_tree)
58-
intervention_params_dtype = jtu.tree_map(lambda _: jnp.float32, intervention_params_tree)
59-
60-
intervention_low = DictTree(intervention_config.low)
61-
intervention_low = jtu.tree_map(lambda val, shape, dtype: val * jnp.ones(shape=shape, dtype=dtype),
62-
intervention_low, intervention_params_shape,
63-
intervention_params_dtype)
64-
intervention_high = DictTree(intervention_config.high)
65-
intervention_high = jtu.tree_map(lambda val, shape, dtype: val * jnp.ones(shape=shape, dtype=dtype),
66-
intervention_high, intervention_params_shape,
67-
intervention_params_dtype)
68-
random_intervention_generator = imgep.UniformRandomGenerator(intervention_params_treedef,
69-
intervention_params_shape,
70-
intervention_params_dtype,
71-
intervention_low, intervention_high)
53+
random_intervention_generator = imgep.UniformRandomGenerator(intervention_config.out_treedef,
54+
intervention_config.out_shape,
55+
intervention_config.out_dtype,
56+
intervention_config.low, intervention_config.high)
7257
else:
7358
raise ValueError
7459
return random_intervention_generator, intervention_fn
@@ -217,8 +202,8 @@ def create_gc_intervention_optimizer_module(gc_intervention_optimizer_config):
217202
gc_intervention_optimizer_config.high,
218203
gc_intervention_optimizer_config.n_optim_steps,
219204
gc_intervention_optimizer_config.n_workers,
220-
init_noise_std=gc_intervention_optimizer_config.init_noise_std,
221-
lr=gc_intervention_optimizer_config.lr,
205+
gc_intervention_optimizer_config.init_noise_std,
206+
gc_intervention_optimizer_config.lr,
222207
)
223208

224209

@@ -230,7 +215,7 @@ def create_gc_intervention_optimizer_module(gc_intervention_optimizer_config):
230215
gc_intervention_optimizer_config.high,
231216
gc_intervention_optimizer_config.n_optim_steps,
232217
gc_intervention_optimizer_config.n_workers,
233-
init_noise_std=gc_intervention_optimizer_config.init_noise_std
218+
gc_intervention_optimizer_config.init_noise_std
234219
)
235220

236221
else:

examples/analyze_imgep_evaluation.ipynb

+109-15
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)