Skip to content

Commit

Permalink
Translate pytests
Browse files Browse the repository at this point in the history
- Translate original filter pytests to work with an in-memory database.
- Move VCF-related tests to a better home in test_utils.
- Split filter tests into separate files:
    - file loading
    - filtering
    - TODO: subsampling
  • Loading branch information
victorlin committed Feb 17, 2022
1 parent 2e99ba3 commit 3fcd258
Show file tree
Hide file tree
Showing 5 changed files with 486 additions and 444 deletions.
280 changes: 32 additions & 248 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,267 +1,51 @@
import argparse
import numpy as np
import random
import shlex

import pytest

from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import sqlite3

import augur.filter
from augur.utils import read_metadata
from augur.filter_support.db.sqlite import FilterSQLite

@pytest.fixture
def argparser():
    """Provide a function that parses an augur filter command-line string.

    The returned callable splits *args* shell-style and runs them through
    the registered augur filter argument parser.
    """
    parser = argparse.ArgumentParser()
    augur.filter.register_arguments(parser)
    def parse(args):
        return parser.parse_args(shlex.split(args))
    return parse


def parse_args(args:str):
    """Parse an augur filter command-line string into an argparse.Namespace.

    Module-level counterpart of the ``argparser`` fixture for use by plain
    helper functions (e.g. ``get_valid_args``).
    """
    # NOTE(review): the pasted diff fused the old fixture and this new helper
    # into one invalid definition; both are reconstructed here.
    parser = argparse.ArgumentParser()
    augur.filter.register_arguments(parser)
    return parser.parse_args(shlex.split(args))

@pytest.fixture
def sequences():
    """Return three SeqRecords with random 10-base nucleotide sequences."""
    def random_seq(k):
        bases = ("A", "T", "G", "C")
        return "".join(random.choices(bases, k=k))
    names = ("SEQ_1", "SEQ_2", "SEQ_3")
    return {name: SeqRecord(Seq(random_seq(10)), id=name) for name in names}

@pytest.fixture
def fasta_fn(tmpdir, sequences):
    """Write the sequence fixtures to a FASTA file and return its path."""
    path = str(tmpdir / "sequences.fasta")
    SeqIO.write(sequences.values(), path, "fasta")
    return path

def write_metadata(tmpdir, metadata):
    """Write rows of *metadata* as a TSV file under *tmpdir*; return its path."""
    rows = ("\t".join(row) for row in metadata)
    path = str(tmpdir / "metadata.tsv")
    with open(path, "w") as out:
        out.write("\n".join(rows))
    return path

@pytest.fixture
def mock_priorities_file_valid(mocker):
    """Stub builtins.open to yield a well-formed priorities file."""
    contents = "strain1 5\nstrain2 6\nstrain3 8\n"
    mocker.patch("builtins.open", mocker.mock_open(read_data=contents))


@pytest.fixture
def mock_priorities_file_malformed(mocker):
    """Stub builtins.open to yield a priorities file with a non-numeric score."""
    contents = "strain1 X\n"
    mocker.patch("builtins.open", mocker.mock_open(read_data=contents))


@pytest.fixture
def mock_run_shell_command(mocker):
    """Stub augur.filter.run_shell_command so no external commands execute."""
    mocker.patch("augur.filter.run_shell_command")


@pytest.fixture
def mock_priorities_file_valid_with_spaces_and_tabs(mocker):
    """Stub builtins.open to yield a tab-delimited priorities file whose strain names contain spaces."""
    contents = "strain 1\t5\nstrain 2\t6\nstrain 3\t8\n"
    mocker.patch("builtins.open", mocker.mock_open(read_data=contents))

class TestFilter:
    def test_read_vcf_compressed(self):
        """read_vcf on a gzip-compressed VCF returns all 150 sample names."""
        # NOTE(review): a stray `return parser.parse_args(...)` fragment from
        # the pasted diff was removed here — a bare return inside a test body
        # at the wrong scope is a syntax error.
        seq_keep, all_seq = augur.filter.read_vcf(
            "tests/builds/tb/data/lee_2015.vcf.gz"
        )

        assert len(seq_keep) == 150
        assert seq_keep[149] == "G22733"
        assert seq_keep == all_seq

