@@ -50,25 +50,10 @@ def create_intervention_module(intervention_config):
50
50
if intervention_config .intervention_type == "set_uniform" :
51
51
intervention_fn = grn .PiecewiseSetConstantIntervention (
52
52
time_to_interval_fn = grn .TimeToInterval (intervals = intervention_config .controlled_intervals ))
53
- intervention_params_tree = DictTree ()
54
- for y_idx in intervention_config .controlled_node_ids :
55
- intervention_params_tree .y [y_idx ] = "placeholder"
56
- intervention_params_treedef = jtu .tree_structure (intervention_params_tree )
57
- intervention_params_shape = jtu .tree_map (lambda _ : (len (intervention_config .controlled_intervals ),), intervention_params_tree )
58
- intervention_params_dtype = jtu .tree_map (lambda _ : jnp .float32 , intervention_params_tree )
59
-
60
- intervention_low = DictTree (intervention_config .low )
61
- intervention_low = jtu .tree_map (lambda val , shape , dtype : val * jnp .ones (shape = shape , dtype = dtype ),
62
- intervention_low , intervention_params_shape ,
63
- intervention_params_dtype )
64
- intervention_high = DictTree (intervention_config .high )
65
- intervention_high = jtu .tree_map (lambda val , shape , dtype : val * jnp .ones (shape = shape , dtype = dtype ),
66
- intervention_high , intervention_params_shape ,
67
- intervention_params_dtype )
68
- random_intervention_generator = imgep .UniformRandomGenerator (intervention_params_treedef ,
69
- intervention_params_shape ,
70
- intervention_params_dtype ,
71
- intervention_low , intervention_high )
53
+ random_intervention_generator = imgep .UniformRandomGenerator (intervention_config .out_treedef ,
54
+ intervention_config .out_shape ,
55
+ intervention_config .out_dtype ,
56
+ intervention_config .low , intervention_config .high )
72
57
else :
73
58
raise ValueError
74
59
return random_intervention_generator , intervention_fn
@@ -217,8 +202,8 @@ def create_gc_intervention_optimizer_module(gc_intervention_optimizer_config):
217
202
gc_intervention_optimizer_config .high ,
218
203
gc_intervention_optimizer_config .n_optim_steps ,
219
204
gc_intervention_optimizer_config .n_workers ,
220
- init_noise_std = gc_intervention_optimizer_config .init_noise_std ,
221
- lr = gc_intervention_optimizer_config .lr ,
205
+ gc_intervention_optimizer_config .init_noise_std ,
206
+ gc_intervention_optimizer_config .lr ,
222
207
)
223
208
224
209
@@ -230,7 +215,7 @@ def create_gc_intervention_optimizer_module(gc_intervention_optimizer_config):
230
215
gc_intervention_optimizer_config .high ,
231
216
gc_intervention_optimizer_config .n_optim_steps ,
232
217
gc_intervention_optimizer_config .n_workers ,
233
- init_noise_std = gc_intervention_optimizer_config .init_noise_std
218
+ gc_intervention_optimizer_config .init_noise_std
234
219
)
235
220
236
221
else :
0 commit comments