Skip to content

Commit

Permalink
BatchNormalization: SYCL: convert memory format to align with SYCL kernel assumption (#3857) (#3882)

Browse files Browse the repository at this point in the history

* check and transform format

* Update BatchNorm.cpp

* Update BatchNorm

* add comments



---------

Signed-off-by: Feng Yuan <feng1.yuan@intel.com>
Co-authored-by: Ye Ting <ting.ye@intel.com>
Co-authored-by: Feng Yuan <feng1.yuan@intel.com>
  • Loading branch information
3 people authored Mar 11, 2024
1 parent 1eef60d commit a1e2271
Showing 1 changed file with 98 additions and 16 deletions.
114 changes: 98 additions & 16 deletions csrc/gpu/aten/operators/BatchNorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,12 +584,14 @@ void batch_norm_elementwise(
batch_norm_elemt_channels_first_template<
scalar_t,
accscalar_t,
int32_t>(out, self, *weight, *bias, mean_, invstd_);
int32_t>(
out, self.contiguous(), *weight, *bias, mean_, invstd_);
} else {
batch_norm_elemt_channels_first_template<
scalar_t,
scalar_t,
int32_t>(out, self, *weight, *bias, mean_, invstd_);
int32_t>(
out, self.contiguous(), *weight, *bias, mean_, invstd_);
}
});
return;
Expand All @@ -607,7 +609,16 @@ void batch_norm_elementwise(
(!mean_.defined() || mean_.is_contiguous()) &&
(!invstd_.defined() || invstd_.is_contiguous())) {
batch_norm_elemt_channels_last_template(
out, self, *weight, *bias, mean_, invstd_);
out,
// It is a WA to fix Mobile-SSD convergence issue.
// TODO: Fully support: Check and convert activations with any
// shapes to align with kernel required memory layout.
self.dim() == 4 ? self.contiguous(at::MemoryFormat::ChannelsLast)
: self,
*weight,
*bias,
mean_,
invstd_);
return;
}
}
Expand Down Expand Up @@ -2858,21 +2869,43 @@ Tensor batch_norm_elementwise_backward_train(
scalar_t,
accscalar_t,
int32_t>(
grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
grad_out.contiguous(),
input.contiguous(),
mean,
invstd,
weight,
sum_dy,
sum_dy_xmu);
} else {
return batch_norm_backward_elemt_channels_first_template<
scalar_t,
scalar_t,
int32_t>(
grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
grad_out.contiguous(),
input.contiguous(),
mean,
invstd,
weight,
sum_dy,
sum_dy_xmu);
}
});
}
case Impl::ChannelsLast: {
if ((!weight.defined() || weight.is_contiguous()) &&
mean.is_contiguous() && invstd.is_contiguous()) {
return batch_norm_backward_elemt_channels_last_template(
grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu);
// It is a WA to fix Mobile-SSD convergence issue.
grad_out.dim() == 4
? grad_out.contiguous(at::MemoryFormat::ChannelsLast)
: grad_out,
input.dim() == 4 ? input.contiguous(at::MemoryFormat::ChannelsLast)
: input,
mean,
invstd,
weight,
sum_dy,
sum_dy_xmu);
}
}
case Impl::General: {
Expand Down Expand Up @@ -3091,7 +3124,18 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_dispatch(
(!weight.defined() || weight.is_contiguous()) && mean.is_contiguous() &&
invstd.is_contiguous()) {
return batch_norm_backward_reduce_channels_last_template(
grad_output, input, mean, invstd, weight, input_g, weight_g, bias_g);
// It is a WA to fix Mobile-SSD convergence issue.
grad_output.dim() == 4
? grad_output.contiguous(at::MemoryFormat::ChannelsLast)
: grad_output,
input.dim() == 4 ? input.contiguous(at::MemoryFormat::ChannelsLast)
: input,
mean,
invstd,
weight,
input_g,
weight_g,
bias_g);
}
return IPEX_DISPATCH_FLOATING_TYPES_AND2(
kHalf,
Expand Down Expand Up @@ -3282,8 +3326,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> native_batch_norm_backward(
scalar_t,
accscalar_t,
int32_t>(
grad_output,
input,
grad_output.contiguous(),
input.contiguous(),
*weight,
*running_mean,
*running_var,
Expand All @@ -3297,8 +3341,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> native_batch_norm_backward(
scalar_t,
scalar_t,
int32_t>(
grad_output,
input,
grad_output.contiguous(),
input.contiguous(),
*weight,
*running_mean,
*running_var,
Expand Down Expand Up @@ -3913,7 +3957,17 @@ Tensor batch_norm_backward_elemt_dispatch(
batch_norm_use_channels_last_kernels(self) &&
batch_norm_use_channels_last_kernels(input)) {
return batch_norm_backward_elemt_channels_last_template(
self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
// It is a WA to fix Mobile-SSD convergence issue.
self.dim() == 4 ? self.contiguous(at::MemoryFormat::ChannelsLast)
: self,
input.dim() == 4 ? input.contiguous(at::MemoryFormat::ChannelsLast)
: input,
mean,
invstd,
weight,
sum_dy,
sum_dy_xmu,
count);
}

return IPEX_DISPATCH_FLOATING_TYPES_AND2(
Expand All @@ -3938,27 +3992,55 @@ Tensor batch_norm_backward_elemt_dispatch(
scalar_t,
accscalar_t,
int32_t>(
self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
self.contiguous(),
input.contiguous(),
mean,
invstd,
weight,
sum_dy,
sum_dy_xmu,
count);
} else {
return batch_norm_backward_elemt_channels_first_template<
scalar_t,
scalar_t,
int32_t>(
self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
self.contiguous(),
input.contiguous(),
mean,
invstd,
weight,
sum_dy,
sum_dy_xmu,
count);
}
} else {
if (is_half_float || is_bfloat16_float) {
return batch_norm_backward_elemt_channels_first_template<
scalar_t,
accscalar_t,
int64_t>(
self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
self.contiguous(),
input.contiguous(),
mean,
invstd,
weight,
sum_dy,
sum_dy_xmu,
count);
} else {
return batch_norm_backward_elemt_channels_first_template<
scalar_t,
scalar_t,
int64_t>(
self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
self.contiguous(),
input.contiguous(),
mean,
invstd,
weight,
sum_dy,
sum_dy_xmu,
count);
}
}
});
Expand Down

0 comments on commit a1e2271

Please sign in to comment.