mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Add ability to specialize class types to ArgumentSpec (#18314)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18314 ghimport-source-id: 8cecb768d476ab19c9460f39c8f94a764e4cb052 Stack from [ghstack](https://github.com/ezyang/ghstack): * **#18314 Add ability to specialize class types to ArgumentSpec** * #18226 Add Slot type to abstract the raw pointers being used for slots. Differential Revision: D14574395 fbshipit-source-id: cc3af6e56e9ae52990f4a1ad56ecceaa2d493577
This commit is contained in:
parent
5f5a2aaab9
commit
2d07993bcb
14 changed files with 422 additions and 199 deletions
|
|
@ -704,7 +704,9 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
|
|||
Symbol name() const {
|
||||
return typename_;
|
||||
}
|
||||
|
||||
const std::vector<IValue>& slots() const {
|
||||
return slots_;
|
||||
}
|
||||
private:
|
||||
const Symbol typename_;
|
||||
std::vector<IValue> slots_;
|
||||
|
|
|
|||
|
|
@ -1203,10 +1203,21 @@ struct CAFFE2_API ClassType : public Type {
|
|||
attributeTypes_.push_back(type);
|
||||
}
|
||||
|
||||
at::ArrayRef<std::string> attributeNames() const {
|
||||
return attributeNames_;
|
||||
}
|
||||
|
||||
at::ArrayRef<TypePtr> containedTypes() const override {
|
||||
return attributeTypes_;
|
||||
}
|
||||
|
||||
// generate a refined version of this class.
|
||||
// It has the same name but the slot Types are subtypes of
|
||||
// the original slots. It is only valid to refine a class type in a context
|
||||
// where it is know that there are not assignments to the objects slots
|
||||
// that would invalidate the refinement.
|
||||
// These variants are not registered in the global class table.
|
||||
ClassTypePtr refine(at::ArrayRef<TypePtr> refined_slots) const;
|
||||
static const TypeKind Kind = TypeKind::ClassType;
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -478,6 +478,16 @@ ClassTypePtr ClassType::create(
|
|||
return ptr;
|
||||
}
|
||||
|
||||
ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
|
||||
auto ptr = ClassTypePtr(new ClassType(typename_, module_));
|
||||
AT_ASSERT(numAttributes() == refined_slots.size());
|
||||
for(size_t i = 0; i < attributeNames_.size(); ++i) {
|
||||
AT_ASSERT(refined_slots[i]->isSubtypeOf(attributeTypes_[i]));
|
||||
ptr->addAttribute(attributeNames_[i], refined_slots[i]);
|
||||
}
|
||||
return ptr;
|
||||
}
|
||||
|
||||
ClassTypePtr ClassType::get(const std::string& name) {
|
||||
return getRegistry().getType(name);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -208,7 +208,10 @@ void testDifferentiateWithRequiresGrad() {
|
|||
at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true);
|
||||
auto b_var = autograd::make_variable(
|
||||
at::empty_strided(2, 2, at::CPU(at::kFloat).options()), false);
|
||||
setInputTypes(*graph, ArgumentSpec(true, {a_var, b_var}, 2));
|
||||
|
||||
ArgumentSpecCreator asc(*graph);
|
||||
asc.setInputTypes(*graph, asc.create(true, {a_var, b_var}));
|
||||
|
||||
PropagateInputShapes(graph);
|
||||
PropagateRequiresGrad(graph);
|
||||
|
||||
|
|
|
|||
|
|
@ -1893,17 +1893,18 @@ class TestJit(JitTestCase):
|
|||
|
||||
def test_tuple_specialization(self):
|
||||
@torch.jit.script
|
||||
def f(t):
|
||||
# type: (Tuple[Tensor, Tensor]) -> Tensor
|
||||
x, y = t
|
||||
def f(t, s):
|
||||
# type: (Tuple[Tensor, Tuple[int, Tensor]], str) -> Tensor
|
||||
x, t2 = t
|
||||
_, y = t2
|
||||
return x + y
|
||||
|
||||
t = torch.randn(2, 2), torch.randn(2, 2)
|
||||
f(t)
|
||||
graph = f.graph_for(t)
|
||||
t = torch.randn(2, 2), (1, torch.randn(2, 2)),
|
||||
f(t, "hi")
|
||||
graph = f.graph_for(t, "hi")
|
||||
input_types = list(next(graph.inputs()).type().elements())
|
||||
for t in input_types:
|
||||
self.assertEqual(t.kind(), 'DimensionedTensorType')
|
||||
self.assertEqual(input_types[0].kind(), 'DimensionedTensorType')
|
||||
self.assertEqual(input_types[1].elements()[1].kind(), 'DimensionedTensorType')
|
||||
|
||||
def test_constant_prop_simple(self):
|
||||
@torch.jit.script
|
||||
|
|
@ -3450,13 +3451,11 @@ a")
|
|||
# test that shape analysis is written correctly for sum with IntArrayRef[1] dim argument
|
||||
self.run_pass('constant_propagation', func.graph)
|
||||
self.run_pass('constant_propagation', func2.graph)
|
||||
torch._C._jit_pass_shape_analysis(
|
||||
func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
|
||||
torch._C._jit_pass_shape_analysis(
|
||||
func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
|
||||
self.assertTrue(func.graph.findNode("aten::sum").output().type().kind()
|
||||
g = func._get_method('forward').propagate_shapes((torch.zeros(1, 1, 1, 1, 4),), False)
|
||||
g2 = func2._get_method('forward').propagate_shapes((torch.zeros(1, 1, 1, 1, 4),), False)
|
||||
self.assertTrue(g.findNode("aten::sum").output().type().kind()
|
||||
== "DimensionedTensorType")
|
||||
self.assertTrue(func2.graph.findNode("aten::sum").output().type().kind()
|
||||
self.assertTrue(g2.findNode("aten::sum").output().type().kind()
|
||||
== "DimensionedTensorType")
|
||||
|
||||
def test_cat(self):
|
||||
|
|
@ -4154,9 +4153,9 @@ a")
|
|||
torch.mul(x, y, out=z)
|
||||
return z
|
||||
|
||||
torch._C._jit_pass_shape_analysis(
|
||||
test.graph, (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
|
||||
self.assertTrue(next(test.graph.outputs()).type() == TensorType.get())
|
||||
graph = test._get_method('forward').propagate_shapes(
|
||||
(torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
|
||||
self.assertTrue(next(graph.outputs()).type() == TensorType.get())
|
||||
out_op_graph_input()
|
||||
|
||||
def test_resize():
|
||||
|
|
@ -4173,10 +4172,8 @@ a")
|
|||
after_resize_alias = b.add_(1)
|
||||
return after_resize_alias
|
||||
|
||||
g = test.graph
|
||||
self.run_pass('constant_propagation', g)
|
||||
torch._C._jit_pass_shape_analysis(
|
||||
g, (torch.zeros(1, 1),), False)
|
||||
self.run_pass('constant_propagation', test.graph)
|
||||
g = test._get_method('forward').propagate_shapes((torch.zeros(1, 1),), False)
|
||||
resize_node = g.findNode("aten::resize_")
|
||||
# first input and output of b.resize_ is b
|
||||
self.assertTrue(next(resize_node.inputs()).type() == TensorType.get())
|
||||
|
|
@ -4200,8 +4197,7 @@ a")
|
|||
|
||||
g = test.graph
|
||||
self.run_pass('constant_propagation', g)
|
||||
torch._C._jit_pass_shape_analysis(
|
||||
g, (torch.zeros(1, 1),), False)
|
||||
g = test._get_method('forward').propagate_shapes((torch.zeros(1, 1),), False)
|
||||
|
||||
# x doesn't alias a resized op so it shouldn't be set to base Tensor type
|
||||
self.assertTrue(next(g.inputs()).type() != TensorType.get())
|
||||
|
|
@ -4255,8 +4251,8 @@ a")
|
|||
return x.view(T, B, C)
|
||||
|
||||
x = torch.randn(3, 1, 5, requires_grad=True)
|
||||
graph = torch.jit.script(fn).graph
|
||||
torch._C._jit_pass_shape_analysis(graph, (x,), False)
|
||||
fn = torch.jit.script(fn)
|
||||
graph = fn._get_method('forward').propagate_shapes((x,), False)
|
||||
a = next(graph.outputs()).type().kind()
|
||||
self.assertTrue(next(graph.outputs()).type().kind() != 'TensorType')
|
||||
|
||||
|
|
@ -6677,7 +6673,7 @@ a")
|
|||
return torch.cat(c)
|
||||
|
||||
b = torch.zeros(2, 4)
|
||||
test_list.graph.propagate_shapes((b,), False)
|
||||
test_list._get_method('forward').propagate_shapes((b,), False)
|
||||
|
||||
def test_if_supertype(self):
|
||||
@torch.jit.script
|
||||
|
|
@ -6694,8 +6690,8 @@ a")
|
|||
b = torch.zeros(2, 4, dtype=torch.long)
|
||||
c = torch.zeros(2, 4, dtype=torch.float)
|
||||
|
||||
tensor_unifying.graph.propagate_shapes((a, b, c), False)
|
||||
if_outputs = list(tensor_unifying.graph.findNode("prim::If").outputs())
|
||||
graph = tensor_unifying._get_method('forward').propagate_shapes((a, b, c), False)
|
||||
if_outputs = list(graph.findNode("prim::If").outputs())
|
||||
self.assertTrue(if_outputs[0].type().str() == "Float(*, *)")
|
||||
self.assertTrue(if_outputs[1].type().str() == "Tensor")
|
||||
self.assertTrue(if_outputs[2].type().str() == "Tensor")
|
||||
|
|
@ -13303,6 +13299,30 @@ class TestClassType(JitTestCase):
|
|||
self.assertEqual(x, f2.x)
|
||||
self.assertEqual(y, f2.y)
|
||||
|
||||
def test_class_specialization(self):
|
||||
@torch.jit.script # noqa: B903
|
||||
class Foo(object):
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def use_foo(foo, foo2, tup):
|
||||
# type: (Foo, Foo, Tuple[Foo, Foo]) -> Tensor
|
||||
a, b = tup
|
||||
return foo.x + foo2.y + a.x + b.y
|
||||
|
||||
# create from python
|
||||
x = torch.ones(2, 3)
|
||||
y = torch.zeros(2, 3)
|
||||
f = Foo(x, y)
|
||||
f2 = Foo(x * 2, y * 3)
|
||||
f3 = Foo(x * 4, y * 4)
|
||||
|
||||
input = (f, f2, (f, f3))
|
||||
sfoo = self.checkScript(use_foo, input)
|
||||
graphstr = str(sfoo.graph_for(*input))
|
||||
FileCheck().check_count("Double(*, *) = prim::GetAttr", 4).run(graphstr)
|
||||
|
||||
|
||||
class TestLogging(JitTestCase):
|
||||
def test_bump_numeric_counter(self):
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ libtorch_sources = [
|
|||
"torch/csrc/Exceptions.cpp",
|
||||
"torch/csrc/jit/autodiff.cpp",
|
||||
"torch/csrc/jit/attributes.cpp",
|
||||
"torch/csrc/jit/argument_spec.cpp",
|
||||
"torch/csrc/jit/constants.cpp",
|
||||
"torch/csrc/jit/node_hashing.cpp",
|
||||
"torch/csrc/jit/export.cpp",
|
||||
|
|
|
|||
|
|
@ -123,6 +123,7 @@ set(TORCH_SRCS
|
|||
${TORCH_SRC_DIR}/csrc/autograd/VariableTypeManual.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/autodiff.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/attributes.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/argument_spec.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/export.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/pickler.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/generated/register_aten_ops_0.cpp
|
||||
|
|
|
|||
229
torch/csrc/jit/argument_spec.cpp
Normal file
229
torch/csrc/jit/argument_spec.cpp
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
|
||||
#include <torch/csrc/jit/argument_spec.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
void ArgumentSpecCreator::scan(
|
||||
const TypePtr& typ,
|
||||
size_t depth,
|
||||
const WrittenSlots& written_slots) {
|
||||
auto finishAggregate = [&](size_t pos) {
|
||||
// it is possible after all the work we did to scan this aggregate,
|
||||
// we found no tensors to specialize. In this case, just generate
|
||||
// a skip for the whole aggregate.
|
||||
bool any_spec = std::any_of(
|
||||
instructions_.begin() + pos, instructions_.end(), [](Inst i) {
|
||||
return i == SPECIALIZE_TENSOR;
|
||||
});
|
||||
if (!any_spec) {
|
||||
instructions_[pos] = SKIP;
|
||||
instructions_.resize(pos + 1);
|
||||
} else {
|
||||
instructions_.emplace_back(LEAVE);
|
||||
}
|
||||
};
|
||||
// the simple vm that scans instructions_ has a limited stack depth,
|
||||
// this prevents going deeper than that.
|
||||
if (depth >= DEPTH_LIMIT) {
|
||||
instructions_.emplace_back(SKIP);
|
||||
}
|
||||
if (typ->isSubtypeOf(TensorType::get())) {
|
||||
num_tensors_++;
|
||||
instructions_.emplace_back(SPECIALIZE_TENSOR);
|
||||
} else if (auto tup = typ->cast<TupleType>()) {
|
||||
size_t pos = instructions_.size();
|
||||
instructions_.emplace_back(ENTER_TUPLE);
|
||||
for (const auto& elem : tup->containedTypes()) {
|
||||
scan(elem, depth + 1, written_slots);
|
||||
}
|
||||
finishAggregate(pos);
|
||||
} else if (auto cls = typ->cast<ClassType>()) {
|
||||
size_t pos = instructions_.size();
|
||||
instructions_.emplace_back(ENTER_OBJECT);
|
||||
for (size_t i = 0; i < cls->numAttributes(); ++i) {
|
||||
auto key = cls->name() + cls->attributeNames().at(i);
|
||||
// it is only safe to specialize because someone might have written to it
|
||||
if (!written_slots.count(key)) {
|
||||
scan(cls->containedTypes().at(i), depth + 1, written_slots);
|
||||
} else {
|
||||
instructions_.emplace_back(SKIP);
|
||||
}
|
||||
}
|
||||
finishAggregate(pos);
|
||||
} else {
|
||||
instructions_.emplace_back(SKIP);
|
||||
}
|
||||
};
|
||||
|
||||
// this is a coarse-grained guarentee that the slots of a class will not be
|
||||
// modified by the function. It works fine for things that used be read-only
|
||||
// modules, but will be overly conservative when some classes are written to.
|
||||
// Doing alias analysis and looking for writes to the class would be more
|
||||
// accurate.
|
||||
static void scanWrittenSlots(
|
||||
Block* block,
|
||||
ArgumentSpecCreator::WrittenSlots& written_slots) {
|
||||
for (Node* n : block->nodes()) {
|
||||
if (n->kind() == prim::SetAttr) {
|
||||
if (auto cls = n->inputs().at(0)->type()->cast<ClassType>()) {
|
||||
written_slots.insert(cls->name() + n->s(attr::name));
|
||||
}
|
||||
}
|
||||
for (Block* subblock : n->blocks()) {
|
||||
scanWrittenSlots(subblock, written_slots);
|
||||
}
|
||||
if (n->hasAttribute(attr::Subgraph)) {
|
||||
scanWrittenSlots(n->g(attr::Subgraph)->block(), written_slots);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ArgumentSpecCreator::ArgumentSpecCreator(Graph& graph)
|
||||
: num_inputs_(graph.inputs().size()) {
|
||||
WrittenSlots written_slots;
|
||||
scanWrittenSlots(graph.block(), written_slots);
|
||||
for (Value* input : graph.inputs()) {
|
||||
scan(input->type(), 0, written_slots);
|
||||
}
|
||||
}
|
||||
|
||||
void ArgumentSpecCreator::dump() const {
|
||||
for (Inst inst : instructions_) {
|
||||
switch (inst) {
|
||||
case LEAVE:
|
||||
std::cout << "] ";
|
||||
break;
|
||||
case ENTER_TUPLE:
|
||||
std::cout << "Tuple[";
|
||||
break;
|
||||
case ENTER_OBJECT:
|
||||
std::cout << "Object[";
|
||||
break;
|
||||
case SKIP:
|
||||
std::cout << "Skip ";
|
||||
break;
|
||||
case SPECIALIZE_TENSOR:
|
||||
std::cout << "SpecializeTensor ";
|
||||
break;
|
||||
}
|
||||
}
|
||||
std::cout << "\n";
|
||||
}
|
||||
|
||||
ArgumentSpec ArgumentSpecCreator::create(bool with_grad, const Stack& input)
|
||||
const {
|
||||
ArgumentSpec spec(num_tensors_);
|
||||
const IValue* stack[DEPTH_LIMIT]; // The stack of IValue lists
|
||||
// The stack gets initialized with the input list
|
||||
stack[0] = last(input, num_inputs_).begin();
|
||||
size_t stack_top = 0; // offset to the top of the stack
|
||||
for (Inst inst : instructions_) {
|
||||
switch (inst) {
|
||||
case SPECIALIZE_TENSOR:
|
||||
// consume a tensor and add to the argspec
|
||||
spec.addTensor(*stack[stack_top]++, with_grad);
|
||||
break;
|
||||
case ENTER_TUPLE: {
|
||||
// consume tuple
|
||||
const IValue* iv = stack[stack_top]++;
|
||||
AT_ASSERT(iv->isTuple());
|
||||
// see [argspec refcounting]
|
||||
auto p = *reinterpret_cast<const at::ivalue::Tuple* const*>(iv);
|
||||
auto tup_ptr = &p->elements()[0];
|
||||
// push list of tuple elements to the stack
|
||||
stack[++stack_top] = tup_ptr;
|
||||
} break;
|
||||
case ENTER_OBJECT: {
|
||||
// consume object
|
||||
const IValue* iv = stack[stack_top]++;
|
||||
AT_ASSERT(iv->isObject());
|
||||
iv->toObject();
|
||||
// see [argspec refcounting]
|
||||
auto p = *reinterpret_cast<const at::ivalue::Object* const*>(iv);
|
||||
auto obj_ptr = &p->slots()[0];
|
||||
// push list of object elements to the stack
|
||||
stack[++stack_top] = obj_ptr;
|
||||
} break;
|
||||
case SKIP:
|
||||
// consume and skip an element
|
||||
stack[stack_top]++;
|
||||
break;
|
||||
case LEAVE:
|
||||
--stack_top;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return spec;
|
||||
}
|
||||
|
||||
// For every input of a given graph, returns a most detailed type that can be
|
||||
// inferred for it based on this ArgumentSpec.
|
||||
std::vector<TypePtr> ArgumentSpecCreator::getSpecializedTypes(
|
||||
Graph& graph,
|
||||
const ArgumentSpec& spec) const {
|
||||
auto input_types =
|
||||
fmap(graph.inputs(), [](Value* input) { return input->type(); });
|
||||
std::vector<std::vector<TypePtr>> result_stack;
|
||||
result_stack.emplace_back();
|
||||
std::vector<const TypePtr*> input_stack = {input_types.data()};
|
||||
std::vector<std::function<TypePtr()>> aggregate_creators;
|
||||
|
||||
size_t arg_spec_offset = 0; // number of specialized tensors seen so far
|
||||
|
||||
for (Inst inst : instructions_) {
|
||||
switch (inst) {
|
||||
case SPECIALIZE_TENSOR: {
|
||||
input_stack.back()++;
|
||||
auto& arg = spec.at(arg_spec_offset++);
|
||||
if (!arg.defined()) {
|
||||
result_stack.back().emplace_back(AutogradZeroTensorType::get());
|
||||
} else {
|
||||
result_stack.back().emplace_back(DimensionedTensorType::create(
|
||||
arg.type(),
|
||||
ConvertIntToCPUOrCUDA(arg.device()),
|
||||
arg.dim(),
|
||||
arg.requires_grad()));
|
||||
}
|
||||
} break;
|
||||
case ENTER_TUPLE: {
|
||||
auto tup = (*input_stack.back()++)->expect<TupleType>();
|
||||
input_stack.emplace_back(tup->elements().data());
|
||||
result_stack.emplace_back();
|
||||
aggregate_creators.emplace_back(
|
||||
[&] { return TupleType::create(result_stack.back()); });
|
||||
} break;
|
||||
case ENTER_OBJECT: {
|
||||
auto cls = (*input_stack.back()++)->expect<ClassType>();
|
||||
input_stack.emplace_back(cls->containedTypes().data());
|
||||
result_stack.emplace_back();
|
||||
aggregate_creators.emplace_back(
|
||||
[&result_stack, cls] { return cls->refine(result_stack.back()); });
|
||||
} break;
|
||||
case SKIP:
|
||||
result_stack.back().emplace_back(*input_stack.back()++);
|
||||
break;
|
||||
case LEAVE:
|
||||
TypePtr result = aggregate_creators.back()();
|
||||
result_stack.pop_back();
|
||||
aggregate_creators.pop_back();
|
||||
input_stack.pop_back();
|
||||
result_stack.back().emplace_back(std::move(result));
|
||||
break;
|
||||
}
|
||||
}
|
||||
AT_ASSERT(result_stack.size() == 1);
|
||||
return result_stack.back();
|
||||
}
|
||||
|
||||
void ArgumentSpecCreator::setInputTypes(Graph& g, const ArgumentSpec& spec)
|
||||
const {
|
||||
auto input_types = getSpecializedTypes(g, spec);
|
||||
auto inputs = g.inputs();
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
inputs[i]->setType(input_types[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/stack.h>
|
||||
#include <torch/csrc/autograd/variable.h>
|
||||
#include <torch/csrc/jit/ir.h>
|
||||
#include <ATen/core/stack.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <torch/csrc/jit/variable_tensor_list.h>
|
||||
#include <torch/csrc/utils/hash.h>
|
||||
#include <iostream>
|
||||
|
|
@ -22,9 +22,6 @@ struct ArgumentInfo {
|
|||
friend struct ArgumentSpec;
|
||||
using plain_data_type = uint32_t;
|
||||
|
||||
bool isTensor() const {
|
||||
return is_tensor_;
|
||||
}
|
||||
bool defined() const {
|
||||
return defined_;
|
||||
}
|
||||
|
|
@ -45,11 +42,11 @@ struct ArgumentInfo {
|
|||
operator TypePtr() const {
|
||||
if (!defined())
|
||||
return TensorType::get();
|
||||
return DimensionedTensorType::create(type(), ConvertIntToCPUOrCUDA(device()), dim());
|
||||
return DimensionedTensorType::create(
|
||||
type(), ConvertIntToCPUOrCUDA(device()), dim());
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned is_tensor_ : 1;
|
||||
unsigned defined_ : 1;
|
||||
unsigned requires_grad_ : 1;
|
||||
unsigned : 5;
|
||||
|
|
@ -67,48 +64,32 @@ static_assert(
|
|||
"ArgumentInfo is expected to be a 32-bit struct");
|
||||
|
||||
struct ArgumentSpec {
|
||||
ArgumentSpec(
|
||||
bool with_grad,
|
||||
at::ArrayRef<IValue> inputs,
|
||||
size_t num_flat_inputs) {
|
||||
ArgumentSpec(size_t num_flat_inputs) {
|
||||
hash_code = num_flat_inputs;
|
||||
args.resize(num_flat_inputs);
|
||||
size_t offset = 0;
|
||||
for (const auto& i : inputs) {
|
||||
addInput(i, offset, with_grad);
|
||||
}
|
||||
AT_ASSERT(offset <= num_flat_inputs);
|
||||
args.reserve(num_flat_inputs);
|
||||
}
|
||||
|
||||
void addInput(const IValue& input, size_t& offset, bool with_grad) {
|
||||
auto& arg = args.at(offset);
|
||||
void addTensor(const IValue& input, bool with_grad) {
|
||||
AT_ASSERT(input.isTensor());
|
||||
args.emplace_back();
|
||||
auto& arg = args.back();
|
||||
// Initialize all fields to 0. This is convenient, because e.g.
|
||||
// requires_grad() can be checked even on tensors AND will make
|
||||
// padding bits all 0s.
|
||||
std::memset(&arg, 0, sizeof(ArgumentInfo));
|
||||
|
||||
if (input.isTensor()) {
|
||||
at::Tensor t = input.toTensor();
|
||||
if ((arg.defined_ = t.defined())) {
|
||||
arg.requires_grad_ = with_grad && autograd::Variable(t).requires_grad();
|
||||
arg.dim_ = t.dim();
|
||||
arg.device_ = t.is_cuda() ? t.get_device() : -1;
|
||||
arg.type_ = static_cast<unsigned>(t.scalar_type());
|
||||
}
|
||||
|
||||
arg.is_tensor_ = true;
|
||||
combineHash(arg);
|
||||
offset++;
|
||||
} else if (input.isTuple()) {
|
||||
for (const IValue& elem : input.toTuple()->elements()) {
|
||||
addInput(elem, offset, with_grad);
|
||||
}
|
||||
} else {
|
||||
// NB: no need to set is_tensor to false, because we memset the struct to
|
||||
// 0 above
|
||||
combineHash(arg);
|
||||
offset++;
|
||||
// [argspec refcounting] reinterpret the IValue to avoid having to refcount
|
||||
// the Tensor microbenchmarks
|
||||
// https://github.com/zdevito/pytorch/commit/21e7200a0a0fc456bea2f10e95b1781f83933d10
|
||||
// show overhead in extra refcounting along this path
|
||||
const at::Tensor* t = reinterpret_cast<const at::Tensor*>(&input);
|
||||
if ((arg.defined_ = t->defined())) {
|
||||
arg.requires_grad_ = with_grad && autograd::Variable(*t).requires_grad();
|
||||
arg.dim_ = t->dim();
|
||||
arg.device_ = t->is_cuda() ? t->get_device() : -1;
|
||||
arg.type_ = static_cast<unsigned>(t->scalar_type());
|
||||
}
|
||||
combineHash(arg);
|
||||
}
|
||||
|
||||
void combineHash(const ArgumentInfo& arg) {
|
||||
|
|
@ -143,38 +124,49 @@ struct ArgumentSpec {
|
|||
size_t hashCode() const {
|
||||
return hash_code;
|
||||
}
|
||||
// For every input of a given graph, returns a most detailed type that can be
|
||||
// inferred for it based on this ArgumentSpec.
|
||||
std::vector<TypePtr> getTypes(Graph& graph) const {
|
||||
size_t offset = 0;
|
||||
return fmap(
|
||||
graph.inputs(), [&](Value* v) { return fillType(v->type(), offset); });
|
||||
}
|
||||
|
||||
private:
|
||||
TypePtr fillType(TypePtr original, size_t& offset) const {
|
||||
if (original->isSubtypeOf(TensorType::get())) {
|
||||
auto& arg = args.at(offset++);
|
||||
if (!arg.defined())
|
||||
return AutogradZeroTensorType::get();
|
||||
return DimensionedTensorType::create(
|
||||
arg.type(),
|
||||
ConvertIntToCPUOrCUDA(arg.device()),
|
||||
arg.dim(),
|
||||
arg.requires_grad());
|
||||
} else if (auto tuple_type = original->cast<TupleType>()) {
|
||||
return TupleType::create(fmap(
|
||||
tuple_type->elements(),
|
||||
[&](const TypePtr& subtype) { return fillType(subtype, offset); }));
|
||||
} else {
|
||||
offset++;
|
||||
return original;
|
||||
}
|
||||
}
|
||||
size_t hash_code; // precomputed on construction
|
||||
std::vector<ArgumentInfo> args;
|
||||
};
|
||||
|
||||
// ArgumentSpecCreator takes an initial graph and comes up with a set
|
||||
// of simple instructions to compute the ArgumentSpec given a set of
|
||||
// input tensors.
|
||||
struct ArgumentSpecCreator {
|
||||
// instructs acts on a stack of a list of input IValues
|
||||
// at the beginning the stack contains a single list of the inputs to the
|
||||
// function the ENTER_ instructs descend into subobjects and push new lists
|
||||
// onto the stack
|
||||
enum Inst : char {
|
||||
ENTER_TUPLE, // consume a tuple ivalue from the top-most list, and push the
|
||||
// list of its elements onto the stack as a new list
|
||||
ENTER_OBJECT, // same as ENTER_TUPLE, but the input is a class
|
||||
LEAVE, // pop the top-most list from the stack
|
||||
SKIP, // consume an element from the top-most list, and discard
|
||||
SPECIALIZE_TENSOR, // consume a tensor for the top-most list, and
|
||||
// add it to the ArgSpec key being created
|
||||
};
|
||||
ArgumentSpecCreator(Graph& graph);
|
||||
ArgumentSpec create(bool with_grad, const Stack& stack) const;
|
||||
void setInputTypes(Graph& g, const ArgumentSpec& spec) const;
|
||||
std::vector<TypePtr> getSpecializedTypes(
|
||||
Graph& graph,
|
||||
const ArgumentSpec& spec) const;
|
||||
void dump() const;
|
||||
using WrittenSlots = std::unordered_set<std::string>;
|
||||
|
||||
private:
|
||||
static constexpr size_t DEPTH_LIMIT = 128;
|
||||
void scan(
|
||||
const TypePtr& typ,
|
||||
size_t depth,
|
||||
const WrittenSlots& written_slots);
|
||||
size_t num_inputs_;
|
||||
size_t num_tensors_ = 0;
|
||||
std::vector<Inst> instructions_;
|
||||
};
|
||||
|
||||
// CompleteArgumentSpec represents one particular specialization.
|
||||
// It is designed so that it can be created, hashed, and compared quickly
|
||||
// since it is used along the hot-path of the JIT to check if the code
|
||||
|
|
@ -398,14 +390,6 @@ inline CompleteArgumentInfo CompleteArgumentSpec::at(size_t i) const {
|
|||
return CompleteArgumentInfo(*this, i);
|
||||
}
|
||||
|
||||
inline void setInputTypes(Graph& g, const ArgumentSpec& spec) {
|
||||
auto input_types = spec.getTypes(g);
|
||||
auto inputs = g.inputs();
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
inputs[i]->setType(input_types[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
||||
|
|
|
|||
|
|
@ -322,28 +322,6 @@ struct GraphExecutorImpl {
|
|||
return copy;
|
||||
}
|
||||
|
||||
static size_t countFlatInputs(const TypePtr& ptr) {
|
||||
if (auto optional_type = ptr->cast<OptionalType>()) {
|
||||
return countFlatInputs(optional_type->getElementType());
|
||||
}
|
||||
if (auto tuple_type = ptr->cast<TupleType>()) {
|
||||
size_t total = 0;
|
||||
for (auto& elem : tuple_type->elements()) {
|
||||
total += countFlatInputs(elem);
|
||||
}
|
||||
return total;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
static size_t countFlatInputs(const std::shared_ptr<Graph>& graph) {
|
||||
size_t total = 0;
|
||||
for (Value* input : graph->inputs()) {
|
||||
total += countFlatInputs(input->type());
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
inline bool hasMutableOperators(Block* block) {
|
||||
for (auto n : block->nodes()) {
|
||||
if (n->kind().is_aten() && n->schema().is_mutable())
|
||||
|
|
@ -362,11 +340,11 @@ struct GraphExecutorImpl {
|
|||
// disables all optimization
|
||||
optimize(optimize),
|
||||
num_inputs(this->graph->inputs().size()),
|
||||
num_flat_inputs(countFlatInputs(graph)),
|
||||
arg_spec_creator_(*graph),
|
||||
num_outputs(this->graph->outputs().size()) {
|
||||
logging::getLogger()->addStatValue(
|
||||
logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
|
||||
}
|
||||
logging::getLogger()->addStatValue(
|
||||
logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
|
||||
}
|
||||
|
||||
// entry point where execution begins
|
||||
void run(Stack& stack) {
|
||||
|
|
@ -391,9 +369,9 @@ struct GraphExecutorImpl {
|
|||
|
||||
std::shared_ptr<Graph> graphFor(const Stack& stack) const {
|
||||
AT_ASSERT(stack.size() >= num_inputs);
|
||||
auto inputs = last(stack, num_inputs);
|
||||
ArgumentSpec spec(
|
||||
autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
|
||||
|
||||
ArgumentSpec spec =
|
||||
arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
|
||||
|
||||
if (!optimize) {
|
||||
AT_CHECK(fallback, "No graph found for given inputs");
|
||||
|
|
@ -441,10 +419,8 @@ struct GraphExecutorImpl {
|
|||
const ExecutionPlan& getOrCompile(const Stack& stack) {
|
||||
// outside lock guard, to minimize the time holding the lock on the fast
|
||||
// path ArgumentSpec even computes its hashCode here.
|
||||
ArgumentSpec spec(
|
||||
autograd::GradMode::is_enabled(),
|
||||
last(stack, num_inputs),
|
||||
num_flat_inputs);
|
||||
ArgumentSpec spec =
|
||||
arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(compile_mutex);
|
||||
auto it = plan_cache.find(spec);
|
||||
|
|
@ -463,7 +439,7 @@ struct GraphExecutorImpl {
|
|||
|
||||
ExecutionPlan compileSpec(const ArgumentSpec& spec) {
|
||||
auto opt_graph = graph->copy();
|
||||
setInputTypes(*opt_graph, spec);
|
||||
arg_spec_creator_.setInputTypes(*opt_graph, spec);
|
||||
|
||||
// Phase 1. Specialize to input definedness (this is very important for
|
||||
// gradient graphs), and run required passes to bring the graph
|
||||
|
|
@ -562,8 +538,8 @@ struct GraphExecutorImpl {
|
|||
auto input_values = fmap(
|
||||
inputs, [](const IValue& v) { return tracer::getNestedValueTrace(v); });
|
||||
|
||||
ArgumentSpec spec(
|
||||
autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
|
||||
ArgumentSpec spec =
|
||||
arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
|
||||
// NB: we could just run the fallback in here and call it a day, but that
|
||||
// would loose all the control flow information we have in the graph. Thus,
|
||||
// we run the fallback to get the correct output values, but we will
|
||||
|
|
@ -580,7 +556,7 @@ struct GraphExecutorImpl {
|
|||
// tracing and so we only do the type propgation if no concrete types have
|
||||
// been set.
|
||||
auto local_graph = this->graph->copy();
|
||||
setInputTypes(*local_graph, spec);
|
||||
arg_spec_creator_.setInputTypes(*local_graph, spec);
|
||||
PropagateInputShapes(local_graph);
|
||||
auto output_values =
|
||||
inlineCallTo(*state->graph, *local_graph, input_values);
|
||||
|
|
@ -600,8 +576,7 @@ struct GraphExecutorImpl {
|
|||
// Useful for debugging.
|
||||
const bool optimize;
|
||||
const size_t num_inputs;
|
||||
const size_t num_flat_inputs; // Number of inputs, assuming all tuples would
|
||||
// be flattened.
|
||||
ArgumentSpecCreator arg_spec_creator_;
|
||||
const size_t num_outputs;
|
||||
|
||||
// Populated only when optimize is false (and in that case plan_cache will be
|
||||
|
|
|
|||
|
|
@ -155,16 +155,6 @@ void initJITBindings(PyObject* module) {
|
|||
"_jit_pass_canonicalize",
|
||||
[](const std::shared_ptr<Graph>& g) { return Canonicalize(g); })
|
||||
.def("_jit_pass_lint", LintGraph)
|
||||
.def(
|
||||
"_jit_pass_shape_analysis",
|
||||
[](std::shared_ptr<Graph> graph,
|
||||
std::vector<at::Tensor> inputs,
|
||||
bool with_grad) {
|
||||
setInputTypes(
|
||||
*graph,
|
||||
ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
|
||||
PropagateInputShapes(graph);
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_complete_shape_analysis",
|
||||
[](std::shared_ptr<Graph> graph, py::tuple inputs, bool with_grad) {
|
||||
|
|
|
|||
|
|
@ -528,6 +528,12 @@ class ShapePropagator {
|
|||
setUnshapedType(node);
|
||||
return;
|
||||
}
|
||||
case prim::GetAttr: {
|
||||
auto cls = node->input()->type()->expect<ClassType>();
|
||||
// propagate any type specializations encoded in the type of the class
|
||||
node->output()->setType(cls->getAttribute(node->s(attr::name)));
|
||||
return;
|
||||
}
|
||||
case aten::_unwrap_optional: {
|
||||
auto input_ivalue = toIValue(node->input());
|
||||
if (input_ivalue && input_ivalue->isNone()) {
|
||||
|
|
@ -997,11 +1003,9 @@ class ShapePropagator {
|
|||
};
|
||||
|
||||
// Requirements:
|
||||
// dims : 0 if dim is None, otherwise preserved if keepdim == false or 1 smaller otherwise
|
||||
// scalar type : preserved
|
||||
// device : preserved
|
||||
// tensor inputs : 1
|
||||
// tensor outputs : 1
|
||||
// dims : 0 if dim is None, otherwise preserved if keepdim ==
|
||||
// false or 1 smaller otherwise scalar type : preserved device :
|
||||
// preserved tensor inputs : 1 tensor outputs : 1
|
||||
// Additionally:
|
||||
// - First input should be the only tensor input
|
||||
// - Has a bool keepdim argument
|
||||
|
|
@ -1094,7 +1098,9 @@ class ShapePropagator {
|
|||
[](Node* node) -> type_vec_t {
|
||||
if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) {
|
||||
return multidim_reduce_with_postprocess(
|
||||
node, /*num_reduced_dim=*/dim->size(), /*upcast_integer=*/false);
|
||||
node,
|
||||
/*num_reduced_dim=*/dim->size(),
|
||||
/*upcast_integer=*/false);
|
||||
}
|
||||
return {};
|
||||
}};
|
||||
|
|
|
|||
|
|
@ -208,16 +208,6 @@ void initPythonIRBindings(PyObject* module_) {
|
|||
AliasDb db(g);
|
||||
db.dump();
|
||||
})
|
||||
.def(
|
||||
"propagate_shapes",
|
||||
[](std::shared_ptr<Graph> g,
|
||||
std::vector<at::Tensor> inputs,
|
||||
bool with_grad) {
|
||||
setInputTypes(
|
||||
*g,
|
||||
ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
|
||||
PropagateInputShapes(g);
|
||||
})
|
||||
.def(
|
||||
"_export_onnx",
|
||||
[](const std::shared_ptr<Graph> g,
|
||||
|
|
|
|||
|
|
@ -1,15 +1,14 @@
|
|||
#pragma once
|
||||
#include <torch/csrc/autograd/variable.h>
|
||||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||
#include <torch/csrc/jit/argument_spec.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||
#include <torch/csrc/autograd/variable.h>
|
||||
#include <torch/csrc/jit/argument_spec.h>
|
||||
#include <torch/csrc/jit/graph_executor.h>
|
||||
#include <torch/csrc/jit/ir.h>
|
||||
#include <torch/csrc/jit/named_value.h>
|
||||
#include <torch/csrc/jit/passes/shape_analysis.h>
|
||||
#include <torch/csrc/jit/source_range.h>
|
||||
#include <torch/csrc/jit/script/slot.h>
|
||||
|
||||
#include <torch/csrc/jit/source_range.h>
|
||||
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <torch/csrc/api/include/torch/ordered_dict.h>
|
||||
|
|
@ -51,8 +50,8 @@ using ExtraFilesMap = std::unordered_map<std::string, std::string>;
|
|||
|
||||
struct Module;
|
||||
|
||||
using ModuleLookup = std::function<std::shared_ptr<Module>(
|
||||
const std::vector<std::string>&)>;
|
||||
using ModuleLookup =
|
||||
std::function<std::shared_ptr<Module>(const std::vector<std::string>&)>;
|
||||
|
||||
struct Method {
|
||||
Method(
|
||||
|
|
@ -137,6 +136,14 @@ struct Method {
|
|||
return graph()->addInput()->setType(type);
|
||||
}
|
||||
|
||||
static void setInputTensorTypes(Graph& g, const Stack& stack) {
|
||||
AT_ASSERT(stack.size() == g.inputs().size());
|
||||
for (size_t i = 0; i < stack.size(); ++i) {
|
||||
g.inputs().at(i)->setType(
|
||||
DimensionedTensorType::create(stack.at(i).toTensor()));
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<Graph> propagate_shapes(
|
||||
std::vector<at::Tensor> inputs,
|
||||
bool with_grad = false) {
|
||||
|
|
@ -149,8 +156,7 @@ struct Method {
|
|||
for (const Slot& inp : initial_ivalues_) {
|
||||
stack.push_back(*inp);
|
||||
}
|
||||
const auto size = stack.size();
|
||||
setInputTypes(*retval, ArgumentSpec(with_grad, stack, size));
|
||||
setInputTensorTypes(*retval, stack);
|
||||
PropagateInputShapes(retval);
|
||||
return retval;
|
||||
}
|
||||
|
|
@ -167,9 +173,7 @@ struct Method {
|
|||
}
|
||||
}
|
||||
if (propagate) {
|
||||
setInputTypes(
|
||||
*retval,
|
||||
ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
|
||||
setInputTensorTypes(*retval, fmap<IValue>(inputs));
|
||||
PropagateInputShapes(retval);
|
||||
}
|
||||
AT_ASSERT(retval->inputs().size() == inputs.size());
|
||||
|
|
@ -288,16 +292,16 @@ struct Method {
|
|||
if (pos < inputs.size()) {
|
||||
if (!isSubvalueOf(inputs[pos], argument.type())) {
|
||||
AT_ERROR(
|
||||
"Expected value of type ",
|
||||
*argument.type(),
|
||||
" for argument '",
|
||||
argument.name(),
|
||||
"' in position ",
|
||||
pos,
|
||||
", but instead got value of type ",
|
||||
attemptToRecoverType(inputs[pos])->str(),
|
||||
". Declaration: ",
|
||||
schema);
|
||||
"Expected value of type ",
|
||||
*argument.type(),
|
||||
" for argument '",
|
||||
argument.name(),
|
||||
"' in position ",
|
||||
pos,
|
||||
", but instead got value of type ",
|
||||
attemptToRecoverType(inputs[pos])->str(),
|
||||
". Declaration: ",
|
||||
schema);
|
||||
}
|
||||
} else if (argument.default_value()) {
|
||||
inputs.push_back(*argument.default_value());
|
||||
|
|
@ -375,7 +379,8 @@ struct NamedIValue {
|
|||
const TypePtr& type() const {
|
||||
return type_;
|
||||
}
|
||||
private:
|
||||
|
||||
private:
|
||||
const std::string name_;
|
||||
const TypePtr type_;
|
||||
std::unique_ptr<IValue> ivalue_;
|
||||
|
|
@ -497,12 +502,10 @@ struct Module {
|
|||
const torch::OrderedDict<std::string, NamedModule>& get_modules() const {
|
||||
return modules;
|
||||
}
|
||||
const torch::OrderedDict<std::string, NamedIValue>& get_parameters()
|
||||
const {
|
||||
const torch::OrderedDict<std::string, NamedIValue>& get_parameters() const {
|
||||
return parameters;
|
||||
}
|
||||
const torch::OrderedDict<std::string, NamedIValue>& get_attributes()
|
||||
const {
|
||||
const torch::OrderedDict<std::string, NamedIValue>& get_attributes() const {
|
||||
return attributes;
|
||||
}
|
||||
const torch::OrderedDict<std::string, std::unique_ptr<Method>>& get_methods()
|
||||
|
|
@ -630,9 +633,7 @@ struct Module {
|
|||
if (!kv.value().type()->isSubtypeOf(TensorType::get())) {
|
||||
continue;
|
||||
}
|
||||
curr->register_buffer(
|
||||
kv.key(),
|
||||
kv.value().slot()->toTensor());
|
||||
curr->register_buffer(kv.key(), kv.value().slot()->toTensor());
|
||||
parameter_remap[kv.value().slot()] = curr->find_buffer(kv.key())->slot();
|
||||
}
|
||||
for (auto& kv : modules) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue