Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add params to CurveFits.combineCurveFits #41

Merged
merged 2 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ All notable changes to this project will be documented in this file.

The format is based on `Keep a Changelog <https://keepachangelog.com>`_.

0.10.0
------

Added
+++++
- Added the following parameters to ``CurveFits.combineCurveFits``: ``sera``, ``viruses``, ``serum_virus_replicates_to_drop``.

0.9.0
-----

Expand Down
2 changes: 1 addition & 1 deletion neutcurve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

__author__ = "Jesse Bloom"
__email__ = "jbloom@fredhutch.org"
__version__ = "0.9.0"
__version__ = "0.10.0"
__url__ = "https://github.com/jbloomlab/neutcurve"

from neutcurve.curvefits import CurveFits # noqa: F401
Expand Down
123 changes: 79 additions & 44 deletions neutcurve/curvefits.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,27 @@ class CurveFits:
_WILDTYPE_NAMES = ("WT", "wt", "wildtype", "Wildtype", "wild type", "Wild type")

@staticmethod
def combineCurveFits(curvefits_list):
def combineCurveFits(
curvefits_list,
*,
sera=None,
viruses=None,
serum_virus_replicates_to_drop=None,
):
"""
Args:
`curvefits_list` (list)
List of :class:`CurveFits` objects that are identical other than the
data they contain and have unique virus/serum/replicate combinations.
They can differ in `fixtop` and `fixbottom`, but then those
will be set to `None` in the returned object.
`sera` (None or list)
Only keep fits for sera in this list, or keep all sera if `None`.
`viruses` (None or list)
Only keep fits for viruses in this list, or keep all viruses if `None`.
`serum_virus_replicates_to_drop` (None or list)
If a list, should specify `(serum, virus, replicate)` tuples, and the
fits for those particular serum/virus/replicate combinations are dropped.

Returns:
combined_fits (:class:`CurveFits`)
Expand Down Expand Up @@ -147,6 +160,31 @@ def combineCurveFits(curvefits_list):
combined_fits.df = combined_fits.df[
combined_fits.df[combined_fits.replicate_col] != "average"
]
for col, keeplist in [
(combined_fits.serum_col, sera),
(combined_fits.virus_col, viruses),
]:
if keeplist is not None:
combined_fits.df = combined_fits.df[
combined_fits.df[col].isin(keeplist)
]
if serum_virus_replicates_to_drop:
assert "tup" not in set(combined_fits.df.columns)
combined_fits.df = (
combined_fits.df.assign(
tup=lambda x: list(
x[
[
combined_fits.serum_col,
combined_fits.virus_col,
combined_fits.replicate_col,
]
].itertuples(index=False, name=None)
),
)
.query("tup not in @serum_virus_replicates_to_drop")
.drop(columns="tup")
)
combined_fits.df = combined_fits._get_avg_and_stderr_df(combined_fits.df)
if len(combined_fits.df) != len(
combined_fits.df.groupby(
Expand All @@ -161,56 +199,47 @@ def combineCurveFits(curvefits_list):
raise ValueError("duplicated sera/virus/replicate in `curvefits_list`")

# combine sera
combined_fits.sera = list(
dict.fromkeys([serum for f in curvefits_list for serum in f.sera])
)
assert set(combined_fits.sera) == set(combined_fits.df[combined_fits.serum_col])
combined_fits.sera = combined_fits.df[combined_fits.serum_col].unique().tolist()

# combine allviruses
combined_fits.allviruses = list(
dict.fromkeys([virus for f in curvefits_list for virus in f.allviruses])
)
assert set(combined_fits.allviruses) == set(
combined_fits.df[combined_fits.virus_col]
combined_fits.allviruses = (
combined_fits.df[combined_fits.virus_col].unique().tolist()
)

# combine viruses and replicates
combined_fits.viruses = {}
combined_fits.replicates = {}
for serum in combined_fits.sera:
combined_fits.viruses[serum] = []
for virus in combined_fits.allviruses:
for f in curvefits_list:
if (
(serum in f.viruses)
and (virus in f.viruses[serum])
and (virus not in combined_fits.viruses[serum])
):
combined_fits.viruses[serum].append(virus)
if (serum, virus) in f.replicates:
if (serum, virus) not in combined_fits.replicates:
combined_fits.replicates[(serum, virus)] = f.replicates[
(serum, virus)
]
else:
combined_fits.replicates[(serum, virus)] += f.replicates[
(serum, virus)
]
for key, val in combined_fits.replicates.items():
val = dict.fromkeys(val)
if "average" in val:
del val["average"]
if len(val) != len(set(val)):
raise ValueError(f"duplicate replicate for {key}")
combined_fits.replicates[key] = list(val)
combined_fits.replicates[key].append("average")
assert combined_fits.serum_col != "viruses"
combined_fits.viruses = (
combined_fits.df.groupby(combined_fits.serum_col, sort=False)
.aggregate(
viruses=pd.NamedAgg(
combined_fits.virus_col,
lambda s: s.unique().tolist(),
),
)["viruses"]
.to_dict()
)
assert combined_fits.serum_col != "replicate"
assert combined_fits.virus_col != "replicate"
combined_fits.replicates = (
combined_fits.df[combined_fits.df[combined_fits.replicate_col] != "average"]
.groupby([combined_fits.serum_col, combined_fits.virus_col], sort=False)
.aggregate(
replicates=pd.NamedAgg(
combined_fits.replicate_col,
lambda s: s.unique().tolist(),
),
)["replicates"]
.to_dict()
)
for serum, virus in combined_fits.replicates:
combined_fits.replicates[(serum, virus)].append("average")
serum_virus_rep_tups = [
(serum, virus, rep)
for (serum, virus), reps in combined_fits.replicates.items()
for rep in reps
]
assert len(serum_virus_rep_tups) == len(set(serum_virus_rep_tups))
assert set(serum_virus_rep_tups) == set(
combined_fits_tups = set(
combined_fits.df[
[
combined_fits.serum_col,
Expand All @@ -219,7 +248,10 @@ def combineCurveFits(curvefits_list):
]
].itertuples(index=False, name=None)
)
assert set(combined_fits.allviruses) == set(
assert (
set(serum_virus_rep_tups) == combined_fits_tups
), f"{combined_fits_tups=}\n\n{serum_virus_rep_tups=}"
assert set(combined_fits.allviruses).issubset(
v for f in curvefits_list for s in f.viruses.values() for v in s
)

Expand All @@ -228,9 +260,12 @@ def combineCurveFits(curvefits_list):
combined_fits._hillcurves = {}
for c in curvefits_list:
for (serum, virus, replicate), curve in c._hillcurves.items():
assert serum in combined_fits.sera
assert virus in combined_fits.allviruses
if replicate != "average":
if (
(serum in combined_fits.sera)
and (virus in combined_fits.allviruses)
and (replicate in combined_fits.replicates[(serum, virus)])
and (replicate != "average")
):
combined_fits._hillcurves[(serum, virus, replicate)] = curve

combined_fits._fitparams = {} # clear this cache
Expand Down
Loading