Skip to content

Commit

Permalink
Merge pull request #4691 from markdewing/fix_4690
Browse files Browse the repository at this point in the history
Guard batch size computation against zero
  • Loading branch information
prckent authored Aug 8, 2023
2 parents ba7c40a + 2a7ed46 commit 1559124
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
24 changes: 17 additions & 7 deletions src/QMCDrivers/WFOpt/QMCCostFunctionBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,16 +191,26 @@ void QMCCostFunctionBatched::getConfigurations(const std::string& aroot)
}
}

// Input - sample_size - number of samples to process
// - batch_size - process samples in batch_size at a time
// Output - num_batches - number of batches to use
// - final_batch_size - the last batch size. May be smaller than batch_size
// if the number of samples is not a multiple of the batch size
/** Compute number of batches and final batch size given the number of samples
* and a batch size.
* \param[in] sample_size number of samples to process.
* \param[in] batch_size process samples in batch_size at a time (typically the number of walkers in a crowd).
* \param[out] num_batches number of batches to use.
* \param[out] final_batch_size the last batch size. May be smaller than batch_size
* if the number of samples is not a multiple of the batch size.
*
* There may be cases where the batch size is zero. One cause is when the number of walkers per
* rank is less than the number of crowds.
*/
void compute_batch_parameters(int sample_size, int batch_size, int& num_batches, int& final_batch_size)
{
num_batches = sample_size / batch_size;
if (batch_size == 0)
num_batches = 0;
else
num_batches = sample_size / batch_size;

final_batch_size = batch_size;
if (sample_size % batch_size != 0)
if (batch_size != 0 && sample_size % batch_size != 0)
{
num_batches += 1;
final_batch_size = sample_size % batch_size;
Expand Down
5 changes: 5 additions & 0 deletions src/QMCDrivers/tests/test_QMCCostFunctionBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ TEST_CASE("compute_batch_parameters", "[drivers]")
compute_batch_parameters(sample_size, batch_size, num_batches, final_batch_size);
CHECK(num_batches == 3);
CHECK(final_batch_size == 3);

batch_size = 0;
compute_batch_parameters(sample_size, batch_size, num_batches, final_batch_size);
CHECK(num_batches == 0);
CHECK(final_batch_size == 0);
}

namespace testing
Expand Down

0 comments on commit 1559124

Please sign in to comment.