Skip to content

Commit

Permalink
enable conv3d module case for imperative and jit path (#425)
Browse files Browse the repository at this point in the history
* enable conv3d module case for imperative and jit path

* add some checks
  • Loading branch information
XiaobingSuper authored Dec 23, 2021
1 parent ed31bba commit ae33faf
Show file tree
Hide file tree
Showing 20 changed files with 714 additions and 931 deletions.
220 changes: 44 additions & 176 deletions intel_extension_for_pytorch/csrc/aten/cpu/Conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ void convolution_kernel_output(
const ideep::tensor mkldnn_input = itensor_view_from_dense(input);
auto output_sizes = output.sizes();

bool is_channels_last =
input.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
ideep::tensor mkldnn_output = itensor_view_from_dense(output);

if (bias.defined()) {
Expand Down Expand Up @@ -131,84 +129,6 @@ at::Tensor convolution_kernel(
return output;
}

// Out-of-place convolution on a weight tensor that is passed in its plain
// (dense) form.
//
// The channels-last path is selected when either `input` or `weight` suggests
// the ChannelsLast memory format; otherwise the contiguous path is used. The
// input is made contiguous in the selected format, the weight is run through
// get_conv_packed_weight(), and the computation itself is delegated to
// convolution_kernel(), whose result is returned.
at::Tensor convolution_impl(
    const at::Tensor& input,
    const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias_opt,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    int64_t groups,
    const ideep::attr_t& attr) {
  const bool input_is_channels_last =
      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
  const bool channels_last = input_is_channels_last ||
      weight.suggest_memory_format() == at::MemoryFormat::ChannelsLast;

  at::MemoryFormat target_format = at::MemoryFormat::Contiguous;
  if (channels_last) {
    target_format = at::MemoryFormat::ChannelsLast;
  }
  at::Tensor dense_input = input.contiguous(target_format);

  // The channels-last decision is passed in both flag positions; the literal
  // `false` in between presumably marks the weight as not pre-packed
  // (TODO confirm against get_conv_packed_weight's signature).
  ideep::tensor packed_weight = get_conv_packed_weight(
      weight,
      stride,
      padding,
      dilation,
      weight.sizes(),
      groups,
      channels_last,
      false,
      channels_last,
      input.sizes(),
      attr);

  return convolution_kernel(
      dense_input,
      packed_weight,
      bias_opt,
      stride,
      padding,
      dilation,
      groups,
      attr);
}

// Convolution on a plain (dense) weight tensor that writes its result into
// `output` instead of returning a new tensor.
//
// The channels-last path is selected when either `input` or `weight` suggests
// the ChannelsLast memory format. Before the kernel runs, `output` is
// re-assigned if needed: first to a contiguous tensor when it is not
// contiguous in any format accepted by IS_CONTIGUOUS_ANY, then to the memory
// format suggested by the converted input.
void convolution_inplace_impl(
    const at::Tensor& input,
    const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias_opt,
    at::Tensor& output,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    int64_t groups,
    const ideep::attr_t& attr) {
  // TODO: the input will be actively converted to channels last format
  // after the 5-D tensor supports channels last format.
  const bool input_is_channels_last =
      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
  const bool channels_last = input_is_channels_last ||
      weight.suggest_memory_format() == at::MemoryFormat::ChannelsLast;

  at::MemoryFormat target_format = at::MemoryFormat::Contiguous;
  if (channels_last) {
    target_format = at::MemoryFormat::ChannelsLast;
  }
  at::Tensor dense_input = input.contiguous(target_format);

  // The channels-last decision is passed in both flag positions; the literal
  // `false` in between presumably marks the weight as not pre-packed
  // (TODO confirm against get_conv_packed_weight's signature).
  ideep::tensor packed_weight = get_conv_packed_weight(
      weight,
      stride,
      padding,
      dilation,
      weight.sizes(),
      groups,
      channels_last,
      false,
      channels_last,
      input.sizes(),
      attr);

  // IS_CONTIGUOUS_ANY is a macro, so its value is stored before being negated;
  // this keeps the test independent of how the macro expands.
  const bool output_is_dense = IS_CONTIGUOUS_ANY(output);
  if (!output_is_dense) {
    output = output.contiguous(output.suggest_memory_format());
  }
  // Align the output's memory format with the one the input ended up in.
  output = output.to(dense_input.suggest_memory_format());

  convolution_kernel_output(
      dense_input,
      packed_weight,
      bias_opt,
      output,
      stride,
      padding,
      dilation,
      groups,
      attr);
}

at::Tensor convolution_forward_impl(
const at::Tensor& input,
const at::Tensor& weight,
Expand All @@ -233,25 +153,32 @@ at::Tensor convolution_forward_impl(
weight.scalar_type() == input.scalar_type(),
"the input and weight need have same data type");
TORCH_CHECK(
input.dim() == 4,
"Only support 2d convolution for convolution_forward_impl");
input.dim() == 4 || input.dim() == 5,
"Only support 2d or 3d convolution for convolution_forward_impl");
// TODO: add bias dtype check
// case 1: weight is not packed, check weight.suggest_memory_format()
// case 2: weight is packed or use user's setting, weight_channels_last.
bool weight_use_channels_last =
weight.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
weight.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d ||
weight_channels_last;
bool use_channels_last =
input.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
weight_use_channels_last;
auto mkldnn_memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast
: at::MemoryFormat::Contiguous;
auto input_ = input.contiguous(mkldnn_memory_format);
auto memory_format = at::MemoryFormat::Contiguous;
if (use_channels_last) {
if (input.dim() == 4) {
memory_format = at::MemoryFormat::ChannelsLast;
} else {
memory_format = at::MemoryFormat::ChannelsLast3d;
}
}
auto input_ = input.to(memory_format);
at::Tensor weight_ = weight;
// If the weight is not packed, convert it so that it has the same memory
// format as the input.
if (!weight_packed) {
weight_ = weight_.contiguous(mkldnn_memory_format);
weight_ = weight_.contiguous(memory_format);
}
// get original weight dims.
std::vector<int64_t> origin_weight_dims;
Expand All @@ -277,88 +204,6 @@ at::Tensor convolution_forward_impl(
input_, mkldnn_weight, bias_opt, stride, padding, dilation, groups, attr);
}

// Forward convolution that writes into `output` and accepts a weight that may
// already be packed.
//
// Requirements enforced here: `input` and `weight` share a scalar type, and
// `input` is 4-D. The original weight shape is rebuilt from `output_channel`,
// the input channel count divided by `groups`, and `kernel_size`, and is then
// handed to get_conv_packed_weight(). `output` is re-assigned if needed so
// that it is contiguous and uses the memory format of the converted input
// before convolution_kernel_output() is called.
void convolution_forward_inplace_impl(
    const at::Tensor& input,
    const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias_opt,
    at::Tensor& output,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    at::IntArrayRef kernel_size,
    int64_t groups,
    int64_t output_channel,
    bool weight_channels_last,
    bool weight_packed,
    const ideep::attr_t& attr) {
#if defined(IPEX_DISP_OP)
  printf("torch_ipex::convolution_forward\n");
#endif
#if defined(IPEX_PROFILE_OP)
  RECORD_FUNCTION(
      "torch_ipex::convolution_forward_inplace_impl",
      std::vector<c10::IValue>({}));
#endif
  TORCH_CHECK(
      weight.scalar_type() == input.scalar_type(),
      "the input and weight need have same data type");
  TORCH_CHECK(
      input.dim() == 4,
      "Only support 2d convolution for convolution_forward_inplace_impl");
  // TODO: add bias dtype check

  // The weight counts as channels-last either because its own layout suggests
  // it (unpacked weight) or because the caller says so via
  // weight_channels_last (packed weight or explicit user setting).
  const bool weight_is_channels_last =
      weight.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
      weight_channels_last;
  const bool channels_last =
      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
      weight_is_channels_last;

  at::MemoryFormat target_format = at::MemoryFormat::Contiguous;
  if (channels_last) {
    target_format = at::MemoryFormat::ChannelsLast;
  }
  at::Tensor dense_input = input.contiguous(target_format);

  // An unpacked weight is converted so that it uses the same memory format as
  // the input; a packed weight is passed through untouched.
  at::Tensor weight_in_format = weight;
  if (!weight_packed) {
    weight_in_format = weight.contiguous(target_format);
  }

  // Rebuild the original weight dims:
  // {output_channel, input_channel / groups, kernel dims...}.
  std::vector<int64_t> origin_weight_dims{
      output_channel, dense_input.size(1) / groups};
  for (const int64_t kernel_extent : kernel_size) {
    origin_weight_dims.push_back(kernel_extent);
  }

  ideep::tensor packed_weight = get_conv_packed_weight(
      weight_in_format,
      stride,
      padding,
      dilation,
      origin_weight_dims,
      groups,
      weight_channels_last,
      weight_packed,
      weight_channels_last,
      {},
      attr);

  // IS_CONTIGUOUS_ANY is a macro, so its value is stored before being negated;
  // this keeps the test independent of how the macro expands.
  const bool output_is_dense = IS_CONTIGUOUS_ANY(output);
  if (!output_is_dense) {
    output = output.contiguous(output.suggest_memory_format());
  }
  // Align the output's memory format with the one the input ended up in.
  output = output.to(dense_input.suggest_memory_format());

  convolution_kernel_output(
      dense_input,
      packed_weight,
      bias_opt,
      output,
      stride,
      padding,
      dilation,
      groups,
      attr);
}

at::Tensor convolution_backward_input(
at::IntArrayRef input_size,
const at::Tensor& grad_output,
Expand All @@ -371,9 +216,14 @@ at::Tensor convolution_backward_input(
bool bias_defined,
bool weight_use_channels_last,
bool weight_packed) {
TORCH_CHECK(
input_size.size() == 4 || input_size.size() == 5,
"Only support 2d or 3d convolution for convolution_backward_input");

const ideep::tensor mkldnn_grad_output = itensor_view_from_dense(grad_output);
bool is_channels_last =
grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d;

std::vector<int64_t> origin_weight_dims;
origin_weight_dims.push_back(grad_output.size(1));
Expand Down Expand Up @@ -435,10 +285,14 @@ std::tuple<at::Tensor, at::Tensor> convolution_backward_weights(
bool bias_defined,
bool weight_use_channels_last,
bool weight_packed) {
TORCH_CHECK(
input.dim() == 4 || input.dim() == 5,
"Only support 2d or 3d convolution for convolution_backward_weights");
const ideep::tensor mkldnn_grad_output = itensor_view_from_dense(grad_output);
const ideep::tensor mkldnn_input = itensor_view_from_dense(input);
bool is_channels_last =
grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d;

auto grad_weight = at::empty(weight_size, grad_output.options());
at::Tensor grad_bias;
Expand Down Expand Up @@ -496,13 +350,15 @@ std::tuple<at::Tensor, at::Tensor> convolution_backward_weights(
return std::make_tuple(grad_weight, grad_bias);
} else {
if (is_channels_last) {
auto memory_format = input.dim() == 4 ? at::MemoryFormat::ChannelsLast
: at::MemoryFormat::ChannelsLast3d;
return std::make_tuple(
mkldnn_to_dense(
new_with_itensor_mkldnn(
std::move(mkldnn_grad_weight),
optTypeMetaToScalarType(grad_output.options().dtype_opt()),
grad_output.options().device_opt()))
.to(at::MemoryFormat::ChannelsLast),
.to(memory_format),
grad_bias);
} else {
return std::make_tuple(
Expand Down Expand Up @@ -538,24 +394,36 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> convolution_backward(
weight.scalar_type() == input.scalar_type() &&
weight.scalar_type() == grad_output_t.scalar_type(),
"the inputs need have same data type");
TORCH_CHECK(
input.dim() == 4 || input.dim() == 5,
"Only support 2d or 3d convolution for convolution_backward");

bool weight_use_channels_last =
weight.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
weight.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d ||
weight_channels_last;
bool use_channels_last =
input.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d ||
weight_use_channels_last;

auto mkldnn_memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast
: at::MemoryFormat::Contiguous;
auto grad_output_ = grad_output_t.contiguous(mkldnn_memory_format);
auto memory_format = at::MemoryFormat::Contiguous;
if (use_channels_last) {
if (input.dim() == 4) {
memory_format = at::MemoryFormat::ChannelsLast;
} else {
memory_format = at::MemoryFormat::ChannelsLast3d;
}
}
auto grad_output_ = grad_output_t.contiguous(memory_format);

at::Tensor grad_input, grad_weight, grad_bias;
if (output_mask[0]) {
at::Tensor weight_ = weight;
// If the weight is not packed, convert it so that it has the same memory
// format as the input.
if (!weight_packed) {
weight_ = weight_.contiguous(mkldnn_memory_format);
weight_ = weight_.contiguous(memory_format);
}
grad_input = convolution_backward_input(
input.sizes(),
Expand All @@ -571,7 +439,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> convolution_backward(
weight_packed);
}
if (output_mask[1] || output_mask[2]) {
auto input_ = input.contiguous(mkldnn_memory_format);
auto input_ = input.contiguous(memory_format);
std::tie(grad_weight, grad_bias) = convolution_backward_weights(
weight.sizes(),
grad_output_,
Expand Down
50 changes: 0 additions & 50 deletions intel_extension_for_pytorch/csrc/aten/cpu/Conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,56 +36,6 @@ std::vector<int64_t> calc_conv_output_size(
at::IntArrayRef stride,
at::IntArrayRef dilation);

// Out-of-place convolution on a plain (dense) weight tensor.
// Returns the convolution result as a new tensor. `bias_opt` is the optional
// bias and `attr` carries the ideep attributes that are passed down to the
// kernel. The definition chooses between the ChannelsLast and Contiguous
// memory formats based on what `input` and `weight` suggest.
at::Tensor convolution_impl(
const at::Tensor& input,
const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_opt,
at::IntArrayRef stride,
at::IntArrayRef padding,
at::IntArrayRef dilation,
int64_t groups,
const ideep::attr_t& attr);

// Convolution on a plain (dense) weight tensor that writes its result into
// `output` rather than returning a new tensor.
// NOTE: the definition may re-assign `output` (to a contiguous tensor and to
// the memory format of the converted input) before running the kernel, so the
// caller's tensor handle can change.
void convolution_inplace_impl(
const at::Tensor& input,
const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_opt,
at::Tensor& output,
at::IntArrayRef stride,
at::IntArrayRef padding,
at::IntArrayRef dilation,
int64_t groups,
const ideep::attr_t& attr);

// Out-of-place forward convolution for a weight that may already be packed.
// `input` and `weight` must have the same scalar type.
// `weight_channels_last` lets the caller state the weight layout explicitly,
// and `weight_prepacked` tells whether `weight` is already packed (the
// definition names this parameter `weight_packed`).
// NOTE(review): `kernel_size` and `output_channel` presumably describe the
// original (unpacked) weight shape, as in the in-place variant; confirm
// against the definition.
at::Tensor convolution_forward_impl(
const at::Tensor& input,
const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_opt,
at::IntArrayRef stride,
at::IntArrayRef padding,
at::IntArrayRef dilation,
at::IntArrayRef kernel_size,
int64_t groups,
int64_t output_channel,
bool weight_channels_last,
bool weight_prepacked,
const ideep::attr_t& attr);

// Forward convolution for a weight that may already be packed, writing the
// result into `output`.
// `input` and `weight` must have the same scalar type and `input` must be
// 4-D. `output_channel`, `groups` and `kernel_size` are used to rebuild the
// original weight dims; `weight_prepacked` tells whether `weight` is already
// packed (the definition names this parameter `weight_packed`).
// NOTE: the definition may re-assign `output` (to a contiguous tensor and to
// the memory format of the converted input) before running the kernel.
void convolution_forward_inplace_impl(
const at::Tensor& input,
const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_opt,
at::Tensor& output,
at::IntArrayRef stride,
at::IntArrayRef padding,
at::IntArrayRef dilation,
at::IntArrayRef kernel_size,
int64_t groups,
int64_t output_channel,
bool weight_channels_last,
bool weight_prepacked,
const ideep::attr_t& attr);

// IPEX customized convolution OP with n-D packed weight
class IPEXConvolutionOp : public torch::autograd::Function<IPEXConvolutionOp> {
public:
Expand Down
Loading

0 comments on commit ae33faf

Please sign in to comment.