Skip to content

Commit

Permalink
Adding spin dependent one-body jastrow
Browse files Browse the repository at this point in the history
  • Loading branch information
shivupa committed Jul 31, 2021
1 parent 7fff80a commit 62f7886
Show file tree
Hide file tree
Showing 8 changed files with 1,557 additions and 41 deletions.
744 changes: 744 additions & 0 deletions src/QMCWaveFunctions/Jastrow/J1Spin.h

Large diffs are not rendered by default.

189 changes: 148 additions & 41 deletions src/QMCWaveFunctions/Jastrow/RadialJastrowBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "RadialJastrowBuilder.h"

#include "QMCWaveFunctions/Jastrow/J1OrbitalSoA.h"
#include "QMCWaveFunctions/Jastrow/J1Spin.h"
#include "QMCWaveFunctions/Jastrow/J2OrbitalSoA.h"

#if defined(ENABLE_OFFLOAD)
Expand Down Expand Up @@ -41,19 +42,19 @@

namespace qmcplusplus
{

// quick helper class to allow use of RPA
class RPAFunctor
{};

// helper class to simplify and localize ugly ifdef stuff for types
template<class RadFuncType, unsigned Implementation>
template<class RadFuncType, unsigned Implementation = RadialJastrowBuilder::detail::CPU>
class JastrowTypeHelper
{
public:
using J1OrbitalType = J1OrbitalSoA<RadFuncType>;
using J2OrbitalType = J2OrbitalSoA<RadFuncType>;
using DiffJ2OrbitalType = DiffTwoBodyJastrowOrbital<RadFuncType>;
using J1Type = J1OrbitalSoA<RadFuncType>;
using J1SpinType = J1Spin<RadFuncType>;
using J2Type = J2OrbitalSoA<RadFuncType>;
using DiffJ2Type = DiffTwoBodyJastrowOrbital<RadFuncType>;
};

#if defined(QMC_CUDA)
Expand All @@ -62,30 +63,20 @@ class JastrowTypeHelper<BsplineFunctor<RadialJastrowBuilder::RealType>, RadialJa
{
public:
using RadFuncType = BsplineFunctor<RadialJastrowBuilder::RealType>;
using J1OrbitalType = OneBodyJastrowOrbitalBspline<RadFuncType>;
using J2OrbitalType = TwoBodyJastrowOrbitalBspline<RadFuncType>;
using DiffJ2OrbitalType = DiffTwoBodyJastrowOrbital<RadFuncType>;
using J1Type = OneBodyJastrowOrbitalBspline<RadFuncType>;
using J2Type = TwoBodyJastrowOrbitalBspline<RadFuncType>;
using DiffJ2Type = DiffTwoBodyJastrowOrbital<RadFuncType>;
};
#endif

template<>
class JastrowTypeHelper<BsplineFunctor<RadialJastrowBuilder::RealType>, RadialJastrowBuilder::detail::CPU>
{
public:
using RadFuncType = BsplineFunctor<RadialJastrowBuilder::RealType>;
using J1OrbitalType = J1OrbitalSoA<RadFuncType>;
using J2OrbitalType = J2OrbitalSoA<RadFuncType>;
using DiffJ2OrbitalType = DiffTwoBodyJastrowOrbital<RadFuncType>;
};

#if defined(ENABLE_OFFLOAD)
template<>
class JastrowTypeHelper<BsplineFunctor<RadialJastrowBuilder::RealType>, RadialJastrowBuilder::detail::OMPTARGET>
{
public:
using RadFuncType = BsplineFunctor<RadialJastrowBuilder::RealType>;
using J2OrbitalType = J2OMPTarget<RadFuncType>;
using DiffJ2OrbitalType = DiffTwoBodyJastrowOrbital<RadFuncType>;
using J2Type = J2OMPTarget<RadFuncType>;
using DiffJ2Type = DiffTwoBodyJastrowOrbital<RadFuncType>;
};
#endif

Expand Down Expand Up @@ -168,15 +159,15 @@ template<class RadFuncType, unsigned Implementation>
std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ2(xmlNodePtr cur)
{
ReportEngine PRE(ClassName, "createJ2(xmlNodePtr)");
using Real = typename RadFuncType::real_type;
using J2OrbitalType = typename JastrowTypeHelper<RadFuncType, Implementation>::J2OrbitalType;
using DiffJ2OrbitalType = typename JastrowTypeHelper<RadFuncType, Implementation>::DiffJ2OrbitalType;
using Real = typename RadFuncType::real_type;
using J2Type = typename JastrowTypeHelper<RadFuncType, Implementation>::J2Type;
using DiffJ2Type = typename JastrowTypeHelper<RadFuncType, Implementation>::DiffJ2Type;

XMLAttrString input_name(cur, "name");
std::string j2name = input_name.empty() ? "J2_" + Jastfunction : input_name;
SpeciesSet& species(targetPtcl.getSpeciesSet());
auto J2 = std::make_unique<J2OrbitalType>(j2name, targetPtcl);
auto dJ2 = std::make_unique<DiffJ2OrbitalType>(targetPtcl);
auto J2 = std::make_unique<J2Type>(j2name, targetPtcl);
auto dJ2 = std::make_unique<DiffJ2Type>(targetPtcl);

std::string init_mode("0");
{
Expand Down Expand Up @@ -352,13 +343,13 @@ template<class RadFuncType, unsigned Implementation>
std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ1(xmlNodePtr cur)
{
ReportEngine PRE(ClassName, "createJ1(xmlNodePtr)");
using Real = typename RadFuncType::real_type;
using J1OrbitalType = typename JastrowTypeHelper<RadFuncType, Implementation>::J1OrbitalType;
using Real = typename RadFuncType::real_type;
using J1Type = typename JastrowTypeHelper<RadFuncType, Implementation>::J1Type;

XMLAttrString input_name(cur, "name");
std::string jname = input_name.empty() ? Jastfunction : input_name;

auto J1 = std::make_unique<J1OrbitalType>(jname, *SourcePtcl, targetPtcl);
auto J1 = std::make_unique<J1Type>(jname, *SourcePtcl, targetPtcl);

xmlNodePtr kids = cur->xmlChildrenNode;

Expand Down Expand Up @@ -444,12 +435,12 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ1(xmlNodePtr
template<>
std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ1<RPAFunctor>(xmlNodePtr cur)
{
using Real = RealType;
using Real = RealType;
using SplineEngineType = CubicBspline<Real, LINEAR_1DGRID, FIRSTDERIV_CONSTRAINTS>;
using RadFunctorType = CubicSplineSingle<Real, SplineEngineType>;
using GridType = LinearGrid<Real>;
using HandlerType = LRHandlerBase;
using J1OrbitalType = J1OrbitalSoA<RadFunctorType>;
using J1Type = J1OrbitalSoA<RadFunctorType>;

std::string input_name;
std::string rpafunc = "RPA";
Expand Down Expand Up @@ -489,18 +480,18 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ1<RPAFunctor
}
myHandler->Breakup(targetPtcl, Rs);

Real Rcut = myHandler->get_rc() - 0.1;
Real Rcut = myHandler->get_rc() - 0.1;
GridType* myGrid = new GridType;
int npts = static_cast<int>(Rcut / 0.01) + 1;
myGrid->set(0, Rcut, npts);

//create the numerical functor
auto nfunc = std::make_unique<RadFunctorType>();
auto nfunc = std::make_unique<RadFunctorType>();
ShortRangePartAdapter<Real>* SRA = new ShortRangePartAdapter<Real>(myHandler);
SRA->setRmax(Rcut);
nfunc->initialize(SRA, myGrid);

auto J1 = std::make_unique<J1OrbitalType>(jname, *SourcePtcl, targetPtcl);
auto J1 = std::make_unique<J1Type>(jname, *SourcePtcl, targetPtcl);

SpeciesSet& sSet = SourcePtcl->getSpeciesSet();
for (int ig = 0; ig < sSet.getTotalNum(); ig++)
Expand All @@ -512,6 +503,89 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ1<RPAFunctor
return J1;
}

template<class RadFuncType, unsigned Implementation>
std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ1Spin(xmlNodePtr cur)
{
ReportEngine PRE(ClassName, "createJ1Spin(xmlNodePtr)");
using RT = typename RadFuncType::real_type;
using J1Type = typename JastrowTypeHelper<RadFuncType>::J1SpinType;

XMLAttrString input_name(cur, "name");
std::string jname = input_name.empty() ? Jastfunction : input_name;

std::unique_ptr<J1Type> J1 = std::make_unique<J1Type>(jname, *SourcePtcl, targetPtcl);

xmlNodePtr kids = cur->xmlChildrenNode;

// Find the number of the source species
SpeciesSet& sSet = SourcePtcl->getSpeciesSet();
SpeciesSet& tSet = targetPtcl.getSpeciesSet();
bool success = false;
bool Opt(true);
while (kids != NULL)
{
std::string kidsname = (char*)kids->name;
tolower(kidsname);
if (kidsname == "correlation")
{
std::string speciesA;
std::string speciesB;
RealType cusp(0);
OhmmsAttributeSet rAttrib;
rAttrib.add(speciesA, "elementType");
rAttrib.add(speciesA, "speciesA");
rAttrib.add(speciesB, "speciesB");
rAttrib.add(cusp, "cusp");
rAttrib.put(kids);
auto functor = std::make_unique<RadFuncType>();
functor->setPeriodic(SourcePtcl->Lattice.SuperCellEnum != SUPERCELL_OPEN);
functor->cutoff_radius = targetPtcl.Lattice.WignerSeitzRadius;
functor->setCusp(cusp);
const int ig = sSet.findSpecies(speciesA);
const int jg = speciesB.size() ? tSet.findSpecies(speciesB) : -1;
if (ig == sSet.getTotalNum())
{
PRE.error("species " + speciesA + " requested for Jastrow " + jname + " does not exist in ParticleSet " +
SourcePtcl->getName(),
true);
}
if (jg == tSet.getTotalNum())
{
PRE.error("species " + speciesB + " requested for Jastrow " + jname + " does not exist in ParticleSet " +
targetPtcl.getName(),
true);
}
app_summary() << " Radial function for species: " << speciesA << " - " << speciesB << std::endl;
functor->put(kids);
app_summary() << std::endl;
if (is_manager())
{
char fname[128];
if (speciesB.size())
sprintf(fname, "%s.%s.%s%s.g%03d.dat", jname.c_str(), NameOpt.c_str(), speciesA.c_str(), speciesB.c_str(),
getGroupID());
else
sprintf(fname, "%s.%s.%s.g%03d.dat", jname.c_str(), NameOpt.c_str(), speciesA.c_str(), getGroupID());
std::ofstream os(fname);
print(*functor.get(), os);
}
J1->addFunc(ig, std::move(functor), jg);
success = true;
}
kids = kids->next;
}
if (success)
{
J1->setOptimizable(Opt);
return J1;
}
else
{
PRE.error("BsplineJastrowBuilder failed to add an One-Body Jastrow.");
return std::unique_ptr<WaveFunctionComponent>();
}
}


std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::buildComponent(xmlNodePtr cur)
{
Expand All @@ -521,7 +595,7 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::buildComponent(xmlN
aAttrib.add(NameOpt, "name");
aAttrib.add(TypeOpt, "type");
aAttrib.add(Jastfunction, "function");
aAttrib.add(SpinOpt, "spin");
aAttrib.add(SpinOpt, "spin", {"no", "yes"});
#if defined(ENABLE_OFFLOAD)
aAttrib.add(useGPU, "gpu", {"yes", "no"});
#endif
Expand All @@ -539,16 +613,34 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::buildComponent(xmlN
// it's a one body jastrow factor
if (Jastfunction == "bspline")
{
if (SpinOpt == "yes")
{
#if defined(QMC_CUDA)
return createJ1<BsplineFunctor<RealType>, detail::CUDA_LEGACY>(cur);
return createJ1Spin<BsplineFunctor<RealType>, detail::CUDA_LEGACY>(cur);
#else
return createJ1<BsplineFunctor<RealType>>(cur);
return createJ1Spin<BsplineFunctor<RealType>>(cur);
#endif
}
else
{
#if defined(QMC_CUDA)
return createJ1<BsplineFunctor<RealType>, detail::CUDA_LEGACY>(cur);
#else
return createJ1<BsplineFunctor<RealType>>(cur);
#endif
}
}
else if (Jastfunction == "pade")
{
guardAgainstPBC();
return createJ1<PadeFunctor<RealType>>(cur);
if (SpinOpt == "yes")
{
return createJ1Spin<PadeFunctor<RealType>>(cur);
}
else
{
return createJ1<PadeFunctor<RealType>>(cur);
}
}
else if (Jastfunction == "pade2")
{
Expand All @@ -558,11 +650,25 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::buildComponent(xmlN
else if (Jastfunction == "shortrangecusp")
{
//guardAgainstPBC(); // is this needed?
return createJ1<ShortRangeCuspFunctor<RealType>>(cur);
if (SpinOpt == "yes")
{
return createJ1Spin<ShortRangeCuspFunctor<RealType>>(cur);
}
else
{
return createJ1<ShortRangeCuspFunctor<RealType>>(cur);
}
}
else if (Jastfunction == "user")
{
return createJ1<UserFunctor<RealType>>(cur);
if (SpinOpt == "yes")
{
return createJ1Spin<UserFunctor<RealType>>(cur);
}
else
{
return createJ1<UserFunctor<RealType>>(cur);
}
}
else if (Jastfunction == "rpa")
{
Expand All @@ -587,8 +693,9 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::buildComponent(xmlN
#if defined(ENABLE_OFFLOAD)
if (useGPU == "yes")
{
static_assert(std::is_same<JastrowTypeHelper<BsplineFunctor<RealType>, OMPTARGET>::J2OrbitalType,
J2OMPTarget<BsplineFunctor<RealType>>>::value, "check consistent type");
static_assert(std::is_same<JastrowTypeHelper<BsplineFunctor<RealType>, OMPTARGET>::J2Type,
J2OMPTarget<BsplineFunctor<RealType>>>::value,
"check consistent type");
app_summary() << " Running on an accelerator via OpenMP offload." << std::endl;
return createJ2<BsplineFunctor<RealType>, detail::OMPTARGET>(cur);
}
Expand Down
3 changes: 3 additions & 0 deletions src/QMCWaveFunctions/Jastrow/RadialJastrowBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class RadialJastrowBuilder : public WaveFunctionComponentBuilder
template<class RadFuncType, unsigned Implementation = detail::CPU>
std::unique_ptr<WaveFunctionComponent> createJ1(xmlNodePtr cur);

template<class RadFuncType, unsigned Implementation = detail::CPU>
std::unique_ptr<WaveFunctionComponent> createJ1Spin(xmlNodePtr cur);

template<class RadFuncType, unsigned Implementation = detail::CPU>
std::unique_ptr<WaveFunctionComponent> createJ2(xmlNodePtr cur);

Expand Down
1 change: 1 addition & 0 deletions src/QMCWaveFunctions/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ set(JASTROW_SRC
test_pade_jastrow.cpp
test_short_range_cusp_jastrow.cpp
test_J1OrbitalSoA.cpp
test_J1Spin.cpp
test_J2_bspline.cpp
test_DiffTwoBodyJastrowOrbital.cpp)
set(DETERMINANT_SRC
Expand Down
Loading

0 comments on commit 62f7886

Please sign in to comment.