diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp
index a275fea93dff..df2baa35c852 100644
--- a/torch_xla/csrc/ir.cpp
+++ b/torch_xla/csrc/ir.cpp
@@ -114,15 +114,10 @@ const xla::Shape& Value::xla_node_shape() const {
 
 Node::Node(torch::lazy::OpKind op, OpList operands, xla::Shape shape,
            size_t num_outputs, torch::lazy::hash_t hash_seed)
-    : torch::lazy::Node(
-          op, num_outputs,
-          /*node_hash=*/torch::lazy::HashCombine(op.hash(), hash_seed),
-          /*dag_hash_fn=*/
-          [&](bool /*bakeInSizes*/) -> torch::lazy::hash_t {
-            return GetOperandHashes(
-                operands, torch::lazy::HashCombine(op.hash(), hash_seed));
-          }),
-      xla_shape_(std::move(shape)) {
+    : torch::lazy::Node(op, num_outputs),
+      xla_shape_(std::move(shape)),
+      node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
+      dag_hash_(GetOperandHashes(operands, node_hash_)) {
   for (auto& operand : operands) {
     AddOperand(operand.node, operand.index);
   }
@@ -139,11 +134,10 @@ Node::Node(torch::lazy::OpKind op, OpList operands,
 
 Node::Node(torch::lazy::OpKind op, xla::Shape shape, size_t num_outputs,
            torch::lazy::hash_t hash_seed)
-    : torch::lazy::Node(op, num_outputs, /*node_hash_fn=*/
-                        [&](bool /*bakeInSizes*/) -> torch::lazy::hash_t {
-                          return GetOpHash(op, shape, hash_seed);
-                        }),
-      xla_shape_(std::move(shape)) {}
+    : torch::lazy::Node(op, num_outputs),
+      xla_shape_(std::move(shape)),
+      node_hash_(GetOpHash(op, xla_shape_, hash_seed)),
+      dag_hash_(node_hash_) {}
 
 Node::~Node() {
   for (size_t i = 0; i < operands_as_outputs_.size(); ++i) {
diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h
index 55d2f9509e12..a14e5f06185e 100644
--- a/torch_xla/csrc/ir.h
+++ b/torch_xla/csrc/ir.h
@@ -125,6 +125,12 @@ class Node : public torch::lazy::Node {
   XlaOpVector ReturnOps(absl::Span<const xla::XlaOp> ops,
                         LoweringContext* loctx) const;
 
+  torch::lazy::hash_t node_hash() const { return node_hash_; }
+
+  torch::lazy::hash_t hash() const override { return dag_hash_; }
+
+  torch::lazy::hash_t shapeHash() const override { return dag_hash_; }
+
  private:
   // Adds node's index output number as operand.
   void AddOperand(torch::lazy::NodePtr node, size_t index = 0);
@@ -144,6 +150,8 @@
   xla::Shape xla_shape_;
   // We use a set for uses, as we want deterministic use sequencing.
   std::set<Use> uses_;
+  torch::lazy::hash_t node_hash_;
+  torch::lazy::hash_t dag_hash_;
 };
 
 // RAII data structure to be used a stack variable to enter a new IR scope. IR
diff --git a/torch_xla/csrc/op_by_op_executor.cpp b/torch_xla/csrc/op_by_op_executor.cpp
index d2a860fca943..b196067ec715 100644
--- a/torch_xla/csrc/op_by_op_executor.cpp
+++ b/torch_xla/csrc/op_by_op_executor.cpp
@@ -55,7 +55,7 @@ torch::lazy::hash_t ComputeNodeKey(
   }
   const ir::Node* casted = dynamic_cast<const ir::Node*>(node);
   key = torch::lazy::HashCombine(key, torch::lazy::Hash(casted->xla_shape()));
-  return torch::lazy::HashCombine(key, node->node_hash());
+  return torch::lazy::HashCombine(key, casted->node_hash());
 }
 
 xla::XlaComputation BuildNodeComputation(
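
Note on the pattern this patch introduces: the hash-computing lambdas previously passed to the torch::lazy::Node base constructor are replaced by node_hash_ / dag_hash_ members computed once in the derived constructor's initializer list; since IR nodes are immutable after construction, the cached hashes never need recomputing. The initializer lists rely on one C++ subtlety: members are initialized in declaration order, so node_hash_ must be declared before dag_hash_ in ir.h (as it is above) for dag_hash_(GetOperandHashes(operands, node_hash_)) to read an already-initialized value. Below is a minimal standalone sketch of the same caching pattern; HashCombine, CombineOperands, and the Node layout are simplified stand-ins for illustration, not the real torch::lazy / torch_xla definitions.

#include <cstdint>
#include <iostream>
#include <vector>

using hash_t = uint64_t;

// Stand-in for torch::lazy::HashCombine (hypothetical mixing function).
hash_t HashCombine(hash_t a, hash_t b) {
  return a ^ (b + 0x9e3779b97f4a7c15ULL + (a << 6) + (a >> 2));
}

struct Node {
  // NOTE: declaration order matters; C++ initializes members in this
  // order, so dag_hash_ may safely read node_hash_ in the initializer
  // list of the constructor below.
  hash_t node_hash_;
  hash_t dag_hash_;

  Node(hash_t op_hash, hash_t seed, const std::vector<const Node*>& operands)
      : node_hash_(HashCombine(op_hash, seed)),
        dag_hash_(CombineOperands(operands, node_hash_)) {}

  // Stand-in for GetOperandHashes: fold each operand's DAG hash into the
  // seed so the result identifies the whole subgraph rooted at this node.
  static hash_t CombineOperands(const std::vector<const Node*>& operands,
                                hash_t seed) {
    for (const Node* n : operands) seed = HashCombine(seed, n->dag_hash_);
    return seed;
  }

  hash_t hash() const { return dag_hash_; }        // whole-DAG identity
  hash_t node_hash() const { return node_hash_; }  // this node only
};

int main() {
  Node leaf(/*op_hash=*/1, /*seed=*/7, {});
  Node root(/*op_hash=*/2, /*seed=*/7, {&leaf});
  std::cout << std::hex << root.hash() << "\n";  // stable across runs
}

Because both hashes are plain members, accessors like hash() and node_hash() become trivial inline reads, which is what lets ComputeNodeKey in op_by_op_executor.cpp combine them cheaply per node.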