- Allow f32 as a dtype for get_pc_dists to reduce memory use on very large datasets (quantify the precision discrepancy)
- Auto-fallback to JAX?
- Handle missing data so more code paths can use JAX? Is there a scikit-learn-style library for JAX?
- TFP (TensorFlow Probability) has a covariance module; can it also handle missing data?
- Allow JAX functions to be placed on different devices (CPU, GPU, TPU)
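For the f32 item, here is a minimal sketch of how the memory saving and the precision discrepancy could be measured. `pairwise_dists` is a hypothetical stand-in for get_pc_dists (its real signature is not shown here). It assumes Euclidean distances between rows computed with the ||a||² + ||b||² - 2a·b expansion, which is prone to cancellation in f32, so the largest errors tend to sit on small distances. The data is synthetic.

```python
import jax

# float64 in JAX is opt-in and must be enabled before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np


def pairwise_dists(x, dtype=jnp.float64):
    """Hypothetical stand-in for get_pc_dists: Euclidean distances between the rows of x."""
    x = jnp.asarray(x, dtype=dtype)
    sq = jnp.sum(x * x, axis=1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; rounding can push tiny values below zero.
    d2 = sq[:, None] + sq[None, :] - 2.0 * (x @ x.T)
    return jnp.sqrt(jnp.maximum(d2, 0.0))


def nbytes(a):
    """Bytes held by the array's elements."""
    return a.size * a.dtype.itemsize


rng = np.random.default_rng(0)
x = rng.normal(size=(500, 10))          # synthetic data: 500 points in 10 dimensions
d64 = pairwise_dists(x, jnp.float64)
d32 = pairwise_dists(x, jnp.float32)

# The f32 matrix uses half the memory; the max absolute deviation quantifies the cost.
max_abs_err = float(jnp.max(jnp.abs(d64 - d32.astype(jnp.float64))))
print(nbytes(d64), nbytes(d32), max_abs_err)
```

Other summaries (relative error, error on the k nearest neighbours only) may be more meaningful depending on how the distances are used downstream.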
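For the missingness items, one option that stays in plain JAX is a pairwise-complete covariance: each entry uses only the rows where both columns are observed, the same convention pandas `DataFrame.cov` uses. This is a sketch, not TFP's API; `nan_cov` is a hypothetical helper and it assumes NaN marks a missing value. Note that a pairwise-complete covariance matrix is not guaranteed to be positive semidefinite.

```python
import jax.numpy as jnp
import numpy as np


def nan_cov(x):
    """Hypothetical helper: pairwise-complete sample covariance (ddof=1) of the columns of x."""
    x = jnp.asarray(x)
    observed = ~jnp.isnan(x)                 # (n, p) True where a value is present
    m = observed.astype(x.dtype)
    xz = jnp.where(observed, x, 0.0)         # zero-fill so missing entries add nothing to sums
    n_ij = m.T @ m                           # number of rows where columns i and j are both observed
    s_ij = xz.T @ m                          # sum of column i over rows where column j is also observed
    mu_i = s_ij / n_ij                       # mean of column i over the co-observed rows
    mu_j = mu_i.T                            # mean of column j over the same rows
    cross = xz.T @ xz                        # sum of x_i * x_j over co-observed rows
    # Entries with fewer than 2 co-observed rows come out as inf/nan.
    return (cross - n_ij * mu_i * mu_j) / (n_ij - 1.0)


rng = np.random.default_rng(0)
x = rng.normal(size=(200, 4))
x[rng.random(x.shape) < 0.1] = np.nan       # knock out roughly 10% of entries
c = nan_cov(x)
```

Because everything is matrix products and `jnp.where`, the function can be wrapped in `jax.jit` unchanged.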
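For device placement, a minimal sketch using `jax.devices` and `jax.device_put`: inputs committed to a device make jit-compiled computations run on that device. `run_on` and its `platform` argument are hypothetical names, not an existing API in this project.

```python
import jax
import jax.numpy as jnp


def run_on(fn, *args, platform="cpu"):
    """Hypothetical helper: run fn with its inputs committed to the first device of `platform`."""
    device = jax.devices(platform)[0]          # errors if that backend is not available
    placed = [jax.device_put(a, device) for a in args]
    # jit-compiled computations run where their committed inputs live.
    return jax.jit(fn)(*placed)


out = run_on(lambda a: a @ a.T, jnp.ones((4, 3)), platform="cpu")
print(out.devices())
```

An alternative is the `jax.default_device(...)` context manager, which changes where uncommitted arrays are created without threading a device through every call.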