From f48a3822035d8a946c7a9479c253a559b870ca06 Mon Sep 17 00:00:00 2001 From: Francisco Barros Date: Sun, 20 Oct 2019 19:31:05 +0100 Subject: [PATCH] Update MetropolisHastings.py --- hive/app/domain/MetropolisHastings.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/hive/app/domain/MetropolisHastings.py b/hive/app/domain/MetropolisHastings.py index 3b9f632f..4cf6cf23 100644 --- a/hive/app/domain/MetropolisHastings.py +++ b/hive/app/domain/MetropolisHastings.py @@ -8,9 +8,9 @@ # region module public functions -def metropolis_algorithm(adj_matrix, ddv, column_major_in=False, column_major_out=True): +def metropolis_algorithm(adj, ddv, column_major_in=False, column_major_out=True): """ - :param adj_matrix: any adjacency matrix (list of lits) provided by the user in row major form + :param adj: any adjacency matrix (list of lits) provided by the user in row major form :type list> :param ddv: a stochastic desired distribution vector :type list @@ -23,15 +23,21 @@ def metropolis_algorithm(adj_matrix, ddv, column_major_in=False, column_major_ou """ ddv = np.asarray(ddv) - adj_matrix = np.asarray(adj_matrix) + adj = np.asarray(adj) + + # Input checking + if ddv.shape[0] != adj.shape[1]: + raise DistributionShapeError("distribution shape: {}, proposal matrix shape: {}".format(ddv.shape, adj.shape)) + if adj.shape[0] != adj.shape[1]: + raise MatrixNotSquareError("rows: {}, columns: {}, expected square matrix".format(adj.shape[0], adj.shape[1])) if column_major_in: - adj_matrix = adj_matrix.transpose() + adj = adj.transpose() - shape = adj_matrix.shape - size = adj_matrix.shape[0] + shape = adj.shape + size = adj.shape[0] - rw = _construct_random_walk_matrix(adj_matrix, shape, size) + rw = _construct_random_walk_matrix(adj, shape, size) r = _construct_rejection_matrix(ddv, rw, shape, size) transition_matrix = np.zeros(shape=shape)