[StaticRuntime] Add out variant for reshape and flatten (#51249)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51249

- Add out variants for reshape and flatten. reshape and flatten only create tensor views when they can; when they can't, they do a copy. The out variant reuses the TensorImpl in both cases: the TensorImpl is a view in the first case and a normal TensorImpl in the second.
- Create a separate registry for view ops with out variants. Because tensor views can't participate in memory reuse (memonger), these ops need to be tracked separately.
- The MemoryPlanner does not track the StorageImpl of tensor views because they don't own the storage. However, in cases where reshape does not create a view, the MemoryPlanner does manage the output tensor.

Reviewed By: ajyu

Differential Revision: D25992202

fbshipit-source-id: dadd63b78088c129e491d78abaf8b33d8303ca0d
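Not part of the diff below — a minimal standalone sketch (assuming a plain libtorch/ATen build; the program and variable names are illustrative only) of the view-vs-copy behavior of aten::reshape that the summary and the reshape_out implementation rely on: when a valid stride can be computed the result aliases the input's storage, otherwise reshape must materialize a copy.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Contiguous input: reshape can compute valid strides, so it returns a
  // view that aliases the original storage.
  auto a = at::randn({2, 3});
  auto v = a.reshape({3, 2});
  std::cout << "view aliases input: "
            << v.storage().is_alias_of(a.storage()) << "\n"; // prints 1

  // Non-contiguous input (transpose): no strides satisfy the requested
  // shape, so reshape falls back to a copy with fresh storage.
  auto b = a.transpose(0, 1);
  auto c = b.reshape({2, 3});
  std::cout << "copy aliases input: "
            << c.storage().is_alias_of(b.storage()) << "\n"; // prints 0
  return 0;
}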
parent 16132a4b1d
commit d035d56bfb

6 changed files with 231 additions and 20 deletions
@@ -1770,12 +1770,10 @@ Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim) {
  start_dim = maybe_wrap_dim(start_dim, self.dim());
  end_dim = maybe_wrap_dim(end_dim, self.dim());
  TORCH_CHECK(start_dim <= end_dim, "flatten() has invalid args: start_dim cannot come after end_dim");
  std::vector<int64_t> shape;

  if (self.dim() == 0) {
    return self.reshape({1});
  }

  if (start_dim == end_dim) {
    return self;
  }

@@ -1785,13 +1783,14 @@ Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim) {
  // It's clear we want result shape [0, 3, 0] but passing [0, -1, 0] to infer_size means the -1
  // can take on any value and satisfy the constraints.
  auto slice_numel = prod_intlist(self.sizes().slice(start_dim, end_dim - start_dim + 1));
  std::vector<int64_t> shape;
  shape.reserve(self.dim() - end_dim + start_dim);
  for (int64_t i = 0; i < start_dim; i++) {
    shape.push_back(self.size(i));
    shape.push_back(self.sizes()[i]);
  }
  shape.push_back(slice_numel);
  for (int64_t i = end_dim + 1; i < self.dim(); i++) {
    shape.push_back(self.size(i));
    shape.push_back(self.sizes()[i]);
  }

  return native::reshape(self, shape);

@@ -24,3 +24,25 @@ const auto add_script = R"JIT(
  def forward(self, a, b):
      return a + b
)JIT";

const auto reshape_script_1 = R"JIT(
  def forward(self, a: Tensor, shape: List[int]):
      return a.reshape(shape)
)JIT";

const auto reshape_script_2 = R"JIT(
  def forward(self, a: Tensor, shape: List[int]):
      b = a.transpose(0, 1)
      return b.reshape(shape)
)JIT";

const auto flatten_script_1 = R"JIT(
  def forward(self, a: Tensor, start_dim: int, end_dim: int):
      return torch.flatten(a, start_dim, end_dim)
)JIT";

const auto flatten_script_2 = R"JIT(
  def forward(self, a: Tensor, start_dim: int, end_dim: int):
      b = a.transpose(0, 1)
      return torch.flatten(b, start_dim, end_dim)
)JIT";

@@ -68,8 +68,7 @@ void testStaticRuntime(
      compareTensorLists(
          expect.toTuple()->elements(), actual.toTuple()->elements());
    } else if (expect.isList()) {
      compareTensorLists(
          expect.toTensorVector(), actual.toTensorVector());
      compareTensorLists(expect.toTensorVector(), actual.toTensorVector());
    } else {
      EXPECT_TRUE(expect.toTensor().equal(actual.toTensor()));
    }

@@ -88,6 +87,33 @@ TEST(StaticRuntime, IndividualOps_Binary) {
  testStaticRuntime(tuple_construct_script, args);
}

TEST(StaticRuntime, IndividualOps_Reshape) {
  auto a = at::randn({2, 3});
  auto b = std::vector<int64_t>({3, 2});
  std::vector<IValue> args{a, b};

  testStaticRuntime(reshape_script_1, args);
  testStaticRuntime(reshape_script_2, args);
}

TEST(StaticRuntime, IndividualOps_flatten) {
  auto test_flatten =
      [](std::vector<int64_t> shape, int64_t start_dim, int64_t end_dim) {
        auto a = at::randn(shape);
        std::vector<IValue> args{a, start_dim, end_dim};
        testStaticRuntime(flatten_script_1, args);
        if (shape.size() > 2) {
          testStaticRuntime(flatten_script_2, args);
        }
      };

  test_flatten({2, 3}, 0, 1);
  test_flatten({2, 1, 3}, 1, 2);
  test_flatten({0, 1, 3, 0}, 1, 2);
  test_flatten({2, 3}, 1, 1);
  test_flatten({}, 0, 0);
}

TEST(StaticRuntime, LongModel) {
  torch::jit::Module mod = getLongScriptModel();
  auto a = torch::randn({2, 2});

@@ -271,7 +297,7 @@ TEST(StaticRuntime, FusionPass) {
  bool hit = false;
  for (const auto& n : module.get_method("forward").graph()->nodes()) {
    if (n->kind() == torch::jit::prim::StaticSubgraph) {
      hit = true;
    }
  }
  EXPECT_TRUE(hit);

@@ -280,4 +306,3 @@ TEST(StaticRuntime, FusionPass) {
    }
  }
}

@@ -177,14 +177,16 @@ std::unordered_set<Value*> GetOptimizableValues(
  std::unordered_set<Value*> cannot_reuse;
  for (const auto& n : graph->nodes()) {
    for (const auto& v : n->inputs()) {
      if (canRunOutOfPlace(n) && canReuseInputs(n)) {
      if (canRunOutOfPlace(n) && canReuseInputsOutputs(n) &&
          canReuseInputs(n)) {
        can_reuse.insert(v);
      } else {
        cannot_reuse.insert(v);
      }
    }
    for (const auto& v : n->outputs()) {
      if (canRunOutOfPlace(n) && canReuseOutputs(n)) {
      if (canRunOutOfPlace(n) && canReuseInputsOutputs(n) &&
          canReuseOutputs(n)) {
        can_reuse.insert(v);
      } else {
        cannot_reuse.insert(v);

@@ -698,15 +700,23 @@ MemoryPlanner::MemoryPlanner(
  // some Values should share storage, this map will
  // keep track of the index into managed_storage_
  std::unordered_map<Value*, size_t> shared;
  // the StorageImpls of Tensor views should not be managed
  std::unordered_set<c10::StorageImpl*> managed_storage_impls;

  // Snapshot of the current memory state
  for (const auto& pnode : runtime->get_nodes()) {
    for (auto i = 0; i < pnode.outputs().size(); ++i) {
      const auto& ival = pnode.outputs()[i];
      const auto& val = pnode.get_node()->outputs()[i];
      auto* val = pnode.get_node()->outputs()[i];
      if (managed_values.count(val)) {
        TORCH_CHECK(ival.isTensor());
        auto* impl = ival.toTensor().storage().unsafeGetStorageImpl();

        auto didInsert = managed_storage_impls.insert(impl).second;
        if (!didInsert) {
          continue;
        }

        if (shared.count(val)) {
          managed_storage_[shared.at(val)].second.emplace_back(impl);
        } else {

@@ -741,12 +751,10 @@ void MemoryPlanner::allocate() {
  if (managed_bytes_ == 0) {
    return;
  }

  buffer_ = allocate_buffer(managed_bytes_);

  size_t offset = 0;
  uint8_t* start = static_cast<uint8_t*>(buffer_.get());

  for (const auto& ms : managed_storage_) {
    auto tensor_size = ms.first;
    if (tensor_size == 0) {

@@ -1,31 +1,136 @@
#include <torch/csrc/jit/runtime/static/ops.h>

#include <ATen/CPUFunctions.h>
#include <ATen/InferSize.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/quantized/cpu/qembeddingbag.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

namespace at {
namespace native {
// The out variants of view ops can't be moved to aten because they don't
// exactly follow the semantics of the aten ops. aten::reshape/flatten create
// views, t, that are tracked by autograd and t.is_view() returns true. Here
// t.is_view() would return false instead.
at::Tensor& reshape_out(
    at::Tensor& out,
    const at::Tensor& self,
    const std::vector<int64_t>& proposed_shape,
    bool infer_size = true) {
  auto shape = infer_size ? at::infer_size(proposed_shape, self.numel())
                          : proposed_shape;
  auto stride = at::detail::computeStride(self.sizes(), self.strides(), shape);

  if (stride.has_value()) {
    // create view
    if (!out.defined() || !out.storage().is_alias_of(self.storage())) {
      auto impl = c10::make_intrusive<c10::TensorImpl>(
          c10::Storage(self.storage()), self.key_set(), self.dtype());
      out = at::Tensor(std::move(impl));
    }

    c10::TensorImpl* impl = out.unsafeGetTensorImpl();
    impl->set_storage_offset(self.storage_offset());
    impl->set_sizes_and_strides(shape, *stride);
  } else {
    // copy over tensor
    if (!out.defined()) {
      out = at::native::empty_like(
          self, self.options(), at::MemoryFormat::Contiguous);
    }
    // copy first and set shape/strides later. It doesn't work the other way
    // around.
    at::native::copy_(out, self);
    stride = at::detail::computeStride(out.sizes(), out.strides(), shape);
    c10::TensorImpl* impl = out.unsafeGetTensorImpl();
    impl->set_sizes_and_strides(shape, *stride);
  }
  // namedinference::propagate_names(output, self);
  return out;
}

at::Tensor& flatten_out(
    at::Tensor& out,
    const at::Tensor& self,
    int64_t start_dim,
    int64_t end_dim) {
  start_dim =
      start_dim < 0 ? c10::maybe_wrap_dim(start_dim, self.dim()) : start_dim;
  end_dim = end_dim < 0 ? c10::maybe_wrap_dim(end_dim, self.dim()) : end_dim;
  TORCH_CHECK(
      start_dim <= end_dim,
      "flatten() has invalid args: start_dim cannot come after end_dim");

  if (self.dim() == 0) {
    return reshape_out(out, self, {1}, false);
  }

  if (start_dim == end_dim) {
    out = self;
    return out;
  }

  // We don't want to infer_size on the entire shape, because that can give us
  // an extra degree of freedom we don't want; for example, consider shape [0,
  // 1, 3, 0], with start_dim=1, end_dim=2. It's clear we want result shape [0,
  // 3, 0] but passing [0, -1, 0] to infer_size means the -1 can take on any
  // value and satisfy the constraints.
  auto iter = self.sizes().data();
  auto slice_numel = std::accumulate(
      iter + start_dim,
      iter + end_dim + 1,
      static_cast<int64_t>(1),
      std::multiplies<int64_t>());

  std::vector<int64_t> shape;
  shape.reserve(self.dim() - end_dim + start_dim);
  for (int64_t i = 0; i < start_dim; i++) {
    shape.push_back(self.sizes()[i]);
  }
  shape.push_back(slice_numel);
  for (int64_t i = end_dim + 1; i < self.dim(); i++) {
    shape.push_back(self.sizes()[i]);
  }
  return reshape_out(out, self, shape, false);
}
} // namespace native
} // namespace at

namespace torch {
namespace jit {

C10_DEFINE_REGISTRY(SROperatorRegistry, SROperatorFunctor);
// View ops with out variants are registered separately
C10_DEFINE_REGISTRY(SRViewOperatorRegistry, SROperatorFunctor);

bool canRunOutOfPlace(Node* n) {
  auto op_name = std::string(n->kind().toQualString());
  return SROperatorRegistry()->Has(op_name);
  return SROperatorRegistry()->Has(op_name) ||
      SRViewOperatorRegistry()->Has(op_name);
}

// The inputs/outputs of view ops do not participate in memory reuse
bool canReuseInputsOutputs(Node* n) {
  auto op_name = std::string(n->kind().toQualString());
  return !SRViewOperatorRegistry()->Has(op_name);
}

bool canReuseInputs(Node* n) {
  auto op_name = std::string(n->kind().toQualString());
  DCHECK(SROperatorRegistry()->Has(op_name));
  return SROperatorRegistry()->Create(op_name)->CanReuseInput();
  if (SROperatorRegistry()->Has(op_name)) {
    return SROperatorRegistry()->Create(op_name)->CanReuseInput();
  }
  return false;
}

bool canReuseOutputs(Node* n) {
  auto op_name = std::string(n->kind().toQualString());
  DCHECK(SROperatorRegistry()->Has(op_name));
  return SROperatorRegistry()->Create(op_name)->CanReuseOutput();
  if (SROperatorRegistry()->Has(op_name)) {
    return SROperatorRegistry()->Create(op_name)->CanReuseOutput();
  }
  return false;
}

// TODO: expand to include all view producing ops, mostly in

@@ -60,7 +165,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::add, aten_add, [](Node* n) -> SROperator {
      p_node->Output(0) = create_empty_from(in0_t);
    }
    auto& out_t = p_node->Output(0).toTensor();
    out_t.resize_({0});
    fastResizeToZero(out_t);
    at::cpu::add_out(out_t, in0_t, in1_t, in2_s);
  };
});

@@ -327,16 +432,55 @@ REGISTER_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROperator {
      p_node->Output(0) = create_empty_from(self);
    }
    auto& output = p_node->Output(0).toTensor();
    output.resize_({0});
    fastResizeToZero(output);
    at::native::narrow_copy_dense_cpu_out(self, dim, start, length, output);
  };
});

// Out variants for view ops are registered to a separate registry because
// their outputs (views) can't participate in memory reuse.
REGISTER_VIEW_OPERATOR_FUNCTOR(
    aten::reshape,
    aten_reshape,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        auto& self = p_node->Input(0).toTensor(); // self
        auto proposed_shape = p_node->Input(1).toIntVector(); // shape

        if (p_node->Output(0).isNone()) {
          p_node->Output(0) = at::Tensor();
        }
        auto& out = p_node->Output(0).toTensor();
        at::native::reshape_out(out, self, proposed_shape, true);
      };
    });

REGISTER_VIEW_OPERATOR_FUNCTOR(
    aten::flatten,
    aten_flatten,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        DCHECK(p_node->inputs().size() == 3);
        auto& self = p_node->Input(0).toTensor();
        auto start_dim = p_node->Input(1).toInt();
        auto end_dim = p_node->Input(2).toInt();

        if (p_node->Output(0).isNone()) {
          p_node->Output(0) = at::Tensor();
        }
        auto& out = p_node->Output(0).toTensor();
        at::native::flatten_out(out, self, start_dim, end_dim);
      };
    });

std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n) {
  auto op_name = n->kind().toQualString();
  if (SROperatorRegistry()->Has(op_name)) {
    return SROperatorRegistry()->Create(op_name)->Generate(n);
  }
  if (SRViewOperatorRegistry()->Has(op_name)) {
    return SRViewOperatorRegistry()->Create(op_name)->Generate(n);
  }

  return [](ProcessedNode*) { TORCH_CHECK(0); };
}

@@ -351,6 +495,7 @@ std::function<void(ProcessedNode*)> getNativeOperation(Node* n) {
    };
  } else if (n->kind() == c10::Symbol::fromQualString("aten::flatten")) {
    return [](ProcessedNode* p_node) {
      DCHECK(p_node->inputs().size() == 3);
      auto& in0_t = p_node->Input(0).toTensor();
      auto in1_i = p_node->Input(1).toInt();
      auto in2_i = p_node->Input(2).toInt();

@@ -44,6 +44,17 @@ C10_DECLARE_REGISTRY(SROperatorRegistry, SROperatorFunctor);
#define REGISTER_OPERATOR_FUNCTOR(name, id, ...) \
  REGISTER_OPERATOR_FUNCTOR_OPT(name, id, true, true, __VA_ARGS__)

#define REGISTER_VIEW_OPERATOR_FUNCTOR(name, id, ...)          \
  struct SROperatorFunctor_##id : public SROperatorFunctor {   \
    const SROpFunctor fn = __VA_ARGS__;                        \
    SROperator Generate(Node* n) override {                    \
      return fn(n);                                            \
    }                                                          \
  };                                                           \
  C10_REGISTER_CLASS(SRViewOperatorRegistry, name, SROperatorFunctor_##id);

C10_DECLARE_REGISTRY(SRViewOperatorRegistry, SROperatorFunctor);

inline at::Tensor create_empty_from(const at::Tensor& t) {
  return at::empty({0}, t.options());
}

@@ -60,6 +71,7 @@ inline void fastResizeToZero(at::Tensor& t) {
}

bool canRunOutOfPlace(Node* n);
bool canReuseInputsOutputs(Node* n);
bool canReuseInputs(Node* n);
bool canReuseOutputs(Node* n);