From 9c6722db00351434ed1b02378f793ba97b068868 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Mon, 7 Feb 2022 18:53:55 -0800 Subject: [PATCH] Add pytests - all samples dropped - metadata id cols - priority column mismatch - date parsing --- tests/test_filter_data_loading.py | 89 +++++++++++++++ tests/test_filter_date_parsing.py | 177 ++++++++++++++++++++++++++++++ tests/test_filter_filtering.py | 77 +++++++++++++ tests/test_filter_groupby.py | 16 +++ 4 files changed, 359 insertions(+) create mode 100644 tests/test_filter_date_parsing.py diff --git a/tests/test_filter_data_loading.py b/tests/test_filter_data_loading.py index 67f6fad1b..a38c03796 100644 --- a/tests/test_filter_data_loading.py +++ b/tests/test_filter_data_loading.py @@ -2,7 +2,9 @@ from augur.filter_support.db.sqlite import ( METADATA_TABLE_NAME, + OUTPUT_METADATA_TABLE_NAME, PRIORITIES_TABLE_NAME, + SEQUENCE_INDEX_TABLE_NAME, ) from test_filter import write_file @@ -67,3 +69,90 @@ def test_load_priority_scores_does_not_exist(self, tmpdir): filter_obj = get_filter_obj_run(args) with pytest.raises(FileNotFoundError): filter_obj.db_load_priorities_table() + + def test_load_invalid_id_column(self, tmpdir): + data = [ + ("invalid_name","date","country"), + ("SEQ_1","2020-01-XX","A"), + ] + args = get_valid_args(data, tmpdir) + with pytest.raises(ValueError) as e_info: + get_filter_obj_run(args) + assert str(e_info.value) == "None of the possible id columns (['strain', 'name']) were found in the metadata's columns ('invalid_name', 'date', 'country')" + + def test_load_custom_id_column(self, tmpdir): + data = [ + ("custom_id_col","date","country"), + ("SEQ_1","2020-01-XX","A"), + ] + args = get_valid_args(data, tmpdir) + args.metadata_id_columns = ["custom_id_col"] + filter_obj = get_filter_obj_run(args) + results = query_fetchall(filter_obj, f""" + SELECT custom_id_col FROM {METADATA_TABLE_NAME} + """) + assert results == [("SEQ_1",)] + + def 
test_load_custom_id_column_with_spaces(self, tmpdir): + data = [ + ("strain name with spaces","date","country"), + ("SEQ_1","2020-01-XX","A"), + ] + args = get_valid_args(data, tmpdir) + args.metadata_id_columns = ["strain name with spaces"] + filter_obj = get_filter_obj_run(args) + results = query_fetchall(filter_obj, f""" + SELECT "strain name with spaces" FROM {METADATA_TABLE_NAME} + """) + assert results == [("SEQ_1",)] + + def test_load_priority_scores_extra_column(self, tmpdir): + """Attempt to load a priority score file with an extra column raises a ValueError.""" + content = "strain1\t5\tbad_col\n" + filter_obj = get_filter_obj_with_priority_loaded(tmpdir, content) + with pytest.raises(ValueError) as e_info: + filter_obj.db_load_priorities_table() + assert str(e_info.value) == f"Failed to parse priority file {filter_obj.args.priority}." + + def test_load_priority_scores_missing_column(self, tmpdir): + """Attempt to load a priority score file with a missing column raises a ValueError.""" + content = "strain1\n" + filter_obj = get_filter_obj_with_priority_loaded(tmpdir, content) + with pytest.raises(ValueError) as e_info: + filter_obj.db_load_priorities_table() + assert str(e_info.value) == f"Failed to parse priority file {filter_obj.args.priority}." 
+
+    def test_load_sequences_subset_strains(self, tmpdir):
+        """Loading sequences filters output to the intersection of strains from metadata and sequences."""
+        data = [("strain",),
+                ("SEQ_1",),
+                ("SEQ_2",),
+                ("SEQ_3",)]
+        args = get_valid_args(data, tmpdir)
+        fasta_lines = [
+            ">SEQ_1", "aaaa",
+            ">SEQ_3", "aaaa",
+            ">SEQ_4", "nnnn",
+        ]
+        args.sequences = write_file(tmpdir, "sequences.fasta", "\n".join(fasta_lines))
+        filter_obj = get_filter_obj_run(args)
+        results = query_fetchall(filter_obj, f"SELECT strain FROM {OUTPUT_METADATA_TABLE_NAME}")
+        assert results == [("SEQ_1",), ("SEQ_3",)]
+
+    def test_generate_sequence_index(self, tmpdir):
+        """The generated sequence index has per-strain base counts for every sequence in the FASTA, including strains absent from the metadata."""
+        data = [("strain",),
+                ("SEQ_1",),
+                ("SEQ_2",),
+                ("SEQ_3",)]
+        args = get_valid_args(data, tmpdir)
+        fasta_lines = [
+            ">SEQ_1", "aaaa",
+            ">SEQ_3", "aaaa",
+            ">SEQ_4", "nnnn",
+        ]
+        args.sequences = write_file(tmpdir, "sequences.fasta", "\n".join(fasta_lines))
+        filter_obj = get_filter_obj_run(args)
+        results = query_fetchall(filter_obj, f"SELECT strain, A, N FROM {SEQUENCE_INDEX_TABLE_NAME}")
+        print(results)
+        assert results == [("SEQ_1", 4, 0), ("SEQ_3", 4, 0), ("SEQ_4", 0, 4)]
diff --git a/tests/test_filter_date_parsing.py b/tests/test_filter_date_parsing.py
new file mode 100644
index 000000000..111a2bdfd
--- /dev/null
+++ b/tests/test_filter_date_parsing.py
@@ -0,0 +1,177 @@
+import pytest
+from treetime.utils import numeric_date
+from datetime import date
+from textwrap import dedent
+from augur.filter_support.date_parsing import InvalidDateFormat
+
+from augur.filter_support.db.sqlite import (
+    DATE_MIN_COL,
+    DATE_MAX_COL,
+    DATE_TABLE_NAME,
+)
+
+from tests.test_filter import get_filter_obj_run, get_valid_args, query_fetchall
+
+
+def get_parsed_date_min_max(date:str, tmpdir):
+    data = [
+        ("strain","date"),
+        ("SEQ_1",date),
+    ]
+    args = get_valid_args(data, tmpdir)
+    filter_obj = get_filter_obj_run(args)
+    results = query_fetchall(filter_obj, f"""
+        SELECT {DATE_MIN_COL}, {DATE_MAX_COL} FROM {DATE_TABLE_NAME}
+    """)
+    return results[0]
+
+
+class TestDateParsing:
+    def test_ambiguous_day(self, tmpdir):
+        """Ambiguous day yields a certain min/max range."""
+        date_min, date_max = get_parsed_date_min_max(
+            "2018-01-XX", tmpdir)
+        assert date_min == pytest.approx(2018.001, abs=1e-3)
+        assert date_max == pytest.approx(2018.083, abs=1e-3)
+
+    def test_missing_day(self, tmpdir):
+        """Date without day yields a range equivalent to ambiguous day."""
+        date_min, date_max = get_parsed_date_min_max(
+            "2018-01", tmpdir)
+        assert date_min == pytest.approx(2018.001, abs=1e-3)
+        assert date_max == pytest.approx(2018.083, abs=1e-3)
+
+    def test_ambiguous_month(self, tmpdir):
+        """Ambiguous month yields a certain min/max range."""
+        date_min, date_max = get_parsed_date_min_max(
+            "2018-XX-XX", tmpdir)
+        assert date_min == pytest.approx(2018.001, abs=1e-3)
+        assert date_max == pytest.approx(2018.999, abs=1e-3)
+
+    def test_missing_month(self, tmpdir):
+        """Date without month/day yields a range equivalent to ambiguous month/day."""
+        date_min, date_max = get_parsed_date_min_max(
+            "2018", tmpdir)
+        assert date_min == pytest.approx(2018.001, abs=1e-3)
+        assert date_max == pytest.approx(2018.999, abs=1e-3)
+
+    def test_numerical_exact_year(self, tmpdir):
+        """Numerical year ending in .0 should be interpreted as exact."""
+        date_min, date_max = get_parsed_date_min_max(
+            "2018.0", tmpdir)
+        assert date_min == pytest.approx(2018.001, abs=1e-3)
+        assert date_max == pytest.approx(2018.001, abs=1e-3)
+
+    def test_ambiguous_year(self, tmpdir):
+        """Ambiguous year replaces X with 0 (min) and 9 (max)."""
+        date_min, date_max = get_parsed_date_min_max(
+            "201X-XX-XX", tmpdir)
+        assert date_min == pytest.approx(2010.001, abs=1e-3)
+        assert date_max == pytest.approx(2019.999, abs=1e-3)
+
+    def test_ambiguous_year_incomplete_date(self, tmpdir):
+        """Ambiguous year without month/day yields a range
+        equivalent to ambiguous month/day counterpart."""
+        date_min, date_max = get_parsed_date_min_max(
+            "201X", tmpdir)
+        assert date_min == pytest.approx(2010.001, abs=1e-3)
+        assert date_max == pytest.approx(2019.999, abs=1e-3)
+
+    def test_ambiguous_year_decade(self, tmpdir):
+        """Parse year-only ambiguous date with ambiguous decade."""
+        date_min, date_max = get_parsed_date_min_max(
+            "10X1", tmpdir)
+        assert date_min == pytest.approx(1001.001, abs=1e-3)
+        assert date_max == pytest.approx(1091.999, abs=1e-3)
+
+    def test_ambiguous_year_lowercase_error(self, tmpdir):
+        """Ambiguous year without explicit X fails parsing."""
+        date_min, date_max = get_parsed_date_min_max("201x", tmpdir)
+        assert date_min == None
+        assert date_max == None
+
+    def test_future_year(self, tmpdir):
+        """Date from the future should be converted to today."""
+        date_min, date_max = get_parsed_date_min_max(
+            "3000", tmpdir)
+        assert date_min == pytest.approx(numeric_date(date.today()), abs=1e-3)
+        assert date_max == pytest.approx(numeric_date(date.today()), abs=1e-3)
+
+    def test_ambiguous_month_exact_date_error(self, tmpdir):
+        """Date that has ambiguous month but exact date raises an error."""
+        with pytest.raises(InvalidDateFormat) as e_info:
+            get_parsed_date_min_max("2018-XX-01", tmpdir)
+        assert str(e_info.value) == dedent(f"""\
+            Some dates have an invalid format (showing at most 3): '2018-XX-01'.
+            If year contains ambiguity, month and day must also be ambiguous.
+            If month contains ambiguity, day must also be ambiguous.""")
+
+    def test_ambiguous_year_exact_month_day_error(self, tmpdir):
+        """Date that has ambiguous year but exact month and date raises an error."""
+        with pytest.raises(InvalidDateFormat) as e_info:
+            get_parsed_date_min_max("20X8-01-01", tmpdir)
+        assert str(e_info.value) == dedent(f"""\
+            Some dates have an invalid format (showing at most 3): '20X8-01-01'.
+            If year contains ambiguity, month and day must also be ambiguous.
+            If month contains ambiguity, day must also be ambiguous.""")
+
+    def test_out_of_bounds_month(self, tmpdir):
+        """Out-of-bounds month cannot be parsed."""
+        date_min, date_max = get_parsed_date_min_max("2018-00-01", tmpdir)
+        assert date_min == None
+        assert date_max == None
+        date_min, date_max = get_parsed_date_min_max("2018-13-01", tmpdir)
+        assert date_min == None
+        assert date_max == None
+
+    def test_out_of_bounds_day(self, tmpdir):
+        """Out-of-bounds day cannot be parsed."""
+        date_min, date_max = get_parsed_date_min_max("2018-01-00", tmpdir)
+        assert date_min == None
+        assert date_max == None
+        date_min, date_max = get_parsed_date_min_max("2018-02-30", tmpdir)
+        assert date_min == None
+        assert date_max == None
+
+    def test_negative_iso_date_error(self, tmpdir):
+        """Negative ISO dates are unsupported."""
+        date_min, date_max = get_parsed_date_min_max("-2018-01-01", tmpdir)
+        assert date_min == None
+        assert date_max == None
+
+    def test_negative_ambiguous_iso_date_error(self, tmpdir):
+        """Negative ambiguous ISO dates are unsupported."""
+        date_min, date_max = get_parsed_date_min_max("-2018-XX-XX", tmpdir)
+        assert date_min == None
+        assert date_max == None
+
+    def test_negative_iso_date_missing_day_error(self, tmpdir):
+        """Negative incomplete ISO dates are unsupported."""
+        date_min, date_max = get_parsed_date_min_max("-2018-01", tmpdir)
+        assert date_min == None
+        assert date_max == None
+
+    def test_negative_iso_date_missing_month_day_error(self, tmpdir):
+        """Negative incomplete ISO dates are unsupported."""
+        date_min, date_max = get_parsed_date_min_max("-2018", tmpdir)
+        assert date_min == None
+        assert date_max == None
+
+    def test_negative_numeric_date(self, tmpdir):
+        """Parse negative numeric date."""
+        date_min, date_max = get_parsed_date_min_max(
+            "-2018.0", tmpdir)
+        assert date_min == pytest.approx(-2018.0, abs=1e-3)
+        assert date_max == pytest.approx(-2018.0, abs=1e-3)
+
+    def test_zero_year_error(self, tmpdir):
+        """Zero year-only date is unsupported."""
+        date_min, date_max = get_parsed_date_min_max("0", tmpdir)
+        assert date_min == None
+        assert date_max == None
+
+    def test_zero_year(self, tmpdir):
+        """Parse the date 0.0."""
+        date_min, date_max = get_parsed_date_min_max(
+            "0.0", tmpdir)
+        assert date_min == pytest.approx(0.0, abs=1e-3)
+        assert date_max == pytest.approx(0.0, abs=1e-3)
diff --git a/tests/test_filter_filtering.py b/tests/test_filter_filtering.py
index 91a4769cb..5c6a38853 100644
--- a/tests/test_filter_filtering.py
+++ b/tests/test_filter_filtering.py
@@ -1,14 +1,17 @@
+import pytest
 from augur.filter_support.db.sqlite import (
     EXCLUDE_COL,
     FILTER_REASON_COL,
     INCLUDE_COL,
     METADATA_FILTER_REASON_TABLE_NAME,
 )
