diff --git a/auton_survival/models/cmhe/__init__.py b/auton_survival/models/cmhe/__init__.py
new file mode 100644
index 0000000..c31434f
--- /dev/null
+++ b/auton_survival/models/cmhe/__init__.py
@@ -0,0 +1 @@
+from .cmhe_api import DeepCoxMixturesHE
\ No newline at end of file
diff --git a/auton_survival/models/cmhe/cmhe_api.py b/auton_survival/models/cmhe/cmhe_api.py
index f7400ee..0d943f4 100644
--- a/auton_survival/models/cmhe/cmhe_api.py
+++ b/auton_survival/models/cmhe/cmhe_api.py
@@ -1,11 +1,12 @@
-from .cmhe_torch import CoxMixtureHeterogenousEffects
-from .cmhe_torch import DeepCoxMixtureHeterogenousEffects
+from .cmhe_torch import CoxMixtureHETorch
+from .cmhe_torch import DeepCoxMixtureHETorch
 from .cmhe_utilities import train_cmhe, predict_survival
+from .cmhe_utilities import predict_latent_phi, predict_latent_z
 
 import torch
 import numpy as np
 
-class DeepCoxMixturesHeterogenousEffects:
+class DeepCoxMixturesHE:
   """A Deep Cox Mixtures with Heterogenous Effects model.
 
   This is the main interface to a Deep Cox Mixture with Heterogenous Effects.
@@ -17,8 +18,8 @@ class DeepCoxMixturesHeterogenousEffects:
 
   References
   ----------
-  [1] <a href="https://arxiv.org/abs/2101.06536">Deep Cox Mixtures
-  for Survival Regression. Machine Learning in Health Conference (2021)</a>
+  [1] Nagpal, C., Goswami M., Dufendach K., and Artur Dubrawski.
+  "Counterfactual phenotyping for censored Time-to-Events" (2022).
 
   Parameters
   ----------
@@ -29,6 +30,7 @@ class DeepCoxMixturesHeterogenousEffects:
   layers: list
       A list of integers consisting of the number of neurons in each
       hidden layer.
+
   Example
   -------
   >>> from auton_survival import CoxMixturesHeterogenousEffects
@@ -37,11 +39,13 @@ class DeepCoxMixturesHeterogenousEffects:
 
   """
 
-  def __init__(self, layers=None):
+  def __init__(self, k, g, layers=None):
 
     self.layers = layers
     self.fitted = False
-
+    self.k = k
+    self.g = g
+
   def __call__(self):
     if self.fitted:
       print("A fitted instance of the CMHE model")
@@ -50,8 +54,11 @@ def __call__(self):
 
     print("Hidden Layers:", self.layers)
 
-  def _preprocess_test_data(self, x, a):
-    return torch.from_numpy(x).float(), torch.from_numpy(a).float()
+  def _preprocess_test_data(self, x, a=None):
+    if a is not None:
+      return torch.from_numpy(x).float(), torch.from_numpy(a).float()
+    else:
+      return torch.from_numpy(x).float()
 
   def _preprocess_training_data(self, x, t, e, a, vsize, val_data, random_state):
 
@@ -60,42 +67,44 @@ def _preprocess_training_data(self, x, t, e, a, vsize, val_data, random_state):
     np.random.seed(random_state)
     np.random.shuffle(idx)
 
-    x_train, t_train, e_train, a_train = x[idx], t[idx], e[idx], a[idx]
+    x_tr, t_tr, e_tr, a_tr = x[idx], t[idx], e[idx], a[idx]
 
-    x_train = torch.from_numpy(x_train).float()
-    t_train = torch.from_numpy(t_train).float()
-    e_train = torch.from_numpy(e_train).float()
-    a_train = torch.from_numpy(a_train).float()
+    x_tr = torch.from_numpy(x_tr).float()
+    t_tr = torch.from_numpy(t_tr).float()
+    e_tr = torch.from_numpy(e_tr).float()
+    a_tr = torch.from_numpy(a_tr).float()
 
     if val_data is None:
 
-      vsize = int(vsize*x_train.shape[0])
-      x_val, t_val, e_val, a_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:], a_train[-vsize:]
+      vsize = int(vsize*x_tr.shape[0])
+      x_vl, t_vl, e_vl, a_vl = x_tr[-vsize:], t_tr[-vsize:], e_tr[-vsize:], a_tr[-vsize:]
 
-      x_train = x_train[:-vsize]
-      t_train = t_train[:-vsize]
-      e_train = e_train[:-vsize]
-      a_train = a_train[:-vsize]
+      x_tr = x_tr[:-vsize]
+      t_tr = t_tr[:-vsize]
+      e_tr = e_tr[:-vsize]
+      a_tr = a_tr[:-vsize]
 
     else:
 
-      x_val, t_val, e_val, a_val = val_data
+      x_vl, t_vl, e_vl, a_vl = val_data
 
-      x_val = torch.from_numpy(x_val).float()
-      t_val = torch.from_numpy(t_val).float()
-      e_val = torch.from_numpy(e_val).float()
-      a_val = torch.from_numpy(a_val).float()
+      x_vl = torch.from_numpy(x_vl).float()
+      t_vl = torch.from_numpy(t_vl).float()
+      e_vl = torch.from_numpy(e_vl).float()
+      a_vl = torch.from_numpy(a_vl).float()
 
-    return (x_train, t_train, e_train, a_train,
-    	      x_val, t_val, e_val, a_val)
+    return (x_tr, t_tr, e_tr, a_tr,
+    	      x_vl, t_vl, e_vl, a_vl)
 
   def _gen_torch_model(self, inputdim, optimizer):
     """Helper function to return a torch model."""
     if len(self.layers):
