Skip to content

Commit

Permalink
Merge pull request #35 from jcyannotty/develop
Browse files Browse the repository at this point in the history
updates to docs and test_trees.py
  • Loading branch information
jcyannotty authored Aug 11, 2023
2 parents 4c3102c + 920facf commit 4b9be31
Show file tree
Hide file tree
Showing 16 changed files with 164 additions and 167 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ venv/
ENV/
env.bak/
venv.bak/
test_env

# Spyder project settings
.spyderproject
Expand Down
79 changes: 28 additions & 51 deletions Taweret/mix/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ def __init__(self, model_dict: dict, **kwargs):
Parameters:
----------
:param model_dict : dict
:param dict model_dict:
Dictionary of models where each item is an instance of BaseModel.
:param kwargs : dict
:param dict kwargs:
Additional arguments to pass to the constructor.
Returns:
Expand Down Expand Up @@ -136,6 +136,7 @@ def __init__(self, model_dict: dict, **kwargs):
self._is_prior_set = False
self._is_predict_run = False ### Remove ????


def evaluate(self):
'''
Evaluate the mixed-model to get a point prediction.
Expand All @@ -144,6 +145,7 @@ def evaluate(self):

raise Exception("Not applicable for trees.")


def evaluate_weights(self):
'''
Evaluate the weight functions to get a point prediction.
Expand Down Expand Up @@ -172,8 +174,7 @@ def posterior(self):
Returns:
---------
:returns posterior:
The posterior of the error standard deviation .
:returns: The posterior of the error standard deviation .
:rtype: np.ndarray
'''
Expand All @@ -186,13 +187,12 @@ def prior(self):
prior distributions.
Parameters:
------------
:param None.
-----------
:param: None.
Returns:
---------
:returns:
A dictionary of the hyperparameters used in the model.
--------
:returns: A dictionary of the hyperparameters used in the model.
:rtype: dict
'''
Expand Down Expand Up @@ -221,7 +221,7 @@ def set_prior(self, ntree: int = 1,ntreeh:int = 1, k: float=2,power: float=2.0,b
overallnu: int = 10,inform_prior: bool = True,tauvec: bool = None,betavec: bool = None):
'''
Sets the hyperparameters in the tree and terminal node priors. Also
specifies if an informative or non-informative prior will be used.
specifies if an informative or non-informative prior will be used when mixing EFTs.
Parameters:
-----------
Expand Down Expand Up @@ -260,7 +260,7 @@ def set_prior(self, ntree: int = 1,ntreeh:int = 1, k: float=2,power: float=2.0,b
Returns:
--------
:retval: None.
:returns: None.
'''
# Extract arguments
Expand Down Expand Up @@ -415,26 +415,19 @@ def predict(self, X: np.ndarray, ci: float = 0.95):
Obtain the posterior predictive distribution of the mixed-model at a set
of inputs X.
Parameters
Parameters:
----------
:param X : np.ndarray
input parameter values
:param ci : double
credible interval width, must be value within the interval (0,1)
:param np.ndarray X: design matrix of testing inputs.
:param float ci: credible interval width, must be value within the interval (0,1).
Returns:
--------
:returns: The posterior prediction draws and summaries.
:rtype: np.ndarray, np.ndarray, np.ndarray, np.ndarray
:retval evaluated_posterior:
the posterior predictive distribution evaluated at the specified
test points
:retval mean:
posterior mean of the mixed-model at each input in X.
:retval credible_intervals:
pointwise credible intervals at each input in X.
:retval std_dev:
posterior standard deviation of the mixed-model samples.
:return value: the posterior predictive distribution evaluated at the specified test points
:return value: the posterior mean of the mixed-model at each input in X.
:return value: the pointwise credible intervals at each input in X.
:return value: the posterior standard deviation of the mixed-model samples.
'''

# Set q_lower and q_upper
Expand Down Expand Up @@ -516,23 +509,17 @@ def predict_weights(self, X: np.ndarray, ci: float = 0.95):
Parameters:
----------
:param np.ndarray X:
input parameter values
:param float ci:
credible interval width, must be value within the interval (0,1)
:param np.ndarray X: design matrix of testing inputs.
:param float ci: credible interval width, must be value within the interval (0,1).
Returns:
--------
:returns: The posterior weight function draws and summaries.
:rtype: np.ndarray, np.ndarray, np.ndarray, np.ndarray
:retval evaluated_posterior:
the posterior draws of the model weight functions at each input in X.
:retval mean:
posterior mean of the model weights at each input in X.
:retval credible_intervals:
pointwise credible intervals for the weight functions.
:retval std_dev:
posterior standard deviation of the weight functions samples.
:return value: the posterior draws of the model weight functions at each input in X.
:return value: posterior mean of the model weights at each input in X.
:return value: pointwise credible intervals for the weight functions.
:return value: posterior standard deviation of the weight functions samples.
'''

