Skip to content

Commit

Permalink
Merge pull request #4160 from ye-luo/sposet-name
Browse files Browse the repository at this point in the history
Require named SPOSet
  • Loading branch information
ye-luo authored Aug 10, 2022
2 parents a58ef79 + 42119b1 commit 53a686d
Show file tree
Hide file tree
Showing 57 changed files with 316 additions and 295 deletions.
1 change: 1 addition & 0 deletions src/Estimators/OneBodyDensityMatrices.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ OneBodyDensityMatrices::OneBodyDensityMatrices(OneBodyDensityMatricesInput&& obd
input_(obdmi),
lattice_(lattice),
species_(species),
basis_functions_("OneBodyDensityMatrices::basis"),
timers_("OneBodyDensityMatrix")
{
my_name_ = "OneBodyDensityMatrices";
Expand Down
11 changes: 5 additions & 6 deletions src/QMCHamiltonians/DensityMatrices1B.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@ using MatrixOperators::product;
using MatrixOperators::product_AtB;


DensityMatrices1B::DensityMatrices1B(ParticleSet& P,
TrialWaveFunction& psi,
ParticleSet* Pcl)
: lattice_(P.getLattice()), Psi(psi), Pq(P), Pc(Pcl)
DensityMatrices1B::DensityMatrices1B(ParticleSet& P, TrialWaveFunction& psi, ParticleSet* Pcl)
: basis_functions("DensityMatrices1B::basis"), lattice_(P.getLattice()), Psi(psi), Pq(P), Pc(Pcl)
{
reset();
}
Expand Down Expand Up @@ -232,7 +230,8 @@ void DensityMatrices1B::set_state(xmlNodePtr cur)
metric = 1.0 / samples;
}
else
throw std::runtime_error("DensityMatrices1B::set_state invalid integrator\n valid options are: uniform_grid, uniform, density");
throw std::runtime_error(
"DensityMatrices1B::set_state invalid integrator\n valid options are: uniform_grid, uniform, density");

if (evstr == "loop")
evaluator = loop;
Expand All @@ -257,7 +256,7 @@ void DensityMatrices1B::set_state(xmlNodePtr cur)
for (int i = 0; i < sposets.size(); ++i)
{
auto& spomap = Psi.getSPOMap();
auto spo_it = spomap.find(sposets[i]);
auto spo_it = spomap.find(sposets[i]);
if (spo_it == spomap.end())
throw std::runtime_error("DensityMatrices1B::put sposet " + sposets[i] + " does not exist.");
basis_functions.add(spo_it->second->makeClone());
Expand Down
6 changes: 3 additions & 3 deletions src/QMCHamiltonians/tests/test_ecp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,10 +503,10 @@ TEST_CASE("Evaluate_soecp", "[hamiltonian]")
kdn.resize(nelec);
kdn[0] = PosType(2, 2, 2);

auto spo_up = std::make_unique<FreeOrbital>(kup);
auto spo_dn = std::make_unique<FreeOrbital>(kdn);
auto spo_up = std::make_unique<FreeOrbital>("free_orb_up", kup);
auto spo_dn = std::make_unique<FreeOrbital>("free_orb_up", kdn);

auto spinor_set = std::make_unique<SpinorSet>();
auto spinor_set = std::make_unique<SpinorSet>("free_orb_spinor");
spinor_set->set_spos(std::move(spo_up), std::move(spo_dn));
QMCTraits::IndexType norb = spinor_set->getOrbitalSetSize();
REQUIRE(norb == 1);
Expand Down
13 changes: 11 additions & 2 deletions src/QMCWaveFunctions/BsplineFactory/BsplineReaderBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,11 @@ void BsplineReaderBase::setCommon(xmlNodePtr cur)
std::unique_ptr<SPOSet> BsplineReaderBase::create_spline_set(int spin, xmlNodePtr cur)
{
int ns(0);
std::string spo_object_name;
OhmmsAttributeSet a;
a.add(ns, "size");
a.add(spo_object_name, "name");
a.add(spo_object_name, "id");
a.put(cur);

if (ns == 0)
Expand All @@ -129,11 +132,17 @@ std::unique_ptr<SPOSet> BsplineReaderBase::create_spline_set(int spin, xmlNodePt
vals.myName = make_bandgroup_name(mybuilder->getName(), spin, mybuilder->twist_num_, mybuilder->TileMatrix, 0, ns);
vals.selectBands(fullband, 0, ns, false);

return create_spline_set(spin, vals);
return create_spline_set(spo_object_name, spin, vals);
}

std::unique_ptr<SPOSet> BsplineReaderBase::create_spline_set(int spin, xmlNodePtr cur, SPOSetInputInfo& input_info)
{
std::string spo_object_name;
OhmmsAttributeSet a;
a.add(spo_object_name, "name");
a.add(spo_object_name, "id");
a.put(cur);

if (spo2band.empty())
spo2band.resize(mybuilder->states.size());

Expand All @@ -156,7 +165,7 @@ std::unique_ptr<SPOSet> BsplineReaderBase::create_spline_set(int spin, xmlNodePt
vals.selectBands(fullband, spo2band[spin][input_info.min_index()], input_info.max_index() - input_info.min_index(),
false);

return create_spline_set(spin, vals);
return create_spline_set(spo_object_name, spin, vals);
}

/** build index tables to map a state to band with k-point folidng
Expand Down
6 changes: 4 additions & 2 deletions src/QMCWaveFunctions/BsplineFactory/BsplineReaderBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ struct BsplineReaderBase

bspline->HalfG = 0;
TinyVector<int, 3> bconds = mybuilder->TargetPtcl.getLattice().BoxBConds;
if (!bspline->is_complex)
if (!bspline->isComplex())
{
//no k-point folding, single special k point (G, L ...)
TinyVector<double, 3> twist0 = mybuilder->TwistAngles[bandgroup.TwistIndex];
Expand Down Expand Up @@ -165,7 +165,9 @@ struct BsplineReaderBase

/** create the actual spline sets
*/
virtual std::unique_ptr<SPOSet> create_spline_set(int spin, const BandInfoGroup& bandgroup) = 0;
virtual std::unique_ptr<SPOSet> create_spline_set(const std::string& my_name,
int spin,
const BandInfoGroup& bandgroup) = 0;

/** setting common parameters
*/
Expand Down
11 changes: 4 additions & 7 deletions src/QMCWaveFunctions/BsplineFactory/BsplineSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ class BsplineSet : public SPOSet
{
protected:
static const int D = DIM;
///true if the computed values are complex
bool is_complex;
///Index of this adoptor, when multiple adoptors are used for NUMA or distributed cases
size_t MyIndex;
///first index of the SPOs this Spline handles
Expand All @@ -56,13 +54,12 @@ class BsplineSet : public SPOSet
aligned_vector<int> BandIndexMap;
///band offsets used for communication
std::vector<int> offset;
///keyword used to match hdf5
std::string KeyWord;

public:
BsplineSet(bool use_OMP_offload = false, bool ion_deriv = false, bool optimizable = false)
: SPOSet(use_OMP_offload, ion_deriv, optimizable), is_complex(false), MyIndex(0), first_spo(0), last_spo(0)
{}
BsplineSet(const std::string& my_name) : SPOSet(my_name), MyIndex(0), first_spo(0), last_spo(0) {}

virtual bool isComplex() const = 0;
virtual std::string getKeyword() const = 0;

auto& getHalfG() const { return HalfG; }

Expand Down
9 changes: 4 additions & 5 deletions src/QMCWaveFunctions/BsplineFactory/HybridRepCplx.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,10 @@ class HybridRepCplx : public SPLINEBASE, private HybridRepCenterOrbitals<typenam
using SPLINEBASE::myV;

public:
HybridRepCplx()
{
this->className = "Hybrid" + this->className;
this->KeyWord = "Hybrid" + this->KeyWord;
}
HybridRepCplx(const std::string& my_name) : SPLINEBASE(my_name) {}

std::string getClassName() const final { return "Hybrid" + SPLINEBASE::getClassName(); }
std::string getKeyword() const final { return "Hybrid" + SPLINEBASE::getKeyword(); }

std::unique_ptr<SPOSet> makeClone() const override { return std::make_unique<HybridRepCplx>(*this); }

Expand Down
9 changes: 4 additions & 5 deletions src/QMCWaveFunctions/BsplineFactory/HybridRepReal.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,10 @@ class HybridRepReal : public SPLINEBASE, private HybridRepCenterOrbitals<typenam
using SPLINEBASE::PrimLattice;

public:
HybridRepReal()
{
this->className = "Hybrid" + this->className;
this->KeyWord = "Hybrid" + this->KeyWord;
}
HybridRepReal(const std::string& my_name) : SPLINEBASE(my_name) {}

std::string getClassName() const final { return "Hybrid" + SPLINEBASE::getClassName(); }
std::string getKeyword() const final { return "Hybrid" + SPLINEBASE::getKeyword(); }

std::unique_ptr<SPOSet> makeClone() const override { return std::make_unique<HybridRepReal>(*this); }

Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/BsplineFactory/HybridRepSetReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ class HybridRepSetReader : public SplineSetReader<SA>
splineData_r[ip] = all_vals[idx][ip][lm];
atomic_spline_r = einspline::create(atomic_spline_r, 0.0, spline_radius, spline_npoints, splineData_r.data(),
((lm == 0) || (lm > 3)));
if (!bspline->is_complex)
if (!bspline->isComplex())
{
mycenter.set_spline(atomic_spline_r, lm, iorb);
einspline::destroy(atomic_spline_r);
Expand Down
3 changes: 3 additions & 0 deletions src/QMCWaveFunctions/BsplineFactory/SplineC2C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

namespace qmcplusplus
{
template<typename ST>
SplineC2C<ST>::SplineC2C(const SplineC2C& in) = default;

template<typename ST>
inline void SplineC2C<ST>::set_spline(SingleSplineType* spline_r,
SingleSplineType* spline_i,
Expand Down
13 changes: 7 additions & 6 deletions src/QMCWaveFunctions/BsplineFactory/SplineC2C.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,13 @@ class SplineC2C : public BsplineSet
ghContainer_type mygH;

public:
SplineC2C()
{
is_complex = true;
className = "SplineC2C";
KeyWord = "SplineC2C";
}
SplineC2C(const std::string& my_name) : BsplineSet(my_name) {}

SplineC2C(const SplineC2C& in);
virtual std::string getClassName() const override { return "SplineC2C"; }
virtual std::string getKeyword() const override { return "SplineC2C"; }
bool isComplex() const override { return true; };


std::unique_ptr<SPOSet> makeClone() const override { return std::make_unique<SplineC2C>(*this); }

Expand Down
3 changes: 3 additions & 0 deletions src/QMCWaveFunctions/BsplineFactory/SplineC2COMPTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

namespace qmcplusplus
{
template<typename ST>
SplineC2COMPTarget<ST>::SplineC2COMPTarget(const SplineC2COMPTarget& in) = default;

namespace C2C
{
template<typename ST, typename TT>
Expand Down
17 changes: 9 additions & 8 deletions src/QMCWaveFunctions/BsplineFactory/SplineC2COMPTarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,19 @@ class SplineC2COMPTarget : public BsplineSet
ghContainer_type mygH;

public:
SplineC2COMPTarget()
: BsplineSet(true),
SplineC2COMPTarget(const std::string& my_name)
: BsplineSet(my_name),
offload_timer_(*timer_manager.createTimer("SplineC2COMPTarget::offload", timer_level_fine)),
GGt_offload(std::make_shared<OffloadVector<ST>>(9)),
PrimLattice_G_offload(std::make_shared<OffloadVector<ST>>(9))
{
is_complex = true;
className = "SplineC2COMPTarget";
KeyWord = "SplineC2C";
}
{}

SplineC2COMPTarget(const SplineC2COMPTarget& in);

SplineC2COMPTarget(const SplineC2COMPTarget& in) = default;
virtual std::string getClassName() const override { return "SplineC2COMPTarget"; }
virtual std::string getKeyword() const override { return "SplineC2C"; }
bool isComplex() const override { return true; };
bool isOMPoffload() const override { return true; }

void createResource(ResourceCollection& collection) const override
{
Expand Down
3 changes: 3 additions & 0 deletions src/QMCWaveFunctions/BsplineFactory/SplineC2R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

namespace qmcplusplus
{
template<typename ST>
SplineC2R<ST>::SplineC2R(const SplineC2R& in) = default;

template<typename ST>
inline void SplineC2R<ST>::set_spline(SingleSplineType* spline_r,
SingleSplineType* spline_i,
Expand Down
12 changes: 6 additions & 6 deletions src/QMCWaveFunctions/BsplineFactory/SplineC2R.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ class SplineC2R : public BsplineSet
ghContainer_type mygH;

public:
SplineC2R() : nComplexBands(0)
{
is_complex = true;
className = "SplineC2R";
KeyWord = "SplineC2R";
}
SplineC2R(const std::string& my_name) : BsplineSet(my_name), nComplexBands(0) {}

SplineC2R(const SplineC2R& in);
virtual std::string getClassName() const override { return "SplineC2R"; }
virtual std::string getKeyword() const override { return "SplineC2R"; }
bool isComplex() const override { return true; };

std::unique_ptr<SPOSet> makeClone() const override { return std::make_unique<SplineC2R>(*this); }

Expand Down
3 changes: 3 additions & 0 deletions src/QMCWaveFunctions/BsplineFactory/SplineC2ROMPTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

namespace qmcplusplus
{
template<typename ST>
SplineC2ROMPTarget<ST>::SplineC2ROMPTarget(const SplineC2ROMPTarget& in) = default;

namespace C2R
{
template<typename ST, typename TT>
Expand Down
17 changes: 9 additions & 8 deletions src/QMCWaveFunctions/BsplineFactory/SplineC2ROMPTarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,20 @@ class SplineC2ROMPTarget : public BsplineSet
ghContainer_type mygH;

public:
SplineC2ROMPTarget()
: BsplineSet(true),
SplineC2ROMPTarget(const std::string& my_name)
: BsplineSet(my_name),
offload_timer_(*timer_manager.createTimer("SplineC2ROMPTarget::offload", timer_level_fine)),
nComplexBands(0),
GGt_offload(std::make_shared<OffloadVector<ST>>(9)),
PrimLattice_G_offload(std::make_shared<OffloadVector<ST>>(9))
{
is_complex = true;
className = "SplineC2ROMPTarget";
KeyWord = "SplineC2R";
}
{}

SplineC2ROMPTarget(const SplineC2ROMPTarget& in);

SplineC2ROMPTarget(const SplineC2ROMPTarget& in) = default;
virtual std::string getClassName() const override { return "SplineC2ROMPTarget"; }
virtual std::string getKeyword() const override { return "SplineC2R"; }
bool isComplex() const override { return true; };
bool isOMPoffload() const override { return true; }

void createResource(ResourceCollection& collection) const override
{
Expand Down
3 changes: 3 additions & 0 deletions src/QMCWaveFunctions/BsplineFactory/SplineR2R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

namespace qmcplusplus
{
template<typename ST>
SplineR2R<ST>::SplineR2R(const SplineR2R& in) = default;

template<typename ST>
inline void SplineR2R<ST>::set_spline(SingleSplineType* spline_r,
SingleSplineType* spline_i,
Expand Down
12 changes: 6 additions & 6 deletions src/QMCWaveFunctions/BsplineFactory/SplineR2R.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ class SplineR2R : public BsplineSet
ghContainer_type mygH;

public:
SplineR2R()
{
is_complex = false;
className = "SplineR2R";
KeyWord = "SplineR2R";
}
SplineR2R(const std::string& my_name) : BsplineSet(my_name) {}

SplineR2R(const SplineR2R& in);
virtual std::string getClassName() const override { return "SplineR2R"; }
virtual std::string getKeyword() const override { return "SplineR2R"; }
bool isComplex() const override { return false; };

std::unique_ptr<SPOSet> makeClone() const override { return std::make_unique<SplineR2R>(*this); }

Expand Down
Loading

0 comments on commit 53a686d

Please sign in to comment.