Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 1RDM estimator data written to stat.h5 #4568

Merged
merged 4 commits into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/Estimators/OneBodyDensityMatrices.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,13 +608,19 @@ void OneBodyDensityMatrices::registerOperatorEstimator(hdf_archive& file)
}
int nentries = std::accumulate(my_indexes.begin(), my_indexes.end(), 1);

int spin_data_size = 0;
if constexpr (IsComplex_t<Value>::value)
spin_data_size = 2 * basis_size_ * basis_size_;
else
spin_data_size = basis_size_ * basis_size_;

hdf_path hdf_name{my_name_};
hdf_name /= "number_matrix";
for (int s = 0; s < species_.size(); ++s)
{
h5desc_.emplace_back(hdf_name / species_.speciesName[s]);
auto& oh = h5desc_.back();
oh.set_dimensions(my_indexes, 0);
oh.set_dimensions(my_indexes, s * spin_data_size);
}
}

Expand Down
139 changes: 139 additions & 0 deletions src/Estimators/tests/test_OneBodyDensityMatrices.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,53 @@ class OneBodyDensityMatricesTests
checkData(returned_data.data(), data.data(), data.size());
}

void testRegisterAndWrite(OneBodyDensityMatrices& obdm)
{
//this test is just going to set some arbitrary data, not actually calculate anything.
//then we will write this data to the hdf5
//then we will open and read the hdf5 and make sure the up and down data are properly written

obdm.data_ = getUniqueSpinData();
hdf_archive hd;
std::string test_file{"1rdm_test.hdf"};
bool okay = hd.create(test_file);
REQUIRE(okay);

obdm.registerOperatorEstimator(hd);
obdm.write(hd);
hd.close();

hdf_archive hd_read;
okay = hd_read.open(test_file);
REQUIRE(okay);
Data up_data;
Data dn_data;

hdf_path up_path = {"OneBodyDensityMatrices/number_matrix/u/value"};
hdf_path dn_path = {"OneBodyDensityMatrices/number_matrix/d/value"};

size_t nb = obdm.basis_size_;
size_t down_offset = 0;
if constexpr (IsComplex_t<OneBodyDensityMatrices::Value>::value)
{
std::array<size_t, 4> shape = {1, nb, nb, 2};
hd_read.readSlabReshaped(up_data, shape, up_path.string());
hd_read.readSlabReshaped(dn_data, shape, dn_path.string());
down_offset = nb * nb * 2; //for real and imag
}
else if constexpr (std::is_floating_point<OneBodyDensityMatrices::Value>::value)
{
std::array<size_t, 3> shape = {1, nb, nb};
hd_read.readSlabReshaped(up_data, shape, up_path.string());
hd_read.readSlabReshaped(dn_data, shape, dn_path.string());
down_offset = nb * nb;
}
//The data in obdm is the reference data, we need to check that
//what was read from the hdf5 is consistent
checkData(obdm.data_.data(), up_data.data(), up_data.size());
checkData(obdm.data_.data() + down_offset, dn_data.data(), dn_data.size());
}

void dumpData(OneBodyDensityMatrices& obdm)
{
std::cout << "Here is what is in your OneBodyDensityMatrices:\n" << NativePrint(obdm.data_) << '\n';
Expand All @@ -199,6 +246,7 @@ class OneBodyDensityMatricesTests
private:
Data getEvaluateMatrixData(OBDMI::Integrator integrator);
Data getAccumulateData();
Data getUniqueSpinData();
};

} // namespace testing
Expand Down Expand Up @@ -461,6 +509,40 @@ TEST_CASE("OneBodyDensityMatrices::evaluateMatrix", "[estimators]")
}
}

