mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add support for permuting dynamic fusion group outputs to channels last format (#70656)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70656 Test Plan: Imported from OSS Reviewed By: navahgar Differential Revision: D33458650 Pulled By: eellison fbshipit-source-id: f0c7d20743deac7a87f7c9176e60da8100aefe41
This commit is contained in:
parent
39be20f259
commit
5480deb183
9 changed files with 150 additions and 33 deletions
|
|
@ -4,12 +4,11 @@
|
|||
#include <ATen/Parallel.h>
|
||||
#include <ATen/core/interned_strings.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
#include <torch/csrc/jit/passes/remove_mutation.h>
|
||||
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
|
||||
#include <torch/csrc/jit/tensorexpr/kernel.h>
|
||||
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
|
||||
#include <torch/csrc/autograd/engine.h>
|
||||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||
#include <torch/csrc/autograd/profiler.h>
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ TEST(ShapeAnalysisTest, DynamicShapesFusion) {
|
|||
subgraph->inputs().at(0)->setType(x_type);
|
||||
subgraph->inputs().at(1)->setType(y_type);
|
||||
subgraph->inputs().at(2)->setType(z_type);
|
||||
subgraph->outputs().at(0)->setType(TensorType::create(at::rand({14, 5})));
|
||||
auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
|
||||
subgraph->outputs().at(0)->setType(TensorType::create(at::rand({14, 5})));
|
||||
output->node()->addInput(x_inp);
|
||||
|
|
@ -84,7 +85,7 @@ TEST(ShapeAnalysisTest, DynamicShapesFusion) {
|
|||
->check("TensorExprGroup")
|
||||
->check_same("symbolic_shape_inputs")
|
||||
->check("block1")
|
||||
->check("FallbackGraph")
|
||||
->check("aten::cat")
|
||||
->run(*g);
|
||||
|
||||
// clang-format off
|
||||
|
|
@ -106,6 +107,7 @@ TEST(ShapeAnalysisTest, DynamicShapesFusion) {
|
|||
%3 : Tensor = prim::TensorExprGroup_0[symbolic_shape_inputs=[-5, -4, -3, -2]](%x_inp, %y_inp, %z_inp, %cat_dim_size.48, %elem.11, %elem.5, %elem.3)
|
||||
-> (%3)
|
||||
block1():
|
||||
// FallbackGraph is inlined
|
||||
%14 : Tensor = prim::FallbackGraph_1(%x_inp, %y_inp, %z_inp)
|
||||
-> (%14)
|
||||
return ()
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ TEST(DynamicShapes, SimpleGraph) {
|
|||
std::vector<torch::jit::StrideInput>>
|
||||
symbolic_strides;
|
||||
symbolic_strides[x_inp] = input_desc;
|
||||
symbolic_strides[graph->outputs().at(0)] = input_desc;
|
||||
std::vector<int64_t> symbolic_shape_inputs = c10::fmap(
|
||||
x_sym_dims,
|
||||
[](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); });
|
||||
|
|
@ -142,6 +143,7 @@ TEST(DynamicShapes, GraphWith2InputsSameDims) {
|
|||
symbolic_strides;
|
||||
symbolic_strides[x_inp] = input_desc;
|
||||
symbolic_strides[y_inp] = input_desc;
|
||||
symbolic_strides[graph->outputs().at(0)] = input_desc;
|
||||
|
||||
TensorExprKernel kernel(
|
||||
graph, {}, symbolic_shape_inputs, false, symbolic_strides);
|
||||
|
|
@ -232,6 +234,7 @@ TEST(DynamicShapes, GraphWith2InputsAndBroadcast) {
|
|||
symbolic_strides;
|
||||
symbolic_strides[x_inp] = input_desc;
|
||||
symbolic_strides[y_inp] = input_desc;
|
||||
symbolic_strides[graph->outputs().at(0)] = input_desc;
|
||||
|
||||
TensorExprKernel kernel(
|
||||
graph, {}, symbolic_shape_inputs, false, symbolic_strides);
|
||||
|
|
@ -313,6 +316,8 @@ TEST(DynamicShapes, GraphWithPartiallySymbolicOutput) {
|
|||
symbolic_strides;
|
||||
symbolic_strides[x_inp] = input_desc;
|
||||
symbolic_strides[y_inp] = input_desc;
|
||||
symbolic_strides[graph->outputs().at(0)] = input_desc;
|
||||
|
||||
|
||||
TensorExprKernel kernel(
|
||||
graph, {}, symbolic_shape_inputs, false, symbolic_strides);
|
||||
|
|
@ -436,6 +441,8 @@ TEST(DynamicShapes, GraphWithCatAndBroadcast) {
|
|||
symbolic_strides[x_inp] = input_desc;
|
||||
symbolic_strides[y_inp] = input_desc;
|
||||
symbolic_strides[z_inp] = input_desc;
|
||||
symbolic_strides[graph->outputs().at(0)] = input_desc;
|
||||
|
||||
|
||||
TensorExprKernel kernel(
|
||||
graph, {}, symbolic_shape_inputs, false, symbolic_strides);
|
||||
|
|
|
|||
|
|
@ -1859,6 +1859,24 @@ class TestTEFuser(JitTestCase):
|
|||
script = self.checkScript(eager, (x, y))
|
||||
self.assertAllFused(script.graph_for(x, y))
|
||||
|
||||
def test_channels_last_dims_dynamic(self):
# Checks that a traced elementwise function run under texpr_dynamic_enabled()
# matches eager results and yields a channels_last-contiguous output, and that
# the last executed optimized graph mentions "TensorExpr".
|
||||
def eager(x, y):
|
||||
return x / (y + 0.0001)
|
||||
|
||||
# Each iteration sets a different one of the four dimensions to size 1.
for i in range(4):
|
||||
size = [2, 3, 4, 5]
|
||||
size[i] = 1
|
||||
inp = torch.rand(size).to(memory_format=torch.channels_last)
|
||||
with texpr_dynamic_enabled():
|
||||
foo_s = torch.jit.trace(eager, (inp, inp))
|
||||
# The traced function is called three times before results are compared.
for _ in range(3):
|
||||
out = foo_s(inp, inp)
|
||||
out_eager = eager(inp, inp)
|
||||
self.assertEqual(out_eager, out)
|
||||
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
|
||||
g = torch.jit.last_executed_optimized_graph()
|
||||
FileCheck().check("TensorExpr").run(g)
|
||||
|
||||
def test_unsqueeze_var_dim(self):
|
||||
def eager(x, y, z: int):
|
||||
return x * torch.unsqueeze(y, dim=z)
|
||||
|
|
|
|||
|
|
@ -454,7 +454,7 @@ void insertDynamicShapesGuard(
|
|||
|
||||
std::vector<std::string> output_striding =
|
||||
fmap(output_strides, [&](StrideInput inp) { return toString(inp); });
|
||||
auto output_ival = IValue(input_striding);
|
||||
auto output_ival = IValue(output_striding);
|
||||
guarded_node->ival_(attr::striding_outputs_desc, output_ival);
|
||||
|
||||
if (add_composed_op) {
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@
|
|||
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
|
||||
#include <torch/csrc/jit/tensorexpr/kernel.h>
|
||||
#include <torch/csrc/utils/memory.h>
|
||||
#include <ATen/core/interned_strings.h>
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
C10_DEFINE_bool(
|
||||
|
|
@ -1279,6 +1280,10 @@ Operation createTensorExprOp(const Node* node) {
|
|||
stride_map[v] = striding_inputs[index];
|
||||
index++;
|
||||
}
|
||||
std::vector<std::string> output_desc = node->ival(attr::striding_outputs_desc).to<std::vector<std::string>>();
|
||||
for (size_t i = 0; i < subgraph->outputs().size(); ++i) {
|
||||
stride_map[subgraph->outputs().at(i)] = {strideInputFromString(output_desc.at(i))};
|
||||
}
|
||||
|
||||
std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
|
||||
std::make_shared<tensorexpr::TensorExprKernel>(
|
||||
|
|
|
|||
|
|
@ -1395,7 +1395,8 @@ ExprPtr PolynomialTransformer::mutate(CompareSelectPtr v) {
|
|||
ExprPtr false_branch = v->ret_val2()->accept_mutator(this);
|
||||
|
||||
// Constant Folding.
|
||||
if (lhs_new->isConstant() && rhs_new->isConstant()) {
|
||||
if (lhs_new->isConstant() && rhs_new->isConstant() &&
|
||||
true_branch->isConstant() && false_branch->isConstant()) {
|
||||
ExprPtr v_new = alloc<CompareSelect>(
|
||||
lhs_new,
|
||||
rhs_new,
|
||||
|
|
|
|||
|
|
@ -1111,7 +1111,85 @@ bool denseAndNonOverlapping(
|
|||
return (strides == at::infer_dense_strides(sizes, strides));
|
||||
}
|
||||
|
||||
Tensor TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
|
||||
// Builds a Compute named "output_1" over `sizes` that loads from `buf` at
// re-mapped indices, so that the result holds the right values when it is
// viewed through `strides`.
//
// sizes: extent of each output dimension.
// sorted_stride_indices_descending: dimension indices, ordered from the
//     dimension with the largest target stride to the smallest.
// strides: target stride of each dimension, indexed by dimension.
// buf: buffer the values are loaded from.
Tensor TensorExprKernel::convertOutputToCorrectStrides(
    const std::vector<ExprHandle>& sizes,
    const std::vector<size_t>& sorted_stride_indices_descending,
    const std::vector<ExprPtr>& strides,
    BufPtr& buf) {
  // We need to convert the output tensor so that its values are laid out
  // so that when viewed from the output strides the values are correct.
  // A contiguous Tensor of size(2, 3) with values 0-5 is laid out as:
  // [0] [1] [2] [3] [4] [5]
  // The same valued tensor with strides (1, 2) would be laid out like
  // [0] [3] [1] [4] [2] [5]
  // When we are doing the re-ordering of values into the output tensor,
  // we are iterating per-element of the input, and we are fixed
  // in indexing in to the output tensor at [i, j] = val
  // `val` we want here is equal to the indices for the output
  // tensor that would have given the same position as the output
  // The position is equal to the sum of stride[i] * index[i],
  // and we can calculate the equivalent indices in the
  // output tensor strides by iteratively computing the index of
  // the biggest stride:
  // absolute = ...
  // for stride in strides_from_largest_to_smallest:
  //     cur_idx = absolute // stride
  //     absolute = absolute % stride
  auto dims = c10::fmap<DimArg>(sizes);
  std::vector<ExprPtr> default_strides = make_contiguous_strides(sizes);
  return Compute(
      "output_1", dims, [&](const std::vector<VarHandle>& axes_input) {
        std::vector<ExprHandle> axes(axes_input.begin(), axes_input.end());
        // NOTE(review): axes[0] is read unconditionally, so this presumably
        // expects at least one dimension - confirm rank-0 outputs never
        // reach this function.
        auto absolute_position = ExprHandle(immLike(axes[0], 0));
        // Linear position of the current element: sum of default_strides[i] * axes[i].
        for (size_t i = 0; i < axes.size(); ++i) {
          ExprHandle stride(default_strides[i]);
          ExprHandle axis = axes[i];
          absolute_position = absolute_position + (stride * axis);
        }
        std::vector<ExprHandle> new_axes(
            sorted_stride_indices_descending.size());
        // Peel off one index per dimension, from largest target stride down.
        for (size_t stride_index : sorted_stride_indices_descending) {
          auto size = sizes[stride_index];
          auto stride = strides[stride_index];
          auto index = absolute_position / ExprHandle(stride);
          auto one = Cast::make(size.dtype(), 1);
          // if the size is one, we don't advance the absolute position
          // which would give 0
          auto non_one_position = absolute_position % ExprHandle(stride);
          absolute_position = CompareSelect::make(
              size, one, absolute_position, non_one_position, kEQ);
          new_axes[stride_index] = index;
        }
        return BufHandle(buf).load(new_axes);
      });
}
|
||||
|
||||
// Returns the Tensor to use for output `v` of a kernel with symbolic shapes.
//
// If the stride descriptor recorded for the output is TENSOR_CONT, the
// existing buffer is wrapped as-is with no statement. Otherwise the
// descriptor must be TENSOR_CONT_CHANNELS_LAST, and the values are re-ordered
// through convertOutputToCorrectStrides using channels-last strides.
Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
    torch::jit::Value* v) {
  const TensorTypePtr& tt = v->type()->expect<TensorType>();
  TORCH_INTERNAL_ASSERT(
      bufs_.count(v),
      buildErrorMessage(
          "Ouput tensor has no corresponding bufs in the fuser."));
  BufPtr buf = bufs_.at(v);

  // NOTE(review): the descriptor is looked up with v->offset(), while
  // compile() appends one descriptor per graph output in iteration order.
  // Confirm that v->offset() matches the output's position in that list
  // when the graph has more than one output.
  const auto stride_desc = tensorOutputStrideDesc_[v->offset()];

  // output is contiguous, no work to do
  if (stride_desc == torch::jit::StrideInput::TENSOR_CONT) {
    return Tensor(buf, nullptr);
  }
  TORCH_INTERNAL_ASSERT(
      stride_desc == torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST);

  auto sizes = sizesFromSymbolicShape(tt->symbolic_sizes());
  // The dimension order below is hard-coded for N C H W, so any other rank
  // would index out of range.
  TORCH_INTERNAL_ASSERT(
      sizes.size() == 4,
      buildErrorMessage(
          "Channels last output conversion expects a 4-d tensor."));
  auto strides = make_channels_last_strides(sizes);

  // For a tensor with dimensions N C H W, the channels last
  // format is laid out as N H W C,
  // so the order largest to smallest will be N, H, W, C
  std::vector<size_t> sorted_stride_indices = {0, 2, 3, 1};

  // See explanation in convertOutputToCorrectStrides
  return convertOutputToCorrectStrides(
      sizes, sorted_stride_indices, strides, buf);
}
|
||||
|
||||
|
||||
Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides(torch::jit::Value* v) {
|
||||
const TensorTypePtr& tt = v->type()->expect<TensorType>();
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
bufs_.count(v),
|
||||
|
|
@ -1150,27 +1228,12 @@ Tensor TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
|
|||
}
|
||||
|
||||
auto dims = c10::fmap<DimArg>(sizesForValue(v));
|
||||
// We need to convert the output tensor so that its values are layed
|
||||
// so that when viewed from the output strides the values are correct.
|
||||
// A contiguous Tensor of size(2, 3) with values 0-5 is layed out as:
|
||||
// [0] [1] [2] [3] [4] [5]
|
||||
// The same valued tensor with strides (2, 1) would be layed out like
|
||||
// [0] [3] [1] [4] [2] [5]
|
||||
// When we are doing the re-ordering of values into the output tensor,
|
||||
// we are iterating per-element of the input, and we are fixed
|
||||
// in indexing in to the output tensor at [i, j] = val
|
||||
// `val` we want here is equal to the indices for the output
|
||||
// tensor that would have given the same position as the output
|
||||
// The position is equal to the sum of stride[i] * index[i],
|
||||
// and we can can calculate the equivalent indices in the
|
||||
// output tensor strides by iteratively computing the index of
|
||||
// the biggest stride:
|
||||
// absolute = ...
|
||||
// for stride in strides_from_largest_to_smallest:
|
||||
// cur_idx = absolute // stride
|
||||
// absolute = absolute % stride
|
||||
|
||||
auto zero = LongImm::make(0);
|
||||
std::vector<size_t> sorted_stride_indices = reverse_sort_indices(strides);
|
||||
|
||||
// TODO: call into `convertOutputToCorrectStrides`. Currently this causes a bug
|
||||
// in IRSimplifier to occur.
|
||||
// See explanation in `convertOutputToCorrectStrides`
|
||||
return Compute(
|
||||
"output_1", dims, [&](const std::vector<VarHandle>& axes_input) {
|
||||
std::vector<ExprHandle> axes(axes_input.begin(), axes_input.end());
|
||||
|
|
@ -1179,8 +1242,7 @@ Tensor TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
|
|||
absolute_position = absolute_position +
|
||||
(ExprHandle(immLike(axes[i], default_strides[i])) * axes[i]);
|
||||
}
|
||||
std::vector<size_t> sorted_stride_indices =
|
||||
reverse_sort_indices(strides);
|
||||
|
||||
std::vector<ExprHandle> new_axes(sorted_stride_indices.size());
|
||||
for (size_t stride_index : sorted_stride_indices) {
|
||||
auto size = sizes[stride_index];
|
||||
|
|
@ -1399,14 +1461,22 @@ void TensorExprKernel::compile() {
|
|||
}
|
||||
const auto& tt = output->type()->expect<TensorType>();
|
||||
if (has_symbolic_shapes_) {
|
||||
// We only support contiguous tensors with symbolic shapes at this time.
|
||||
auto sizes = sizesFromSymbolicShape(tt->symbolic_sizes());
|
||||
tensorOutputSymbolicSizes_.push_back(sizes);
|
||||
TORCH_INTERNAL_ASSERT(symbolic_strides_.count(output));
|
||||
auto stride_desc = symbolic_strides_[output];
|
||||
TORCH_INTERNAL_ASSERT(stride_desc.size() == 1);
|
||||
tensorOutputStrideDesc_.push_back(stride_desc[0]);
|
||||
Tensor properly_strided_output = convertSymbolicOutputToCorrectStrides(output);
|
||||
if (properly_strided_output.stmt()) {
|
||||
block->append_stmt(properly_strided_output.stmt());
|
||||
}
|
||||
bufs_[output] = properly_strided_output.buf();
|
||||
} else {
|
||||
// The "strided" tensor will be incorrect if used in NNC,
|
||||
// since NNC views it as contiguous. Only convert it to the right
|
||||
// strides at the end of the kernel (if already contiguous it's a no-op)
|
||||
Tensor properly_strided_output = convertOutputToCorrectStrides(output);
|
||||
Tensor properly_strided_output = convertStaticShapeOutputToCorrectStrides(output);
|
||||
if (properly_strided_output.stmt()) {
|
||||
block->append_stmt(properly_strided_output.stmt());
|
||||
}
|
||||
|
|
@ -1529,8 +1599,16 @@ void TensorExprKernel::updateOutputSizesAndStrides(
|
|||
tensorOutputSizes_[i].emplace_back(inputs[input_pos].toInt());
|
||||
}
|
||||
}
|
||||
tensorOutputStrides_[i] =
|
||||
TensorType::contiguousStridesOf(tensorOutputSizes_[i]);
|
||||
|
||||
if (tensorOutputStrideDesc_[i] == torch::jit::StrideInput::TENSOR_CONT) {
|
||||
tensorOutputStrides_[i] = TensorType::contiguousStridesOf(tensorOutputSizes_[i]);
|
||||
} else if (tensorOutputStrideDesc_[i] == torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST) {
|
||||
tensorOutputStrides_[i] = at::get_channels_last_strides_2d(tensorOutputSizes_[i]);
|
||||
} else {
|
||||
std::string output_desc = toString(tensorOutputStrideDesc_[i]);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
false, "Expected contiguous or channels last, got ", output_desc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -228,7 +228,13 @@ class TORCH_API TensorExprKernel {
|
|||
Tensor bindInput(const torch::jit::Value* input);
|
||||
BlockPtr bindAllInputs();
|
||||
|
||||
Tensor convertOutputToCorrectStrides(torch::jit::Value* v);
|
||||
Tensor convertSymbolicOutputToCorrectStrides(torch::jit::Value* v);
|
||||
Tensor convertStaticShapeOutputToCorrectStrides(torch::jit::Value* v);
|
||||
Tensor convertOutputToCorrectStrides(
|
||||
const std::vector<ExprHandle>& sizes,
|
||||
const std::vector<size_t>& sorted_stride_indices_descending,
|
||||
const std::vector<ExprPtr>& strides,
|
||||
BufPtr& buf);
|
||||
|
||||
NNCLoweringFunction getCustomLoweringFor(c10::Symbol op) const;
|
||||
std::unordered_map<c10::Symbol, NNCLoweringFunction> getCustomLowerings()
|
||||
|
|
@ -271,6 +277,7 @@ class TORCH_API TensorExprKernel {
|
|||
std::vector<CodeGen::BufferArg> bufferArgs_;
|
||||
std::vector<std::vector<int64_t>> tensorOutputSizes_;
|
||||
std::vector<std::vector<int64_t>> tensorOutputStrides_;
|
||||
std::vector<torch::jit::StrideInput> tensorOutputStrideDesc_;
|
||||
std::vector<UnpackedTensorOptions> tensorOutputTensorOptions_;
|
||||
std::unordered_set<BufPtr> bufOutputs_;
|
||||
std::unordered_map<const torch::jit::Value*, BufPtr> bufs_;
|
||||
|
|
|
|||
Loading…
Reference in a new issue