diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c475aa9..d7e14ef 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,6 +6,13 @@ All notable changes to this project will be documented in this file. The format is based on `Keep a Changelog `_. +0.10.0 +------ + +Added ++++++ +- Added the following parameters to ``CurveFits.combineCurveFits``: ``sera``, ``viruses``, ``serum_virus_replicates_to_drop``. + 0.9.0 ----- diff --git a/neutcurve/__init__.py b/neutcurve/__init__.py index e8986e1..e0fc2c5 100644 --- a/neutcurve/__init__.py +++ b/neutcurve/__init__.py @@ -16,7 +16,7 @@ __author__ = "Jesse Bloom" __email__ = "jbloom@fredhutch.org" -__version__ = "0.9.0" +__version__ = "0.10.0" __url__ = "https://github.com/jbloomlab/neutcurve" from neutcurve.curvefits import CurveFits # noqa: F401 diff --git a/neutcurve/curvefits.py b/neutcurve/curvefits.py index 4f254a2..4c68fe6 100644 --- a/neutcurve/curvefits.py +++ b/neutcurve/curvefits.py @@ -69,7 +69,13 @@ class CurveFits: _WILDTYPE_NAMES = ("WT", "wt", "wildtype", "Wildtype", "wild type", "Wild type") @staticmethod - def combineCurveFits(curvefits_list): + def combineCurveFits( + curvefits_list, + *, + sera=None, + viruses=None, + serum_virus_replicates_to_drop=None, + ): """ Args: `curvesfit_list` (list) @@ -77,6 +83,13 @@ def combineCurveFits(curvefits_list): data they contain and have unique virus/serum/replicate combinations. They can differ in `fixtop` and `fixbottom`, but then those will be set to `None` in the returned object. + `sera` (None or list) + Only keep fits for sera in this list, or keep all sera if `None`. + `viruses` (None or list) + Only keep fits for viruses in this list, or keep all viruses if `None`. + `serum_virus_replicates_to_drop` (None or list) + If a list, should specify `(serum, virus, replicate)` tuples, and those + particular fits are dropped. 
Returns: combined_fits (:class:`CurveFits`) @@ -147,6 +160,31 @@ def combineCurveFits(curvefits_list): combined_fits.df = combined_fits.df[ combined_fits.df[combined_fits.replicate_col] != "average" ] + for col, keeplist in [ + (combined_fits.serum_col, sera), + (combined_fits.virus_col, viruses), + ]: + if keeplist is not None: + combined_fits.df = combined_fits.df[ + combined_fits.df[col].isin(keeplist) + ] + if serum_virus_replicates_to_drop: + assert "tup" not in set(combined_fits.df.columns) + combined_fits.df = ( + combined_fits.df.assign( + tup=lambda x: list( + x[ + [ + combined_fits.serum_col, + combined_fits.virus_col, + combined_fits.replicate_col, + ] + ].itertuples(index=False, name=None) + ), + ) + .query("tup not in @serum_virus_replicates_to_drop") + .drop(columns="tup") + ) combined_fits.df = combined_fits._get_avg_and_stderr_df(combined_fits.df) if len(combined_fits.df) != len( combined_fits.df.groupby( @@ -161,56 +199,47 @@ def combineCurveFits(curvefits_list): raise ValueError("duplicated sera/virus/replicate in `curvefits_list`") # combine sera - combined_fits.sera = list( - dict.fromkeys([serum for f in curvefits_list for serum in f.sera]) - ) - assert set(combined_fits.sera) == set(combined_fits.df[combined_fits.serum_col]) + combined_fits.sera = combined_fits.df[combined_fits.serum_col].unique().tolist() # combine allviruses - combined_fits.allviruses = list( - dict.fromkeys([virus for f in curvefits_list for virus in f.allviruses]) - ) - assert set(combined_fits.allviruses) == set( - combined_fits.df[combined_fits.virus_col] + combined_fits.allviruses = ( + combined_fits.df[combined_fits.virus_col].unique().tolist() ) # combine viruses and replicates - combined_fits.viruses = {} - combined_fits.replicates = {} - for serum in combined_fits.sera: - combined_fits.viruses[serum] = [] - for virus in combined_fits.allviruses: - for f in curvefits_list: - if ( - (serum in f.viruses) - and (virus in f.viruses[serum]) - and (virus not in 
combined_fits.viruses[serum]) - ): - combined_fits.viruses[serum].append(virus) - if (serum, virus) in f.replicates: - if (serum, virus) not in combined_fits.replicates: - combined_fits.replicates[(serum, virus)] = f.replicates[ - (serum, virus) - ] - else: - combined_fits.replicates[(serum, virus)] += f.replicates[ - (serum, virus) - ] - for key, val in combined_fits.replicates.items(): - val = dict.fromkeys(val) - if "average" in val: - del val["average"] - if len(val) != len(set(val)): - raise ValueError(f"duplicate replicate for {key}") - combined_fits.replicates[key] = list(val) - combined_fits.replicates[key].append("average") + assert combined_fits.serum_col != "viruses" + combined_fits.viruses = ( + combined_fits.df.groupby(combined_fits.serum_col, sort=False) + .aggregate( + viruses=pd.NamedAgg( + combined_fits.virus_col, + lambda s: s.unique().tolist(), + ), + )["viruses"] + .to_dict() + ) + assert combined_fits.serum_col != "replicate" + assert combined_fits.virus_col != "replicate" + combined_fits.replicates = ( + combined_fits.df[combined_fits.df[combined_fits.replicate_col] != "average"] + .groupby([combined_fits.serum_col, combined_fits.virus_col], sort=False) + .aggregate( + replicates=pd.NamedAgg( + combined_fits.replicate_col, + lambda s: s.unique().tolist(), + ), + )["replicates"] + .to_dict() + ) + for serum, virus in combined_fits.replicates: + combined_fits.replicates[(serum, virus)].append("average") serum_virus_rep_tups = [ (serum, virus, rep) for (serum, virus), reps in combined_fits.replicates.items() for rep in reps ] assert len(serum_virus_rep_tups) == len(set(serum_virus_rep_tups)) - assert set(serum_virus_rep_tups) == set( + combined_fits_tups = set( combined_fits.df[ [ combined_fits.serum_col, @@ -219,7 +248,10 @@ def combineCurveFits(curvefits_list): ] ].itertuples(index=False, name=None) ) - assert set(combined_fits.allviruses) == set( + assert ( + set(serum_virus_rep_tups) == combined_fits_tups + ), 
f"{combined_fits_tups=}\n\n{serum_virus_rep_tups=}" + assert set(combined_fits.allviruses).issubset( v for f in curvefits_list for s in f.viruses.values() for v in s ) @@ -228,9 +260,12 @@ def combineCurveFits(curvefits_list): combined_fits._hillcurves = {} for c in curvefits_list: for (serum, virus, replicate), curve in c._hillcurves.items(): - assert serum in combined_fits.sera - assert virus in combined_fits.allviruses - if replicate != "average": + if ( + (serum in combined_fits.sera) + and (virus in combined_fits.allviruses) + and (replicate in combined_fits.replicates[(serum, virus)]) + and (replicate != "average") + ): combined_fits._hillcurves[(serum, virus, replicate)] = curve combined_fits._fitparams = {} # clear this cache diff --git a/notebooks/combine_curvefits.ipynb b/notebooks/combine_curvefits.ipynb index ebe5f99..cad05bf 100644 --- a/notebooks/combine_curvefits.ipynb +++ b/notebooks/combine_curvefits.ipynb @@ -17,11 +17,11 @@ "execution_count": 1, "metadata": { "execution": { - "iopub.execute_input": "2023-12-13T04:24:38.393504Z", - "iopub.status.busy": "2023-12-13T04:24:38.392651Z", - "iopub.status.idle": "2023-12-13T04:24:40.233700Z", - "shell.execute_reply": "2023-12-13T04:24:40.232271Z", - "shell.execute_reply.started": "2023-12-13T04:24:38.393466Z" + "iopub.execute_input": "2023-12-13T22:23:18.563770Z", + "iopub.status.busy": "2023-12-13T22:23:18.563475Z", + "iopub.status.idle": "2023-12-13T22:23:20.528422Z", + "shell.execute_reply": "2023-12-13T22:23:20.527054Z", + "shell.execute_reply.started": "2023-12-13T22:23:18.563739Z" } }, "outputs": [], @@ -43,11 +43,11 @@ "execution_count": 2, "metadata": { "execution": { - "iopub.execute_input": "2023-12-13T04:24:40.242761Z", - "iopub.status.busy": "2023-12-13T04:24:40.242066Z", - "iopub.status.idle": "2023-12-13T04:24:40.251031Z", - "shell.execute_reply": "2023-12-13T04:24:40.250248Z", - "shell.execute_reply.started": "2023-12-13T04:24:40.242729Z" + "iopub.execute_input": 
"2023-12-13T22:23:20.537617Z", + "iopub.status.busy": "2023-12-13T22:23:20.537206Z", + "iopub.status.idle": "2023-12-13T22:23:20.546183Z", + "shell.execute_reply": "2023-12-13T22:23:20.545224Z", + "shell.execute_reply.started": "2023-12-13T22:23:20.537585Z" } }, "outputs": [], @@ -68,11 +68,11 @@ "execution_count": 3, "metadata": { "execution": { - "iopub.execute_input": "2023-12-13T04:24:40.255392Z", - "iopub.status.busy": "2023-12-13T04:24:40.255064Z", - "iopub.status.idle": "2023-12-13T04:24:40.556528Z", - "shell.execute_reply": "2023-12-13T04:24:40.555779Z", - "shell.execute_reply.started": "2023-12-13T04:24:40.255368Z" + "iopub.execute_input": "2023-12-13T22:23:20.550240Z", + "iopub.status.busy": "2023-12-13T22:23:20.549869Z", + "iopub.status.idle": "2023-12-13T22:23:20.852757Z", + "shell.execute_reply": "2023-12-13T22:23:20.851679Z", + "shell.execute_reply.started": "2023-12-13T22:23:20.550215Z" }, "tags": [] }, @@ -103,11 +103,11 @@ "execution_count": 4, "metadata": { "execution": { - "iopub.execute_input": "2023-12-13T04:24:40.560854Z", - "iopub.status.busy": "2023-12-13T04:24:40.560533Z", - "iopub.status.idle": "2023-12-13T04:24:41.092707Z", - "shell.execute_reply": "2023-12-13T04:24:41.091945Z", - "shell.execute_reply.started": "2023-12-13T04:24:40.560831Z" + "iopub.execute_input": "2023-12-13T22:23:20.856943Z", + "iopub.status.busy": "2023-12-13T22:23:20.856706Z", + "iopub.status.idle": "2023-12-13T22:23:21.378382Z", + "shell.execute_reply": "2023-12-13T22:23:21.377240Z", + "shell.execute_reply.started": "2023-12-13T22:23:20.856918Z" }, "tags": [] }, @@ -152,11 +152,11 @@ "execution_count": 5, "metadata": { "execution": { - "iopub.execute_input": "2023-12-13T04:24:41.097074Z", - "iopub.status.busy": "2023-12-13T04:24:41.096829Z", - "iopub.status.idle": "2023-12-13T04:24:41.175799Z", - "shell.execute_reply": "2023-12-13T04:24:41.174999Z", - "shell.execute_reply.started": "2023-12-13T04:24:41.097051Z" + "iopub.execute_input": "2023-12-13T22:23:21.382761Z", 
+ "iopub.status.busy": "2023-12-13T22:23:21.382532Z", + "iopub.status.idle": "2023-12-13T22:23:21.466523Z", + "shell.execute_reply": "2023-12-13T22:23:21.465608Z", + "shell.execute_reply.started": "2023-12-13T22:23:21.382737Z" }, "tags": [] }, @@ -179,6 +179,205 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Combine fits only for certain sera:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2023-12-13T22:23:21.470062Z", + "iopub.status.busy": "2023-12-13T22:23:21.469813Z", + "iopub.status.idle": "2023-12-13T22:23:21.528967Z", + "shell.execute_reply": "2023-12-13T22:23:21.528093Z", + "shell.execute_reply.started": "2023-12-13T22:23:21.470038Z" + }, + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
serumvirusreplicatenreplicatesic50ic50_boundic50_strmidpointslopetopbottom
0FI6v3WT1<NA>0.016702interpolated0.01670.0167022.50467010
1FI6v3WT2<NA>0.019020interpolated0.0190.0190202.51273310
2FI6v3WT3<NA>0.015167interpolated0.01520.0151671.87799510
3FI6v3WTaverage30.017029interpolated0.0170.0170292.27931610
4FI6v3P80D1<NA>0.012115interpolated0.01210.0121152.02469810
5FI6v3P80D3<NA>0.012835interpolated0.01280.0128352.05906110
6FI6v3P80Daverage20.012472interpolated0.01250.0124722.03543510
\n", + "
" + ], + "text/plain": [ + " serum virus replicate nreplicates ic50 ic50_bound ic50_str \\\n", + "0 FI6v3 WT 1 0.016702 interpolated 0.0167 \n", + "1 FI6v3 WT 2 0.019020 interpolated 0.019 \n", + "2 FI6v3 WT 3 0.015167 interpolated 0.0152 \n", + "3 FI6v3 WT average 3 0.017029 interpolated 0.017 \n", + "4 FI6v3 P80D 1 0.012115 interpolated 0.0121 \n", + "5 FI6v3 P80D 3 0.012835 interpolated 0.0128 \n", + "6 FI6v3 P80D average 2 0.012472 interpolated 0.0125 \n", + "\n", + " midpoint slope top bottom \n", + "0 0.016702 2.504670 1 0 \n", + "1 0.019020 2.512733 1 0 \n", + "2 0.015167 1.877995 1 0 \n", + "3 0.017029 2.279316 1 0 \n", + "4 0.012115 2.024698 1 0 \n", + "5 0.012835 2.059061 1 0 \n", + "6 0.012472 2.035435 1 0 " + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(\n", + " neutcurve.CurveFits.combineCurveFits(\n", + " [fit1, fit2],\n", + " sera=[\"FI6v3\"],\n", + " viruses=[\"WT\", \"P80D\", \"V135T\"],\n", + " serum_virus_replicates_to_drop=[\n", + " (\"FI6v3\", \"P80D\", \"2\"),\n", + " (\"FI6v3\", \"V135T\", \"1\"),\n", + " (\"FI6v3\", \"V135T\", \"2\"),\n", + " (\"FI6v3\", \"V135T\", \"3\"),\n", + " ],\n", + " ).fitParams(average_only=False)\n", + ")" + ] + }, { "cell_type": "markdown", "metadata": { @@ -196,14 +395,14 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": { "execution": { - "iopub.execute_input": "2023-12-13T04:24:41.179493Z", - "iopub.status.busy": "2023-12-13T04:24:41.179178Z", - "iopub.status.idle": "2023-12-13T04:24:41.804570Z", - "shell.execute_reply": "2023-12-13T04:24:41.803184Z", - "shell.execute_reply.started": "2023-12-13T04:24:41.179468Z" + "iopub.execute_input": "2023-12-13T22:23:21.532457Z", + "iopub.status.busy": "2023-12-13T22:23:21.532252Z", + "iopub.status.idle": "2023-12-13T22:23:22.126597Z", + "shell.execute_reply": "2023-12-13T22:23:22.125230Z", + "shell.execute_reply.started": "2023-12-13T22:23:21.532436Z" } }, "outputs": [ 
@@ -214,8 +413,8 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[6], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# NBVAL_RAISES_EXCEPTION\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[43mneutcurve\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mCurveFits\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcombineCurveFits\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mfit1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfit2_invalid\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/neutcurve/neutcurve/curvefits.py:161\u001b[0m, in \u001b[0;36mCurveFits.combineCurveFits\u001b[0;34m(curvefits_list)\u001b[0m\n\u001b[1;32m 150\u001b[0m combined_fits\u001b[38;5;241m.\u001b[39mdf \u001b[38;5;241m=\u001b[39m combined_fits\u001b[38;5;241m.\u001b[39m_get_avg_and_stderr_df(combined_fits\u001b[38;5;241m.\u001b[39mdf)\n\u001b[1;32m 151\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(combined_fits\u001b[38;5;241m.\u001b[39mdf) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(\n\u001b[1;32m 152\u001b[0m combined_fits\u001b[38;5;241m.\u001b[39mdf\u001b[38;5;241m.\u001b[39mgroupby(\n\u001b[1;32m 153\u001b[0m [\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 159\u001b[0m )\n\u001b[1;32m 160\u001b[0m ):\n\u001b[0;32m--> 161\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mduplicated sera/virus/replicate in `curvefits_list`\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 163\u001b[0m \u001b[38;5;66;03m# combine sera\u001b[39;00m\n\u001b[1;32m 164\u001b[0m combined_fits\u001b[38;5;241m.\u001b[39msera \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\n\u001b[1;32m 165\u001b[0m 
\u001b[38;5;28mdict\u001b[39m\u001b[38;5;241m.\u001b[39mfromkeys([serum \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m curvefits_list \u001b[38;5;28;01mfor\u001b[39;00m serum \u001b[38;5;129;01min\u001b[39;00m f\u001b[38;5;241m.\u001b[39msera])\n\u001b[1;32m 166\u001b[0m )\n", + "Cell \u001b[0;32mIn[7], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# NBVAL_RAISES_EXCEPTION\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[43mneutcurve\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mCurveFits\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcombineCurveFits\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mfit1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfit2_invalid\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/neutcurve/neutcurve/curvefits.py:199\u001b[0m, in \u001b[0;36mCurveFits.combineCurveFits\u001b[0;34m(curvefits_list, sera, viruses, serum_virus_replicates_to_drop)\u001b[0m\n\u001b[1;32m 188\u001b[0m combined_fits\u001b[38;5;241m.\u001b[39mdf \u001b[38;5;241m=\u001b[39m combined_fits\u001b[38;5;241m.\u001b[39m_get_avg_and_stderr_df(combined_fits\u001b[38;5;241m.\u001b[39mdf)\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(combined_fits\u001b[38;5;241m.\u001b[39mdf) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(\n\u001b[1;32m 190\u001b[0m combined_fits\u001b[38;5;241m.\u001b[39mdf\u001b[38;5;241m.\u001b[39mgroupby(\n\u001b[1;32m 191\u001b[0m [\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 197\u001b[0m )\n\u001b[1;32m 198\u001b[0m ):\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mduplicated sera/virus/replicate in `curvefits_list`\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 201\u001b[0m \u001b[38;5;66;03m# combine sera\u001b[39;00m\n\u001b[1;32m 202\u001b[0m 
combined_fits\u001b[38;5;241m.\u001b[39msera \u001b[38;5;241m=\u001b[39m combined_fits\u001b[38;5;241m.\u001b[39mdf[combined_fits\u001b[38;5;241m.\u001b[39mserum_col]\u001b[38;5;241m.\u001b[39munique()\u001b[38;5;241m.\u001b[39mtolist()\n", "\u001b[0;31mValueError\u001b[0m: duplicated sera/virus/replicate in `curvefits_list`" ] } @@ -225,13 +424,6 @@ "\n", "neutcurve.CurveFits.combineCurveFits([fit1, fit2_invalid])" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": {