Skip to content

Commit

Permalink
Fix filter points option (gwastro#3477)
Browse files Browse the repository at this point in the history
* Making filter points option to work again

* zip to list
  • Loading branch information
bhooshan-gadre authored and OliverEdy committed Apr 3, 2023
1 parent 190b0b5 commit 498dd7a
Showing 1 changed file with 6 additions and 13 deletions.
19 changes: 6 additions & 13 deletions bin/bank/pycbc_geom_aligned_bank
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ metricParams = pycbc.tmpltbank.determine_eigen_directions(metricParams)
logging.info("Calculating covariance matrix")

vals = pycbc.tmpltbank.estimate_mass_range(1000000, massRangeParams, \
metricParams, metricParams.fUpper, covary=False)
metricParams, metricParams.fUpper, covary=False)
cov = numpy.cov(vals)
evalsCV, evecsCV = numpy.linalg.eig(cov)
evecsCVdict = {}
Expand Down Expand Up @@ -355,19 +355,12 @@ if opts.filter_points:
# Create a large set of points and map to xi_i to give a starting point when
# mapping from xi_i to masses and spins
# Use the EM constraint only if asked to do so
rTotmass, rEta, rBeta, rSigma, rGamma, rSpin1z, rSpin2z = \
rMass1, rMass2, rSpin1z, rSpin2z = \
pycbc.tmpltbank.get_random_mass(2000000, massRangeParams)
diff = (rTotmass*rTotmass * (1-4*rEta))**0.5
rMass1 = (rTotmass + diff)/2.
rMass2 = (rTotmass - diff)/2.
rChis = (rSpin1z + rSpin2z)/2.

rXis = pycbc.tmpltbank.get_cov_params(rTotmass, rEta, rBeta, rSigma,
rGamma, rChis, metricParams, metricParams.fUpper)
rXis = pycbc.tmpltbank.get_cov_params(rMass1, rMass2, rSpin1z, rSpin2z, metricParams, metricParams.fUpper)

xis = (numpy.array(rXis)).T
physMasses = numpy.array([rTotmass, rEta, rSpin1z, rSpin2z])
physMasses = physMasses.T
f0 = opts.f0
order = opts.pn_order
maxmass1 = opts.max_mass1
Expand All @@ -386,10 +379,10 @@ if opts.filter_points:
logging.info("Setting up KDtree to compute distances.")
if opts.threed_lattice:
tree = spatial.KDTree(xis[:,:3])
xi_points = zip(v1s,v2s,v3s)
xi_points = list(zip(v1s,v2s,v3s))
else:
tree = spatial.KDTree(xis[:,:2])
xi_points = zip(v1s,v2s)
xi_points = list(zip(v1s,v2s))

logging.info("Computing distances using KDtree.")
dists, pointargs = tree.query(xi_points)
Expand Down Expand Up @@ -453,7 +446,7 @@ h5file['metric_evals'] = metricParams.evals[metricParams.fUpper]
h5file['metric_evecs'] = metricParams.evecs[metricParams.fUpper]

h5file.close()

# And begin dag generation
# First: Set up the config parser.
cp = ConfigParser.ConfigParser()
Expand Down

0 comments on commit 498dd7a

Please sign in to comment.