Skip to content

Commit

Permalink
Add second BatchRule ctor taking string format.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga committed Apr 2, 2024
1 parent f6a05a2 commit 8e69cad
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 9 deletions.
10 changes: 10 additions & 0 deletions PhysicsTools/TensorFlowAOT/interface/Batching.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ namespace tfaot {
// constructor taking the target batch size, the composite batch sizes and the padding applied to
// the last composite size; the arguments are validated on construction, which throws a
// cms::Exception when they are inconsistent
explicit BatchRule(size_t batchSize, const std::vector<size_t>& sizes, size_t lastPadding = 0);

// constructor taking a string in the format "batchSize:size1,...,sizeN" with lastPadding being
// inferred from the sum of sizes
BatchRule(const std::string& ruleString);

// destructor, compiler-generated
~BatchRule() = default;

Expand All @@ -43,6 +47,9 @@ namespace tfaot {
// target batch size covered by this rule; must equal the sum of sizes_ minus lastPadding_
size_t batchSize_;
// composite batch sizes that are combined to cover batchSize_; must not be empty
std::vector<size_t> sizes_;
// padding applied to the last entry of sizes_; must be smaller than that entry
size_t lastPadding_;

// validation helper, throws a cms::Exception when batch size, composite sizes and padding do not
// form a consistent rule
void validate() const;
};

// stream operator
Expand All @@ -60,6 +67,9 @@ namespace tfaot {
// stores a rule under its batch size, replacing a rule that was previously stored for that size
void setRule(const BatchRule& rule) {
  const auto key = rule.getBatchSize();
  rules_.insert_or_assign(key, rule);
}

// registers a new rule for a batch size, given a rule string (see BatchRule constructor)
void setRule(const std::string& ruleString) { this->setRule(BatchRule(ruleString)); }

// tells whether a rule is stored for the given batch size
bool hasRule(size_t batchSize) const {
  return rules_.count(batchSize) > 0;
}

Expand Down
3 changes: 3 additions & 0 deletions PhysicsTools/TensorFlowAOT/interface/Model.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ namespace tfaot {
batchStrategy_.setRule(BatchRule(batchSize, sizes, lastPadding));
}

// adds a new batch rule to the strategy, given a rule string (see BatchRule constructor)
void setBatchRule(const std::string& batchRule) { batchStrategy_.setRule(BatchRule(batchRule)); }

// evaluates the model for multiple inputs and outputs of different types;
// takes the batch size and an arbitrary number of forwarded inputs, and returns a tuple with one
// element per requested output type
// NOTE(review): presumably batchSize selects the batch rule to apply - confirm in the definition
template <typename... Outputs, typename... Inputs>
std::tuple<Outputs...> run(size_t batchSize, Inputs&&... inputs);
Expand Down
55 changes: 46 additions & 9 deletions PhysicsTools/TensorFlowAOT/src/Batching.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,70 @@ namespace tfaot {

// stores the target batch size, the composite sizes and the padding of the last size, then
// checks that they form a consistent rule
BatchRule::BatchRule(size_t batchSize, const std::vector<size_t>& sizes, size_t lastPadding)
    : batchSize_(batchSize),
      sizes_(sizes),
      lastPadding_(lastPadding) {
  // throws when the three values do not fit together
  this->validate();
}

// Builds a rule from a string in the format "batchSize:size1,...,sizeN", e.g. "5:2,2,2".
// The padding of the last size is inferred as the sum of all sizes minus the batch size.
// Blanks around numbers and a single trailing comma are tolerated.
// Throws cms::Exception("InvalidBatchRule") when the ':' separator is missing, when a number is
// empty, negative, not purely decimal or out of range, or when the sizes sum up to less than the
// batch size. Further consistency checks are done by validate(), which throws its own exceptions.
BatchRule::BatchRule(const std::string& ruleString) {
  // converts one token of the rule string into a size; only plain decimal digits are accepted
  // since the std::sto* functions alone would silently accept signs (a negative value would wrap
  // around when stored in a size_t) as well as trailing garbage such as "5x"
  const auto parseSize = [&ruleString](const std::string& token) -> size_t {
    const auto first = token.find_first_not_of(" \t");
    const auto last = token.find_last_not_of(" \t");
    const std::string digits =
        first == std::string::npos ? std::string() : token.substr(first, last - first + 1);
    if (digits.empty() || digits.find_first_not_of("0123456789") != std::string::npos) {
      throw cms::Exception("InvalidBatchRule")
          << "invalid number '" << token << "' in batch rule: " << ruleString;
    }
    try {
      return std::stoul(digits);
    } catch (const std::exception&) {
      // digits were checked above, so this is only reached when the value is too large
      throw cms::Exception("InvalidBatchRule")
          << "number '" << token << "' is out of range in batch rule: " << ruleString;
    }
  };

  // extract the target batch size from the front
  const auto colon = ruleString.find(':');
  if (colon == std::string::npos) {
    throw cms::Exception("InvalidBatchRule") << "invalid batch rule format: " << ruleString;
  }
  batchSize_ = parseSize(ruleString.substr(0, colon));

  // loop through the remaining comma-separated sizes without re-copying the rule string
  size_t sumSizes = 0;
  size_t start = colon + 1;
  while (start < ruleString.size()) {
    const auto comma = ruleString.find(',', start);
    const size_t end = comma == std::string::npos ? ruleString.size() : comma;
    sizes_.push_back(parseSize(ruleString.substr(start, end - start)));
    sumSizes += sizes_.back();
    // continue behind the comma; after the last token this moves past the end and stops the loop
    start = end + 1;
  }

  // the sum of composite batch sizes should never be smaller than the target batch size,
  // checked before the subtraction below to avoid an unsigned underflow
  if (sumSizes < batchSize_) {
    throw cms::Exception("InvalidBatchRule")
        << "sum of composite batch sizes is smaller than target batch size: " << ruleString;
  }

  // infer the padding and validate the resulting rule
  lastPadding_ = sumSizes - batchSize_;
  validate();
}

// Checks that batch size, composite sizes and padding form a consistent rule.
// Throws cms::Exception("EmptySizes") when no sizes are stored, cms::Exception("WrongPadding")
// when the padding is not smaller than the last size or exceeds the sum of sizes, and
// cms::Exception("WrongBatchSize") when the batch size differs from the sum of sizes minus the
// padding.
void BatchRule::validate() const {
  // sizes must not be empty
  if (sizes_.empty()) {
    throw cms::Exception("EmptySizes") << "no batch sizes provided for stitching";
  }

  // the padding must be smaller than the last size
  const size_t lastSize = sizes_.back();
  if (lastPadding_ >= lastSize) {
    throw cms::Exception("WrongPadding")
        << "padding " << lastPadding_ << " must be smaller than last size " << lastSize;
  }

  // compute the covered batch size
  size_t sizeSum = 0;
  for (const size_t s : sizes_) {
    sizeSum += s;
  }
  // guards the unsigned subtraction below against underflow
  if (lastPadding_ > sizeSum) {
    throw cms::Exception("WrongPadding")
        << "padding " << lastPadding_ << " must not be larger than sum of sizes " << sizeSum;
  }
  sizeSum -= lastPadding_;

  // compare to given batch size
  if (batchSize_ != sizeSum) {
    throw cms::Exception("WrongBatchSize")
        << "batch size " << batchSize_ << " does not match sum of sizes - padding " << sizeSum;
  }
}

Expand Down
4 changes: 4 additions & 0 deletions PhysicsTools/TensorFlowAOT/test/testInterface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ void testInterface::test_simple() {
// register (optional) batch rules
model.setBatchRule(1, {1});
model.setBatchRule(3, {2, 2}, 1);
model.setBatchRule("5:2,2,2");

// test batching strategies
CPPUNIT_ASSERT(model.getBatchStrategy().hasRule(1));
Expand All @@ -50,6 +51,9 @@ void testInterface::test_simple() {
CPPUNIT_ASSERT(model.getBatchStrategy().getRule(3).nSizes() == 2);
CPPUNIT_ASSERT(model.getBatchStrategy().getRule(3).getLastPadding() == 1);
CPPUNIT_ASSERT(!model.getBatchStrategy().hasRule(4));
CPPUNIT_ASSERT(model.getBatchStrategy().hasRule(5));
CPPUNIT_ASSERT(model.getBatchStrategy().getRule(5).nSizes() == 3);
CPPUNIT_ASSERT(model.getBatchStrategy().getRule(5).getLastPadding() == 1);

// evaluate batch size 1
tfaot::FloatArrays input_bs1 = {{0, 1, 2, 3}};
Expand Down

0 comments on commit 8e69cad

Please sign in to comment.