Skip to content

Commit

Permalink
Always read alleles as strings (not bytes)
Browse files Browse the repository at this point in the history
Fixes #810
  • Loading branch information
hyanwong authored and mergify[bot] committed Jul 17, 2024
1 parent e309fa4 commit 141f0c7
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 2 deletions.
52 changes: 52 additions & 0 deletions tests/test_sgkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import sys
import tempfile

import msprime
import numcodecs
import numpy as np
import pytest
Expand All @@ -37,6 +38,43 @@
from tsinfer import formats


def ts_to_dataset(ts, chunks=None, samples=None):
"""
# From https://github.com/sgkit-dev/sgkit/blob/main/sgkit/tests/test_popgen.py#L63
Convert the specified tskit tree sequence into an sgkit dataset.
Note this just generates haploids for now - see the note above
in simulate_ts.
"""
if samples is None:
samples = ts.samples()
tables = ts.dump_tables()
alleles = []
genotypes = []
max_alleles = 0
for var in ts.variants(samples=samples):
alleles.append(var.alleles)
max_alleles = max(max_alleles, len(var.alleles))
genotypes.append(var.genotypes.astype(np.int8))
padded_alleles = [
list(site_alleles) + [""] * (max_alleles - len(site_alleles))
for site_alleles in alleles
]
alleles = np.array(padded_alleles).astype("S")
genotypes = np.expand_dims(genotypes, axis=2)

ds = sgkit.create_genotype_call_dataset(
variant_contig_names=["1"],
variant_contig=np.zeros(len(tables.sites), dtype=int),
variant_position=tables.sites.position.astype(int),
variant_allele=alleles,
sample_id=np.array([f"tsk_{u}" for u in samples]).astype("U"),
call_genotype=genotypes,
)
if chunks is not None:
ds = ds.chunk(dict(zip(["variants", "samples"], chunks)))
return ds


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
def test_sgkit_dataset_roundtrip(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
Expand Down Expand Up @@ -209,6 +247,20 @@ def test_sgkit_accessors_defaults(tmp_path):
)


def test_simulate_genotype_call_dataset(tmp_path):
# Test that byte alleles are correctly converted to string
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
ds = ts_to_dataset(ts)
ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})
ds.to_zarr(tmp_path, mode="w")
sd = tsinfer.SgkitSampleData(tmp_path)
ts = tsinfer.infer(sd)
for v, ds_v, sd_v in zip(ts.variants(), ds.call_genotype, sd.sites_genotypes):
assert np.all(v.genotypes == ds_v.values.flatten())
assert np.all(v.genotypes == sd_v)


@pytest.mark.skipif(sys.platform == "win32", reason="File permission errors on Windows")
class TestSgkitMask:
@pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0], []])
Expand Down
6 changes: 4 additions & 2 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2440,13 +2440,15 @@ def sites_position(self):

@functools.cached_property
def sites_alleles(self):
return self.data["variant_allele"][:][self.sites_select]
return self.data["variant_allele"][:][self.sites_select].astype(str)

@functools.cached_property
def sites_ancestral_allele(self):
unknown_alleles = collections.Counter()
try:
string_allele = self.data["variant_ancestral_allele"][:][self.sites_select]
string_allele = (
self.data["variant_ancestral_allele"][:][self.sites_select]
).astype(str)
except KeyError:
raise ValueError(
"variant_ancestral_allele was not found in the dataset."
Expand Down

0 comments on commit 141f0c7

Please sign in to comment.