def test_read_vcf_uncompressed(self):
seq_keep, all_seq = augur.filter.read_vcf("tests/builds/tb/data/lee_2015.vcf")

assert len(seq_keep) == 150
assert seq_keep[149] == "G22733"
assert seq_keep == all_seq
def write_file(tmpdir, filename:str, content:str):
    """Write *content* to *filename* under *tmpdir* and return the full path."""
    path = str(tmpdir / filename)
    with open(path, "w") as out:
        out.write(content)
    return path

def test_read_priority_scores_valid(self, mock_priorities_file_valid):
# builtins.open is stubbed, but we need a valid file to satisfy the existence check
priorities = augur.filter.read_priority_scores(
"tests/builds/tb/data/lee_2015.vcf"
)

assert priorities == {"strain1": 5, "strain2": 6, "strain3": 8}
assert priorities["strain1"] == 5
assert priorities["strain42"] == -np.inf, "Default priority is negative infinity for unlisted sequences"

def test_read_priority_scores_malformed(self, mock_priorities_file_malformed):
with pytest.raises(ValueError):
# builtins.open is stubbed, but we need a valid file to satisfy the existence check
augur.filter.read_priority_scores("tests/builds/tb/data/lee_2015.vcf")

def test_read_priority_scores_valid_with_spaces_and_tabs(self, mock_priorities_file_valid_with_spaces_and_tabs):
# builtins.open is stubbed, but we need a valid file to satisfy the existence check
priorities = augur.filter.read_priority_scores(
"tests/builds/tb/data/lee_2015.vcf"
)

assert priorities == {"strain 1": 5, "strain 2": 6, "strain 3": 8}

def test_read_priority_scores_does_not_exist(self):
with pytest.raises(FileNotFoundError):
augur.filter.read_priority_scores("/does/not/exist.txt")

    def test_write_vcf_compressed_input(self, mock_run_shell_command):
        """A .vcf.gz input is read with vcftools --gzvcf."""
        augur.filter.write_vcf(
            "tests/builds/tb/data/lee_2015.vcf.gz", "output_file.vcf.gz", []
        )

        augur.filter.run_shell_command.assert_called_once_with(
            "vcftools --gzvcf tests/builds/tb/data/lee_2015.vcf.gz --recode --stdout | gzip -c > output_file.vcf.gz",
            raise_errors=True,
        )

    def test_write_vcf_uncompressed_input(self, mock_run_shell_command):
        """An uncompressed .vcf input is read with vcftools --vcf."""
        augur.filter.write_vcf(
            "tests/builds/tb/data/lee_2015.vcf", "output_file.vcf.gz", []
        )

        augur.filter.run_shell_command.assert_called_once_with(
            "vcftools --vcf tests/builds/tb/data/lee_2015.vcf --recode --stdout | gzip -c > output_file.vcf.gz",
            raise_errors=True,
        )

    def test_write_vcf_compressed_output(self, mock_run_shell_command):
        """A .vcf.gz output path makes write_vcf pipe vcftools output through gzip."""
        augur.filter.write_vcf(
            "tests/builds/tb/data/lee_2015.vcf", "output_file.vcf.gz", []
        )

        augur.filter.run_shell_command.assert_called_once_with(
            "vcftools --vcf tests/builds/tb/data/lee_2015.vcf --recode --stdout | gzip -c > output_file.vcf.gz",
            raise_errors=True,
        )

    def test_write_vcf_uncompressed_output(self, mock_run_shell_command):
        """A plain .vcf output path is written directly, without a gzip pipe."""
        augur.filter.write_vcf(
            "tests/builds/tb/data/lee_2015.vcf", "output_file.vcf", []
        )

        augur.filter.run_shell_command.assert_called_once_with(
            "vcftools --vcf tests/builds/tb/data/lee_2015.vcf --recode --stdout > output_file.vcf",
            raise_errors=True,
        )

    def test_write_vcf_dropped_samples(self, mock_run_shell_command):
        """Dropped samples are excluded via one --remove-indv flag per sample."""
        augur.filter.write_vcf(
            "tests/builds/tb/data/lee_2015.vcf", "output_file.vcf", ["x", "y", "z"]
        )

        augur.filter.run_shell_command.assert_called_once_with(
            "vcftools --remove-indv x --remove-indv y --remove-indv z --vcf tests/builds/tb/data/lee_2015.vcf --recode --stdout > output_file.vcf",
            raise_errors=True,
        )