TEST_CASE("OneBodyDensityMatrices::registerAndWrite", "[estimators]")
{
using namespace testing;
using namespace onebodydensitymatrices;
using MCPWalker = OperatorEstBase::MCPWalker;
using namespace onebodydensitymatrices;

ProjectData test_project("test", ProjectData::DriverVersion::BATCH);
Communicate* comm;
comm = OHMMS::Controller;

Libxml2Document doc;
bool okay = doc.parseFromString(valid_one_body_density_matrices_input_sections[valid_obdm_input]);
if (!okay)
throw std::runtime_error("cannot parse OneBodyDensitMatricesInput section");
xmlNodePtr node = doc.getRoot();
OneBodyDensityMatricesInput obdmi(node);

std::string integrator_str =
InputSection::reverseLookupInputEnumMap(obdmi.get_integrator(), OBDMI::lookup_input_enum_value);
std::cout << "Test registerAndWrite for: " << integrator_str << '\n';

auto particle_pool = MinimalParticlePool::make_diamondC_1x1x1(comm);
auto wavefunction_pool =
MinimalWaveFunctionPool::make_diamondC_1x1x1(test_project.getRuntimeOptions(), comm, particle_pool);
auto& spomap = wavefunction_pool.getWaveFunction("wavefunction")->getSPOMap();
auto& pset_target = *(particle_pool.getParticleSet("e"));
auto& species_set = pset_target.getSpeciesSet();
OneBodyDensityMatrices obdm(std::move(obdmi), pset_target.getLattice(), species_set, spomap, pset_target);

OneBodyDensityMatricesTests<double> obdmt;
obdmt.testRegisterAndWrite(obdm);
}

namespace testing
{
// The test result data is defined down here for readability of the test code.
Expand Down Expand Up @@ -949,5 +1031,62 @@ typename OneBodyDensityMatricesTests<T>::Data OneBodyDensityMatricesTests<T>::ge
return data;
}

