Skip to content

Commit

Permalink
[C++] Refactor PredictionContext and yet more performance improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jcking committed Mar 29, 2022
1 parent de8b934 commit 89c958b
Show file tree
Hide file tree
Showing 13 changed files with 459 additions and 485 deletions.
4 changes: 2 additions & 2 deletions runtime/Cpp/runtime/src/atn/ATNConfigSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ namespace atn {

ATNConfigSet(const ATNConfigSet &other);

ATNConfigSet(ATNConfigSet &&) = delete;
ATNConfigSet(ATNConfigSet&&) = delete;

explicit ATNConfigSet(bool fullCtx);

Expand All @@ -68,7 +68,7 @@ namespace atn {

bool addAll(const ATNConfigSet &other);

std::vector<ATNState *> getStates() const;
std::vector<ATNState*> getStates() const;

/**
* Gets the complete set of represented alternatives for the configuration
Expand Down
7 changes: 3 additions & 4 deletions runtime/Cpp/runtime/src/atn/ATNSimulator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ void ATNSimulator::clearDFA() {
throw UnsupportedOperationException("This ATN simulator does not support clearing the DFA.");
}

PredictionContextCache& ATNSimulator::getSharedContextCache() {
PredictionContextCache& ATNSimulator::getSharedContextCache() const {
return _sharedContextCache;
}

Ref<const PredictionContext> ATNSimulator::getCachedContext(Ref<const PredictionContext> const& context) {
Ref<const PredictionContext> ATNSimulator::getCachedContext(const Ref<const PredictionContext> &context) {
// This function must only be called with an active state lock, as we are going to change a shared structure.
std::map<Ref<const PredictionContext>, Ref<const PredictionContext>> visited;
return PredictionContext::getCachedContext(context, _sharedContextCache, visited);
return PredictionContext::getCachedContext(context, getSharedContextCache());
}
5 changes: 3 additions & 2 deletions runtime/Cpp/runtime/src/atn/ATNSimulator.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ namespace atn {
* @since 4.3
*/
virtual void clearDFA();
virtual PredictionContextCache& getSharedContextCache();
virtual Ref<const PredictionContext> getCachedContext(Ref<const PredictionContext> const& context);

PredictionContextCache& getSharedContextCache() const;
Ref<const PredictionContext> getCachedContext(const Ref<const PredictionContext> &context);

protected:
/// <summary>
Expand Down
37 changes: 24 additions & 13 deletions runtime/Cpp/runtime/src/atn/ArrayPredictionContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,25 @@
*/

#include "support/Arrays.h"
#include "support/Casts.h"
#include "atn/SingletonPredictionContext.h"
#include "misc/MurmurHash.h"

#include "atn/ArrayPredictionContext.h"

using namespace antlr4::atn;
using namespace antlrcpp;

ArrayPredictionContext::ArrayPredictionContext(Ref<const SingletonPredictionContext> const& a)
: ArrayPredictionContext({ a->parent }, { a->returnState }) {
}

ArrayPredictionContext::ArrayPredictionContext(std::vector<Ref<const PredictionContext>> parents,
std::vector<size_t> returnStates)
: PredictionContext(calculateHashCode(parents, returnStates)), parents(std::move(parents)), returnStates(std::move(returnStates)) {
: parents(std::move(parents)), returnStates(std::move(returnStates)) {
assert(this->parents.size() > 0);
assert(this->returnStates.size() > 0);
assert(this->parents.size() == this->returnStates.size());
}

PredictionContextType ArrayPredictionContext::getContextType() const {
Expand All @@ -34,29 +38,36 @@ size_t ArrayPredictionContext::size() const {
return returnStates.size();
}

Ref<const PredictionContext> ArrayPredictionContext::getParent(size_t index) const {
const Ref<const PredictionContext>& ArrayPredictionContext::getParent(size_t index) const {
return parents[index];
}

size_t ArrayPredictionContext::getReturnState(size_t index) const {
return returnStates[index];
}

bool ArrayPredictionContext::operator == (PredictionContext const& o) const {
if (this == &o) {
return true;
size_t ArrayPredictionContext::hashCodeImpl() const {
size_t hash = misc::MurmurHash::initialize();
hash = misc::MurmurHash::update(hash, static_cast<size_t>(getContextType()));
for (const auto &parent : parents) {
hash = misc::MurmurHash::update(hash, parent);
}
if (o.getContextType() != PredictionContextType::ARRAY) {
return false;
for (const auto &returnState : returnStates) {
hash = misc::MurmurHash::update(hash, returnState);
}
return misc::MurmurHash::finish(hash, 1 + parents.size() + returnStates.size());
}

const ArrayPredictionContext *other = static_cast<const ArrayPredictionContext*>(&o);
if (hashCode() != other->hashCode()) {
return false; // can't be same if hash is different
bool ArrayPredictionContext::equals(const PredictionContext &other) const {
if (this == &other) {
return true;
}

return antlrcpp::Arrays::equals(returnStates, other->returnStates) &&
antlrcpp::Arrays::equals(parents, other->parents);
if (getContextType() != other.getContextType()) {
return false;
}
const ArrayPredictionContext &array = downCast<const ArrayPredictionContext&>(other);
return Arrays::equals(returnStates, array.returnStates) &&
Arrays::equals(parents, array.parents);
}

std::string ArrayPredictionContext::toString() const {
Expand Down
23 changes: 13 additions & 10 deletions runtime/Cpp/runtime/src/atn/ArrayPredictionContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,27 @@ namespace atn {
/// returnState == EMPTY_RETURN_STATE.
// Also here: we use a strong reference to our parents to avoid having them freed prematurely.
// See also SinglePredictionContext.
const std::vector<Ref<const PredictionContext>> parents;
std::vector<Ref<const PredictionContext>> parents;

/// Sorted for merge, no duplicates; if present, EMPTY_RETURN_STATE is always last.
const std::vector<size_t> returnStates;
std::vector<size_t> returnStates;

explicit ArrayPredictionContext(Ref<const SingletonPredictionContext> const &a);

ArrayPredictionContext(std::vector<Ref<const PredictionContext>> parents, std::vector<size_t> returnStates);

PredictionContextType getContextType() const override;

virtual bool isEmpty() const override;
virtual size_t size() const override;
virtual Ref<const PredictionContext> getParent(size_t index) const override;
virtual size_t getReturnState(size_t index) const override;
bool operator == (const PredictionContext &o) const override;
ArrayPredictionContext(ArrayPredictionContext&&) = default;

virtual std::string toString() const override;
PredictionContextType getContextType() const override;
bool isEmpty() const override;
size_t size() const override;
const Ref<const PredictionContext>& getParent(size_t index) const override;
size_t getReturnState(size_t index) const override;
bool equals(const PredictionContext &other) const override;
std::string toString() const override;

protected:
size_t hashCodeImpl() const override;
};

} // namespace atn
Expand Down
Loading

0 comments on commit 89c958b

Please sign in to comment.