def write_metadata(tmpdir, metadata):
    """Write rows of *metadata* as metadata.tsv under *tmpdir*; return its path."""
    lines = ("\t".join(row) for row in metadata)
    return write_file(tmpdir, "metadata.tsv", "\n".join(lines))

    def test_filter_on_query_good(self, tmpdir, sequences):
        """Basic filter_on_query test"""
        meta_fn = write_metadata(tmpdir, (("strain","location","quality"),
                                          ("SEQ_1","colorado","good"),
                                          ("SEQ_2","colorado","bad"),
                                          ("SEQ_3","nevada","good")))
        # read_metadata appears to return (data frame, column names) — the
        # columns value is unused here.
        metadata, columns = read_metadata(meta_fn, as_data_frame=True)
        filtered = augur.filter.filter_by_query(metadata, 'quality=="good"')
        # Only the two strains whose quality column is "good" should remain.
        assert sorted(filtered) == ["SEQ_1", "SEQ_3"]

    def test_filter_run_with_query(self, tmpdir, fasta_fn, argparser):
        """Test that filter --query works as expected"""
        out_fn = str(tmpdir / "out.fasta")
        meta_fn = write_metadata(tmpdir, (("strain","location","quality"),
                                          ("SEQ_1","colorado","good"),
                                          ("SEQ_2","colorado","bad"),
                                          ("SEQ_3","nevada","good")))
        args = argparser('-s %s --metadata %s -o %s --query "location==\'colorado\'"'
                         % (fasta_fn, meta_fn, out_fn))
        augur.filter.run(args)
        output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta"))
        # Both colorado strains pass the query; the nevada strain is dropped.
        assert list(output.keys()) == ["SEQ_1", "SEQ_2"]
def get_filter_obj_run(args:argparse.Namespace):
    """Returns a filter object connected to an in-memory database with run() invoked."""
    filter_obj = FilterSQLite(':memory:')
    filter_obj.set_args(args)
    # keep intermediate tables to validate contents
    filter_obj.run(cleanup=False)
    return filter_obj

def test_filter_run_with_query_and_include(self, tmpdir, fasta_fn, argparser):
"""Test that --include still works with filtering on query"""
out_fn = str(tmpdir / "out.fasta")
meta_fn = write_metadata(tmpdir, (("strain","location","quality"),
("SEQ_1","colorado","good"),
("SEQ_2","colorado","bad"),
("SEQ_3","nevada","good")))
include_fn = str(tmpdir / "include")
open(include_fn, "w").write("SEQ_3")
args = argparser('-s %s --metadata %s -o %s --query "quality==\'good\' & location==\'colorado\'" --include %s'
% (fasta_fn, meta_fn, out_fn, include_fn))
augur.filter.run(args)
output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta"))
assert list(output.keys()) == ["SEQ_1", "SEQ_3"]

    def test_filter_run_with_query_and_include_where(self, tmpdir, fasta_fn, argparser):
        """Test that --include_where still works with filtering on query"""
        out_fn = str(tmpdir / "out.fasta")
        meta_fn = write_metadata(tmpdir, (("strain","location","quality"),
                                          ("SEQ_1","colorado","good"),
                                          ("SEQ_2","colorado","bad"),
                                          ("SEQ_3","nevada","good")))
        args = argparser('-s %s --metadata %s -o %s --query "quality==\'good\' & location==\'colorado\'" --include-where "location=nevada"'
                         % (fasta_fn, meta_fn, out_fn))
        augur.filter.run(args)
        output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta"))
        # SEQ_1 passes the query; SEQ_3 is re-added by the include-where clause.
        assert list(output.keys()) == ["SEQ_1", "SEQ_3"]
