Skip to content

Commit

Permalink
Revert "Revert "Node hash new (#3514)" (#3517)" (#3518)
Browse files Browse the repository at this point in the history
This reverts commit 7c57210.
  • Loading branch information
JackCaoG authored Apr 22, 2022
1 parent 1d0a6e7 commit f4b6488
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
22 changes: 8 additions & 14 deletions torch_xla/csrc/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,10 @@ const xla::Shape& Value::xla_node_shape() const {

Node::Node(torch::lazy::OpKind op, OpList operands, xla::Shape shape,
size_t num_outputs, torch::lazy::hash_t hash_seed)
: torch::lazy::Node(
op, num_outputs,
/*node_hash=*/torch::lazy::HashCombine(op.hash(), hash_seed),
/*dag_hash_fn=*/
[&](bool /*bakeInSizes*/) -> torch::lazy::hash_t {
return GetOperandHashes(
operands, torch::lazy::HashCombine(op.hash(), hash_seed));
}),
xla_shape_(std::move(shape)) {
: torch::lazy::Node(op, num_outputs),
xla_shape_(std::move(shape)),
node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
dag_hash_(GetOperandHashes(operands, node_hash_)) {
for (auto& operand : operands) {
AddOperand(operand.node, operand.index);
}
Expand All @@ -139,11 +134,10 @@ Node::Node(torch::lazy::OpKind op, OpList operands,

Node::Node(torch::lazy::OpKind op, xla::Shape shape, size_t num_outputs,
torch::lazy::hash_t hash_seed)
: torch::lazy::Node(op, num_outputs, /*node_hash_fn=*/
[&](bool /*bakeInSizes*/) -> torch::lazy::hash_t {
return GetOpHash(op, shape, hash_seed);
}),
xla_shape_(std::move(shape)) {}
: torch::lazy::Node(op, num_outputs),
xla_shape_(std::move(shape)),
node_hash_(GetOpHash(op, xla_shape_, hash_seed)),
dag_hash_(node_hash_) {}

Node::~Node() {
for (size_t i = 0; i < operands_as_outputs_.size(); ++i) {
Expand Down
8 changes: 8 additions & 0 deletions torch_xla/csrc/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ class Node : public torch::lazy::Node {
XlaOpVector ReturnOps(absl::Span<const xla::XlaOp> ops,
LoweringContext* loctx) const;

torch::lazy::hash_t node_hash() const { return node_hash_; }

torch::lazy::hash_t hash() const override { return dag_hash_; }

torch::lazy::hash_t shapeHash() const override { return dag_hash_; }

private:
// Adds node's index output number as operand.
void AddOperand(torch::lazy::NodePtr node, size_t index = 0);
Expand All @@ -144,6 +150,8 @@ class Node : public torch::lazy::Node {
xla::Shape xla_shape_;
// We use a set for uses, as we want deterministic use sequencing.
std::set<Use> uses_;
torch::lazy::hash_t node_hash_;
torch::lazy::hash_t dag_hash_;
};

// RAII data structure to be used as a stack variable to enter a new IR scope. IR
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/op_by_op_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ torch::lazy::hash_t ComputeNodeKey(
}
const ir::Node* casted = dynamic_cast<const ir::Node*>(node);
key = torch::lazy::HashCombine(key, torch::lazy::Hash(casted->xla_shape()));
return torch::lazy::HashCombine(key, node->node_hash());
return torch::lazy::HashCombine(key, casted->node_hash());
}

xla::XlaComputation BuildNodeComputation(
Expand Down

0 comments on commit f4b6488

Please sign in to comment.