Skip to content

Commit

Permalink
[C++] Refactor PredictionContext and yet more performance improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jcking committed Mar 31, 2022
1 parent e962bef commit 77f332a
Show file tree
Hide file tree
Showing 13 changed files with 501 additions and 502 deletions.
4 changes: 2 additions & 2 deletions runtime/Cpp/runtime/src/atn/ATNConfigSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ namespace atn {

ATNConfigSet(const ATNConfigSet &other);

ATNConfigSet(ATNConfigSet &&) = delete;
ATNConfigSet(ATNConfigSet&&) = delete;

explicit ATNConfigSet(bool fullCtx);

Expand All @@ -68,7 +68,7 @@ namespace atn {

bool addAll(const ATNConfigSet &other);

std::vector<ATNState *> getStates() const;
std::vector<ATNState*> getStates() const;

/**
* Gets the complete set of represented alternatives for the configuration
Expand Down
7 changes: 3 additions & 4 deletions runtime/Cpp/runtime/src/atn/ATNSimulator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ void ATNSimulator::clearDFA() {
throw UnsupportedOperationException("This ATN simulator does not support clearing the DFA.");
}

PredictionContextCache& ATNSimulator::getSharedContextCache() {
PredictionContextCache& ATNSimulator::getSharedContextCache() const {
return _sharedContextCache;
}

Ref<const PredictionContext> ATNSimulator::getCachedContext(Ref<const PredictionContext> const& context) {
Ref<const PredictionContext> ATNSimulator::getCachedContext(const Ref<const PredictionContext> &context) {
// This function must only be called with an active state lock, as we are going to change a shared structure.
std::map<Ref<const PredictionContext>, Ref<const PredictionContext>> visited;
return PredictionContext::getCachedContext(context, _sharedContextCache, visited);
return PredictionContext::getCachedContext(context, getSharedContextCache());
}
5 changes: 3 additions & 2 deletions runtime/Cpp/runtime/src/atn/ATNSimulator.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ namespace atn {
* @since 4.3
*/
virtual void clearDFA();
virtual PredictionContextCache& getSharedContextCache();
virtual Ref<const PredictionContext> getCachedContext(Ref<const PredictionContext> const& context);

PredictionContextCache& getSharedContextCache() const;
Ref<const PredictionContext> getCachedContext(const Ref<const PredictionContext> &context);

protected:
/// <summary>
Expand Down
67 changes: 47 additions & 20 deletions runtime/Cpp/runtime/src/atn/ArrayPredictionContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,39 @@
* can be found in the LICENSE.txt file in the project root.
*/

#include "support/Arrays.h"
#include "atn/SingletonPredictionContext.h"

#include "atn/ArrayPredictionContext.h"

#include <cstring>

#include "atn/SingletonPredictionContext.h"
#include "misc/MurmurHash.h"
#include "support/Casts.h"

using namespace antlr4::atn;
using namespace antlr4::misc;
using namespace antlrcpp;

namespace {

bool cachedHashCodeEqual(size_t lhs, size_t rhs) {
return lhs == rhs || lhs == 0 || rhs == 0;
}

bool predictionContextEqual(const Ref<const PredictionContext> &lhs, const Ref<const PredictionContext> &rhs) {
return *lhs == *rhs;
}

ArrayPredictionContext::ArrayPredictionContext(Ref<const SingletonPredictionContext> const& a)
: ArrayPredictionContext({ a->parent }, { a->returnState }) {
}

ArrayPredictionContext::ArrayPredictionContext(const SingletonPredictionContext &predictionContext)
: ArrayPredictionContext({ predictionContext.parent }, { predictionContext.returnState }) {}

ArrayPredictionContext::ArrayPredictionContext(std::vector<Ref<const PredictionContext>> parents,
std::vector<size_t> returnStates)
: PredictionContext(PredictionContextType::ARRAY, calculateHashCode(parents, returnStates)), parents(std::move(parents)), returnStates(std::move(returnStates)) {
assert(this->parents.size() > 0);
assert(this->returnStates.size() > 0);
: PredictionContext(PredictionContextType::ARRAY), parents(std::move(parents)), returnStates(std::move(returnStates)) {
assert(this->parents.size() > 0);
assert(this->returnStates.size() > 0);
assert(this->parents.size() == this->returnStates.size());
}

bool ArrayPredictionContext::isEmpty() const {
Expand All @@ -30,29 +47,39 @@ size_t ArrayPredictionContext::size() const {
return returnStates.size();
}

Ref<const PredictionContext> ArrayPredictionContext::getParent(size_t index) const {
const Ref<const PredictionContext>& ArrayPredictionContext::getParent(size_t index) const {
return parents[index];
}

size_t ArrayPredictionContext::getReturnState(size_t index) const {
return returnStates[index];
}

bool ArrayPredictionContext::operator == (PredictionContext const& o) const {
if (this == &o) {
return true;
size_t ArrayPredictionContext::hashCodeImpl() const {
size_t hash = MurmurHash::initialize();
hash = MurmurHash::update(hash, static_cast<size_t>(getContextType()));
for (const auto &parent : parents) {
hash = MurmurHash::update(hash, parent);
}
if (o.getContextType() != PredictionContextType::ARRAY) {
return false;
for (const auto &returnState : returnStates) {
hash = MurmurHash::update(hash, returnState);
}
return MurmurHash::finish(hash, 1 + parents.size() + returnStates.size());
}

const ArrayPredictionContext *other = static_cast<const ArrayPredictionContext*>(&o);
if (hashCode() != other->hashCode()) {
return false; // can't be same if hash is different
bool ArrayPredictionContext::equals(const PredictionContext &other) const {
if (this == std::addressof(other)) {
return true;
}

return antlrcpp::Arrays::equals(returnStates, other->returnStates) &&
antlrcpp::Arrays::equals(parents, other->parents);
if (getContextType() != other.getContextType()) {
return false;
}
const auto &array = downCast<const ArrayPredictionContext&>(other);
return returnStates.size() == array.returnStates.size() &&
parents.size() == array.parents.size() &&
cachedHashCodeEqual(cachedHashCode(), array.cachedHashCode()) &&
std::memcmp(returnStates.data(), array.returnStates.data(), returnStates.size() * sizeof(decltype(returnStates)::value_type)) == 0 &&
std::equal(parents.begin(), parents.end(), array.parents.begin(), predictionContextEqual);
}

std::string ArrayPredictionContext::toString() const {
Expand Down
22 changes: 13 additions & 9 deletions runtime/Cpp/runtime/src/atn/ArrayPredictionContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,26 @@ namespace atn {
/// returnState == EMPTY_RETURN_STATE.
// Also here: we use a strong reference to our parents to avoid having them freed prematurely.
// See also SinglePredictionContext.
const std::vector<Ref<const PredictionContext>> parents;
std::vector<Ref<const PredictionContext>> parents;

/// Sorted for merge, no duplicates; if present, EMPTY_RETURN_STATE is always last.
const std::vector<size_t> returnStates;
std::vector<size_t> returnStates;

explicit ArrayPredictionContext(Ref<const SingletonPredictionContext> const &a);
explicit ArrayPredictionContext(const SingletonPredictionContext &predictionContext);

ArrayPredictionContext(std::vector<Ref<const PredictionContext>> parents, std::vector<size_t> returnStates);

virtual bool isEmpty() const override;
virtual size_t size() const override;
virtual Ref<const PredictionContext> getParent(size_t index) const override;
virtual size_t getReturnState(size_t index) const override;
bool operator == (const PredictionContext &o) const override;
ArrayPredictionContext(ArrayPredictionContext&&) = default;

virtual std::string toString() const override;
bool isEmpty() const override;
size_t size() const override;
const Ref<const PredictionContext>& getParent(size_t index) const override;
size_t getReturnState(size_t index) const override;
bool equals(const PredictionContext &other) const override;
std::string toString() const override;

protected:
size_t hashCodeImpl() const override;
};

} // namespace atn
Expand Down
Loading

0 comments on commit 77f332a

Please sign in to comment.