From 4eddbe9386681303ad03af37ec191e4c9e62223a Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Wed, 28 Aug 2024 19:20:24 +0800 Subject: [PATCH] Align the semantics of the accumulate type in the renorm operator with stock PyTorch (#823) Align the semantics of the accumulate type in the renorm operator with stock PyTorch. --- src/ATen/native/xpu/Normalization.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ATen/native/xpu/Normalization.cpp b/src/ATen/native/xpu/Normalization.cpp index 29236e553..3bc170da6 100644 --- a/src/ATen/native/xpu/Normalization.cpp +++ b/src/ATen/native/xpu/Normalization.cpp @@ -47,7 +47,10 @@ Tensor& renorm_impl( reduce_dims.erase(reduce_dims.begin() + dim); auto dtype = self.scalar_type(); - auto acc_type = at::toAccumulateType(dtype, c10::DeviceType::XPU); + + // This is a device-independent accumulate type, and we follow PyTorch's design. + auto acc_type = at::toAccumulateType(dtype, true); + Tensor norm; if (acc_type != dtype) { norm = at::linalg_vector_norm(