def get_valid_args(data, tmpdir):
    """Returns an argparse.Namespace with metadata and output_strains"""
    metadata_path = write_metadata(tmpdir, data)
    strains_path = tmpdir / "strains.txt"
    return parse_args(f'--metadata {metadata_path} --output-strains {strains_path}')

    def test_filter_run_min_date(self, tmpdir, fasta_fn, argparser):
        """Test that filter --min-date is inclusive"""
        out_fn = str(tmpdir / "out.fasta")
        min_date = "2020-02-26"
        meta_fn = write_metadata(tmpdir, (("strain","date"),
                                          ("SEQ_1","2020-02-XX"),
                                          ("SEQ_2","2020-02-26"),
                                          ("SEQ_3","2020-02-25")))
        args = argparser('-s %s --metadata %s -o %s --min-date %s'
                         % (fasta_fn, meta_fn, out_fn, min_date))
        augur.filter.run(args)
        output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta"))
        # SEQ_2 equals the bound (inclusive) and the ambiguous SEQ_1 is kept;
        # SEQ_3 falls strictly before the bound and is dropped.
        assert list(output.keys()) == ["SEQ_1", "SEQ_2"]

    def test_filter_run_max_date(self, tmpdir, fasta_fn, argparser):
        """Test that filter --max-date is inclusive"""
        out_fn = str(tmpdir / "out.fasta")
        max_date = "2020-03-01"
        meta_fn = write_metadata(tmpdir, (("strain","date"),
                                          ("SEQ_1","2020-03-XX"),
                                          ("SEQ_2","2020-03-01"),
                                          ("SEQ_3","2020-03-02")))
        args = argparser('-s %s --metadata %s -o %s --max-date %s'
                         % (fasta_fn, meta_fn, out_fn, max_date))
        augur.filter.run(args)
        output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta"))
        # SEQ_2 equals the bound (inclusive) and the ambiguous SEQ_1 is kept;
        # SEQ_3 falls strictly after the bound and is dropped.
        assert list(output.keys()) == ["SEQ_1", "SEQ_2"]
def query_fetchall(filter_obj:FilterSQLite, query:str):
    """Run *query* on the filter object's cursor and return all result rows."""
    cursor = filter_obj.cur
    cursor.execute(query)
    return cursor.fetchall()

    def test_filter_incomplete_year(self, tmpdir, fasta_fn, argparser):
        """Test that 2020 is evaluated as 2020-XX-XX"""
        out_fn = str(tmpdir / "out.fasta")
        min_date = "2020-02-01"
        meta_fn = write_metadata(tmpdir, (("strain","date"),
                                          ("SEQ_1","2020.0"),
                                          ("SEQ_2","2020"),
                                          ("SEQ_3","2020-XX-XX")))
        args = argparser('-s %s --metadata %s -o %s --min-date %s'
                         % (fasta_fn, meta_fn, out_fn, min_date))
        augur.filter.run(args)
        output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta"))
        # "2020" and "2020-XX-XX" survive a mid-year min-date while the
        # decimal form "2020.0" does not.
        assert list(output.keys()) == ["SEQ_2", "SEQ_3"]

    def test_filter_date_formats(self, tmpdir, fasta_fn, argparser):
        """Test that 2020.0, 2020, and 2020-XX-XX all pass --min-date 2019"""
        out_fn = str(tmpdir / "out.fasta")
        min_date = "2019"
        meta_fn = write_metadata(tmpdir, (("strain","date"),
                                          ("SEQ_1","2020.0"),
                                          ("SEQ_2","2020"),
                                          ("SEQ_3","2020-XX-XX")))
        args = argparser('-s %s --metadata %s -o %s --min-date %s'
                         % (fasta_fn, meta_fn, out_fn, min_date))
        augur.filter.run(args)
        output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta"))
        # All three date formats are after 2019, so every strain is kept.
        assert list(output.keys()) == ["SEQ_1", "SEQ_2", "SEQ_3"]
def query_fetchall_dict(filter_obj:FilterSQLite, query:str):
    """Run *query* and return all rows as dicts keyed by column name."""
    filter_obj.connection.row_factory = sqlite3.Row
    cursor = filter_obj.connection.cursor()
    cursor.execute(query)
    rows = cursor.fetchall()
    return [dict(row) for row in rows]
