fold binary operators with linear (#438)
* fold binary operators with linear

* put common functions in an independent file
jiayisunx authored Jan 21, 2022
1 parent 275ff50 commit b4e7dac
Showing 6 changed files with 470 additions and 54 deletions.
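For context, "folding" here means absorbing a binary op whose second operand is a frozen constant into the parameters of the preceding linear (or conv) node, so the op disappears from the graph. Below is a minimal standalone sketch of the underlying identity for the add case, using the public libtorch API; it is an illustration only, not code from this commit:

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randn({4, 8});
  auto W = torch::randn({16, 8});  // frozen linear weight
  auto b = torch::randn({16});     // frozen linear bias
  auto c = torch::randn({16});     // frozen constant operand of the add

  // linear(x, W, b) + c is equivalent to linear(x, W, b + c), so the add
  // node can be deleted once c is folded into the bias at compile time.
  auto eager  = torch::linear(x, W, b) + c;
  auto folded = torch::linear(x, W, b + c);
  std::cout << torch::allclose(eager, folded) << std::endl;  // prints 1
}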
folding_common_utils.h
@@ -0,0 +1,63 @@
#pragma once

#include <ATen/ATen.h>

namespace torch {
namespace jit {

inline bool nonConstantParameters(Node* n) {
  // Checks if the parameters, not including the first param, are all constants.
  for (size_t i = 1; i < n->inputs().size(); i++) {
    if (n->inputs().at(i)->node()->kind() != prim::Constant) {
      return true;
    }
  }
  return false;
}

inline bool supportedAddOrSub(Node* n) {
  if (n->kind() == aten::add || n->kind() == aten::sub) {
    return true;
  } else {
    return false;
  }
}

inline bool supportedMulOrDiv(Node* n) {
  if (n->kind() == aten::mul || n->kind() == aten::div) {
    return true;
  } else {
    return false;
  }
}

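// Materializes a constant scalar or tensor Value as a tensor of the given
// shape: scalar constants are zero-initialized then filled; a one-element
// tensor is expanded, and a tensor whose numel matches the shape is viewed.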
inline at::Tensor resizeConstantScalarOrTensorToShape(
    Value* v,
    const std::vector<int64_t>& shape,
    at::TensorOptions options) {
  at::Tensor ret_tensor;
  if (v->type()->cast<TensorType>()) {
    ret_tensor = constant_as<at::Tensor>(v).value();
  } else {
    ret_tensor = at::zeros(shape, options);
    if (v->type()->cast<IntType>()) {
      ret_tensor.fill_(constant_as<int64_t>(v).value());
    } else {
      ret_tensor.fill_(constant_as<double>(v).value());
    }
  }

  if (ret_tensor.numel() == 1) {
    // expand errors if the shape input has fewer dims than the tensor input
    ret_tensor = ret_tensor.reshape({1});
    ret_tensor = ret_tensor.expand(shape);
  } else {
    TORCH_INTERNAL_ASSERT(ret_tensor.numel() == c10::multiply_integers(shape));
    ret_tensor = ret_tensor.view(shape);
  }
  return ret_tensor;
}

} // namespace jit
} // namespace torch
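As an aside, the resize helper above cannot run outside a JIT graph (it takes a Value*), but its two shape branches can be illustrated with plain ATen. The following sketch is an illustration under that assumption, not part of the commit:

#include <ATen/ATen.h>
#include <c10/util/accumulate.h>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> shape = {16, 1, 1};  // e.g. a per-channel bias shape

  // numel() == 1 branch: reshape to {1}, then broadcast via expand.
  at::Tensor one = at::scalar_tensor(2.0);
  at::Tensor expanded = one.reshape({1}).expand(shape);

  // matching-numel branch: element count must equal the product of `shape`.
  at::Tensor flat = at::randn({16});
  TORCH_INTERNAL_ASSERT(flat.numel() == c10::multiply_integers(shape));
  at::Tensor viewed = flat.view(shape);

  std::cout << expanded.sizes() << " " << viewed.sizes() << std::endl;
}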
frozen_conv_folding.cpp
@@ -12,6 +12,7 @@
#include <torch/csrc/jit/tensorexpr/types.h>

#include "csrc/aten/cpu/WeightPack.h"
#include "folding_common_utils.h"
#include "frozen_conv_folding.h"

namespace torch {
@@ -21,17 +22,6 @@ namespace {

using Tensor = at::Tensor;

bool nonConstantParameters(Node* n) {
  // Checks if the parameters, not including the
  // first param are all constants.
  for (size_t i = 1; i < n->inputs().size(); i++) {
    if (n->inputs().at(i)->node()->kind() != prim::Constant) {
      return true;
    }
  }
  return false;
}

bool supportedConvNode(Node* n) {
  if (n->kind() == aten::conv2d || n->kind() == aten::conv3d ||
      n->kind() == Symbol::fromQualString("torch_ipex::convolution_forward")) {
@@ -41,14 +31,6 @@ bool supportedConvNode(Node* n) {
  }
}

bool supportedAddOrSub(Node* n) {
  if (n->kind() == aten::add || n->kind() == aten::sub) {
    return true;
  } else {
    return false;
  }
}

// In order to fuse add/sub/mul/div with conv, the dimensions of its
// constant tensor must satisfy the following:
// - with resizing, broadcast to w/ weight/bias tensor shape
@@ -137,33 +119,6 @@ bool checkConvAndBroadcastingOpPreConditions(Node* conv, Node* op) {
  return true;
}

Tensor resizeConstantScalarOrTensorToShape(
    Value* v,
    const std::vector<int64_t>& shape,
    at::TensorOptions options) {
  Tensor ret_tensor;
  if (v->type()->cast<TensorType>()) {
    ret_tensor = constant_as<Tensor>(v).value();
  } else {
    ret_tensor = at::zeros(shape, options);
    if (v->type()->cast<IntType>()) {
      ret_tensor.fill_(constant_as<int64_t>(v).value());
    } else {
      ret_tensor.fill_(constant_as<double>(v).value());
    }
  }

  if (ret_tensor.numel() == 1) {
    // expand errors if the shape input has less # dims than the tensor input
    ret_tensor = ret_tensor.reshape({1});
    ret_tensor = ret_tensor.expand(shape);
  } else {
    TORCH_INTERNAL_ASSERT(ret_tensor.numel() == c10::multiply_integers(shape));
    ret_tensor = ret_tensor.view(shape);
  }
  return ret_tensor;
}

void FoldFrozenConvAddOrSub(Block* b) {
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
@@ -224,14 +179,6 @@ void FoldFrozenConvAddOrSub(Block* b) {
  }
}

bool supportedMulOrDiv(Node* n) {
  if (n->kind() == aten::mul || n->kind() == aten::div) {
    return true;
  } else {
    return false;
  }
}

void FoldFrozenConvMulOrDiv(Block* b) {
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
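The body of FoldFrozenConvMulOrDiv is collapsed in this view. For intuition, here is the identity it relies on, as another standalone, illustration-only sketch (not code from this commit):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randn({1, 3, 8, 8});
  auto W = torch::randn({4, 3, 3, 3});  // frozen conv weight
  auto b = torch::randn({4});           // frozen conv bias
  double s = 0.5;                       // frozen scalar multiplier

  // conv(x, W, b) * s == conv(x, W * s, b * s); symmetrically, a div by s
  // folds as a division of the frozen weight and bias.
  auto eager  = torch::conv2d(x, W, b) * s;
  auto folded = torch::conv2d(x, W * s, b * s);
  std::cout << torch::allclose(eager, folded) << std::endl;  // prints 1
}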