template<typename T>
typename OneBodyDensityMatricesTests<T>::Data OneBodyDensityMatricesTests<T>::getUniqueSpinData()
{
Data data;
if constexpr (IsComplex_t<OneBodyDensityMatrices::Value>::value)
{
data = {0.80788614, 0.03850586, 0.31091131, 0.38205625, 0.44777905, 0.21091672, 0.10743698, 0.96152914, 0.20791861,
0.6312172, 0.45626258, 0.08119136, 0.90431926, 0.11375654, 0.8067544, 0.54595911, 0.84150056, 0.20621477,
0.03414278, 0.29305023, 0.45018883, 0.20085125, 0.40840351, 0.66825399, 0.84792934, 0.16337382, 0.29096146,
0.39040606, 0.83718991, 0.36255192, 0.75091973, 0.46448582, 0.52401399, 0.90176318, 0.38023557, 0.71028634,
0.84000996, 0.23419118, 0.30553855, 0.72794856, 0.32839433, 0.09785895, 0.4900262, 0.03030494, 0.14355491,
0.57876809, 0.75103524, 0.02254518, 0.66510853, 0.03025319, 0.14589509, 0.01467558, 0.48894704, 0.04131938,
0.78907133, 0.99612016, 0.96833451, 0.67448307, 0.93273006, 0.20283358, 0.73427328, 0.51123229, 0.58574252,
0.75568488, 0.9315693, 0.31038253, 0.34644132, 0.36777262, 0.63775962, 0.05487314, 0.83054504, 0.73965924,
0.37751022, 0.93023354, 0.48824517, 0.03106788, 0.91065432, 0.65471251, 0.00795302, 0.31126833, 0.58215106,
0.89782226, 0.33120933, 0.38434518, 0.94334496, 0.36956736, 0.30646061, 0.07810828, 0.97154079, 0.02514768,
0.84117586, 0.23971438, 0.22623238, 0.35594978, 0.36437615, 0.77427557, 0.53121471, 0.53240224, 0.39277132,
0.69702301, 0.00654816, 0.94311127, 0.99594941, 0.49187803, 0.10536416, 0.97302689, 0.44064925, 0.28800946,
0.29735792, 0.09520762, 0.16488844, 0.53418347, 0.65144419, 0.03973145, 0.53885783, 0.7358722, 0.45309111,
0.79190158, 0.22632916, 0.20165111, 0.42914622, 0.22774585, 0.85961036, 0.56634649, 0.10872035, 0.69251128,
0.2021623, 0.67509793, 0.52598371, 0.46793523, 0.91830822, 0.77020753, 0.09795741, 0.76007691, 0.50678007,
0.4795003, 0.66071835, 0.32442191, 0.23907092, 0.00210992, 0.75524376, 0.94234713, 0.5690928, 0.34679526,
0.30868398, 0.77709926, 0.51133303, 0.20918474, 0.39361117, 0.7962174, 0.48195751, 0.99470092, 0.30140288,
0.54417926, 0.83619258, 0.66607407, 0.63815312, 0.84754301, 0.46217861, 0.40382941, 0.8669163, 0.80892846,
0.78275811, 0.08359574, 0.74127963, 0.83093771, 0.88526743, 0.83504267, 0.97853459, 0.21692576, 0.45164991,
0.9074328, 0.87753109, 0.85276772, 0.53516365, 0.53168515, 0.40269904, 0.89566943, 0.37372088, 0.73002752,
0.35593885, 0.73251118, 0.77375582, 0.12760254, 0.8455303, 0.91220549, 0.97342524, 0.6901983, 0.01956135,
0.98966668, 0.89508114, 0.86107002, 0.85139379, 0.26868744, 0.49119517, 0.79074938, 0.91333433, 0.16790021,
0.14352574, 0.97134471, 0.69411371, 0.28605858, 0.70411151, 0.36656519, 0.70878013, 0.21327726, 0.34290399,
0.84309746, 0.90860334, 0.97362624, 0.05473755, 0.71643348, 0.14711903, 0.38781449, 0.784074, 0.40246134,
0.40066814, 0.93058349, 0.43298608, 0.11167385, 0.27113968, 0.33209627, 0.40601194, 0.81328762, 0.68107437,
0.46367926, 0.13120107, 0.38408714, 0.64249068, 0.68798637, 0.35231959, 0.98679773, 0.12638461, 0.75466016,
0.97212161, 0.15569373, 0.55338423, 0.2814492, 0.88983892, 0.33155614, 0.25340461, 0.02949572, 0.08162776,
0.49678983, 0.59962038, 0.20915831, 0.7750513, 0.6575729, 0.50223288, 0.37927361, 0.09678806, 0.22351711,
0.127808, 0.16958427, 0.26687417, 0.37408405};
}
else if constexpr (std::is_floating_point<OneBodyDensityMatrices::Value>::value)
{
data = {0.19381403, 0.71283387, 0.74185289, 0.01593182, 0.4237967, 0.30713958, 0.04877154, 0.32839586, 0.55966269,
0.86166257, 0.13164395, 0.18967966, 0.52716562, 0.75501117, 0.89099129, 0.91181301, 0.56299847, 0.84040689,
0.42003362, 0.54417536, 0.91786624, 0.24179691, 0.80097938, 0.52312247, 0.49135109, 0.06204171, 0.32918368,
0.19109569, 0.27910843, 0.13869181, 0.39276429, 0.84432737, 0.55296738, 0.58336158, 0.42912991, 0.73847165,
0.70821508, 0.44336715, 0.34394486, 0.81857914, 0.17582283, 0.85206451, 0.00962322, 0.6018576, 0.77709871,
0.9363809, 0.77303195, 0.79399817, 0.64872202, 0.90561921, 0.34909147, 0.5382709, 0.4735065, 0.97592345,
0.64042891, 0.98233348, 0.61072865, 0.99648271, 0.93723708, 0.12341335, 0.87404106, 0.52492966, 0.50025206,
0.32956586, 0.35388674, 0.41701219, 0.35787114, 0.78154075, 0.19389593, 0.56085759, 0.42076409, 0.45505835,
0.13691315, 0.92741853, 0.65416634, 0.01324141, 0.70580805, 0.36063625, 0.20206282, 0.04019175, 0.3161708,
0.8021294, 0.47419179, 0.58339627, 0.94680233, 0.14275504, 0.51723762, 0.88195736, 0.02861162, 0.54720941,
0.47704361, 0.72112318, 0.71249342, 0.57327699, 0.82174918, 0.65460258, 0.58492448, 0.0654615, 0.23514782,
0.56317195, 0.99078012, 0.16018222, 0.98232388, 0.48021303, 0.45997915, 0.22901306, 0.30486665, 0.47519321,
0.11869839, 0.25773838, 0.30733499, 0.03014402, 0.41846284, 0.51370103, 0.14486378, 0.91023931, 0.45369315,
0.18793261, 0.51507439, 0.46019929, 0.67773434, 0.20830221, 0.59268401, 0.48456955, 0.9678142, 0.50709602,
0.85130517, 0.60737725};
}
return data;
}

} // namespace testing
} // namespace qmcplusplus