Add support for permuting dynamic fusion group outputs to channels last format (#70656)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70656

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D33458650

Pulled By: eellison

fbshipit-source-id: f0c7d20743deac7a87f7c9176e60da8100aefe41
This commit is contained in:
Elias Ellison 2022-01-12 09:08:04 -08:00 committed by Facebook GitHub Bot
parent 39be20f259
commit 5480deb183
9 changed files with 150 additions and 33 deletions

View file

@ -4,12 +4,11 @@
#include <ATen/Parallel.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/profiler.h>

View file

@ -68,6 +68,7 @@ TEST(ShapeAnalysisTest, DynamicShapesFusion) {
subgraph->inputs().at(0)->setType(x_type);
subgraph->inputs().at(1)->setType(y_type);
subgraph->inputs().at(2)->setType(z_type);
subgraph->outputs().at(0)->setType(TensorType::create(at::rand({14, 5})));
auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
subgraph->outputs().at(0)->setType(TensorType::create(at::rand({14, 5})));
output->node()->addInput(x_inp);
@ -84,7 +85,7 @@ TEST(ShapeAnalysisTest, DynamicShapesFusion) {
->check("TensorExprGroup")
->check_same("symbolic_shape_inputs")
->check("block1")
->check("FallbackGraph")
->check("aten::cat")
->run(*g);
// clang-format off
@ -106,6 +107,7 @@ TEST(ShapeAnalysisTest, DynamicShapesFusion) {
%3 : Tensor = prim::TensorExprGroup_0[symbolic_shape_inputs=[-5, -4, -3, -2]](%x_inp, %y_inp, %z_inp, %cat_dim_size.48, %elem.11, %elem.5, %elem.3)
-> (%3)
block1():
// FallbackGraph is inlined
%14 : Tensor = prim::FallbackGraph_1(%x_inp, %y_inp, %z_inp)
-> (%14)
return ()

View file

@ -56,6 +56,7 @@ TEST(DynamicShapes, SimpleGraph) {
std::vector<torch::jit::StrideInput>>
symbolic_strides;
symbolic_strides[x_inp] = input_desc;
symbolic_strides[graph->outputs().at(0)] = input_desc;
std::vector<int64_t> symbolic_shape_inputs = c10::fmap(
x_sym_dims,
[](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); });
@ -142,6 +143,7 @@ TEST(DynamicShapes, GraphWith2InputsSameDims) {
symbolic_strides;
symbolic_strides[x_inp] = input_desc;
symbolic_strides[y_inp] = input_desc;
symbolic_strides[graph->outputs().at(0)] = input_desc;
TensorExprKernel kernel(
graph, {}, symbolic_shape_inputs, false, symbolic_strides);
@ -232,6 +234,7 @@ TEST(DynamicShapes, GraphWith2InputsAndBroadcast) {
symbolic_strides;
symbolic_strides[x_inp] = input_desc;
symbolic_strides[y_inp] = input_desc;
symbolic_strides[graph->outputs().at(0)] = input_desc;
TensorExprKernel kernel(
graph, {}, symbolic_shape_inputs, false, symbolic_strides);
@ -313,6 +316,8 @@ TEST(DynamicShapes, GraphWithPartiallySymbolicOutput) {
symbolic_strides;
symbolic_strides[x_inp] = input_desc;
symbolic_strides[y_inp] = input_desc;
symbolic_strides[graph->outputs().at(0)] = input_desc;
TensorExprKernel kernel(
graph, {}, symbolic_shape_inputs, false, symbolic_strides);
@ -436,6 +441,8 @@ TEST(DynamicShapes, GraphWithCatAndBroadcast) {
symbolic_strides[x_inp] = input_desc;
symbolic_strides[y_inp] = input_desc;
symbolic_strides[z_inp] = input_desc;
symbolic_strides[graph->outputs().at(0)] = input_desc;
TensorExprKernel kernel(
graph, {}, symbolic_shape_inputs, false, symbolic_strides);

View file

@ -1859,6 +1859,24 @@ class TestTEFuser(JitTestCase):
script = self.checkScript(eager, (x, y))
self.assertAllFused(script.graph_for(x, y))
def test_channels_last_dims_dynamic(self):
    # Verify that dynamically-shaped fusion groups preserve the
    # channels-last memory format of their outputs, including when one
    # dimension is 1 (which makes the stride pattern ambiguous).
    def eager(x, y):
        return x / (y + 0.0001)

    for i in range(4):
        # Collapse one dimension at a time to exercise size-1 strides.
        size = [2, 3, 4, 5]
        size[i] = 1
        inp = torch.rand(size).to(memory_format=torch.channels_last)
        with texpr_dynamic_enabled():
            foo_s = torch.jit.trace(eager, (inp, inp))
            # Warm up so the dynamic-shape fusion group gets compiled.
            for _ in range(3):
                out = foo_s(inp, inp)
            out_eager = eager(inp, inp)
            self.assertEqual(out_eager, out)
            # The fused output must keep the channels-last layout.
            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            g = torch.jit.last_executed_optimized_graph()
            FileCheck().check("TensorExpr").run(g)
def test_unsqueeze_var_dim(self):
def eager(x, y, z: int):
return x * torch.unsqueeze(y, dim=z)

View file

@ -454,7 +454,7 @@ void insertDynamicShapesGuard(
std::vector<std::string> output_striding =
fmap(output_strides, [&](StrideInput inp) { return toString(inp); });
auto output_ival = IValue(input_striding);
auto output_ival = IValue(output_striding);
guarded_node->ival_(attr::striding_outputs_desc, output_ival);
if (add_composed_op) {

View file

@ -22,6 +22,7 @@
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/utils/memory.h>
#include <ATen/core/interned_strings.h>
// NOLINTNEXTLINE
C10_DEFINE_bool(
@ -1279,6 +1280,10 @@ Operation createTensorExprOp(const Node* node) {
stride_map[v] = striding_inputs[index];
index++;
}
std::vector<std::string> output_desc = node->ival(attr::striding_outputs_desc).to<std::vector<std::string>>();
for (size_t i = 0; i < subgraph->outputs().size(); ++i) {
stride_map[subgraph->outputs().at(i)] = {strideInputFromString(output_desc.at(i))};
}
std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
std::make_shared<tensorexpr::TensorExprKernel>(

View file

@ -1395,7 +1395,8 @@ ExprPtr PolynomialTransformer::mutate(CompareSelectPtr v) {
ExprPtr false_branch = v->ret_val2()->accept_mutator(this);
// Constant Folding.
if (lhs_new->isConstant() && rhs_new->isConstant()) {
if (lhs_new->isConstant() && rhs_new->isConstant() &&
true_branch->isConstant() && false_branch->isConstant()) {
ExprPtr v_new = alloc<CompareSelect>(
lhs_new,
rhs_new,

View file

@ -1111,7 +1111,85 @@ bool denseAndNonOverlapping(
return (strides == at::infer_dense_strides(sizes, strides));
}
Tensor TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
Tensor TensorExprKernel::convertOutputToCorrectStrides(
    const std::vector<ExprHandle>& sizes,
    const std::vector<size_t>& sorted_stride_indices_descending,
    const std::vector<ExprPtr>& strides,
    BufPtr& buf) {
  // We need to convert the output tensor so that its values are laid
  // out so that when viewed from the output strides the values are correct.
  // A contiguous Tensor of size (2, 3) with values 0-5 is laid out as:
  // [0] [1] [2] [3] [4] [5]
  // The same valued tensor with strides (1, 2) would be laid out like
  // [0] [3] [1] [4] [2] [5]
  // When we are doing the re-ordering of values into the output tensor,
  // we are iterating per-element of the input, and we are fixed
  // in indexing in to the output tensor at [i, j] = val
  // `val` we want here is equal to the indices for the output
  // tensor that would have given the same position as the output
  // The position is equal to the sum of stride[i] * index[i],
  // and we can calculate the equivalent indices in the
  // output tensor strides by iteratively computing the index of
  // the biggest stride:
  // absolute = ...
  // for stride in strides_from_largest_to_smallest:
  //     cur_idx = absolute // stride
  //     absolute = absolute % stride
  auto dims = c10::fmap<DimArg>(sizes);
  std::vector<ExprPtr> default_strides = make_contiguous_strides(sizes);
  return Compute(
      "output_1", dims, [&](const std::vector<VarHandle>& axes_input) {
        std::vector<ExprHandle> axes(axes_input.begin(), axes_input.end());
        // Flatten the (contiguous-layout) indices into a single linear
        // position: sum(default_strides[i] * axes[i]).
        auto absolute_position = ExprHandle(immLike(axes[0], 0));
        for (size_t i = 0; i < axes.size(); ++i) {
          ExprHandle stride(default_strides[i]);
          absolute_position = absolute_position + (stride * axes[i]);
        }
        // Un-flatten the linear position against the requested output
        // strides, visiting strides from largest to smallest.
        std::vector<ExprHandle> new_axes(
            sorted_stride_indices_descending.size());
        for (size_t stride_index : sorted_stride_indices_descending) {
          const auto& size = sizes[stride_index];
          auto stride = strides[stride_index];
          auto index = absolute_position / ExprHandle(stride);
          auto one = Cast::make(size.dtype(), 1);
          // If the size is one, we don't advance the absolute position,
          // which would otherwise give 0.
          auto non_one_position = absolute_position % ExprHandle(stride);
          absolute_position = CompareSelect::make(
              size, one, absolute_position, non_one_position, kEQ);
          new_axes[stride_index] = index;
        }
        return BufHandle(buf).load(new_axes);
      });
}
Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
    torch::jit::Value* v) {
  const TensorTypePtr& tt = v->type()->expect<TensorType>();
  TORCH_INTERNAL_ASSERT(
      bufs_.count(v),
      buildErrorMessage(
          "Output tensor has no corresponding bufs in the fuser."));
  BufPtr buf = bufs_.at(v);

  // Output is contiguous: no re-layout work to do.
  if (tensorOutputStrideDesc_[v->offset()] ==
      torch::jit::StrideInput::TENSOR_CONT) {
    return Tensor(buf, nullptr);
  }
  // Only contiguous and channels-last outputs are supported for symbolic
  // shapes at this time.
  TORCH_INTERNAL_ASSERT(
      tensorOutputStrideDesc_[v->offset()] ==
      torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST);
  auto sizes = sizesFromSymbolicShape(tt->symbolic_sizes());
  auto strides = make_channels_last_strides(sizes);
  // For a tensor with dimensions N C H W, channels-last format stores the
  // data in N H W C order, so the strides ordered from largest to smallest
  // correspond to dimensions N, H, W, C.
  std::vector<size_t> sorted_stride_indices = {0, 2, 3, 1};
  // See explanation in convertOutputToCorrectStrides.
  return convertOutputToCorrectStrides(sizes, sorted_stride_indices, strides, buf);
}
Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides(torch::jit::Value* v) {
const TensorTypePtr& tt = v->type()->expect<TensorType>();
TORCH_INTERNAL_ASSERT(
bufs_.count(v),
@ -1150,27 +1228,12 @@ Tensor TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
}
auto dims = c10::fmap<DimArg>(sizesForValue(v));
// We need to convert the output tensor so that its values are layed
// so that when viewed from the output strides the values are correct.
// A contiguous Tensor of size(2, 3) with values 0-5 is layed out as:
// [0] [1] [2] [3] [4] [5]
// The same valued tensor with strides (2, 1) would be layed out like
// [0] [3] [1] [4] [2] [5]
// When we are doing the re-ordering of values into the output tensor,
// we are iterating per-element of the input, and we are fixed
// in indexing in to the output tensor at [i, j] = val
// `val` we want here is equal to the indices for the output
// tensor that would have given the same position as the output
// The position is equal to the sum of stride[i] * index[i],
// and we can can calculate the equivalent indices in the
// output tensor strides by iteratively computing the index of
// the biggest stride:
// absolute = ...
// for stride in strides_from_largest_to_smallest:
// cur_idx = absolute // stride
// absolute = absolute % stride
auto zero = LongImm::make(0);
std::vector<size_t> sorted_stride_indices = reverse_sort_indices(strides);
// TODO: call into `convertOutputToCorrectStrides`. Currently this causes a bug
// in IRSimplifier to occur.
// See explanation in `convertOutputToCorrectStrides`
return Compute(
"output_1", dims, [&](const std::vector<VarHandle>& axes_input) {
std::vector<ExprHandle> axes(axes_input.begin(), axes_input.end());
@ -1179,8 +1242,7 @@ Tensor TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
absolute_position = absolute_position +
(ExprHandle(immLike(axes[i], default_strides[i])) * axes[i]);
}
std::vector<size_t> sorted_stride_indices =
reverse_sort_indices(strides);
std::vector<ExprHandle> new_axes(sorted_stride_indices.size());
for (size_t stride_index : sorted_stride_indices) {
auto size = sizes[stride_index];
@ -1399,14 +1461,22 @@ void TensorExprKernel::compile() {
}
const auto& tt = output->type()->expect<TensorType>();
if (has_symbolic_shapes_) {
// We only support contiguous tensors with symbolic shapes at this time.
auto sizes = sizesFromSymbolicShape(tt->symbolic_sizes());
tensorOutputSymbolicSizes_.push_back(sizes);
TORCH_INTERNAL_ASSERT(symbolic_strides_.count(output));
auto stride_desc = symbolic_strides_[output];
TORCH_INTERNAL_ASSERT(stride_desc.size() == 1);
tensorOutputStrideDesc_.push_back(stride_desc[0]);
Tensor properly_strided_output = convertSymbolicOutputToCorrectStrides(output);
if (properly_strided_output.stmt()) {
block->append_stmt(properly_strided_output.stmt());
}
bufs_[output] = properly_strided_output.buf();
} else {
// The "strided" tensor will be incorrect if used in NNC,
// since NNC views it as contiguous. Only convert it to the right
// strides at the end of the kernel (if already contiguous it's a no-op)
Tensor properly_strided_output = convertOutputToCorrectStrides(output);
Tensor properly_strided_output = convertStaticShapeOutputToCorrectStrides(output);
if (properly_strided_output.stmt()) {
block->append_stmt(properly_strided_output.stmt());
}
@ -1529,8 +1599,16 @@ void TensorExprKernel::updateOutputSizesAndStrides(
tensorOutputSizes_[i].emplace_back(inputs[input_pos].toInt());
}
}
tensorOutputStrides_[i] =
TensorType::contiguousStridesOf(tensorOutputSizes_[i]);
if (tensorOutputStrideDesc_[i] == torch::jit::StrideInput::TENSOR_CONT) {
tensorOutputStrides_[i] = TensorType::contiguousStridesOf(tensorOutputSizes_[i]);
} else if (tensorOutputStrideDesc_[i] == torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST) {
tensorOutputStrides_[i] = at::get_channels_last_strides_2d(tensorOutputSizes_[i]);
} else {
std::string output_desc = toString(tensorOutputStrideDesc_[i]);
TORCH_INTERNAL_ASSERT(
false, "Expected contiguous or channels last, got ", output_desc);
}
}
}

View file

@ -228,7 +228,13 @@ class TORCH_API TensorExprKernel {
Tensor bindInput(const torch::jit::Value* input);
BlockPtr bindAllInputs();
Tensor convertOutputToCorrectStrides(torch::jit::Value* v);
Tensor convertSymbolicOutputToCorrectStrides(torch::jit::Value* v);
Tensor convertStaticShapeOutputToCorrectStrides(torch::jit::Value* v);
Tensor convertOutputToCorrectStrides(
const std::vector<ExprHandle>& sizes,
const std::vector<size_t>& sorted_stride_indices_descending,
const std::vector<ExprPtr>& strides,
BufPtr& buf);
NNCLoweringFunction getCustomLoweringFor(c10::Symbol op) const;
std::unordered_map<c10::Symbol, NNCLoweringFunction> getCustomLowerings()
@ -271,6 +277,7 @@ class TORCH_API TensorExprKernel {
std::vector<CodeGen::BufferArg> bufferArgs_;
std::vector<std::vector<int64_t>> tensorOutputSizes_;
std::vector<std::vector<int64_t>> tensorOutputStrides_;
std::vector<torch::jit::StrideInput> tensorOutputStrideDesc_;
std::vector<UnpackedTensorOptions> tensorOutputTensorOptions_;
std::unordered_set<BufPtr> bufOutputs_;
std::unordered_map<const torch::jit::Value*, BufPtr> bufs_;