69 changes: 69 additions & 0 deletions tests/test_filter_data_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest

from augur.filter_support.db.sqlite import (
    METADATA_TABLE_NAME,
    PRIORITIES_TABLE_NAME,
)

# Import every helper through the same module path. The original mixed
# `from test_filter import ...` with `from tests.test_filter import ...`,
# which can load test_filter twice under two different module names.
from test_filter import (
    get_filter_obj_run,
    get_valid_args,
    query_fetchall,
    write_file,
)


def get_filter_obj_with_priority_loaded(tmpdir, content:str):
    """Build a run filter object whose args point at a priorities file holding *content*."""
    priority_path = write_file(tmpdir, "priorities.txt", content)
    # metadata is a required arg but we don't need it
    metadata_rows = [("strain","location","quality"),
                     ("SEQ_1","colorado","good")]
    args = get_valid_args(metadata_rows, tmpdir)
    args.priority = priority_path
    return get_filter_obj_run(args)


class TestDataLoading:
    """Tests for loading metadata and priority files into the SQLite backend."""

    def test_load_metadata(self, tmpdir):
        """Load a metadata file and find its rows in the metadata table."""
        data = [("strain","location","quality"),
                ("SEQ_1","colorado","good"),
                ("SEQ_2","colorado","bad"),
                ("SEQ_3","nevada","good")]
        args = get_valid_args(data, tmpdir)
        filter_obj = get_filter_obj_run(args)
        results = query_fetchall(filter_obj, f"SELECT * FROM {METADATA_TABLE_NAME}")
        # row[0] is presumably an auto-generated index column — TODO confirm;
        # the remaining columns should match the data rows (header excluded).
        assert [row[1:] for row in results] == data[1:]

    def test_load_priority_scores_valid(self, tmpdir):
        """Load a priority score file."""
        content = "strain1\t5\nstrain2\t6\nstrain3\t8\n"
        filter_obj = get_filter_obj_with_priority_loaded(tmpdir, content)
        filter_obj.db_load_priorities_table()
        results = query_fetchall(filter_obj, f"SELECT * FROM {PRIORITIES_TABLE_NAME}")
        # Priorities are stored as floats alongside a 0-based index column.
        assert results == [(0, "strain1", 5.0), (1, "strain2", 6.0), (2, "strain3", 8.0)]

    @pytest.mark.skip(reason="this isn't trivial with SQLite's flexible typing rules")
    def test_load_priority_scores_malformed(self, tmpdir):
        """Attempt to load a priority score file with non-float in priority column raises a ValueError."""
        content = "strain1 X\n"
        filter_obj = get_filter_obj_with_priority_loaded(tmpdir, content)
        with pytest.raises(ValueError) as e_info:
            filter_obj.db_load_priorities_table()
        assert str(e_info.value) == f"Failed to parse priority file {filter_obj.args.priority}."

    def test_load_priority_scores_valid_with_spaces_and_tabs(self, tmpdir):
        """Load a priority score file with spaces in strain names."""
        content = "strain 1\t5\nstrain 2\t6\nstrain 3\t8\n"
        filter_obj = get_filter_obj_with_priority_loaded(tmpdir, content)
        filter_obj.db_load_priorities_table()
        results = query_fetchall(filter_obj, f"SELECT * FROM {PRIORITIES_TABLE_NAME}")
        assert results == [(0, "strain 1", 5.0), (1, "strain 2", 6.0), (2, "strain 3", 8.0)]

    def test_load_priority_scores_does_not_exist(self, tmpdir):
        """Attempt to load a non-existent priority score file raises a FileNotFoundError."""
        invalid_priorities_fn = str(tmpdir / "does/not/exist.txt")
        # metadata is a required arg but we don't need it
        data = [("strain","location","quality"),
                ("SEQ_1","colorado","good")]
        args = get_valid_args(data, tmpdir)
        args.priority = invalid_priorities_fn
        filter_obj = get_filter_obj_run(args)
        with pytest.raises(FileNotFoundError):
            filter_obj.db_load_priorities_table()
Loading

0 comments on commit 3fcd258

Please sign in to comment.