diff --git a/lineage/Analyze.py b/lineage/Analyze.py index c9db81dc7..c69181b78 100644 --- a/lineage/Analyze.py +++ b/lineage/Analyze.py @@ -66,7 +66,7 @@ def fit_list( def Analyze_list( - pop_list: list, num_states: int, fpi=None, fT=None, rng=None + pop_list: list, num_states: int, fpi=None, fT=None, rng=None, write_states=False ) -> Tuple[list[tHMM], float, list[np.ndarray]]: """This function runs the analyze function for the case when we want to fit multiple conditions at the same time. :param pop_list: The list of cell populations to run the analyze function on. @@ -93,6 +93,14 @@ def Analyze_list( LL = LL2 gammas = gammas2 + # store the Viterbi-predicted states + if write_states: + for tHMMobj in tHMMobj_list: + states = tHMMobj.predict() + + for lin_indx, lin in enumerate(tHMMobj.X): + lin.states = states[lin_indx] + return tHMMobj_list, LL, gammas @@ -168,13 +176,9 @@ def Results(tHMMobj: tHMM, LL: float) -> dict[str, Any]: results_dict["total_number_of_lineages"] = len(tHMMobj.X) results_dict["LL"] = LL - results_dict["total_number_of_cells"] = sum( - [len(lineage.output_lineage) for lineage in tHMMobj.X] - ) + results_dict["total_number_of_cells"] = sum([len(lin) for lin in tHMMobj.X]) - true_states_by_lineage = [ - [cell.state for cell in lineage.output_lineage] for lineage in tHMMobj.X - ] + true_states_by_lineage = [lin.states for lin in tHMMobj.X] results_dict["transition_matrix_similarity"] = np.linalg.norm( tHMMobj.estimate.T - tHMMobj.X[0].T diff --git a/lineage/CellVar.py b/lineage/CellVar.py index 0cf192194..7b7c8e185 100644 --- a/lineage/CellVar.py +++ b/lineage/CellVar.py @@ -1,11 +1,23 @@ """ This file contains the class for CellVar which holds the state and observation information in the hidden and observed trees respectively. """ -from __future__ import annotations import numpy as np from typing import Optional from dataclasses import dataclass +@dataclass(init=True, repr=True, eq=True, order=True) +class Time: + """ + Class that stores all the time related observations in a neater format. + This assists in pruning based on experimental time and obtaining + attributes of the lineage as a whole like the average growth rate. + """ + + startT: float + endT: float + transition_time: float = 0.0 + + class CellVar: """ Cell class. @@ -14,13 +26,13 @@ class CellVar: parent: Optional["CellVar"] gen: int observed: bool - state: Optional[int] + state: int obs: np.ndarray time: Optional[Time] left: Optional["CellVar"] right: Optional["CellVar"] - def __init__(self, parent: Optional["CellVar"], state: Optional[int] = None): + def __init__(self, parent: Optional["CellVar"], state: int = -1): """Instantiates the cell object. Contains memeber variables that identify daughter cells and parent cells. Also contains the state of the cell. @@ -78,16 +90,3 @@ def isLeaf(self) -> bool: # otherwise, it itself is observed and at least one of its daughters is observed return False - - -@dataclass(init=True, repr=True, eq=True, order=True) -class Time: - """ - Class that stores all the time related observations in a neater format. - This assists in pruning based on experimental time and obtaining - attributes of the lineage as a whole like the average growth rate. - """ - - startT: float - endT: float - transition_time: float = 0.0 diff --git a/lineage/LineageInputOutput.py b/lineage/LineageInputOutput.py index d23248a7b..a26f6bd8d 100644 --- a/lineage/LineageInputOutput.py +++ b/lineage/LineageInputOutput.py @@ -1,5 +1,6 @@ """ The file contains the methods used to input lineage data from the Heiser lab. """ +import logging import math import pandas as pd from .CellVar import CellVar as c @@ -250,7 +251,7 @@ def tryRecursion( # Check that the parent cell didn't get time censored (Likely divided in last frame) if divisionTime == exp_time: - print( + logging.info( f"Cell time censorship, but daughters were found in row {parentPos+1}, column {pColumn+1}. By default they will be set to None" ) @@ -263,7 +264,7 @@ def tryRecursion( # Creating daughter daughterCell = c(parent=parentCell) - daughterCell.obs = [0, 0, 0, 0, 0, 0] + daughterCell.obs = np.array([0.0, 0, 0, 0, 0, 0]) # find upper daughter daughterCell.left = tryRecursion( diff --git a/lineage/LineageTree.py b/lineage/LineageTree.py index 60c7df749..e8b71363e 100644 --- a/lineage/LineageTree.py +++ b/lineage/LineageTree.py @@ -19,6 +19,7 @@ class LineageTree: leaves_idx: np.ndarray output_lineage: list[CellVar] cell_to_daughters: np.ndarray + states: np.ndarray E: Sequence[StA | StB] def __init__(self, list_of_cells: list, E: Sequence[StA | StB]): @@ -31,6 +32,8 @@ def __init__(self, list_of_cells: list, E: Sequence[StA | StB]): # Leaves have no daughters self.leaves_idx = np.nonzero(np.all(self.cell_to_daughters == -1, axis=1))[0] + self.states = np.array([cell.state for cell in self.output_lineage], dtype=int) + @classmethod def rand_init( cls, diff --git a/lineage/figures/common.py b/lineage/figures/common.py index 88de397af..c8023d5e9 100644 --- a/lineage/figures/common.py +++ b/lineage/figures/common.py @@ -1023,14 +1023,7 @@ def plot_all(ax, num_states, tHMMobj_list, Dname, cons, concsValues): (4, num_states, 2) ) # the avg lifetime: num_conc x num_states x num_phases bern_lpt = np.zeros((4, num_states, 2)) # bernoulli - # print parameters and estimated values - print( - Dname, - "\n the \u03C0: ", - tHMMobj_list[0].estimate.pi, - "\n the transition matrix: ", - tHMMobj_list[0].estimate.T, - ) + for idx, tHMMobj in enumerate(tHMMobj_list): # for each concentration data for i in range(num_states): lpt_avg[idx, i, 0] = np.log10( @@ -1046,19 +1039,6 @@ def plot_all(ax, num_states, tHMMobj_list, Dname, cons, concsValues): plotting(ax, lpt_avg, bern_lpt, cons, concsValues, num_states) -def sort_lins(tHMMobj): +def sort_lins(tHMMobj) -> list: """Sorts lineages based on their root cell state for plotting the lineage trees.""" - num_st = tHMMobj.estimate.num_states - - st = [] # holds the state of root cell in all lineages for this particular tHMMobj - for lins in tHMMobj.X: - st.append(lins.output_lineage[0].state) - - states = [] - for i in range(num_st): - st_i = [index for index, val in enumerate(st) if val == i] - temp = [tHMMobj.X[k] for k in st_i] - - states += temp - - return states + return sorted(tHMMobj.X, key=lambda lin: lin.states[0]) diff --git a/lineage/figures/figure11.py b/lineage/figures/figure11.py index 500e984a3..c6ff33f98 100644 --- a/lineage/figures/figure11.py +++ b/lineage/figures/figure11.py @@ -10,15 +10,7 @@ concsValues = ["Control", "25 nM", "50 nM", "250 nM"] num_states = 4 -lapt_tHMMobj_list = Analyze_list(AllLapatinib, num_states)[0] - -lapt_states_list = [tHMMobj.predict() for tHMMobj in lapt_tHMMobj_list] - -# assign the predicted states to each cell -for idx, lapt_tHMMobj in enumerate(lapt_tHMMobj_list): - for lin_indx, lin in enumerate(lapt_tHMMobj.X): - for cell_indx, cell in enumerate(lin.output_lineage): - cell.state = lapt_states_list[idx][lin_indx][cell_indx] +lapt_tHMMobj_list = Analyze_list(AllLapatinib, num_states, write_states=True)[0] T_lap = lapt_tHMMobj_list[0].estimate.T num_states = lapt_tHMMobj_list[0].num_states diff --git a/lineage/figures/figure111.py b/lineage/figures/figure111.py index 03bb34365..37d21b2b0 100644 --- a/lineage/figures/figure111.py +++ b/lineage/figures/figure111.py @@ -51,34 +51,20 @@ def state_abundance_perRep(reps): s3 = [] s4 = [] s5 = [] + for rep in reps: - st0 = 0 - st1 = 0 - st2 = 0 - st3 = 0 - st4 = 0 - st5 = 0 + states = [] + for lineageTree_list in rep: for lineage_tree in lineageTree_list: - for cell in lineage_tree.output_lineage: - if cell.state == 0: - st0 += 1 - elif cell.state == 1: - st1 += 1 - elif cell.state == 2: - st2 += 1 - elif cell.state == 3: - st3 += 1 - elif cell.state == 4: - st4 += 1 - elif cell.state == 5: - st5 += 1 - s0.append(st0) - s1.append(st1) - s2.append(st2) - s3.append(st3) - s4.append(st4) - s5.append(st5) + states = np.concatenate((states, lineage_tree.states)) + + s0.append(np.sum(states == 0)) + s1.append(np.sum(states == 1)) + s2.append(np.sum(states == 2)) + s3.append(np.sum(states == 3)) + s4.append(np.sum(states == 4)) + s5.append(np.sum(states == 5)) return [s0, s1, s2, s3, s4, s5] diff --git a/lineage/figures/figure12.py b/lineage/figures/figure12.py index 2d4f98e0a..b8e5adefc 100644 --- a/lineage/figures/figure12.py +++ b/lineage/figures/figure12.py @@ -11,17 +11,9 @@ concsValues = ["Control", "5 nM", "10 nM", "30 nM"] num_states = 5 -gemc_tHMMobj_list = Analyze_list(AllGemcitabine, num_states)[0] - -gemc_states_list = [tHMMobj.predict() for tHMMobj in gemc_tHMMobj_list] - -for idx, gemc_tHMMobj in enumerate(gemc_tHMMobj_list): - for lin_indx, lin in enumerate(gemc_tHMMobj.X): - for cell_indx, cell in enumerate(lin.output_lineage): - cell.state = gemc_states_list[idx][lin_indx][cell_indx] +gemc_tHMMobj_list = Analyze_list(AllGemcitabine, num_states, write_states=True)[0] T_gem = gemc_tHMMobj_list[0].estimate.T -num_states = gemc_tHMMobj_list[0].num_states # plot transition block plot_networkx(T_gem, "gemcitabine") diff --git a/lineage/figures/figure13.py b/lineage/figures/figure13.py index a5f1aec47..26c350371 100644 --- a/lineage/figures/figure13.py +++ b/lineage/figures/figure13.py @@ -14,20 +14,10 @@ concs = ["PBS", "EGF", "HGF", "OSM"] num_states = 3 -hgf_tHMMobj_list = Analyze_list(GFs, num_states)[0] - -hgf_states_list = [tHMMobj.predict() for tHMMobj in hgf_tHMMobj_list] - -# assign the predicted states to each cell -for idx, hgf_tHMMobj in enumerate(hgf_tHMMobj_list): - for lin_indx, lin in enumerate(hgf_tHMMobj.X): - for cell_indx, cell in enumerate(lin.output_lineage): - cell.state = hgf_states_list[idx][lin_indx][cell_indx] +hgf_tHMMobj_list = Analyze_list(GFs, num_states, write_states=True)[0] T_hgf = hgf_tHMMobj_list[0].estimate.T -num_states = hgf_tHMMobj_list[0].num_states - rcParams["font.sans-serif"] = "Arial" diff --git a/lineage/figures/figure16.py b/lineage/figures/figure16.py index 0454de523..c572fa794 100644 --- a/lineage/figures/figure16.py +++ b/lineage/figures/figure16.py @@ -6,15 +6,7 @@ from ..Lineage_collections import AllGemcitabine num_states = 5 -gemc_tHMMobj_list = Analyze_list(AllGemcitabine, num_states)[0] - -gemc_states_list = [tHMMobj.predict() for tHMMobj in gemc_tHMMobj_list] - -# assign the predicted states to each cell -for idx, lapt_tHMMobj in enumerate(gemc_tHMMobj_list): - for lin_indx, lin in enumerate(lapt_tHMMobj.X): - for cell_indx, cell in enumerate(lin.output_lineage): - cell.state = gemc_states_list[idx][lin_indx][cell_indx] +gemc_tHMMobj_list = Analyze_list(AllGemcitabine, num_states, write_states=True)[0] only_lapatinib_control_1 = gemc_tHMMobj_list[0].X[0:100] diff --git a/lineage/figures/figure17.py b/lineage/figures/figure17.py index 88d44dd69..3057d7f8d 100644 --- a/lineage/figures/figure17.py +++ b/lineage/figures/figure17.py @@ -6,15 +6,7 @@ from ..Lineage_collections import AllGemcitabine num_states = 5 -gemc_tHMMobj_list = Analyze_list(AllGemcitabine, num_states)[0] - -gemc_states_list = [tHMMobj.predict() for tHMMobj in gemc_tHMMobj_list] - -# assign the predicted states to each cell -for idx, lapt_tHMMobj in enumerate(gemc_tHMMobj_list): - for lin_indx, lin in enumerate(lapt_tHMMobj.X): - for cell_indx, cell in enumerate(lin.output_lineage): - cell.state = gemc_states_list[idx][lin_indx][cell_indx] +gemc_tHMMobj_list = Analyze_list(AllGemcitabine, num_states, write_states=True)[0] only_gemcitabine_control_1 = gemc_tHMMobj_list[0].X[100:200] diff --git a/lineage/figures/figure6.py b/lineage/figures/figure6.py index b1782f100..99fe33954 100644 --- a/lineage/figures/figure6.py +++ b/lineage/figures/figure6.py @@ -77,9 +77,7 @@ def accuracy(): balanced_score = np.empty(len(list_of_populations)) for ii, pop in enumerate(list_of_populations): - ravel_true_states = np.array( - [cell.state for lineage in pop for cell in lineage.output_lineage] - ) + ravel_true_states = np.array([lineage.states for lineage in pop]).flatten() all_cells = np.array( [cell.obs[2] for lineage in pop for cell in lineage.output_lineage] ) @@ -122,9 +120,9 @@ def accuracy(): for cell in lineage.output_lineage ] distribution_df["State"] = [ - "State 1" if cell.state == 0 else "State 2" + "State 1" if state == 0 else "State 2" for lineage in lineages - for cell in lineage.output_lineage + for state in lineage.states ] distribution_df["Distribution Similarity"] = ( len_lineages[0] * ["Same\n" + str(0) + "-" + str(wass[-1] / 4)] diff --git a/lineage/figures/figureS11.py b/lineage/figures/figureS11.py index 1063556fe..a8f7b004f 100644 --- a/lineage/figures/figureS11.py +++ b/lineage/figures/figureS11.py @@ -6,16 +6,7 @@ from ..Analyze import Analyze_list from ..Lineage_collections import AllLapatinib -num_states = 4 -lapt_tHMMobj_list = Analyze_list(AllLapatinib, num_states)[0] - -lapt_states_list = [tHMMobj.predict() for tHMMobj in lapt_tHMMobj_list] - -# assign the predicted states to each cell -for idx, lapt_tHMMobj in enumerate(lapt_tHMMobj_list): - for lin_indx, lin in enumerate(lapt_tHMMobj.X): - for cell_indx, cell in enumerate(lin.output_lineage): - cell.state = lapt_states_list[idx][lin_indx][cell_indx] +lapt_tHMMobj_list = Analyze_list(AllLapatinib, 4, write_states=True)[0] for i in range(4): lapt_tHMMobj_list[i].X = sort_lins(lapt_tHMMobj_list[i]) diff --git a/lineage/figures/figureS12.py b/lineage/figures/figureS12.py index b2a47d26f..e7f6d1e5c 100644 --- a/lineage/figures/figureS12.py +++ b/lineage/figures/figureS12.py @@ -7,14 +7,7 @@ from ..Lineage_collections import AllGemcitabine num_states = 5 -gemc_tHMMobj_list = Analyze_list(AllGemcitabine, num_states)[0] - -gemc_states_list = [tHMMobj.predict() for tHMMobj in gemc_tHMMobj_list] - -for idx, gemc_tHMMobj in enumerate(gemc_tHMMobj_list): - for lin_indx, lin in enumerate(gemc_tHMMobj.X): - for cell_indx, cell in enumerate(lin.output_lineage): - cell.state = gemc_states_list[idx][lin_indx][cell_indx] +gemc_tHMMobj_list = Analyze_list(AllGemcitabine, num_states, write_states=True)[0] for i in range(4): gemc_tHMMobj_list[i].X = sort_lins(gemc_tHMMobj_list[i]) diff --git a/lineage/figures/figureS15.py b/lineage/figures/figureS15.py index a67a2f592..d5382ba14 100644 --- a/lineage/figures/figureS15.py +++ b/lineage/figures/figureS15.py @@ -28,15 +28,15 @@ def find_state_proportions(lapt_tHMMobj, control=None): censor_condition=0, full_lineage=lineage.output_lineage ) - for cell in output_lineage: + for ii, cell in enumerate(output_lineage): if math.isnan(cell.time.startT): # left censored. startT = 0 cell.time.startT = 0.0 if math.isnan(cell.time.endT): # right censored. endT = 96 cell.time.endT = 96.0 if cell.time.startT <= t <= cell.time.endT: - if cell.state == 0: + if lineage.states[ii] == 0: st0 += 1 - elif cell.state == 1: + elif lineage.states[ii] == 1: st1 += 1 else: st2 += 1 diff --git a/lineage/figures/figureS16.py b/lineage/figures/figureS16.py index 211b789e3..3f5f0721d 100644 --- a/lineage/figures/figureS16.py +++ b/lineage/figures/figureS16.py @@ -27,25 +27,15 @@ def plot_barcode_vs_state(ax, drug_name): """Plots the histogram of barcode vs states after clustering, using the parameters from lapatinib and gemcitabine fits.""" if drug_name == "lapatinibs": - tHMMobj_list, _, _ = Analyze_list(AllLapatinib, 4) + tHMMobj_list, _, _ = Analyze_list(AllLapatinib, 4, write_states=True) elif drug_name == "gemcitabines": - tHMMobj_list, _, _ = Analyze_list(AllGemcitabine, 5) - - states_list = [tHMMobj.predict() for tHMMobj in tHMMobj_list] - - for idx, tHMMobj in enumerate(tHMMobj_list): - for lin_indx, lin in enumerate(tHMMobj.X): - for cell_indx, cell in enumerate(lin.output_lineage): - cell.state = states_list[idx][lin_indx][cell_indx] + tHMMobj_list, _, _ = Analyze_list(AllGemcitabine, 5, write_states=True) num_states = tHMMobj_list[0].num_states states_by_lin = [] for lineage in tHMMobj_list[0].X: - tmp2 = [] - for cell in lineage.output_lineage: - tmp2.append(cell.state) - states_by_lin.append(tmp2) + states_by_lin.append(lineage.states) for i in range(num_lineages): ax[i].hist(states_by_lin[i], bins=np.linspace(0, 5, 11)) diff --git a/lineage/figures/figureS17.py b/lineage/figures/figureS17.py index 479219873..21859a11e 100644 --- a/lineage/figures/figureS17.py +++ b/lineage/figures/figureS17.py @@ -7,14 +7,7 @@ from ..Lineage_collections import GFs from ..Analyze import Analyze_list -hgf_tHMMobj_list = Analyze_list(GFs, 3)[0] -hgf_states_list = [tHMMobj.predict() for tHMMobj in hgf_tHMMobj_list] - -# assign the predicted states to each cell -for idx, hgf_tHMMobj in enumerate(hgf_tHMMobj_list): - for lin_indx, lin in enumerate(hgf_tHMMobj.X): - for cell_indx, cell in enumerate(lin.output_lineage): - cell.state = hgf_states_list[idx][lin_indx][cell_indx] +hgf_tHMMobj_list = Analyze_list(GFs, 3, write_states=True)[0] for thmm_obj in hgf_tHMMobj_list: thmm_obj.X = sort_lins(thmm_obj) @@ -26,7 +19,7 @@ def makeFigure(): """ Makes figure 101. """ - num_lins = [len(hgf_tHMMobj_list[i].X) for i in range(4)] + num_lins = [len(tM.X) for tM in hgf_tHMMobj_list] ax, f = getSetup((10, 20), (np.max(num_lins), 4)) for i in range(4 * np.max(num_lins)): diff --git a/lineage/plotTree.py b/lineage/plotTree.py index b3f71b733..445908dcb 100644 --- a/lineage/plotTree.py +++ b/lineage/plotTree.py @@ -4,6 +4,8 @@ from Bio.Phylo.BaseTree import Clade import networkx as nx from .figures.common import getSetup +from .LineageTree import LineageTree +from .CellVar import CellVar cs = ["lightblue", "orange", "lightgreen", "red", "purple", "grey"] stateColors = ["blue", "orange", "green", "red", "purple", "grey"] @@ -26,7 +28,7 @@ def plot_lineage_samples(tHMMobj_list, name): f.savefig("lineage/figures/cartoons/" + name + ".svg") -def CladeRecursive(cell, a: list, censor: bool, color: bool): +def CladeRecursive(cell: CellVar, a: list, censor: bool, color: bool): """A recurssive function that takes in the root cell and traverses through cells to plot the lineage. The width of the lines show the phase of the cells. The color of the lines show the state of the cells. @@ -80,10 +82,15 @@ def CladeRecursive(cell, a: list, censor: bool, color: bool): return my_clade -def plotLineage(lineage, axes, censor=True, color=True): +def plotLineage(lineage: LineageTree, axes, censor: bool = True, color: bool = True): """ Given a lineage of cells, uses the `CladeRecursive` function to plot the lineage. """ + for ii, cell in enumerate(lineage.output_lineage): + cell.state = lineage.states[ii] + + if color: + assert cell.state >= 0 root = lineage.output_lineage[0] if np.isfinite(root.obs[4]): # the lineage starts from G1 phase @@ -103,7 +110,9 @@ def plotLineage(lineage, axes, censor=True, color=True): return draw(c, axes=axes) -def plotLineage_MCF10A(lineage, axes, censor=True, color=True): +def plotLineage_MCF10A( + lineage: LineageTree, axes, censor: bool = True, color: bool = True +): """ Given a lineage of cells, uses the `CladeRecursive` function to plot the lineage. """ diff --git a/lineage/tests/test_BaumWelch.py b/lineage/tests/test_BaumWelch.py index a75acce19..789bd6943 100644 --- a/lineage/tests/test_BaumWelch.py +++ b/lineage/tests/test_BaumWelch.py @@ -115,7 +115,7 @@ def test_E_step(cens): do_E_step(tHMMobj) pred_states = tHMMobj.predict() - true_states = [cell.state for cell in tHMMobj.X[0].output_lineage] + true_states = tHMMobj.X[0].states assert rand_score(true_states, pred_states[0]) >= 0.9