🚧 Use Dask

victorlin committed Mar 11, 2024
1 parent 850d908 commit 946d93f

Showing 5 changed files with 30 additions and 26 deletions.
10 changes: 5 additions & 5 deletions augur/dates/__init__.py
@@ -126,22 +126,22 @@ def get_numerical_date_from_value(value, fmt=None, min_max_year=None):
     except:
         return None
 
-def get_numerical_dates(metadata:pd.DataFrame, name_col = None, date_col='date', fmt=None, min_max_year=None):
-    if not isinstance(metadata, pd.DataFrame):
-        raise AugurError("Metadata should be a pandas.DataFrame.")
+def get_numerical_dates(metadata, name_col = None, date_col='date', fmt=None, min_max_year=None):
+    # if not isinstance(metadata, pd.DataFrame):
+    #     raise AugurError("Metadata should be a pandas.DataFrame.")
     if fmt:
         strains = metadata.index.values
         dates = metadata[date_col].apply(
             lambda date: get_numerical_date_from_value(
                 date,
                 fmt,
                 min_max_year
-            )
+            ), meta=(date_col, 'str')
         ).values
     else:
         strains = metadata.index.values
         dates = metadata[date_col].astype(float)
-    return dict(zip(strains, dates))
+    return dict(zip(strains.compute(), dates.compute()))
 
 def get_iso_year_week(year, month, day):
     return datetime.date(year, month, day).isocalendar()[:2]
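
A note on the two new Dask idioms above (my illustration, not part of the commit): `Series.apply` on a Dask series is lazy, so it needs a `meta` hint describing the output, and `.values` yields a lazy array that only materializes on `.compute()`. A minimal sketch with a toy frame:

```python
import pandas as pd
import dask.dataframe as dd

# Toy stand-in for Augur metadata; the frame and names are illustrative.
df = pd.DataFrame(
    {"date": ["2020-01-01", "2021-06-15"]},
    index=pd.Index(["strain1", "strain2"], name="strain"),
)
ddf = dd.from_pandas(df, npartitions=2)

# `meta` declares the result's name/dtype so Dask can build the task
# graph without running the function eagerly.
dates = ddf["date"].apply(lambda d: float(d[:4]), meta=("date", "f8")).values

# Nothing executes until `.compute()`.
print(dict(zip(ddf.index.values.compute(), dates.compute())))
```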
31 changes: 15 additions & 16 deletions augur/filter/_run.py
@@ -4,6 +4,7 @@
 import json
 import os
 import pandas as pd
+import dask.dataframe as dd
 from tempfile import NamedTemporaryFile
 
 from augur.errors import AugurError
@@ -14,7 +15,7 @@
     DELIMITER as SEQUENCE_INDEX_DELIMITER,
 )
 from augur.io.file import PANDAS_READ_CSV_OPTIONS, open_file
-from augur.io.metadata import InvalidDelimiter, Metadata, read_metadata
+from augur.io.metadata import InvalidDelimiter, Metadata
 from augur.io.sequences import read_sequences, write_sequences
 from augur.io.print import print_err
 from augur.io.vcf import is_vcf as filename_is_vcf, write_vcf
@@ -97,20 +98,17 @@ def run(args):
     )
     useful_metadata_columns = get_useful_metadata_columns(args, metadata_object.id_column, metadata_object.columns)
 
-    metadata = read_metadata(
+    metadata = dd.read_csv(
         args.metadata,
-        delimiters=[metadata_object.delimiter],
-        columns=useful_metadata_columns,
-        id_columns=[metadata_object.id_column],
-        dtype={col: 'category' for col in useful_metadata_columns},
-    )
-
-    duplicate_strains = metadata.index[metadata.index.duplicated()]
-    if len(duplicate_strains) > 0:
-        raise AugurError(f"The following strains are duplicated in '{args.metadata}':\n" + "\n".join(sorted(duplicate_strains)))
-
-    # FIXME: remove redundant variable from chunking logic
-    metadata_strains = set(metadata.index.values)
+        delimiter=metadata_object.delimiter,
+        usecols=useful_metadata_columns,
+        dtype='str',
+    ).set_index(metadata_object.id_column)
+
+    # FIXME: detect duplicates
+    # duplicate_strains = metadata.index[metadata.index.duplicated()]
+    # if len(duplicate_strains) > 0:
+    #     raise AugurError(f"The following strains are duplicated in '{args.metadata}':\n" + "\n".join(sorted(duplicate_strains)))
 
     # Setup filters.
     exclude_by, include_by = construct_filters(
@@ -261,16 +259,17 @@ def run(args):
         args.metadata_id_columns, args.output_metadata,
         args.output_strains, valid_strains)
 
+    # FIXME: inspect metadata/sequence mismatch
     # Calculate the number of strains that don't exist in either metadata or
     # sequences.
     num_excluded_by_lack_of_metadata = 0
-    if sequence_strains:
-        num_excluded_by_lack_of_metadata = len(sequence_strains - metadata_strains)
+    # if sequence_strains:
+    #     num_excluded_by_lack_of_metadata = len(sequence_strains - metadata_strains)
 
 
     # Calculate the number of strains passed and filtered.
     total_strains_passed = len(valid_strains)
-    total_strains_filtered = len(metadata_strains) + num_excluded_by_lack_of_metadata - total_strains_passed
+    total_strains_filtered = len(metadata.index) + num_excluded_by_lack_of_metadata - total_strains_passed
 
     print(f"{total_strains_filtered} {'strain was' if total_strains_filtered == 1 else 'strains were'} dropped during filtering")
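
For orientation (not from the commit): `dd.read_csv` forwards keyword arguments like `delimiter`, `usecols`, and `dtype` to `pandas.read_csv` for each partition, and `set_index` is lazy, triggering a sort/shuffle only when the graph runs. A sketch with a hypothetical file path and column names:

```python
import dask.dataframe as dd

# Hypothetical path and column names, for illustration only.
metadata = dd.read_csv(
    "metadata.tsv",
    delimiter="\t",                        # forwarded to pandas.read_csv
    usecols=["strain", "date", "region"],
    dtype="str",                           # avoid per-partition dtype inference drift
).set_index("strain")                      # lazy; shuffles partitions when computed

# The graph only executes when something concrete is requested:
print(len(metadata.index))                 # triggers a pass over the file
```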
4 changes: 2 additions & 2 deletions augur/filter/include_exclude_rules.py
@@ -341,7 +341,7 @@ def filter_by_min_date(metadata, date_column, min_date) -> FilterFunctionReturn:
     ['strain1', 'strain2']
     """
-    strains = set(metadata.index.values)
+    strains = set(metadata.index.values.compute())
 
     # Skip this filter if the date column does not exist.
     if date_column not in metadata.columns:
@@ -766,7 +766,7 @@ def apply_filters(metadata, exclude_by: List[FilterOption], include_by: List[Fil
     [{'strain': 'strain2', 'filter': 'force_include_where', 'kwargs': '[["include_where", "region=Europe"]]'}]
     """
-    strains_to_keep = set(metadata.index.values)
+    strains_to_keep = set(metadata.index.values.compute())
     strains_to_filter = []
     strains_to_force_include = []
     distinct_strains_to_force_include: Set = set()
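
The `.compute()` additions above follow the same rule (illustration, not commit code): a Dask index's `.values` is a lazy `dask.array.Array`, and `.compute()` turns it into a concrete NumPy array that `set()` can consume directly:

```python
import pandas as pd
import dask.dataframe as dd

# Toy frame; the index name "strain" is illustrative.
ddf = dd.from_pandas(
    pd.DataFrame(
        {"date": ["2020", "2021"]},
        index=pd.Index(["strain1", "strain2"], name="strain"),
    ),
    npartitions=2,
)

lazy = ddf.index.values        # dask.array.Array; nothing read yet
strains = set(lazy.compute())  # concrete numpy array -> Python set
print(strains)                 # {'strain1', 'strain2'}
```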
9 changes: 6 additions & 3 deletions augur/filter/subsample.py
@@ -27,19 +27,22 @@ def subsample(metadata, args, group_by):
         metadata,
         group_by,
     )
 
+    def apply_priorities(row):
+        return priorities[row.name]
+
     # Enrich with priorities.
-    grouping_metadata['priority'] = [priorities[strain] for strain in grouping_metadata.index]
+    grouping_metadata['priority'] = grouping_metadata.apply(apply_priorities, axis=1, meta=('priority', 'f8'))
 
     pandas_groupby = grouping_metadata.groupby(list(group_by), group_keys=False)
 
-    n_groups = len(pandas_groupby.groups)
+    n_groups = len(pandas_groupby.size())
 
     # Determine sequences per group.
     if args.sequences_per_group:
         sequences_per_group = args.sequences_per_group
     elif args.subsample_max_sequences:
-        group_sizes = [len(strains) for strains in pandas_groupby.groups.values()]
+        group_sizes = pandas_groupby.size().compute().tolist()
 
         try:
             # Calculate sequences per group. If there are more groups than maximum
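
Two Dask constraints drive the subsampling changes (my reading, not stated in the commit): `DataFrame.apply` only supports row-wise `axis=1` and wants a `meta` hint, and Dask's groupby exposes no `.groups` mapping (building one would materialize every partition), so `.size()` serves as the lazy per-group count. A sketch:

```python
import pandas as pd
import dask.dataframe as dd

# Toy grouping frame; column names are illustrative.
df = pd.DataFrame({
    "region": ["Europe", "Asia", "Europe"],
    "priority": [0.1, 0.9, 0.5],
})
gb = dd.from_pandas(df, npartitions=2).groupby("region")

sizes = gb.size()                       # lazy Series of per-group counts
n_groups = len(sizes)                   # computes the number of groups
group_sizes = sizes.compute().tolist()  # e.g. [1, 2]
print(n_groups, group_sizes)
```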
2 changes: 2 additions & 0 deletions setup.py
@@ -55,6 +55,8 @@
         # TODO: Remove biopython >= 1.80 pin if it is added to bcbio-gff: https://github.com/chapmanb/bcbb/issues/142
         "biopython >=1.80, ==1.*",
         "cvxopt >=1.1.9, ==1.*",
+        "dask[dataframe]",
+        "pyarrow",
         "importlib_resources >=5.3.0; python_version < '3.11'",
         "isodate ==0.6.*",
         "jsonschema >=3.0.0, ==3.*",
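
One note on the new requirements (my reading; the commit doesn't say): the `dataframe` extra pulls in what `dask.dataframe` needs at import time, and `pyarrow` presumably backs Dask's I/O fast paths. A quick check that both resolve:

```python
# With a bare `pip install dask` (no extra), the first import raises
# ImportError and points at the dask[dataframe] extra.
import dask.dataframe as dd
import pyarrow

print(dd.__name__, pyarrow.__version__)
```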
