Skip to content

Commit 693ecc4

Browse files
committed
Remove axis swapping
1 parent f2c20a4 commit 693ecc4

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

src/ott/neural/flow_models/genot.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -145,29 +145,29 @@ def prepare_data(
145145
rng, rng_resample, rng_noise, rng_time, rng_latent, rng_step_fn = rng
146146

147147
batch = jtu.tree_map(jnp.asarray, batch)
148-
(src, src_cond, tgt), data = prepare_data(batch)
148+
(src, src_cond, tgt), matching_data = prepare_data(batch)
149149

150-
time = self.time_sampler(rng_time, len(src) * self.k_samples_per_x)
151-
latent = self.latent_noise_fn(rng_noise, (self.k_samples_per_x, len(src)))
150+
n = src.shape[0]
151+
time = self.time_sampler(rng_time, n * self.k_samples_per_x)
152+
latent = self.latent_noise_fn(rng_noise, (n, self.k_samples_per_x))
152153

153-
tmat = self.data_match_fn(*data) # (n, m)
154+
tmat = self.data_match_fn(*matching_data) # (n, m)
154155
src_ixs, tgt_ixs = flow_utils.sample_conditional( # (n, k), (m, k)
155156
rng_resample,
156157
tmat,
157158
k=self.k_samples_per_x,
158159
uniform_marginals=True, # TODO(michalk8): expose
159160
)
160161

161-
src = src[src_ixs].swapaxes(0, 1) # (k, n, ...)
162-
tgt = tgt[tgt_ixs].swapaxes(0, 1) # (k, m, ...)
162+
src, tgt = src[src_ixs], tgt[tgt_ixs] # (n, k, ...), # (m, k, ...)
163163
if src_cond is not None:
164-
src_cond = src_cond[src_ixs].swapaxes(0, 1) # (k, n, ...)
164+
src_cond = src_cond[src_ixs]
165165

166166
if self.latent_match_fn is not None:
167167
src, src_cond, tgt = self._match_latent(rng, src, src_cond, latent, tgt)
168168

169-
src = src.reshape(-1, *src.shape[2:]) # (k * bs, ...)
170-
tgt = tgt.reshape(-1, *tgt.shape[2:])
169+
src = src.reshape(-1, *src.shape[2:]) # (n * k, ...)
170+
tgt = tgt.reshape(-1, *tgt.shape[2:]) # (m * k, ...)
171171
latent = latent.reshape(-1, *latent.shape[2:])
172172
if src_cond is not None:
173173
src_cond = src_cond.reshape(-1, *src_cond.shape[2:])
@@ -197,8 +197,8 @@ def resample(
197197

198198
return src, src_cond, tgt
199199

200-
cond_axis = None if src_cond is None else 0
201-
in_axes, out_axes = (0, 0, cond_axis, 0, 0), (0, None, 0)
200+
cond_axis = None if src_cond is None else 1
201+
in_axes, out_axes = (0, 1, cond_axis, 1, 1), (1, cond_axis, 1)
202202
resample_fn = jax.jit(jax.vmap(resample, in_axes, out_axes))
203203

204204
rngs = jax.random.split(rng, self.k_samples_per_x)

0 commit comments

Comments
 (0)