From ac613a73fb395836b210710a6fefdf6d32df3386 Mon Sep 17 00:00:00 2001
From: sanchitintel
Date: Fri, 19 Jan 2024 08:47:52 +0000
Subject: [PATCH] [LLGA JIT fuser] Unary aten::max would have two outputs
 (#2491)

* [Change 1/2] aten::max may be unary & may have two outputs

* [Change 2/2] Add UT

* Fix style

---------

Co-authored-by: Chunyuan WU
---
 csrc/cpu/jit/codegen/onednn/utils.cpp |  8 ++++----
 tests/cpu/test_jit_llga_fuser.py      | 15 +++++++++++++++
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/csrc/cpu/jit/codegen/onednn/utils.cpp b/csrc/cpu/jit/codegen/onednn/utils.cpp
index 63685b3cd..5db9f8b7a 100644
--- a/csrc/cpu/jit/codegen/onednn/utils.cpp
+++ b/csrc/cpu/jit/codegen/onednn/utils.cpp
@@ -157,14 +157,14 @@ void convertInputTo0DTensor(
 
 void modifyDtypeOfNode(torch::jit::Node* node, at::ScalarType dtype) {
   auto existingDtype =
-      node->output()->type()->expect<TensorType>()->scalarType();
+      node->outputs()[0]->type()->expect<TensorType>()->scalarType();
   if (existingDtype.has_value()) {
     switch (existingDtype.value()) {
      case at::ScalarType::Float:
      case at::ScalarType::BFloat16:
      case at::kInt:
-        node->output()->setType(
-            node->output()->type()->expect<TensorType>()->withScalarType(
+        node->outputs()[0]->setType(
+            node->outputs()[0]->type()->expect<TensorType>()->withScalarType(
                dtype));
        break;
      default:
@@ -189,7 +189,7 @@ void insertTypeCast(
 }
 
 void mayModifyOutputDtype(torch::jit::Node* node) {
-  if (node->output()->type()->isSubtypeOf(TensorType::get())) {
+  if (node->outputs()[0]->type()->isSubtypeOf(TensorType::get())) {
     if (node->hasAttributeS("was_float")) {
       modifyDtypeOfNode(node, at::ScalarType::Float);
       node->removeAttributeS("was_float");
diff --git a/tests/cpu/test_jit_llga_fuser.py b/tests/cpu/test_jit_llga_fuser.py
index c960250d0..e6a2d2999 100644
--- a/tests/cpu/test_jit_llga_fuser.py
+++ b/tests/cpu/test_jit_llga_fuser.py
@@ -405,6 +405,21 @@ def forward(self, x, y):
         graph, _ = self.checkTrace(m, [x, y])
         self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
 
+    @llga_fp32_bf16_test_env
+    def test_max_two_outputs(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+
+            def forward(self, x):
+                # max is unary, and would have 2 outputs
+                return torch.max(x, dim=1)
+
+        m = M()
+        x = torch.rand(8, 12, 12, 12)
+        graph, _ = self.checkTrace(m, [x])
+        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)
+
     @llga_fp32_bf16_test_env
     def test_bmm_div(self):
         class M(nn.Module):
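
Note (illustration only, not part of the patch): the reduction overload exercised by the new test takes a single tensor input but produces two outputs (values and indices), which is why the bridge code above reads node->outputs()[0] rather than node->output(). A minimal standalone sketch of that PyTorch behavior:

import torch

x = torch.rand(8, 12, 12, 12)

# Reduction over a dim: one tensor input, two tensor outputs
# (a namedtuple of values and indices), traced as a two-output aten::max node.
values, indices = torch.max(x, dim=1)
print(values.shape, indices.shape)  # torch.Size([8, 12, 12]) torch.Size([8, 12, 12])

# Full reduction: a single 0-dim tensor output, the single-output aten::max overload.
overall_max = torch.max(x)
print(overall_max.shape)  # torch.Size([])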