@@ -145,29 +145,29 @@ def prepare_data(
145
145
rng , rng_resample , rng_noise , rng_time , rng_latent , rng_step_fn = rng
146
146
147
147
batch = jtu .tree_map (jnp .asarray , batch )
148
- (src , src_cond , tgt ), data = prepare_data (batch )
148
+ (src , src_cond , tgt ), matching_data = prepare_data (batch )
149
149
150
- time = self .time_sampler (rng_time , len (src ) * self .k_samples_per_x )
151
- latent = self .latent_noise_fn (rng_noise , (self .k_samples_per_x , len (src )))
150
+ n = src .shape [0 ]
151
+ time = self .time_sampler (rng_time , n * self .k_samples_per_x )
152
+ latent = self .latent_noise_fn (rng_noise , (n , self .k_samples_per_x ))
152
153
153
- tmat = self .data_match_fn (* data ) # (n, m)
154
+ tmat = self .data_match_fn (* matching_data ) # (n, m)
154
155
src_ixs , tgt_ixs = flow_utils .sample_conditional ( # (n, k), (m, k)
155
156
rng_resample ,
156
157
tmat ,
157
158
k = self .k_samples_per_x ,
158
159
uniform_marginals = True , # TODO(michalk8): expose
159
160
)
160
161
161
- src = src [src_ixs ].swapaxes (0 , 1 ) # (k, n, ...)
162
- tgt = tgt [tgt_ixs ].swapaxes (0 , 1 ) # (k, m, ...)
162
+ src , tgt = src [src_ixs ], tgt [tgt_ixs ] # (n, k, ...), # (m, k, ...)
163
163
if src_cond is not None :
164
- src_cond = src_cond [src_ixs ]. swapaxes ( 0 , 1 ) # (k, n, ...)
164
+ src_cond = src_cond [src_ixs ]
165
165
166
166
if self .latent_match_fn is not None :
167
167
src , src_cond , tgt = self ._match_latent (rng , src , src_cond , latent , tgt )
168
168
169
- src = src .reshape (- 1 , * src .shape [2 :]) # (k * bs , ...)
170
- tgt = tgt .reshape (- 1 , * tgt .shape [2 :])
169
+ src = src .reshape (- 1 , * src .shape [2 :]) # (n * k , ...)
170
+ tgt = tgt .reshape (- 1 , * tgt .shape [2 :]) # (m * k, ...)
171
171
latent = latent .reshape (- 1 , * latent .shape [2 :])
172
172
if src_cond is not None :
173
173
src_cond = src_cond .reshape (- 1 , * src_cond .shape [2 :])
@@ -197,8 +197,8 @@ def resample(
197
197
198
198
return src , src_cond , tgt
199
199
200
- cond_axis = None if src_cond is None else 0
201
- in_axes , out_axes = (0 , 0 , cond_axis , 0 , 0 ), (0 , None , 0 )
200
+ cond_axis = None if src_cond is None else 1
201
+ in_axes , out_axes = (0 , 1 , cond_axis , 1 , 1 ), (1 , cond_axis , 1 )
202
202
resample_fn = jax .jit (jax .vmap (resample , in_axes , out_axes ))
203
203
204
204
rngs = jax .random .split (rng , self .k_samples_per_x )
0 commit comments