// pytorch/torch/csrc/jit/passes/frozen_conv_folding.cpp
#include <ATen/Utils.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>
#include <torch/csrc/jit/tensorexpr/types.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif
namespace torch {
namespace jit {
namespace {
using Tensor = at::Tensor;
bool supportedConvNode(Node* n) {
switch (n->kind()) {
case aten::conv1d:
case aten::conv2d:
case aten::conv3d:
return true;
case aten::_convolution: {
auto transposed_conv =
constant_as<bool>(n->namedInput("transposed")).value_or(true);
// don't handle transposed convs yet, nor a non-constant "transposed" parameter
return !transposed_conv;
}
default:
return false;
}
}
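// Folds a frozen batch_norm into the conv that feeds it. A sketch of the
// rewrite (the %names are illustrative, not taken from real graphs):
//   %y = aten::conv2d(%x, %w, %b, ...)
//   %z = aten::batch_norm(%y, %bn_w, %bn_b, %bn_rm, %bn_rv, ...)
// becomes
//   %z = aten::conv2d(%x, %w_fused, %b_fused, ...)
// where %w_fused and %b_fused are inserted as constants.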
bool FoldFrozenConvBatchnorm(Block* b) {
bool graph_modified = false;
for (Node* n : b->nodes()) {
for (Block* block : n->blocks()) {
graph_modified |= FoldFrozenConvBatchnorm(block);
}
if (n->kind() == aten::batch_norm &&
supportedConvNode(n->inputs().at(0)->node())) {
auto conv = n->inputs().at(0)->node();
auto bn = n;
if (nonConstantParameters(conv) || nonConstantParameters(bn)) {
continue;
}
if (conv->output()->uses().size() > 1) {
continue;
}
auto bn_rm_ivalue = bn->namedInput("running_mean");
auto bn_rv_ivalue = bn->namedInput("running_var");
// Check that running_mean and running_var have values; if they are
// None (track_running_stats=False), skip the folding path.
if (bn_rm_ivalue->type() == NoneType::get() &&
bn_rv_ivalue->type() == NoneType::get()) {
continue;
}
auto bn_rm = constant_as<Tensor>(bn->namedInput("running_mean")).value();
auto bn_rv = constant_as<Tensor>(bn->namedInput("running_var")).value();
auto bn_eps = constant_as<double>(bn->namedInput("eps")).value();
auto conv_w = constant_as<Tensor>(conv->namedInput("weight")).value();
// implementation taken from torch/nn/utils/fusion.py
Tensor conv_b;
if (conv->namedInput("bias")->type() == NoneType::get()) {
// If this is on GPU, bias is None, and the weight was half/bfloat16
// while bn_rm was float, then this was probably a case where
// autocasting cast the inputs to conv. Since the CUDA conv
// implementation requires all of its inputs to have the same scalar
// dtype, we need to make this placeholder have the same type as conv_w.
at::ScalarType bias_dtype = bn_rm.scalar_type();
at::ScalarType weight_dtype = conv_w.scalar_type();
if ((weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
bias_dtype == at::kFloat) {
bias_dtype = weight_dtype;
}
conv_b = at::zeros_like(bn_rm, at::TensorOptions().dtype(bias_dtype));
} else {
conv_b = constant_as<Tensor>(conv->namedInput("bias")).value();
}
Tensor bn_w;
if (bn->namedInput("weight")->type() == NoneType::get()) {
bn_w = at::ones_like(bn_rm);
} else {
bn_w = constant_as<Tensor>(bn->namedInput("weight")).value();
}
Tensor bn_b;
if (bn->namedInput("bias")->type() == NoneType::get()) {
bn_b = at::zeros_like(bn_rm);
} else {
bn_b = constant_as<Tensor>(bn->namedInput("bias")).value();
}
ConvBNParameters params;
params.conv_w = conv_w;
params.conv_b = conv_b;
params.bn_rm = bn_rm;
params.bn_rv = bn_rv;
params.bn_eps = bn_eps;
params.bn_w = bn_w;
params.bn_b = bn_b;
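// A sketch of the fusion math computed below (cf. torch/nn/utils/fusion.py):
//   scale   = bn_w / sqrt(bn_rv + bn_eps)
//   w_fused = conv_w * scale   (scale broadcast over the out-channel dim)
//   b_fused = (conv_b - bn_rm) * scale + bn_b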
std::tuple<Tensor, Tensor> out = computeUpdatedConvWeightAndBias(params);
WithInsertPoint guard(conv);
auto fused_conv_w = b->owningGraph()->insertConstant(std::get<0>(out));
auto fused_conv_b = b->owningGraph()->insertConstant(std::get<1>(out));
auto conv_w_value = conv->namedInput("weight");
auto conv_b_value = conv->namedInput("bias");
fused_conv_w->setDebugName(conv_w_value->debugName() + "_fused_bn");
fused_conv_b->setDebugName(conv_b_value->debugName() + "_fused_bn");
conv->replaceInputWith(conv_w_value, fused_conv_w);
conv->replaceInputWith(conv_b_value, fused_conv_b);
bn->output()->replaceAllUsesWith(conv->output());
graph_modified = true;
}
}
return graph_modified;
}
bool supportedAddOrSub(Node* n) {
static const OperatorSet add_set{
"aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor",
"aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
// sub is equivalent to add
"aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor",
"aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
};
return n->isMemberOf(add_set);
}
// In order to fuse add/sub/mul/div with conv, the dimensions of its
// constant tensor must satisfy the following:
// - with resizing, broadcast to the weight/bias tensor shape
// - broadcast to the conv output shape
// It needs to have a shape that can resize to the weight/bias
// tensor shape because we need to run the op with the conv
// weights/bias without changing their sizes.
// It needs to broadcast to the conv output shape so that we don't
// accidentally change the shape of the op output by pre-fusing it
// compared to eager.
// The only dimension value shared by weight/bias/conv output
// is that they all contain a dim with value == channels-out. In the
// conv output tensor, this is the second dimension,
// so the pointwise op tensor may have a second dimension of
// value == channels-out, but all the other dimensions have to be 1.
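// For example, with a conv2d weight of shape [C_out, C_in, kH, kW], op
// tensors of shape [], [1], or [1, C_out, 1, 1] pass the check below, while
// [1, C_out, kH, 1] (kH > 1) or [N, C_out, 1, 1] (N > 1) do not.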
bool opDoesNotBroadCastWithConv(Tensor& op_tensor, Tensor& weight_tensor) {
if (op_tensor.ndimension() > weight_tensor.ndimension()) {
return false;
}
for (int64_t i = op_tensor.ndimension() - 1; i >= 0; i--) {
// channels-out dimension == weight_tensor.size(0)
if (i == 1 && op_tensor.size(i) == weight_tensor.size(0)) {
continue;
}
if (op_tensor.size(i) != 1) {
return false;
}
}
return true;
}
bool checkConvAndBroadcastingOpPreConditions(Node* conv, Node* op) {
if (nonConstantParameters(conv) || nonConstantParameters(op)) {
return false;
}
if (conv->output()->uses().size() > 1) {
return false;
}
Tensor weight_tensor =
constant_as<Tensor>(conv->namedInput("weight")).value();
// Avoid fusing an op that causes type promotion;
// restricting to float avoids int/float difficulties with the Scalar overload.
if (!weight_tensor.is_floating_point()) {
return false;
}
if (op->inputs().at(1)->type()->cast<TensorType>()) {
auto op_tensor = constant_as<Tensor>(op->inputs().at(1)).value();
if (!opDoesNotBroadCastWithConv(op_tensor, weight_tensor)) {
return false;
}
if (!op_tensor.is_floating_point() &&
c10::promoteTypes(
op_tensor.scalar_type(), weight_tensor.scalar_type()) !=
weight_tensor.scalar_type()) {
return false;
}
}
return true;
}
Tensor resizeConstantScalarOrTensorToShape(
Value* v,
const std::vector<int64_t>& shape,
at::TensorOptions options) {
Tensor ret_tensor;
if (v->type()->cast<TensorType>()) {
ret_tensor = constant_as<Tensor>(v).value();
} else {
ret_tensor = at::zeros(shape, options);
if (v->type()->cast<IntType>()) {
ret_tensor.fill_(constant_as<int64_t>(v).value());
} else {
ret_tensor.fill_(constant_as<double>(v).value());
}
}
if (ret_tensor.numel() == 1) {
// expand throws an error if the shape input has fewer dims than the tensor input
ret_tensor = ret_tensor.reshape({1});
ret_tensor = ret_tensor.expand(shape);
} else {
TORCH_INTERNAL_ASSERT(ret_tensor.numel() == c10::multiply_integers(shape));
ret_tensor = ret_tensor.view(shape);
}
return ret_tensor;
}
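// For example, a constant scalar 2.0 resized to shape {4} yields a tensor of
// four 2.0s, while a 4-element constant tensor resized to {4, 1, 1, 1} is
// just viewed with that shape.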
bool FoldFrozenConvAddOrSub(Block* b) {
bool graph_modified = false;
for (Node* n : b->nodes()) {
for (Block* block : n->blocks()) {
graph_modified |= FoldFrozenConvAddOrSub(block);
}
if (supportedAddOrSub(n) && supportedConvNode(n->inputs().at(0)->node())) {
auto conv = n->inputs().at(0)->node();
auto add_or_sub = n;
if (!checkConvAndBroadcastingOpPreConditions(conv, add_or_sub)) {
continue;
}
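// In effect (a sketch): conv(x, w, b) + c  ==>  conv(x, w, b + c).
// The constant c is resized to the bias shape, and the add/sub node itself
// is constant-run to compute the new bias, so alpha and sub semantics carry
// over unchanged.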
Tensor weight_tensor =
constant_as<Tensor>(conv->namedInput("weight")).value();
Tensor add_or_sub_tensor = resizeConstantScalarOrTensorToShape(
add_or_sub->inputs().at(1),
{weight_tensor.size(0)},
weight_tensor.options());
Tensor bias;
if (conv->namedInput("bias")->type() == NoneType::get()) {
bias = at::zeros_like(add_or_sub_tensor, weight_tensor.dtype());
} else {
bias = constant_as<Tensor>(conv->namedInput("bias")).value();
}
WithInsertPoint guard(conv);
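// Reuse the add/sub node to compute the fused bias: swap the conv output
// for the constant bias, constant-fold the node, then feed the result back
// into the conv as its new bias.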
add_or_sub->replaceInputWith(
conv->output(), b->owningGraph()->insertConstant(bias));
add_or_sub->replaceInput(
1, b->owningGraph()->insertConstant(add_or_sub_tensor));
auto stack_out = runNodeIfInputsAreConstant(add_or_sub);
TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
Tensor fuse_bias = (*stack_out)[0].toTensor().to(bias.dtype());
auto fused_conv_b = b->owningGraph()->insertConstant(fuse_bias);
auto conv_b_value = conv->namedInput("bias");
fused_conv_b->setDebugName(
conv_b_value->debugName() + "_fused_" +
add_or_sub->kind().toUnqualString());
conv->replaceInputWith(conv_b_value, fused_conv_b);
add_or_sub->output()->replaceAllUsesWith(conv->output());
graph_modified = true;
// a DCE run afterwards cleans up the dead nodes
}
}
return graph_modified;
}
bool supportedMulOrDiv(Node* n) {
static const OperatorSet mul_div_set{
"aten::mul.Tensor(Tensor self, Tensor other) -> Tensor",
"aten::mul.Scalar(Tensor self, Scalar other) -> Tensor",
// div is equivalent to mul
"aten::div.Tensor(Tensor self, Tensor other) -> Tensor",
"aten::div.Scalar(Tensor self, Scalar other) -> Tensor",
};
return n->isMemberOf(mul_div_set);
}
bool FoldFrozenConvMulOrDiv(Block* b) {
bool graph_modified = false;
for (Node* n : b->nodes()) {
for (Block* block : n->blocks()) {
graph_modified |= FoldFrozenConvMulOrDiv(block);
}
if (supportedMulOrDiv(n) && supportedConvNode(n->inputs().at(0)->node())) {
auto conv = n->inputs().at(0)->node();
auto mul_or_div = n;
if (!checkConvAndBroadcastingOpPreConditions(conv, mul_or_div)) {
continue;
}
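// In effect (a sketch): conv(x, w, b) * c  ==>  conv(x, w * c, b * c).
// Since (conv(x, w) + b) * c == conv(x, w * c) + b * c, c is folded first
// into the weight and then, below, into the bias.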
Tensor weight_tensor =
constant_as<Tensor>(conv->namedInput("weight")).value();
int64_t out_channels = weight_tensor.size(0);
// We've already verified that the second input has numel == 1 or
// numel == channels-out. Resize it to a shape that will broadcast to
// weight_tensor when the op is run, so we don't change the weight size.
std::vector<int64_t> weight_compatible_size = {out_channels};
for (const auto i : c10::irange(1, weight_tensor.ndimension())) {
(void)i; // Suppress unused variable warning
weight_compatible_size.push_back(1);
}
WithInsertPoint guard(conv);
Tensor mul_tensor = resizeConstantScalarOrTensorToShape(
mul_or_div->inputs().at(1),
weight_compatible_size,
weight_tensor.options());
// First fold with weight tensor
mul_or_div->replaceInputWith(
conv->output(), b->owningGraph()->insertConstant(weight_tensor));
mul_or_div->replaceInput(1, b->owningGraph()->insertConstant(mul_tensor));
auto stack_out = runNodeIfInputsAreConstant(mul_or_div);
TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
Tensor fuse_weight = (*stack_out)[0].toTensor().to(weight_tensor.dtype());
auto fused_conv_weight = b->owningGraph()->insertConstant(fuse_weight);
auto conv_weight_value = conv->namedInput("weight");
fused_conv_weight->setDebugName(
conv_weight_value->debugName() + "_fused_" +
mul_or_div->kind().toUnqualString());
conv->replaceInputWith(conv_weight_value, fused_conv_weight);
mul_or_div->output()->replaceAllUsesWith(conv->output());
// now fold with bias tensor
if (conv->namedInput("bias")->type() != NoneType::get()) {
Tensor bias = constant_as<Tensor>(conv->namedInput("bias")).value();
// bias is of shape {channels_out}
auto mul_tensor = resizeConstantScalarOrTensorToShape(
mul_or_div->inputs().at(1), {out_channels}, bias.options());
mul_or_div->replaceInput(0, b->owningGraph()->insertConstant(bias));
mul_or_div->replaceInput(
1, b->owningGraph()->insertConstant(mul_tensor));
auto stack_out = runNodeIfInputsAreConstant(mul_or_div);
TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
Tensor fuse_bias = (*stack_out)[0].toTensor().to(bias.dtype());
auto fused_conv_bias = b->owningGraph()->insertConstant(fuse_bias);
auto conv_b_value = conv->namedInput("bias");
fused_conv_bias->setDebugName(
conv_b_value->debugName() + "_fused_" +
mul_or_div->kind().toUnqualString());
conv->replaceInputWith(conv_b_value, fused_conv_bias);
}
graph_modified = true;
// a DCE run afterwards cleans up the dead nodes
}
}
return graph_modified;
}
} // namespace
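// Public entry points. These passes are typically run on frozen graphs
// (e.g. the output of torch::jit::freeze), where conv weights and biases
// are constants; each returns true if the graph was modified.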
bool FoldFrozenConvBatchnorm(std::shared_ptr<Graph>& graph) {
bool graph_modified = FoldFrozenConvBatchnorm(graph->block());
EliminateDeadCode(graph);
return graph_modified;
}
bool FoldFrozenConvAddOrSub(std::shared_ptr<Graph>& graph) {
bool graph_modified = FoldFrozenConvAddOrSub(graph->block());
EliminateDeadCode(graph);
return graph_modified;
}
bool FoldFrozenConvMulOrDiv(std::shared_ptr<Graph>& graph) {
bool graph_modified = FoldFrozenConvMulOrDiv(graph->block());
EliminateDeadCode(graph);
return graph_modified;
}
} // namespace jit
} // namespace torch