# Set q_lower and q_upper
Expand Down Expand Up @@ -600,8 +587,7 @@ def plot_prediction(self, xdim: int = 0):
Parameters:
----------
:param int xdim:
index of the column to plot against the predictions.
:param int xdim: index of the column to plot against the predictions.
Returns:
--------
Expand Down Expand Up @@ -635,14 +621,14 @@ def plot_weights(self, xdim: int = 0):
can be any column of the design matrix X, which is passed into
the predict_weights function.
Parameters
Parameters:
----------
:param int xdim:
index of the column to plot against the predictions.
:param int xdim: index of the column to plot against the predictions.
Returns:
--------
:return: None.
'''
# Check if weights are already loaded
col_list = ['red','blue','green','purple','orange']
Expand All @@ -669,15 +655,6 @@ def plot_sigma(self):
'''
Plot the posterior distribution of the observational error
standard deviation.
Parameters
----------
:param: None
index of the column to plot against the predictions.
Returns:
--------
:return: None.
'''
fig = plt.figure(figsize=(6,5))
plt.hist(self.posterior, zorder = 2)
Expand Down
32 changes: 14 additions & 18 deletions docs/source/Taweret.core.rst
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
Taweret.core namespace
======================

.. py:module:: Taweret.core
Submodules
----------


.. automodule:: Taweret.core.base_mixer
:members:
:undoc-members:
:show-inheritance:

:members:
:undoc-members:
:show-inheritance:

.. automodule:: Taweret.core.base_model
:members:
:undoc-members:
:show-inheritance:

:members:
:undoc-members:
:show-inheritance:

.. automodule:: Taweret.core.setup
:members:
:undoc-members:
:show-inheritance:

:members:
:undoc-members:
:show-inheritance:

.. automodule:: Taweret.core.wrappers
:members:
:undoc-members:
:show-inheritance:
:members:
:undoc-members:
:show-inheritance:


32 changes: 14 additions & 18 deletions docs/source/Taweret.mix.rst
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
Taweret.mix namespace
=====================

.. py:module:: Taweret.mix
Submodules
----------


.. automodule:: Taweret.mix.bivariate_linear
:members:
:undoc-members:
:show-inheritance:

:members:
:undoc-members:
:show-inheritance:

.. automodule:: Taweret.mix.gaussian
:members:
:undoc-members:
:show-inheritance:

:members:
:undoc-members:
:show-inheritance:

.. automodule:: Taweret.mix.linear
:members:
:undoc-members:
:show-inheritance:

:members:
:undoc-members:
:show-inheritance:

.. automodule:: Taweret.mix.trees
:members:
:undoc-members:
:show-inheritance:
:members:
:undoc-members:
:show-inheritance:


21 changes: 12 additions & 9 deletions docs/source/Taweret.models.rst
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
Taweret.models namespace
========================

.. py:module:: Taweret.models
Submodules
----------


.. automodule:: Taweret.models.coleman_models
:members:
:undoc-members:
:show-inheritance:
:members:
:undoc-members:
:show-inheritance:

.. automodule:: Taweret.models.polynomial_models
:members:
:undoc-members:
:show-inheritance:

.. automodule:: Taweret.models.samba_models
:members:
:undoc-members:
:show-inheritance:
:members:
:undoc-members:
:show-inheritance:


17 changes: 8 additions & 9 deletions docs/source/Taweret.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@ Subpackages
-----------

.. toctree::
:maxdepth: 4

Taweret.core
Taweret.mix
Taweret.models
Taweret.sampler
Taweret.utils
Taweret.core
Taweret.mix
Taweret.models
Taweret.sampler
Taweret.utils

Module contents
---------------

.. automodule:: Taweret
:members:
:undoc-members:
:show-inheritance:
:members:
:undoc-members:
:show-inheritance:
11 changes: 5 additions & 6 deletions docs/source/Taweret.sampler.rst
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
Taweret.sampler namespace
=========================

.. py:module:: Taweret.sampler
Submodules
----------


.. automodule:: Taweret.sampler.likelihood_wrappers
:members:
:undoc-members:
:show-inheritance:
:members:
:undoc-members:
:show-inheritance:


11 changes: 5 additions & 6 deletions docs/source/Taweret.utils.rst
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
Taweret.utils namespace
=======================

.. py:module:: Taweret.utils
Submodules
----------


.. automodule:: Taweret.utils.utils
:members:
:undoc-members:
:show-inheritance:
:members:
:undoc-members:
:show-inheritance:


3 changes: 3 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,6 @@

# To turn off pesky smart quotes warnings
smartquotes = False

# Set the master/main rst file
master_doc = 'index'
Loading

0 comments on commit 4b9be31

Please sign in to comment.