Add ability to specialize class types to ArgumentSpec (#18314)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18314
ghimport-source-id: 8cecb768d476ab19c9460f39c8f94a764e4cb052

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18314 Add ability to specialize class types to ArgumentSpec**
* #18226 Add Slot type to abstract the raw pointers being used for slots.
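
In brief: ArgumentSpecCreator is built once per graph and scans the graph's input types into a flat list of instructions (SPECIALIZE_TENSOR, ENTER_TUPLE, ENTER_OBJECT, SKIP, LEAVE). At call time, create() replays those instructions over the input stack to build the ArgumentSpec, and setInputTypes() replays them again to install specialized input types, using the new ClassType::refine() for objects. A class slot is only specialized when scanWrittenSlots() finds no prim::SetAttr write to it anywhere in the graph. As a sketch of the instruction language, an input signature of (Tensor, Tuple[int, Tensor]) would compile to something like (in dump() notation):

    SpecializeTensor Tuple[Skip SpecializeTensor ]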

Differential Revision: D14574395

fbshipit-source-id: cc3af6e56e9ae52990f4a1ad56ecceaa2d493577
Zachary DeVito 2019-04-02 17:33:06 -07:00 committed by Facebook Github Bot
parent 5f5a2aaab9
commit 2d07993bcb
14 changed files with 422 additions and 199 deletions


@ -704,7 +704,9 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
Symbol name() const {
return typename_;
}
const std::vector<IValue>& slots() const {
return slots_;
}
private:
const Symbol typename_;
std::vector<IValue> slots_;


@ -1203,10 +1203,21 @@ struct CAFFE2_API ClassType : public Type {
attributeTypes_.push_back(type);
}
at::ArrayRef<std::string> attributeNames() const {
return attributeNames_;
}
at::ArrayRef<TypePtr> containedTypes() const override {
return attributeTypes_;
}
// generate a refined version of this class.
// It has the same name but the slot Types are subtypes of
// the original slots. It is only valid to refine a class type in a context
// where it is known that there are no assignments to the object's slots
// that would invalidate the refinement.
// These variants are not registered in the global class table.
ClassTypePtr refine(at::ArrayRef<TypePtr> refined_slots) const;
static const TypeKind Kind = TypeKind::ClassType;
private:


@ -478,6 +478,16 @@ ClassTypePtr ClassType::create(
return ptr;
}
ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
auto ptr = ClassTypePtr(new ClassType(typename_, module_));
AT_ASSERT(numAttributes() == refined_slots.size());
for(size_t i = 0; i < attributeNames_.size(); ++i) {
AT_ASSERT(refined_slots[i]->isSubtypeOf(attributeTypes_[i]));
ptr->addAttribute(attributeNames_[i], refined_slots[i]);
}
return ptr;
}
ClassTypePtr ClassType::get(const std::string& name) {
return getRegistry().getType(name);
}
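
As a usage sketch of refine (hypothetical: it assumes a registered class "Foo" with two Tensor slots, and reuses the three-argument DimensionedTensorType::create form that appears elsewhere in this diff):

    ClassTypePtr foo = ClassType::get("Foo");
    std::vector<TypePtr> refined = {
        DimensionedTensorType::create(at::kFloat, at::Device(at::kCPU), /*dim=*/2),
        DimensionedTensorType::create(at::kFloat, at::Device(at::kCPU), /*dim=*/2)};
    // same name, subtype slots; the result is not registered in the global
    // class table and is only valid while no writes invalidate the refinement
    ClassTypePtr refined_foo = foo->refine(refined);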


@ -208,7 +208,10 @@ void testDifferentiateWithRequiresGrad() {
at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true);
auto b_var = autograd::make_variable(
at::empty_strided(2, 2, at::CPU(at::kFloat).options()), false);
setInputTypes(*graph, ArgumentSpec(true, {a_var, b_var}, 2));
ArgumentSpecCreator asc(*graph);
asc.setInputTypes(*graph, asc.create(true, {a_var, b_var}));
PropagateInputShapes(graph);
PropagateRequiresGrad(graph);


@ -1893,17 +1893,18 @@ class TestJit(JitTestCase):
def test_tuple_specialization(self):
@torch.jit.script
def f(t):
# type: (Tuple[Tensor, Tensor]) -> Tensor
x, y = t
def f(t, s):
# type: (Tuple[Tensor, Tuple[int, Tensor]], str) -> Tensor
x, t2 = t
_, y = t2
return x + y
t = torch.randn(2, 2), torch.randn(2, 2)
f(t)
graph = f.graph_for(t)
t = torch.randn(2, 2), (1, torch.randn(2, 2)),
f(t, "hi")
graph = f.graph_for(t, "hi")
input_types = list(next(graph.inputs()).type().elements())
for t in input_types:
self.assertEqual(t.kind(), 'DimensionedTensorType')
self.assertEqual(input_types[0].kind(), 'DimensionedTensorType')
self.assertEqual(input_types[1].elements()[1].kind(), 'DimensionedTensorType')
def test_constant_prop_simple(self):
@torch.jit.script
@ -3450,13 +3451,11 @@ a")
# test that shape analysis is written correctly for sum with IntArrayRef[1] dim argument
self.run_pass('constant_propagation', func.graph)
self.run_pass('constant_propagation', func2.graph)
torch._C._jit_pass_shape_analysis(
func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
torch._C._jit_pass_shape_analysis(
func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
self.assertTrue(func.graph.findNode("aten::sum").output().type().kind()
g = func._get_method('forward').propagate_shapes((torch.zeros(1, 1, 1, 1, 4),), False)
g2 = func2._get_method('forward').propagate_shapes((torch.zeros(1, 1, 1, 1, 4),), False)
self.assertTrue(g.findNode("aten::sum").output().type().kind()
== "DimensionedTensorType")
self.assertTrue(func2.graph.findNode("aten::sum").output().type().kind()
self.assertTrue(g2.findNode("aten::sum").output().type().kind()
== "DimensionedTensorType")
def test_cat(self):
@ -4154,9 +4153,9 @@ a")
torch.mul(x, y, out=z)
return z
torch._C._jit_pass_shape_analysis(
test.graph, (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
self.assertTrue(next(test.graph.outputs()).type() == TensorType.get())
graph = test._get_method('forward').propagate_shapes(
(torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
self.assertTrue(next(graph.outputs()).type() == TensorType.get())
out_op_graph_input()
def test_resize():
@ -4173,10 +4172,8 @@ a")
after_resize_alias = b.add_(1)
return after_resize_alias
g = test.graph
self.run_pass('constant_propagation', g)
torch._C._jit_pass_shape_analysis(
g, (torch.zeros(1, 1),), False)
self.run_pass('constant_propagation', test.graph)
g = test._get_method('forward').propagate_shapes((torch.zeros(1, 1),), False)
resize_node = g.findNode("aten::resize_")
# the first input and output of b.resize_ are b
self.assertTrue(next(resize_node.inputs()).type() == TensorType.get())
@ -4200,8 +4197,7 @@ a")
g = test.graph
self.run_pass('constant_propagation', g)
torch._C._jit_pass_shape_analysis(
g, (torch.zeros(1, 1),), False)
g = test._get_method('forward').propagate_shapes((torch.zeros(1, 1),), False)
# x doesn't alias a resized op so it shouldn't be set to base Tensor type
self.assertTrue(next(g.inputs()).type() != TensorType.get())
@ -4255,8 +4251,8 @@ a")
return x.view(T, B, C)
x = torch.randn(3, 1, 5, requires_grad=True)
graph = torch.jit.script(fn).graph
torch._C._jit_pass_shape_analysis(graph, (x,), False)
fn = torch.jit.script(fn)
graph = fn._get_method('forward').propagate_shapes((x,), False)
a = next(graph.outputs()).type().kind()
self.assertTrue(next(graph.outputs()).type().kind() != 'TensorType')
@ -6677,7 +6673,7 @@ a")
return torch.cat(c)
b = torch.zeros(2, 4)
test_list.graph.propagate_shapes((b,), False)
test_list._get_method('forward').propagate_shapes((b,), False)
def test_if_supertype(self):
@torch.jit.script
@ -6694,8 +6690,8 @@ a")
b = torch.zeros(2, 4, dtype=torch.long)
c = torch.zeros(2, 4, dtype=torch.float)
tensor_unifying.graph.propagate_shapes((a, b, c), False)
if_outputs = list(tensor_unifying.graph.findNode("prim::If").outputs())
graph = tensor_unifying._get_method('forward').propagate_shapes((a, b, c), False)
if_outputs = list(graph.findNode("prim::If").outputs())
self.assertTrue(if_outputs[0].type().str() == "Float(*, *)")
self.assertTrue(if_outputs[1].type().str() == "Tensor")
self.assertTrue(if_outputs[2].type().str() == "Tensor")
@ -13303,6 +13299,30 @@ class TestClassType(JitTestCase):
self.assertEqual(x, f2.x)
self.assertEqual(y, f2.y)
def test_class_specialization(self):
@torch.jit.script # noqa: B903
class Foo(object):
def __init__(self, x, y):
self.x = x
self.y = y
def use_foo(foo, foo2, tup):
# type: (Foo, Foo, Tuple[Foo, Foo]) -> Tensor
a, b = tup
return foo.x + foo2.y + a.x + b.y
# create from python
x = torch.ones(2, 3)
y = torch.zeros(2, 3)
f = Foo(x, y)
f2 = Foo(x * 2, y * 3)
f3 = Foo(x * 4, y * 4)
input = (f, f2, (f, f3))
sfoo = self.checkScript(use_foo, input)
graphstr = str(sfoo.graph_for(*input))
FileCheck().check_count("Double(*, *) = prim::GetAttr", 4).run(graphstr)
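
This test exercises two pieces of the diff together: ArgumentSpecCreator records SPECIALIZE_TENSOR instructions for Foo's slots (nothing in use_foo assigns to them, so scanWrittenSlots leaves them specializable), and the new prim::GetAttr case in shape propagation then surfaces the refined slot types, which is why the four attribute reads appear as Double(*, *) = prim::GetAttr in the optimized graph.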
class TestLogging(JitTestCase):
def test_bump_numeric_counter(self):


@ -51,6 +51,7 @@ libtorch_sources = [
"torch/csrc/Exceptions.cpp",
"torch/csrc/jit/autodiff.cpp",
"torch/csrc/jit/attributes.cpp",
"torch/csrc/jit/argument_spec.cpp",
"torch/csrc/jit/constants.cpp",
"torch/csrc/jit/node_hashing.cpp",
"torch/csrc/jit/export.cpp",


@ -123,6 +123,7 @@ set(TORCH_SRCS
${TORCH_SRC_DIR}/csrc/autograd/VariableTypeManual.cpp
${TORCH_SRC_DIR}/csrc/jit/autodiff.cpp
${TORCH_SRC_DIR}/csrc/jit/attributes.cpp
${TORCH_SRC_DIR}/csrc/jit/argument_spec.cpp
${TORCH_SRC_DIR}/csrc/jit/export.cpp
${TORCH_SRC_DIR}/csrc/jit/pickler.cpp
${TORCH_SRC_DIR}/csrc/jit/generated/register_aten_ops_0.cpp


@ -0,0 +1,229 @@
#include <torch/csrc/jit/argument_spec.h>
namespace torch {
namespace jit {
void ArgumentSpecCreator::scan(
const TypePtr& typ,
size_t depth,
const WrittenSlots& written_slots) {
auto finishAggregate = [&](size_t pos) {
// it is possible that, after all the work we did to scan this aggregate,
// we found no tensors to specialize. In this case, just generate
// a skip for the whole aggregate.
bool any_spec = std::any_of(
instructions_.begin() + pos, instructions_.end(), [](Inst i) {
return i == SPECIALIZE_TENSOR;
});
if (!any_spec) {
instructions_[pos] = SKIP;
instructions_.resize(pos + 1);
} else {
instructions_.emplace_back(LEAVE);
}
};
// the simple vm that scans instructions_ has a limited stack depth;
// this prevents going deeper than that.
if (depth >= DEPTH_LIMIT) {
instructions_.emplace_back(SKIP);
return;
}
if (typ->isSubtypeOf(TensorType::get())) {
num_tensors_++;
instructions_.emplace_back(SPECIALIZE_TENSOR);
} else if (auto tup = typ->cast<TupleType>()) {
size_t pos = instructions_.size();
instructions_.emplace_back(ENTER_TUPLE);
for (const auto& elem : tup->containedTypes()) {
scan(elem, depth + 1, written_slots);
}
finishAggregate(pos);
} else if (auto cls = typ->cast<ClassType>()) {
size_t pos = instructions_.size();
instructions_.emplace_back(ENTER_OBJECT);
for (size_t i = 0; i < cls->numAttributes(); ++i) {
auto key = cls->name() + cls->attributeNames().at(i);
// it is only safe to specialize a slot if no one may have written to it
if (!written_slots.count(key)) {
scan(cls->containedTypes().at(i), depth + 1, written_slots);
} else {
instructions_.emplace_back(SKIP);
}
}
finishAggregate(pos);
} else {
instructions_.emplace_back(SKIP);
}
}
// this is a coarse-grained guarantee that the slots of a class will not be
// modified by the function. It works fine for things that used to be read-only
// modules, but will be overly conservative when some classes are written to.
// Doing alias analysis and looking for writes to the class would be more
// accurate.
static void scanWrittenSlots(
Block* block,
ArgumentSpecCreator::WrittenSlots& written_slots) {
for (Node* n : block->nodes()) {
if (n->kind() == prim::SetAttr) {
if (auto cls = n->inputs().at(0)->type()->cast<ClassType>()) {
written_slots.insert(cls->name() + n->s(attr::name));
}
}
for (Block* subblock : n->blocks()) {
scanWrittenSlots(subblock, written_slots);
}
if (n->hasAttribute(attr::Subgraph)) {
scanWrittenSlots(n->g(attr::Subgraph)->block(), written_slots);
}
}
}
ArgumentSpecCreator::ArgumentSpecCreator(Graph& graph)
: num_inputs_(graph.inputs().size()) {
WrittenSlots written_slots;
scanWrittenSlots(graph.block(), written_slots);
for (Value* input : graph.inputs()) {
scan(input->type(), 0, written_slots);
}
}
void ArgumentSpecCreator::dump() const {
for (Inst inst : instructions_) {
switch (inst) {
case LEAVE:
std::cout << "] ";
break;
case ENTER_TUPLE:
std::cout << "Tuple[";
break;
case ENTER_OBJECT:
std::cout << "Object[";
break;
case SKIP:
std::cout << "Skip ";
break;
case SPECIALIZE_TENSOR:
std::cout << "SpecializeTensor ";
break;
}
}
std::cout << "\n";
}
ArgumentSpec ArgumentSpecCreator::create(bool with_grad, const Stack& input)
const {
ArgumentSpec spec(num_tensors_);
const IValue* stack[DEPTH_LIMIT]; // The stack of IValue lists
// The stack gets initialized with the input list
stack[0] = last(input, num_inputs_).begin();
size_t stack_top = 0; // offset to the top of the stack
for (Inst inst : instructions_) {
switch (inst) {
case SPECIALIZE_TENSOR:
// consume a tensor and add to the argspec
spec.addTensor(*stack[stack_top]++, with_grad);
break;
case ENTER_TUPLE: {
// consume tuple
const IValue* iv = stack[stack_top]++;
AT_ASSERT(iv->isTuple());
// see [argspec refcounting]
auto p = *reinterpret_cast<const at::ivalue::Tuple* const*>(iv);
auto tup_ptr = &p->elements()[0];
// push list of tuple elements to the stack
stack[++stack_top] = tup_ptr;
} break;
case ENTER_OBJECT: {
// consume object
const IValue* iv = stack[stack_top]++;
AT_ASSERT(iv->isObject());
// see [argspec refcounting]
auto p = *reinterpret_cast<const at::ivalue::Object* const*>(iv);
auto obj_ptr = &p->slots()[0];
// push list of object elements to the stack
stack[++stack_top] = obj_ptr;
} break;
case SKIP:
// consume and skip an element
stack[stack_top]++;
break;
case LEAVE:
--stack_top;
break;
}
}
return spec;
}
// For every input of a given graph, returns the most detailed type that can be
// inferred for it based on this ArgumentSpec.
std::vector<TypePtr> ArgumentSpecCreator::getSpecializedTypes(
Graph& graph,
const ArgumentSpec& spec) const {
auto input_types =
fmap(graph.inputs(), [](Value* input) { return input->type(); });
std::vector<std::vector<TypePtr>> result_stack;
result_stack.emplace_back();
std::vector<const TypePtr*> input_stack = {input_types.data()};
std::vector<std::function<TypePtr()>> aggregate_creators;
size_t arg_spec_offset = 0; // number of specialized tensors seen so far
for (Inst inst : instructions_) {
switch (inst) {
case SPECIALIZE_TENSOR: {
input_stack.back()++;
auto& arg = spec.at(arg_spec_offset++);
if (!arg.defined()) {
result_stack.back().emplace_back(AutogradZeroTensorType::get());
} else {
result_stack.back().emplace_back(DimensionedTensorType::create(
arg.type(),
ConvertIntToCPUOrCUDA(arg.device()),
arg.dim(),
arg.requires_grad()));
}
} break;
case ENTER_TUPLE: {
auto tup = (*input_stack.back()++)->expect<TupleType>();
input_stack.emplace_back(tup->elements().data());
result_stack.emplace_back();
aggregate_creators.emplace_back(
[&] { return TupleType::create(result_stack.back()); });
} break;
case ENTER_OBJECT: {
auto cls = (*input_stack.back()++)->expect<ClassType>();
input_stack.emplace_back(cls->containedTypes().data());
result_stack.emplace_back();
aggregate_creators.emplace_back(
[&result_stack, cls] { return cls->refine(result_stack.back()); });
} break;
case SKIP:
result_stack.back().emplace_back(*input_stack.back()++);
break;
case LEAVE:
TypePtr result = aggregate_creators.back()();
result_stack.pop_back();
aggregate_creators.pop_back();
input_stack.pop_back();
result_stack.back().emplace_back(std::move(result));
break;
}
}
AT_ASSERT(result_stack.size() == 1);
return result_stack.back();
}
void ArgumentSpecCreator::setInputTypes(Graph& g, const ArgumentSpec& spec)
const {
auto input_types = getSpecializedTypes(g, spec);
auto inputs = g.inputs();
for (size_t i = 0; i < inputs.size(); ++i) {
inputs[i]->setType(input_types[i]);
}
}
} // namespace jit
} // namespace torch
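
As a worked example tying this file to test_class_specialization: for inputs of type (Foo, Foo, Tuple[Foo, Foo]), where Foo has two Tensor slots and the graph contains no prim::SetAttr writes, scan() should emit eight SPECIALIZE_TENSOR instructions inside their aggregates, and dump() would print roughly:

    Object[SpecializeTensor SpecializeTensor ] Object[SpecializeTensor SpecializeTensor ] Tuple[Object[SpecializeTensor SpecializeTensor ] Object[SpecializeTensor SpecializeTensor ] ]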


@ -1,9 +1,9 @@
#pragma once
#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/ir.h>
#include <ATen/core/stack.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/variable_tensor_list.h>
#include <torch/csrc/utils/hash.h>
#include <iostream>
@ -22,9 +22,6 @@ struct ArgumentInfo {
friend struct ArgumentSpec;
using plain_data_type = uint32_t;
bool isTensor() const {
return is_tensor_;
}
bool defined() const {
return defined_;
}
@ -45,11 +42,11 @@ struct ArgumentInfo {
operator TypePtr() const {
if (!defined())
return TensorType::get();
return DimensionedTensorType::create(type(), ConvertIntToCPUOrCUDA(device()), dim());
return DimensionedTensorType::create(
type(), ConvertIntToCPUOrCUDA(device()), dim());
}
private:
unsigned is_tensor_ : 1;
unsigned defined_ : 1;
unsigned requires_grad_ : 1;
unsigned : 5;
@ -67,48 +64,32 @@ static_assert(
"ArgumentInfo is expected to be a 32-bit struct");
struct ArgumentSpec {
ArgumentSpec(
bool with_grad,
at::ArrayRef<IValue> inputs,
size_t num_flat_inputs) {
ArgumentSpec(size_t num_flat_inputs) {
hash_code = num_flat_inputs;
args.resize(num_flat_inputs);
size_t offset = 0;
for (const auto& i : inputs) {
addInput(i, offset, with_grad);
}
AT_ASSERT(offset <= num_flat_inputs);
args.reserve(num_flat_inputs);
}
void addInput(const IValue& input, size_t& offset, bool with_grad) {
auto& arg = args.at(offset);
void addTensor(const IValue& input, bool with_grad) {
AT_ASSERT(input.isTensor());
args.emplace_back();
auto& arg = args.back();
// Initialize all fields to 0. This is convenient, because e.g.
// requires_grad() can be checked even on undefined tensors AND will make
// padding bits all 0s.
std::memset(&arg, 0, sizeof(ArgumentInfo));
if (input.isTensor()) {
at::Tensor t = input.toTensor();
if ((arg.defined_ = t.defined())) {
arg.requires_grad_ = with_grad && autograd::Variable(t).requires_grad();
arg.dim_ = t.dim();
arg.device_ = t.is_cuda() ? t.get_device() : -1;
arg.type_ = static_cast<unsigned>(t.scalar_type());
}
arg.is_tensor_ = true;
combineHash(arg);
offset++;
} else if (input.isTuple()) {
for (const IValue& elem : input.toTuple()->elements()) {
addInput(elem, offset, with_grad);
}
} else {
// NB: no need to set is_tensor to false, because we memset the struct to
// 0 above
combineHash(arg);
offset++;
// [argspec refcounting] reinterpret the IValue to avoid having to refcount
// the Tensor; microbenchmarks
// (https://github.com/zdevito/pytorch/commit/21e7200a0a0fc456bea2f10e95b1781f83933d10)
// show overhead in extra refcounting along this path
const at::Tensor* t = reinterpret_cast<const at::Tensor*>(&input);
if ((arg.defined_ = t->defined())) {
arg.requires_grad_ = with_grad && autograd::Variable(*t).requires_grad();
arg.dim_ = t->dim();
arg.device_ = t->is_cuda() ? t->get_device() : -1;
arg.type_ = static_cast<unsigned>(t->scalar_type());
}
combineHash(arg);
}
void combineHash(const ArgumentInfo& arg) {
@ -143,38 +124,49 @@ struct ArgumentSpec {
size_t hashCode() const {
return hash_code;
}
// For every input of a given graph, returns the most detailed type that can be
// inferred for it based on this ArgumentSpec.
std::vector<TypePtr> getTypes(Graph& graph) const {
size_t offset = 0;
return fmap(
graph.inputs(), [&](Value* v) { return fillType(v->type(), offset); });
}
private:
TypePtr fillType(TypePtr original, size_t& offset) const {
if (original->isSubtypeOf(TensorType::get())) {
auto& arg = args.at(offset++);
if (!arg.defined())
return AutogradZeroTensorType::get();
return DimensionedTensorType::create(
arg.type(),
ConvertIntToCPUOrCUDA(arg.device()),
arg.dim(),
arg.requires_grad());
} else if (auto tuple_type = original->cast<TupleType>()) {
return TupleType::create(fmap(
tuple_type->elements(),
[&](const TypePtr& subtype) { return fillType(subtype, offset); }));
} else {
offset++;
return original;
}
}
size_t hash_code; // precomputed on construction
std::vector<ArgumentInfo> args;
};
// ArgumentSpecCreator takes an initial graph and comes up with a set
// of simple instructions to compute the ArgumentSpec given a set of
// input tensors.
struct ArgumentSpecCreator {
// instructions act on a stack of lists of input IValues;
// at the beginning the stack contains a single list of the inputs to the
// function. The ENTER_ instructions descend into subobjects and push new
// lists onto the stack
enum Inst : char {
ENTER_TUPLE, // consume a tuple ivalue from the top-most list, and push the
// list of its elements onto the stack as a new list
ENTER_OBJECT, // same as ENTER_TUPLE, but the input is a class
LEAVE, // pop the top-most list from the stack
SKIP, // consume an element from the top-most list, and discard
SPECIALIZE_TENSOR, // consume a tensor from the top-most list, and
// add it to the ArgSpec key being created
};
ArgumentSpecCreator(Graph& graph);
ArgumentSpec create(bool with_grad, const Stack& stack) const;
void setInputTypes(Graph& g, const ArgumentSpec& spec) const;
std::vector<TypePtr> getSpecializedTypes(
Graph& graph,
const ArgumentSpec& spec) const;
void dump() const;
using WrittenSlots = std::unordered_set<std::string>;
private:
static constexpr size_t DEPTH_LIMIT = 128;
void scan(
const TypePtr& typ,
size_t depth,
const WrittenSlots& written_slots);
size_t num_inputs_;
size_t num_tensors_ = 0;
std::vector<Inst> instructions_;
};
// CompleteArgumentSpec represents one particular specialization.
// It is designed so that it can be created, hashed, and compared quickly
// since it is used along the hot-path of the JIT to check if the code
@ -398,14 +390,6 @@ inline CompleteArgumentInfo CompleteArgumentSpec::at(size_t i) const {
return CompleteArgumentInfo(*this, i);
}
inline void setInputTypes(Graph& g, const ArgumentSpec& spec) {
auto input_types = spec.getTypes(g);
auto inputs = g.inputs();
for (size_t i = 0; i < inputs.size(); ++i) {
inputs[i]->setType(input_types[i]);
}
}
} // namespace jit
} // namespace torch
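
A minimal sketch of the intended call pattern, mirroring the graph_executor.cpp changes below (plan_cache and the surrounding logic come from that file; error handling omitted):

    ArgumentSpecCreator arg_spec_creator(*graph);  // one-time scan of the input types

    // per call:
    ArgumentSpec spec =
        arg_spec_creator.create(autograd::GradMode::is_enabled(), stack);
    auto it = plan_cache.find(spec);  // cheap hash/compare on the hot path
    if (it == plan_cache.end()) {
      auto opt_graph = graph->copy();
      arg_spec_creator.setInputTypes(*opt_graph, spec);  // install specialized types
      // ... run optimization passes and cache the resulting ExecutionPlan ...
    }

This is also why countFlatInputs and num_flat_inputs disappear from the executor: the creator's one-time scan already counts the specializable tensors (num_tensors_), so a flattened input count no longer needs to be pre-computed.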


@ -322,28 +322,6 @@ struct GraphExecutorImpl {
return copy;
}
static size_t countFlatInputs(const TypePtr& ptr) {
if (auto optional_type = ptr->cast<OptionalType>()) {
return countFlatInputs(optional_type->getElementType());
}
if (auto tuple_type = ptr->cast<TupleType>()) {
size_t total = 0;
for (auto& elem : tuple_type->elements()) {
total += countFlatInputs(elem);
}
return total;
}
return 1;
}
static size_t countFlatInputs(const std::shared_ptr<Graph>& graph) {
size_t total = 0;
for (Value* input : graph->inputs()) {
total += countFlatInputs(input->type());
}
return total;
}
inline bool hasMutableOperators(Block* block) {
for (auto n : block->nodes()) {
if (n->kind().is_aten() && n->schema().is_mutable())
@ -362,11 +340,11 @@ struct GraphExecutorImpl {
// disables all optimization
optimize(optimize),
num_inputs(this->graph->inputs().size()),
num_flat_inputs(countFlatInputs(graph)),
arg_spec_creator_(*graph),
num_outputs(this->graph->outputs().size()) {
logging::getLogger()->addStatValue(
logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
}
logging::getLogger()->addStatValue(
logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
}
// entry point where execution begins
void run(Stack& stack) {
@ -391,9 +369,9 @@ struct GraphExecutorImpl {
std::shared_ptr<Graph> graphFor(const Stack& stack) const {
AT_ASSERT(stack.size() >= num_inputs);
auto inputs = last(stack, num_inputs);
ArgumentSpec spec(
autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
ArgumentSpec spec =
arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
if (!optimize) {
AT_CHECK(fallback, "No graph found for given inputs");
@ -441,10 +419,8 @@ struct GraphExecutorImpl {
const ExecutionPlan& getOrCompile(const Stack& stack) {
// outside lock guard, to minimize the time holding the lock on the fast
// path ArgumentSpec even computes its hashCode here.
ArgumentSpec spec(
autograd::GradMode::is_enabled(),
last(stack, num_inputs),
num_flat_inputs);
ArgumentSpec spec =
arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
{
std::lock_guard<std::mutex> lock(compile_mutex);
auto it = plan_cache.find(spec);
@ -463,7 +439,7 @@ struct GraphExecutorImpl {
ExecutionPlan compileSpec(const ArgumentSpec& spec) {
auto opt_graph = graph->copy();
setInputTypes(*opt_graph, spec);
arg_spec_creator_.setInputTypes(*opt_graph, spec);
// Phase 1. Specialize to input definedness (this is very important for
// gradient graphs), and run required passes to bring the graph
@ -562,8 +538,8 @@ struct GraphExecutorImpl {
auto input_values = fmap(
inputs, [](const IValue& v) { return tracer::getNestedValueTrace(v); });
ArgumentSpec spec(
autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
ArgumentSpec spec =
arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
// NB: we could just run the fallback in here and call it a day, but that
// would lose all the control flow information we have in the graph. Thus,
// we run the fallback to get the correct output values, but we will
@ -580,7 +556,7 @@ struct GraphExecutorImpl {
// tracing and so we only do the type propagation if no concrete types have
// been set.
auto local_graph = this->graph->copy();
setInputTypes(*local_graph, spec);
arg_spec_creator_.setInputTypes(*local_graph, spec);
PropagateInputShapes(local_graph);
auto output_values =
inlineCallTo(*state->graph, *local_graph, input_values);
@ -600,8 +576,7 @@ struct GraphExecutorImpl {
// Useful for debugging.
const bool optimize;
const size_t num_inputs;
const size_t num_flat_inputs; // Number of inputs, assuming all tuples would
// be flattened.
ArgumentSpecCreator arg_spec_creator_;
const size_t num_outputs;
// Populated only when optimize is false (and in that case plan_cache will be


@ -155,16 +155,6 @@ void initJITBindings(PyObject* module) {
"_jit_pass_canonicalize",
[](const std::shared_ptr<Graph>& g) { return Canonicalize(g); })
.def("_jit_pass_lint", LintGraph)
.def(
"_jit_pass_shape_analysis",
[](std::shared_ptr<Graph> graph,
std::vector<at::Tensor> inputs,
bool with_grad) {
setInputTypes(
*graph,
ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
PropagateInputShapes(graph);
})
.def(
"_jit_pass_complete_shape_analysis",
[](std::shared_ptr<Graph> graph, py::tuple inputs, bool with_grad) {


@ -528,6 +528,12 @@ class ShapePropagator {
setUnshapedType(node);
return;
}
case prim::GetAttr: {
auto cls = node->input()->type()->expect<ClassType>();
// propagate any type specializations encoded in the type of the class
node->output()->setType(cls->getAttribute(node->s(attr::name)));
return;
}
case aten::_unwrap_optional: {
auto input_ivalue = toIValue(node->input());
if (input_ivalue && input_ivalue->isNone()) {
@ -997,11 +1003,9 @@ class ShapePropagator {
};
// Requirements:
// dims : 0 if dim is None, otherwise preserved if keepdim == false or 1 smaller otherwise
// scalar type : preserved
// device : preserved
// tensor inputs : 1
// tensor outputs : 1
// dims : 0 if dim is None, otherwise preserved if keepdim ==
// false or 1 smaller otherwise scalar type : preserved device :
// preserved tensor inputs : 1 tensor outputs : 1
// Additionally:
// - First input should be the only tensor input
// - Has a bool keepdim argument
@ -1094,7 +1098,9 @@ class ShapePropagator {
[](Node* node) -> type_vec_t {
if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) {
return multidim_reduce_with_postprocess(
node, /*num_reduced_dim=*/dim->size(), /*upcast_integer=*/false);
node,
/*num_reduced_dim=*/dim->size(),
/*upcast_integer=*/false);
}
return {};
}};


@ -208,16 +208,6 @@ void initPythonIRBindings(PyObject* module_) {
AliasDb db(g);
db.dump();
})
.def(
"propagate_shapes",
[](std::shared_ptr<Graph> g,
std::vector<at::Tensor> inputs,
bool with_grad) {
setInputTypes(
*g,
ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
PropagateInputShapes(g);
})
.def(
"_export_onnx",
[](const std::shared_ptr<Graph> g,


@ -1,15 +1,14 @@
#pragma once
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/argument_spec.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/argument_spec.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/named_value.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/source_range.h>
#include <torch/csrc/jit/script/slot.h>
#include <torch/csrc/jit/source_range.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/api/include/torch/ordered_dict.h>
@ -51,8 +50,8 @@ using ExtraFilesMap = std::unordered_map<std::string, std::string>;
struct Module;
using ModuleLookup = std::function<std::shared_ptr<Module>(
const std::vector<std::string>&)>;
using ModuleLookup =
std::function<std::shared_ptr<Module>(const std::vector<std::string>&)>;
struct Method {
Method(
@ -137,6 +136,14 @@ struct Method {
return graph()->addInput()->setType(type);
}
static void setInputTensorTypes(Graph& g, const Stack& stack) {
AT_ASSERT(stack.size() == g.inputs().size());
for (size_t i = 0; i < stack.size(); ++i) {
g.inputs().at(i)->setType(
DimensionedTensorType::create(stack.at(i).toTensor()));
}
}
std::shared_ptr<Graph> propagate_shapes(
std::vector<at::Tensor> inputs,
bool with_grad = false) {
@ -149,8 +156,7 @@ struct Method {
for (const Slot& inp : initial_ivalues_) {
stack.push_back(*inp);
}
const auto size = stack.size();
setInputTypes(*retval, ArgumentSpec(with_grad, stack, size));
setInputTensorTypes(*retval, stack);
PropagateInputShapes(retval);
return retval;
}
@ -167,9 +173,7 @@ struct Method {
}
}
if (propagate) {
setInputTypes(
*retval,
ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
setInputTensorTypes(*retval, fmap<IValue>(inputs));
PropagateInputShapes(retval);
}
AT_ASSERT(retval->inputs().size() == inputs.size());
@ -288,16 +292,16 @@ struct Method {
if (pos < inputs.size()) {
if (!isSubvalueOf(inputs[pos], argument.type())) {
AT_ERROR(
"Expected value of type ",
*argument.type(),
" for argument '",
argument.name(),
"' in position ",
pos,
", but instead got value of type ",
attemptToRecoverType(inputs[pos])->str(),
". Declaration: ",
schema);
"Expected value of type ",
*argument.type(),
" for argument '",
argument.name(),
"' in position ",
pos,
", but instead got value of type ",
attemptToRecoverType(inputs[pos])->str(),
". Declaration: ",
schema);
}
} else if (argument.default_value()) {
inputs.push_back(*argument.default_value());
@ -375,7 +379,8 @@ struct NamedIValue {
const TypePtr& type() const {
return type_;
}
private:
private:
const std::string name_;
const TypePtr type_;
std::unique_ptr<IValue> ivalue_;
@ -497,12 +502,10 @@ struct Module {
const torch::OrderedDict<std::string, NamedModule>& get_modules() const {
return modules;
}
const torch::OrderedDict<std::string, NamedIValue>& get_parameters()
const {
const torch::OrderedDict<std::string, NamedIValue>& get_parameters() const {
return parameters;
}
const torch::OrderedDict<std::string, NamedIValue>& get_attributes()
const {
const torch::OrderedDict<std::string, NamedIValue>& get_attributes() const {
return attributes;
}
const torch::OrderedDict<std::string, std::unique_ptr<Method>>& get_methods()
@ -630,9 +633,7 @@ struct Module {
if (!kv.value().type()->isSubtypeOf(TensorType::get())) {
continue;
}
curr->register_buffer(
kv.key(),
kv.value().slot()->toTensor());
curr->register_buffer(kv.key(), kv.value().slot()->toTensor());
parameter_remap[kv.value().slot()] = curr->find_buffer(kv.key())->slot();
}
for (auto& kv : modules) {