diff --git a/tests/test_variantdata.py b/tests/test_variantdata.py index 6b5eeacd..a353c4c0 100644 --- a/tests/test_variantdata.py +++ b/tests/test_variantdata.py @@ -80,14 +80,14 @@ def test_sgkit_dataset_roundtrip(tmp_path): inf_ts = tsinfer.infer(samples) ds = sgkit.load_dataset(zarr_path) - assert ts.num_individuals == inf_ts.num_individuals == ds.dims["samples"] + assert ts.num_individuals == inf_ts.num_individuals == ds.sizes["samples"] for ts_ind, sample_id in zip(inf_ts.individuals(), ds["sample_id"].values): assert ts_ind.metadata["variant_data_sample_id"] == sample_id assert ( - ts.num_samples == inf_ts.num_samples == ds.dims["samples"] * ds.dims["ploidy"] + ts.num_samples == inf_ts.num_samples == ds.sizes["samples"] * ds.sizes["ploidy"] ) - assert ts.num_sites == inf_ts.num_sites == ds.dims["variants"] + assert ts.num_sites == inf_ts.num_sites == ds.sizes["variants"] assert ts.sequence_length == inf_ts.sequence_length == ds.attrs["contig_lengths"][0] for ( v, @@ -122,7 +122,7 @@ def test_sgkit_individual_metadata_not_clobbered(tmp_path): inf_ts = tsinfer.infer(samples) ds = sgkit.load_dataset(zarr_path) - assert ts.num_individuals == inf_ts.num_individuals == ds.dims["samples"] + assert ts.num_individuals == inf_ts.num_individuals == ds.sizes["samples"] for i, (ts_ind, sample_id) in enumerate( zip(inf_ts.individuals(), ds["sample_id"].values) ): @@ -694,23 +694,15 @@ def test_phased(self, tmp_path): ds["call_genotype"].dims, np.ones(ds["call_genotype"].shape, dtype=bool), ) - ds["variant_ancestral_allele"] = ( - ds["variant_position"].dims, - np.array(["A", "C", "G"], dtype="S1"), - ) sgkit.save_dataset(ds, path) - tsinfer.VariantData(path, "variant_ancestral_allele") + tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str)) def test_ploidy1_missing_phase(self, tmp_path): path = tmp_path / "data.zarr" # Ploidy==1 is always ok ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1) - ds["variant_ancestral_allele"] = ( - ds["variant_position"].dims, - np.array(["A", "C", "G"], dtype="S1"), - ) sgkit.save_dataset(ds, path) - tsinfer.VariantData(path, "variant_ancestral_allele") + tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str)) def test_ploidy1_unphased(self, tmp_path): path = tmp_path / "data.zarr" @@ -719,12 +711,8 @@ def test_ploidy1_unphased(self, tmp_path): ds["call_genotype"].dims, np.zeros(ds["call_genotype"].shape, dtype=bool), ) - ds["variant_ancestral_allele"] = ( - ds["variant_position"].dims, - np.array(["A", "C", "G"], dtype="S1"), - ) sgkit.save_dataset(ds, path) - tsinfer.VariantData(path, "variant_ancestral_allele") + tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str)) def test_duplicate_positions(self, tmp_path): path = tmp_path / "data.zarr" @@ -749,14 +737,10 @@ def test_empty_alleles_not_at_end(self, tmp_path): ds["variant_allele"].dims, np.array([["", "A", "C"], ["A", "C", ""], ["A", "C", ""]], dtype="S1"), ) - ds["variant_ancestral_allele"] = ( - ["variants"], - np.array(["C", "A", "A"], dtype="S1"), - ) sgkit.save_dataset(ds, path) - samples = tsinfer.VariantData(path, "variant_ancestral_allele") + vdata = tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str)) with pytest.raises(ValueError, match="Empty alleles must be at the end"): - tsinfer.infer(samples) + tsinfer.infer(vdata) def test_unimplemented_from_tree_sequence(self): # NB we should reimplement something like this functionality.