-      return DeepCoxMixtureHeterogenousEffects(inputdim, layers=self.layers,
-                                               optimizer=optimizer)
+      return DeepCoxMixtureHETorch(self.k, self.g, inputdim,
+                                   hidden=self.layers,
+                                   optimizer=optimizer)
     else:
-      return CoxMixtureHeterogenousEffects(inputdim, optimizer=optimizer)
+      return CoxMixtureHETorch(self.k, self.g, inputdim,
+                               optimizer=optimizer)
 
   def fit(self, x, t, e, a, vsize=0.15, val_data=None,
           iters=1, learning_rate=1e-3, batch_size=100,
@@ -137,17 +146,17 @@ def fit(self, x, t, e, a, vsize=0.15, val_data=None,
                                                     vsize, val_data,
                                                     random_state)
 
-    x_train, t_train, e_train, a_train, x_val, t_val, e_val, a_val = processed_data
+    x_tr, t_tr, e_tr, a_tr, x_vl, t_vl, e_vl, a_vl = processed_data
 
     #Todo: Change this somehow. The base design shouldn't depend on child
 
-    inputdim = x_train.shape[-1]
+    inputdim = x_tr.shape[-1]
 
     model = self._gen_torch_model(inputdim, optimizer)
 
     model, _ = train_cmhe(model,
-                          (x_train, t_train, e_train, a_train),
-                          (x_val, t_val, e_val, a_val),
+                          (x_tr, t_tr, e_tr, a_tr),
+                          (x_vl, t_vl, e_vl, a_vl),
                           epochs=iters,
                           lr=learning_rate,
                           bs=batch_size,
@@ -198,3 +207,30 @@ def predict_survival(self, x, a, t=None):
     scores = predict_survival(self.torch_model, x, a, t)
     return scores
 
+  def predict_latent_z(self, x):
+
+    r"""Returns the estimated latent base survival group \( z \) given the confounders \( x \)."""
+
+    x = self._preprocess_test_data(x)
+
+    if self.fitted:
+      scores = predict_latent_z(self.torch_model, x)
+      return scores
+    else:
+      raise Exception("The model has not been fitted yet. Please fit the " +
+                      "model using the `fit` method on some training data " +
+                      "before calling `predict_latent_z`.")
+
+  def predict_latent_phi(self, x):
+
+    r"""Returns the estimated latent treatment effect group \( \phi \) given the confounders \( x \)."""
+
+    x = self._preprocess_test_data(x)
+
+    if self.fitted:
+      scores = predict_latent_phi(self.torch_model, x)
+      return scores
+    else:
+      raise Exception("The model has not been fitted yet. Please fit the " +
+                      "model using the `fit` method on some training data " +
+                      "before calling `predict_latent_phi`.")
\ No newline at end of file
diff --git a/auton_survival/models/cmhe/cmhe_torch.py b/auton_survival/models/cmhe/cmhe_torch.py
index 891f24b..3dcc725 100644
--- a/auton_survival/models/cmhe/cmhe_torch.py
+++ b/auton_survival/models/cmhe/cmhe_torch.py
@@ -1,6 +1,6 @@
 import torch
 
-class CoxMixtureHeterogenousEffects(torch.nn.Module):
+class CoxMixtureHETorch(torch.nn.Module):
   """PyTorch model definition of the Cox Mixture with Hereogenous Effects Model.
 
   Cox Mixtures with Heterogenous Effects involves the assuming that the
@@ -14,7 +14,7 @@ class CoxMixtureHeterogenousEffects(torch.nn.Module):
 
   def __init__(self, k, g, inputdim, hidden=None):
 
-    super(CoxMixtureHeterogenousEffects, self).__init__()
+    super(CoxMixtureHETorch, self).__init__()
 
     assert isinstance(k, int)
 
@@ -57,13 +57,11 @@ def forward(self, x, a):
 
     return logp_jointlatent_gate, logp_joint_hrs
 
-class DeepCoxMixtureHeterogenousEffects(CoxMixtureHeterogenousEffects):
+class DeepCoxMixtureHETorch(CoxMixtureHETorch):
 
   def __init__(self, k, g, inputdim, hidden):
 
-    super(DeepCoxMixtureHeterogenousEffects, self).__init__(k, g,
-                                                            inputdim,
-                                                            hidden)
+    super(DeepCoxMixtureHETorch, self).__init__(k, g, inputdim, hidden)
 
     # Get rich feature representations of the covariates
     self.embedding = torch.nn.Sequential(torch.nn.Linear(inputdim, hidden),
diff --git a/auton_survival/models/cmhe/cmhe_utilities.py b/auton_survival/models/cmhe/cmhe_utilities.py
index 687fa28..530b40c 100644
--- a/auton_survival/models/cmhe/cmhe_utilities.py
+++ b/auton_survival/models/cmhe/cmhe_utilities.py
@@ -315,3 +315,21 @@ def predict_survival(model, x, a, t):
 
     predictions.append((gates*expert_outputs).sum(axis=1))
   return np.array(predictions).T
+
+def predict_latent_z(model, x):
+
+  model, _ = model
+  gates = model.model.embedding(x)
+
+  z_gate_probs = torch.exp(gates).sum(axis=2).detach().numpy()
+
+  return z_gate_probs
+
+def predict_latent_phi(model, x):
+
+  model, _ = model
+  gates, _ = model(x)
+
+  phi_gate_probs = torch.exp(gates).sum(axis=1).detach().numpy()
+
+  return phi_gate_probs