Skip to content

Commit

Permalink
made the argument min_blockdim non-const to avoid copying twice
Browse files Browse the repository at this point in the history
  • Loading branch information
manuschneider committed Oct 16, 2024
1 parent f21621e commit 7c62492
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 46 deletions.
4 changes: 2 additions & 2 deletions include/linalg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ namespace cytnx {
*/
std::vector<cytnx::UniTensor> Svd_truncate(const cytnx::UniTensor &Tin,
const cytnx_uint64 &keepdim,
const std::vector<cytnx_uint64> min_blockdim,
std::vector<cytnx_uint64> min_blockdim,
const double &err = 0., const bool &is_UvT = true,
const unsigned int &return_err = 0,
const cytnx_uint64 &mindim = 1);
Expand Down Expand Up @@ -814,7 +814,7 @@ namespace cytnx {
*/
std::vector<cytnx::UniTensor> Gesvd_truncate(
const cytnx::UniTensor &Tin, const cytnx_uint64 &keepdim,
const std::vector<cytnx_uint64> min_blockdim, const double &err = 0., const bool &is_U = true,
std::vector<cytnx_uint64> min_blockdim, const double &err = 0., const bool &is_U = true,
const bool &is_vT = true, const unsigned int &return_err = 0, const cytnx_uint64 &mindim = 1);

std::vector<cytnx::UniTensor> Hosvd(
Expand Down
12 changes: 6 additions & 6 deletions pybind/linalg_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ void linalg_binding(py::module &m) {
py::arg("is_vT") = true, py::arg("return_err") = (unsigned int)(0), py::arg("mindim") = 1);
m_linalg.def(
"Gesvd_truncate",
[](const UniTensor &Tin, const cytnx_uint64 &keepdim,
const std::vector<cytnx_uint64> min_blockdim, const cytnx_double &err, const bool &is_U,
const bool &is_vT, const unsigned int &return_err, const cytnx_uint64 &mindim) {
[](const UniTensor &Tin, const cytnx_uint64 &keepdim, std::vector<cytnx_uint64> min_blockdim,
const cytnx_double &err, const bool &is_U, const bool &is_vT, const unsigned int &return_err,
const cytnx_uint64 &mindim) {
return cytnx::linalg::Gesvd_truncate(Tin, keepdim, min_blockdim, err, is_U, is_vT, return_err,
mindim);
},
Expand All @@ -92,9 +92,9 @@ void linalg_binding(py::module &m) {
py::arg("return_err") = (unsigned int)(0), py::arg("mindim") = 1);
m_linalg.def(
"Svd_truncate",
[](const UniTensor &Tin, const cytnx_uint64 &keepdim,
const std::vector<cytnx_uint64> min_blockdim, const cytnx_double &err, const bool &is_UvT,
const unsigned int &return_err, const cytnx_uint64 &mindim) {
[](const UniTensor &Tin, const cytnx_uint64 &keepdim, std::vector<cytnx_uint64> min_blockdim,
const cytnx_double &err, const bool &is_UvT, const unsigned int &return_err,
const cytnx_uint64 &mindim) {
return cytnx::linalg::Svd_truncate(Tin, keepdim, min_blockdim, err, is_UvT, return_err,
mindim);
},
Expand Down
35 changes: 17 additions & 18 deletions src/linalg/Gesvd_truncate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,49 +385,48 @@ namespace cytnx {

void _gesvd_truncate_Block_UT(std::vector<UniTensor> &outCyT, const cytnx::UniTensor &Tin,

Check warning on line 386 in src/linalg/Gesvd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Gesvd_truncate.cpp#L386

Added line #L386 was not covered by tests
const cytnx_uint64 &keepdim,
const std::vector<cytnx_uint64> min_blockdim, const double &err,
std::vector<cytnx_uint64> min_blockdim, const double &err,
const bool &is_U, const bool &is_vT,
const unsigned int &return_err, const cytnx_uint64 &mindim) {
// currently, Gesvd is used as a standard for the full SVD before truncation
cytnx_int64 keep_dim = keepdim; // these must be signed int, because they can become

Check warning on line 392 in src/linalg/Gesvd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Gesvd_truncate.cpp#L392

Added line #L392 was not covered by tests
// negative!
cytnx_int64 min_dim = (mindim < 1 ? 1 : mindim);

Check warning on line 394 in src/linalg/Gesvd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Gesvd_truncate.cpp#L394

Added line #L394 was not covered by tests
std::vector<cytnx_uint64> minblockdim = min_blockdim;

outCyT = linalg::Gesvd(Tin, is_U, is_vT);

Check warning on line 396 in src/linalg/Gesvd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Gesvd_truncate.cpp#L396

Added line #L396 was not covered by tests
if (minblockdim.size() == 1) // if only one element given, make it a vector
minblockdim = std::vector<cytnx_uint64>(outCyT[0].Nblocks(), minblockdim[0]);
if (min_blockdim.size() == 1) // if only one element given, make it a vector
min_blockdim.resize(outCyT[0].Nblocks(), min_blockdim.front());

Check warning on line 398 in src/linalg/Gesvd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Gesvd_truncate.cpp#L398

Added line #L398 was not covered by tests
cytnx_error_msg(
minblockdim.size() != outCyT[0].Nblocks(),
min_blockdim.size() != outCyT[0].Nblocks(),
"[ERROR][Gesvd_truncate] min_blockdim must have the same number of elements as "
"blocks in the singular value UniTensor%s",
"\n");

Check warning on line 403 in src/linalg/Gesvd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Gesvd_truncate.cpp#L403

Added line #L403 was not covered by tests

// process truncation:
// 1) concate all S vals from all blk but exclude the first minblockdim Svals in each block
// 1) concate all S vals from all blk but exclude the first min_blockdim Svals in each block
// (since they will be kept anyways later)
Tensor Sall; // S vals excluding the already kept ones

Check warning on line 408 in src/linalg/Gesvd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Gesvd_truncate.cpp#L408

Added line #L408 was not covered by tests
Tensor Block; // current block
cytnx_uint64 blockdim;
bool anySall = false; // are there already any values in Sall vals?
bool any_min_blockdim = false; // is any minblockdim > 0?
bool any_min_blockdim = false; // is any min_blockdim > 0?
for (int b = 0; b < outCyT[0].Nblocks(); b++) {
if (minblockdim[b] < 1) // save whole block to Sall
if (min_blockdim[b] < 1) // save whole block to Sall
Block = outCyT[0].get_block_(b);
else {
any_min_blockdim = true;

Check warning on line 417 in src/linalg/Gesvd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Gesvd_truncate.cpp#L417

Added line #L417 was not covered by tests
blockdim = outCyT[0].get_block_(b).shape()[0];
if (blockdim <= minblockdim[b]) {
if (blockdim <= min_blockdim[b]) {
// keep whole block
keep_dim -= blockdim;
min_dim -= blockdim;
continue;

Check warning on line 423 in src/linalg/Gesvd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Gesvd_truncate.cpp#L421-L423

Added lines #L421 - L423 were not covered by tests
}
// remove first minblockdim[b] values
// remove first min_blockdim[b] values
blockdim = outCyT[0].get_block_(b).shape()[0];
Block = outCyT[0].get_block_(b).get({ac::range(minblockdim[b], blockdim)});
keep_dim -= minblockdim[b];
min_dim -= minblockdim[b];
Block = outCyT[0].get_block_(b).get({ac::range(min_blockdim[b], blockdim)});
keep_dim -= min_blockdim[b];
min_dim -= min_blockdim[b];

Check warning on line 429 in src/linalg/Gesvd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Gesvd_truncate.cpp#L428-L429

Added lines #L428 - L429 were not covered by tests
}
if (anySall)
Sall = algo::Concatenate(Sall, Block);
Expand Down Expand Up @@ -498,13 +497,13 @@ namespace cytnx {
cytnx_uint64 cnt = 0;

Check warning on line 497 in src/linalg/Gesvd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Gesvd_truncate.cpp#L496-L497

Added lines #L496 - L497 were not covered by tests
for (int b = 0; b < S.Nblocks(); b++) {
Storage stmp = S.get_block_(b).storage();
cytnx_int64 kdim = minblockdim[b];
cytnx_int64 kdim = min_blockdim[b];
if (keep_dim > 0) {
// search for first value >= Smin
for (int i = stmp.size(); i > minblockdim[b]; i--) {
// Careful here: if (int i = stmp.size() -1; i >= minblockdim[b]; i--) is used
// instead, the compiler might make i an unsigned integer; if then minblockdim[b] ==
// 0, the condition i > minblockdim[b] is always fulfilled and the loop never stops!
for (int i = stmp.size(); i > min_blockdim[b]; i--) {
// Careful here: if (int i = stmp.size() -1; i >= min_blockdim[b]; i--) is used
// instead, the compiler might make i an unsigned integer; if then min_blockdim[b] ==
// 0, the condition i > min_blockdim[b] is always fulfilled and the loop never stops!
if (stmp(i - 1) >= Smin) {
kdim = i;
break;

Check warning on line 509 in src/linalg/Gesvd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Gesvd_truncate.cpp#L508-L509

Added lines #L508 - L509 were not covered by tests
Expand Down
38 changes: 18 additions & 20 deletions src/linalg/Svd_truncate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,49 +354,47 @@ namespace cytnx {
} // Svd_truncate

void _svd_truncate_Block_UT(std::vector<UniTensor> &outCyT, const cytnx::UniTensor &Tin,

Check warning on line 356 in src/linalg/Svd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Svd_truncate.cpp#L356

Added line #L356 was not covered by tests
const cytnx_uint64 &keepdim,
const std::vector<cytnx_uint64> min_blockdim, const double &err,
const bool &is_UvT, const int &return_err,
const cytnx_uint64 &keepdim, std::vector<cytnx_uint64> min_blockdim,
const double &err, const bool &is_UvT, const int &return_err,
const cytnx_uint64 &mindim) {
// currently, Gesvd is used as a standard for the full SVD before truncation
cytnx_int64 keep_dim = keepdim; // these must be signed int, because they can become

Check warning on line 361 in src/linalg/Svd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Svd_truncate.cpp#L361

Added line #L361 was not covered by tests
// negative!
cytnx_int64 min_dim = (mindim < 1 ? 1 : mindim);

Check warning on line 363 in src/linalg/Svd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Svd_truncate.cpp#L363

Added line #L363 was not covered by tests
std::vector<cytnx_uint64> minblockdim = min_blockdim;

outCyT = linalg::Gesvd(Tin, is_UvT, is_UvT);

Check warning on line 365 in src/linalg/Svd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Svd_truncate.cpp#L365

Added line #L365 was not covered by tests
if (minblockdim.size() == 1) // if only one element given, make it a vector
minblockdim = std::vector<cytnx_uint64>(outCyT[0].Nblocks(), minblockdim[0]);
cytnx_error_msg(minblockdim.size() != outCyT[0].Nblocks(),
if (min_blockdim.size() == 1) // if only one element given, make it a vector
min_blockdim.resize(outCyT[0].Nblocks(), min_blockdim.front());

Check warning on line 367 in src/linalg/Svd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Svd_truncate.cpp#L367

Added line #L367 was not covered by tests
cytnx_error_msg(min_blockdim.size() != outCyT[0].Nblocks(),
"[ERROR][Svd_truncate] min_blockdim must have the same number of elements as "
"blocks in the singular value UniTensor%s",
"\n");

Check warning on line 371 in src/linalg/Svd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Svd_truncate.cpp#L371

Added line #L371 was not covered by tests

// process truncation:
// 1) concate all S vals from all blk but exclude the firt minblockdim Svals in each block
// 1) concate all S vals from all blk but exclude the first min_blockdim Svals in each block
// (since they will be kept anyways later)
Tensor Sall; // S vals excluding the already kept ones

Check warning on line 376 in src/linalg/Svd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Svd_truncate.cpp#L376

Added line #L376 was not covered by tests
Tensor Block; // current block
cytnx_uint64 blockdim;
bool anySall = false; // are there already any values in Sall vals?
bool any_min_blockdim = false; // is any minblockdim > 0?
bool any_min_blockdim = false; // is any min_blockdim > 0?
for (int b = 0; b < outCyT[0].Nblocks(); b++) {
if (minblockdim[b] < 1) // save whole block to Sall
if (min_blockdim[b] < 1) // save whole block to Sall
Block = outCyT[0].get_block_(b);
else {
any_min_blockdim = true;

Check warning on line 385 in src/linalg/Svd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Svd_truncate.cpp#L385

Added line #L385 was not covered by tests
blockdim = outCyT[0].get_block_(b).shape()[0];
if (blockdim <= minblockdim[b]) {
if (blockdim <= min_blockdim[b]) {
// keep whole block
keep_dim -= blockdim;
min_dim -= blockdim;
continue;

Check warning on line 391 in src/linalg/Svd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Svd_truncate.cpp#L389-L391

Added lines #L389 - L391 were not covered by tests
}
// remove first minblockdim[b] values
// remove first min_blockdim[b] values
blockdim = outCyT[0].get_block_(b).shape()[0];
Block = outCyT[0].get_block_(b).get({ac::range(minblockdim[b], blockdim)});
keep_dim -= minblockdim[b];
min_dim -= minblockdim[b];
Block = outCyT[0].get_block_(b).get({ac::range(min_blockdim[b], blockdim)});
keep_dim -= min_blockdim[b];
min_dim -= min_blockdim[b];

Check warning on line 397 in src/linalg/Svd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Svd_truncate.cpp#L396-L397

Added lines #L396 - L397 were not covered by tests
}
if (anySall)
Sall = algo::Concatenate(Sall, Block);
Expand Down Expand Up @@ -467,13 +465,13 @@ namespace cytnx {
cytnx_uint64 cnt = 0;

Check warning on line 465 in src/linalg/Svd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Svd_truncate.cpp#L464-L465

Added lines #L464 - L465 were not covered by tests
for (int b = 0; b < S.Nblocks(); b++) {
Storage stmp = S.get_block_(b).storage();
cytnx_int64 kdim = minblockdim[b];
cytnx_int64 kdim = min_blockdim[b];
if (keep_dim > 0) {
// search for first value >= Smin
for (int i = stmp.size(); i > minblockdim[b]; i--) {
// Careful here: if (int i = stmp.size() -1; i >= minblockdim[b]; i--) is used
// instead, the compiler might make i an unsigned integer; if then minblockdim[b] ==
// 0, the condition i > minblockdim[b] is always fulfilled and the loop never stops!
for (int i = stmp.size(); i > min_blockdim[b]; i--) {
// Careful here: if (int i = stmp.size() -1; i >= min_blockdim[b]; i--) is used
// instead, the compiler might make i an unsigned integer; if then min_blockdim[b] ==
// 0, the condition i > min_blockdim[b] is always fulfilled and the loop never stops!
if (stmp(i - 1) >= Smin) {
kdim = i;
break;

Check warning on line 477 in src/linalg/Svd_truncate.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Svd_truncate.cpp#L476-L477

Added lines #L476 - L477 were not covered by tests
Expand Down

0 comments on commit 7c62492

Please sign in to comment.