Skip to content

Commit

Permalink
Move state handling to the lineage (#993)
Browse files Browse the repository at this point in the history
* Moved state handling to the lineage

* Fixes

* Fixes

* Cleanup

* Clean up

* Fix the state sorting

* Minor fixes

* Minor fix
  • Loading branch information
aarmey authored Mar 6, 2024
1 parent 9f37cb0 commit 39929b4
Show file tree
Hide file tree
Showing 19 changed files with 77 additions and 172 deletions.
18 changes: 11 additions & 7 deletions lineage/Analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def fit_list(


def Analyze_list(
pop_list: list, num_states: int, fpi=None, fT=None, rng=None
pop_list: list, num_states: int, fpi=None, fT=None, rng=None, write_states=False
) -> Tuple[list[tHMM], float, list[np.ndarray]]:
"""This function runs the analyze function for the case when we want to fit multiple conditions at the same time.
:param pop_list: The list of cell populations to run the analyze function on.
Expand All @@ -93,6 +93,14 @@ def Analyze_list(
LL = LL2
gammas = gammas2

# store the Viterbi-predicted states
if write_states:
for tHMMobj in tHMMobj_list:
states = tHMMobj.predict()

for lin_indx, lin in enumerate(tHMMobj.X):
lin.states = states[lin_indx]

return tHMMobj_list, LL, gammas


Expand Down Expand Up @@ -168,13 +176,9 @@ def Results(tHMMobj: tHMM, LL: float) -> dict[str, Any]:

results_dict["total_number_of_lineages"] = len(tHMMobj.X)
results_dict["LL"] = LL
results_dict["total_number_of_cells"] = sum(
[len(lineage.output_lineage) for lineage in tHMMobj.X]
)
results_dict["total_number_of_cells"] = sum([len(lin) for lin in tHMMobj.X])

true_states_by_lineage = [
[cell.state for cell in lineage.output_lineage] for lineage in tHMMobj.X
]
true_states_by_lineage = [lin.states for lin in tHMMobj.X]

results_dict["transition_matrix_similarity"] = np.linalg.norm(
tHMMobj.estimate.T - tHMMobj.X[0].T
Expand Down
31 changes: 15 additions & 16 deletions lineage/CellVar.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
""" This file contains the class for CellVar which holds the state and observation information in the hidden and observed trees respectively. """

from __future__ import annotations
import numpy as np
from typing import Optional
from dataclasses import dataclass


@dataclass(init=True, repr=True, eq=True, order=True)
class Time:
"""
Class that stores all the time related observations in a neater format.
This assists in pruning based on experimental time and obtaining
attributes of the lineage as a whole like the average growth rate.
"""

startT: float
endT: float
transition_time: float = 0.0


class CellVar:
"""
Cell class.
Expand All @@ -14,13 +26,13 @@ class CellVar:
parent: Optional["CellVar"]
gen: int
observed: bool
state: Optional[int]
state: int
obs: np.ndarray
time: Optional[Time]
left: Optional["CellVar"]
right: Optional["CellVar"]

def __init__(self, parent: Optional["CellVar"], state: Optional[int] = None):
def __init__(self, parent: Optional["CellVar"], state: int = -1):
"""Instantiates the cell object.
Contains memeber variables that identify daughter cells
and parent cells. Also contains the state of the cell.
Expand Down Expand Up @@ -78,16 +90,3 @@ def isLeaf(self) -> bool:

# otherwise, it itself is observed and at least one of its daughters is observed
return False


@dataclass(init=True, repr=True, eq=True, order=True)
class Time:
"""
Class that stores all the time related observations in a neater format.
This assists in pruning based on experimental time and obtaining
attributes of the lineage as a whole like the average growth rate.
"""

startT: float
endT: float
transition_time: float = 0.0
5 changes: 3 additions & 2 deletions lineage/LineageInputOutput.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" The file contains the methods used to input lineage data from the Heiser lab. """

import logging
import math
import pandas as pd
from .CellVar import CellVar as c
Expand Down Expand Up @@ -250,7 +251,7 @@ def tryRecursion(

# Check that the parent cell didn't get time censored (Likely divided in last frame)
if divisionTime == exp_time:
print(
logging.info(
f"Cell time censorship, but daughters were found in row {parentPos+1}, column {pColumn+1}. By default they will be set to None"
)

Expand All @@ -263,7 +264,7 @@ def tryRecursion(

# Creating daughter
daughterCell = c(parent=parentCell)
daughterCell.obs = [0, 0, 0, 0, 0, 0]
daughterCell.obs = np.array([0.0, 0, 0, 0, 0, 0])

# find upper daughter
daughterCell.left = tryRecursion(
Expand Down
3 changes: 3 additions & 0 deletions lineage/LineageTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class LineageTree:
leaves_idx: np.ndarray
output_lineage: list[CellVar]
cell_to_daughters: np.ndarray
states: np.ndarray
E: Sequence[StA | StB]

def __init__(self, list_of_cells: list, E: Sequence[StA | StB]):
Expand All @@ -31,6 +32,8 @@ def __init__(self, list_of_cells: list, E: Sequence[StA | StB]):
# Leaves have no daughters
self.leaves_idx = np.nonzero(np.all(self.cell_to_daughters == -1, axis=1))[0]

self.states = np.array([cell.state for cell in self.output_lineage], dtype=int)

@classmethod
def rand_init(
cls,
Expand Down
26 changes: 3 additions & 23 deletions lineage/figures/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,14 +1023,7 @@ def plot_all(ax, num_states, tHMMobj_list, Dname, cons, concsValues):
(4, num_states, 2)
) # the avg lifetime: num_conc x num_states x num_phases
bern_lpt = np.zeros((4, num_states, 2)) # bernoulli
# print parameters and estimated values
print(
Dname,
"\n the \u03C0: ",
tHMMobj_list[0].estimate.pi,
"\n the transition matrix: ",
tHMMobj_list[0].estimate.T,
)

for idx, tHMMobj in enumerate(tHMMobj_list): # for each concentration data
for i in range(num_states):
lpt_avg[idx, i, 0] = np.log10(
Expand All @@ -1046,19 +1039,6 @@ def plot_all(ax, num_states, tHMMobj_list, Dname, cons, concsValues):
plotting(ax, lpt_avg, bern_lpt, cons, concsValues, num_states)


def sort_lins(tHMMobj):
def sort_lins(tHMMobj) -> list:
"""Sorts lineages based on their root cell state for plotting the lineage trees."""
num_st = tHMMobj.estimate.num_states

st = [] # holds the state of root cell in all lineages for this particular tHMMobj
for lins in tHMMobj.X:
st.append(lins.output_lineage[0].state)

states = []
for i in range(num_st):
st_i = [index for index, val in enumerate(st) if val == i]
temp = [tHMMobj.X[k] for k in st_i]

states += temp

return states
return sorted(tHMMobj.X, key=lambda lin: lin.states[0])
10 changes: 1 addition & 9 deletions lineage/figures/figure11.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,7 @@
concsValues = ["Control", "25 nM", "50 nM", "250 nM"]

num_states = 4
lapt_tHMMobj_list = Analyze_list(AllLapatinib, num_states)[0]

lapt_states_list = [tHMMobj.predict() for tHMMobj in lapt_tHMMobj_list]

# assign the predicted states to each cell
for idx, lapt_tHMMobj in enumerate(lapt_tHMMobj_list):
for lin_indx, lin in enumerate(lapt_tHMMobj.X):
for cell_indx, cell in enumerate(lin.output_lineage):
cell.state = lapt_states_list[idx][lin_indx][cell_indx]
lapt_tHMMobj_list = Analyze_list(AllLapatinib, num_states, write_states=True)[0]

T_lap = lapt_tHMMobj_list[0].estimate.T
num_states = lapt_tHMMobj_list[0].num_states
Expand Down
36 changes: 11 additions & 25 deletions lineage/figures/figure111.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,34 +51,20 @@ def state_abundance_perRep(reps):
s3 = []
s4 = []
s5 = []

for rep in reps:
st0 = 0
st1 = 0
st2 = 0
st3 = 0
st4 = 0
st5 = 0
states = []

for lineageTree_list in rep:
for lineage_tree in lineageTree_list:
for cell in lineage_tree.output_lineage:
if cell.state == 0:
st0 += 1
elif cell.state == 1:
st1 += 1
elif cell.state == 2:
st2 += 1
elif cell.state == 3:
st3 += 1
elif cell.state == 4:
st4 += 1
elif cell.state == 5:
st5 += 1
s0.append(st0)
s1.append(st1)
s2.append(st2)
s3.append(st3)
s4.append(st4)
s5.append(st5)
states = np.concatenate((states, lineage_tree.states))

s0.append(np.sum(states == 0))
s1.append(np.sum(states == 1))
s2.append(np.sum(states == 2))
s3.append(np.sum(states == 3))
s4.append(np.sum(states == 4))
s5.append(np.sum(states == 5))

return [s0, s1, s2, s3, s4, s5]

Expand Down
10 changes: 1 addition & 9 deletions lineage/figures/figure12.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,9 @@
concsValues = ["Control", "5 nM", "10 nM", "30 nM"]

num_states = 5
gemc_tHMMobj_list = Analyze_list(AllGemcitabine, num_states)[0]

gemc_states_list = [tHMMobj.predict() for tHMMobj in gemc_tHMMobj_list]

for idx, gemc_tHMMobj in enumerate(gemc_tHMMobj_list):
for lin_indx, lin in enumerate(gemc_tHMMobj.X):
for cell_indx, cell in enumerate(lin.output_lineage):
cell.state = gemc_states_list[idx][lin_indx][cell_indx]
gemc_tHMMobj_list = Analyze_list(AllGemcitabine, num_states, write_states=True)[0]

T_gem = gemc_tHMMobj_list[0].estimate.T
num_states = gemc_tHMMobj_list[0].num_states

# plot transition block
plot_networkx(T_gem, "gemcitabine")
Expand Down
12 changes: 1 addition & 11 deletions lineage/figures/figure13.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,10 @@
concs = ["PBS", "EGF", "HGF", "OSM"]

num_states = 3
hgf_tHMMobj_list = Analyze_list(GFs, num_states)[0]

hgf_states_list = [tHMMobj.predict() for tHMMobj in hgf_tHMMobj_list]

# assign the predicted states to each cell
for idx, hgf_tHMMobj in enumerate(hgf_tHMMobj_list):
for lin_indx, lin in enumerate(hgf_tHMMobj.X):
for cell_indx, cell in enumerate(lin.output_lineage):
cell.state = hgf_states_list[idx][lin_indx][cell_indx]
hgf_tHMMobj_list = Analyze_list(GFs, num_states, write_states=True)[0]

T_hgf = hgf_tHMMobj_list[0].estimate.T

num_states = hgf_tHMMobj_list[0].num_states

rcParams["font.sans-serif"] = "Arial"


Expand Down
10 changes: 1 addition & 9 deletions lineage/figures/figure16.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,7 @@
from ..Lineage_collections import AllGemcitabine

num_states = 5
gemc_tHMMobj_list = Analyze_list(AllGemcitabine, num_states)[0]

gemc_states_list = [tHMMobj.predict() for tHMMobj in gemc_tHMMobj_list]

# assign the predicted states to each cell
for idx, lapt_tHMMobj in enumerate(gemc_tHMMobj_list):
for lin_indx, lin in enumerate(lapt_tHMMobj.X):
for cell_indx, cell in enumerate(lin.output_lineage):
cell.state = gemc_states_list[idx][lin_indx][cell_indx]
gemc_tHMMobj_list = Analyze_list(AllGemcitabine, num_states, write_states=True)[0]

only_lapatinib_control_1 = gemc_tHMMobj_list[0].X[0:100]

Expand Down
10 changes: 1 addition & 9 deletions lineage/figures/figure17.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,7 @@
from ..Lineage_collections import AllGemcitabine

num_states = 5
gemc_tHMMobj_list = Analyze_list(AllGemcitabine, num_states)[0]

gemc_states_list = [tHMMobj.predict() for tHMMobj in gemc_tHMMobj_list]

# assign the predicted states to each cell
for idx, lapt_tHMMobj in enumerate(gemc_tHMMobj_list):
for lin_indx, lin in enumerate(lapt_tHMMobj.X):
for cell_indx, cell in enumerate(lin.output_lineage):
cell.state = gemc_states_list[idx][lin_indx][cell_indx]
gemc_tHMMobj_list = Analyze_list(AllGemcitabine, num_states, write_states=True)[0]

only_gemcitabine_control_1 = gemc_tHMMobj_list[0].X[100:200]

Expand Down
8 changes: 3 additions & 5 deletions lineage/figures/figure6.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ def accuracy():
balanced_score = np.empty(len(list_of_populations))

for ii, pop in enumerate(list_of_populations):
ravel_true_states = np.array(
[cell.state for lineage in pop for cell in lineage.output_lineage]
)
ravel_true_states = np.array([lineage.states for lineage in pop]).flatten()
all_cells = np.array(
[cell.obs[2] for lineage in pop for cell in lineage.output_lineage]
)
Expand Down Expand Up @@ -122,9 +120,9 @@ def accuracy():
for cell in lineage.output_lineage
]
distribution_df["State"] = [
"State 1" if cell.state == 0 else "State 2"
"State 1" if state == 0 else "State 2"
for lineage in lineages
for cell in lineage.output_lineage
for state in lineage.states
]
distribution_df["Distribution Similarity"] = (
len_lineages[0] * ["Same\n" + str(0) + "-" + str(wass[-1] / 4)]
Expand Down
11 changes: 1 addition & 10 deletions lineage/figures/figureS11.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,7 @@
from ..Analyze import Analyze_list
from ..Lineage_collections import AllLapatinib

num_states = 4
lapt_tHMMobj_list = Analyze_list(AllLapatinib, num_states)[0]

lapt_states_list = [tHMMobj.predict() for tHMMobj in lapt_tHMMobj_list]

# assign the predicted states to each cell
for idx, lapt_tHMMobj in enumerate(lapt_tHMMobj_list):
for lin_indx, lin in enumerate(lapt_tHMMobj.X):
for cell_indx, cell in enumerate(lin.output_lineage):
cell.state = lapt_states_list[idx][lin_indx][cell_indx]
lapt_tHMMobj_list = Analyze_list(AllLapatinib, 4, write_states=True)[0]

for i in range(4):
lapt_tHMMobj_list[i].X = sort_lins(lapt_tHMMobj_list[i])
Expand Down
9 changes: 1 addition & 8 deletions lineage/figures/figureS12.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,7 @@
from ..Lineage_collections import AllGemcitabine

num_states = 5
gemc_tHMMobj_list = Analyze_list(AllGemcitabine, num_states)[0]

gemc_states_list = [tHMMobj.predict() for tHMMobj in gemc_tHMMobj_list]

for idx, gemc_tHMMobj in enumerate(gemc_tHMMobj_list):
for lin_indx, lin in enumerate(gemc_tHMMobj.X):
for cell_indx, cell in enumerate(lin.output_lineage):
cell.state = gemc_states_list[idx][lin_indx][cell_indx]
gemc_tHMMobj_list = Analyze_list(AllGemcitabine, num_states, write_states=True)[0]

for i in range(4):
gemc_tHMMobj_list[i].X = sort_lins(gemc_tHMMobj_list[i])
Expand Down
6 changes: 3 additions & 3 deletions lineage/figures/figureS15.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ def find_state_proportions(lapt_tHMMobj, control=None):
censor_condition=0, full_lineage=lineage.output_lineage
)

for cell in output_lineage:
for ii, cell in enumerate(output_lineage):
if math.isnan(cell.time.startT): # left censored. startT = 0
cell.time.startT = 0.0
if math.isnan(cell.time.endT): # right censored. endT = 96
cell.time.endT = 96.0
if cell.time.startT <= t <= cell.time.endT:
if cell.state == 0:
if lineage.states[ii] == 0:
st0 += 1
elif cell.state == 1:
elif lineage.states[ii] == 1:
st1 += 1
else:
st2 += 1
Expand Down
Loading

0 comments on commit 39929b4

Please sign in to comment.