+from augur.filter_support.exceptions import FilterException
 
 from test_filter import (
     get_filter_obj_run,
     get_valid_args,
     query_fetchall,
+    write_file,
 )
 
 
@@ -110,3 +113,77 @@ def test_filter_by_max_date(self, tmpdir):
             WHERE {FILTER_REASON_COL} = 'filter_by_max_date'
             """)
         assert results == [("SEQ_3",)]
+
+    def test_filter_by_exclude_where(self, tmpdir):
+        """Exclude all strains matching an exclude-where expression."""
+        data = [("strain","location","quality"),
+                ("SEQ_1","colorado","good"),
+                ("SEQ_2","colorado","bad"),
+                ("SEQ_3","nevada","good")]
+        args = get_valid_args(data, tmpdir)
+        args.exclude_where = ["location=colorado"]
+        filter_obj = get_filter_obj_run(args)
+        results = query_fetchall(filter_obj, f"""
+            SELECT strain
+            FROM {METADATA_FILTER_REASON_TABLE_NAME}
+            WHERE {FILTER_REASON_COL} = 'filter_by_exclude_where'
+        """)
+        assert results == [("SEQ_1",), ("SEQ_2",)]
+
+    def test_filter_by_exclude_where_missing_column_error(self, tmpdir):
+        """An exclude-where expression on a column missing from the metadata raises a FilterException."""
+        data = [("strain","location","quality"),
+                ("SEQ_1","colorado","good"),
+                ("SEQ_2","colorado","bad"),
+                ("SEQ_3","nevada","good")]
+        args = get_valid_args(data, tmpdir)
+        args.exclude_where = ["invalid=colorado"]
+        with pytest.raises(FilterException) as e_info:
+            get_filter_obj_run(args)
+        assert str(e_info.value) == 'no such column: metadata.invalid'
+
+    def test_filter_by_min_length(self, tmpdir):
+        """Filter by minimum sequence length of 3."""
+        data = [("strain",),
+                ("SEQ_1",),
+                ("SEQ_2",),
+                ("SEQ_3",)]
+        args = get_valid_args(data, tmpdir)
+        fasta_lines = [
+            ">SEQ_1", "aa",
+            ">SEQ_2", "aaa",
+            ">SEQ_3", "nnnn",
+        ]
+        args.sequences = write_file(tmpdir, "sequences.fasta", "\n".join(fasta_lines))
+        args.min_length = 3
+        filter_obj = get_filter_obj_run(args)
+        results = query_fetchall(filter_obj, f"""
+            SELECT strain
+            FROM {METADATA_FILTER_REASON_TABLE_NAME}
+            WHERE {FILTER_REASON_COL} = 'filter_by_sequence_length'
+        """)
+        assert results == [("SEQ_1",), ("SEQ_3",)]
+
+    def test_filter_by_non_nucleotide(self, tmpdir):
+        """Filter out sequences with at least 1 invalid nucleotide character."""
+        data = [("strain",),
+                ("SEQ_1",),
+                ("SEQ_2",),
+                ("SEQ_3",),
+                ("SEQ_4",)]
+        args = get_valid_args(data, tmpdir)
+        fasta_lines = [
+            ">SEQ_1", "aaaa",
+            ">SEQ_2", "nnnn",
+            ">SEQ_3", "xxxx",
+            ">SEQ_4", "aaax",
+        ]
+        args.sequences = write_file(tmpdir, "sequences.fasta", "\n".join(fasta_lines))
+        args.non_nucleotide = True
+        filter_obj = get_filter_obj_run(args)
+        results = query_fetchall(filter_obj, f"""
+            SELECT strain
+            FROM {METADATA_FILTER_REASON_TABLE_NAME}
+            WHERE {FILTER_REASON_COL} = 'filter_by_non_nucleotide'
+        """)
+        assert results == [("SEQ_3",), ("SEQ_4",)]
diff --git a/tests/test_filter_groupby.py b/tests/test_filter_groupby.py
index 48b25ec79..0ce190d45 100644
--- a/tests/test_filter_groupby.py
+++ b/tests/test_filter_groupby.py
@@ -221,3 +221,19 @@ def test_filter_groupby_only_year_month_provided(self, tmpdir):
             WHERE {FILTER_REASON_COL} IS NULL
             """)
         assert results == [("SEQ_1",), ("SEQ_2",), ("SEQ_3",), ("SEQ_4",), ("SEQ_5",)]
+
+    def test_all_samples_dropped(self, tmpdir):
+        data = [
+            ("strain","date","country"),
+            ("SEQ_1","2020","A"),
+            ("SEQ_2","2020","B"),
+            ("SEQ_3","2020","C"),
+            ("SEQ_4","2020","D"),
+            ("SEQ_5","2020","E")
+        ]
+        args = get_valid_args(data, tmpdir)
+        args.group_by = ["country", "year", "month"]
+        args.sequences_per_group = 1
+        with pytest.raises(FilterException) as e_info:
+            get_filter_obj_run(args)
+        assert str(e_info.value) == "All samples have been dropped! Check filter rules and metadata file format."