pytorch/test/cpp/jit/test_misc.cpp

1317 lines
41 KiB
C++
Raw Normal View History

#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/ThreadLocalDebugInfo.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include "test/cpp/jit/test_base.h"
#include "test/cpp/jit/test_utils.h"
#include <torch/csrc/jit/ir/type_hashing.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include "torch/csrc/autograd/generated/variable_factories.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/codegen/fuser/interface.h"
#include "torch/csrc/jit/frontend/code_template.h"
#include "torch/csrc/jit/frontend/tracer.h"
#include "torch/csrc/jit/ir/alias_analysis.h"
#include "torch/csrc/jit/ir/attributes.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/ir/scope.h"
#include "torch/csrc/jit/passes/bailout_graph.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/guard_elimination.h"
#include "torch/csrc/jit/passes/inline_autodiff_subgraphs.h"
#include "torch/csrc/jit/passes/insert_guards.h"
#include "torch/csrc/jit/passes/liveness.h"
#include "torch/csrc/jit/passes/lower_grad_of.h"
#include "torch/csrc/jit/passes/lower_tuples.h"
#include "torch/csrc/jit/passes/pass_manager.h"
#include "torch/csrc/jit/passes/requires_grad_analysis.h"
#include "torch/csrc/jit/passes/shape_analysis.h"
#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
#include "torch/csrc/jit/runtime/argument_spec.h"
#include "torch/csrc/jit/runtime/autodiff.h"
#include "torch/csrc/jit/runtime/custom_operator.h"
#include "torch/csrc/jit/runtime/interpreter.h"
#include "torch/csrc/jit/runtime/symbolic_script.h"
#include "torch/csrc/jit/serialization/import.h"
#include "torch/csrc/autograd/engine.h"
#include "torch/csrc/autograd/variable.h"
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/script.h>
#include "torch/csrc/jit/api/module.h"
#include "torch/csrc/jit/frontend/ir_emitter.h"
#include "torch/csrc/jit/runtime/profiling_record.h"
// First class modules in the compiler, round 2 (#19167) Summary: This PR propagates where we use first-class module objects into the compiler. This creates a transitionary state where: * compiler.cpp creates Graphs where `self` is a Module class and attributes/parameters/buffers/submodules are looked up with `prim::GetAttr` * GraphExecutor still runs "lowered graphs" where the self object has been removed by a compiler pass `lower_first_class_method`. * Tracing still creates "lowered graphs", and a pass "lift_lowered_method" creates a first-class method graph for things. * This PR separates out Method and Function. A script::Function is a pure Graph with no `self` bound. Similar to Python, a script::Method is just a bound `self` and its underlying `script::Function`. * This PR also separates CompilationUnit from Module. A CompilationUnit is just a list of named script::Functions. Classes have a CompilationUnit holding the class methods, and Modules also have a CompilationUnit holding their Methods. This avoids the weird circular case Module --has a-> Class -> has a -> Module ... Details: * In this transitionary state, we maintain two copies of a Graph, first-class module and lowered. The first-class one has a self argument that is the module's class type. The lowered one is the lowered graph that uses the initial_ivalues inputs. * When defining lowered methods using `_defined_lowered` we immediately create the first-class equivalent. The reverse is done lazily, creating lowered_methods on demand from the class. * The two-way conversions will be deleted in a future PR when the executor itself runs first-class objects. However this requires more changes to (1) the traces, (2) the python bindings, and (3) the onnx export pass and would make this PR way too large. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19167 Differential Revision: D14891966 Pulled By: zdevito fbshipit-source-id: 0b5f03118aa65448a15c7a7818e64089ec93d7ea
// 2019-04-11 20:30:42 +00:00
#include "torch/jit.h"
#include "onnx/onnx_pb.h"
#include <c10/util/Exception.h>
#include <algorithm>
#include <cstddef>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
// Convenience accessor: yields c10::AliasAnalysisKind::FROM_SCHEMA.
inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  constexpr auto kind = c10::AliasAnalysisKind::FROM_SCHEMA;
  return kind;
}
// Streams a vector as a brace-delimited, comma-separated list, e.g.
// `{1, 2, 3}`; an empty vector is written as `{}`.
template <typename T>
std::ostream& operator<<(std::ostream& out, const std::vector<T>& list) {
  out << "{";
  auto it = list.begin();
  const auto last = list.end();
  if (it != last) {
    // First element carries no leading separator.
    out << *it;
    for (++it; it != last; ++it) {
      out << ", ";
      out << *it;
    }
  }
  out << "}";
  return out;
}
void testInternedStrings() {
ASSERT_EQ(prim::Param, Symbol::prim("Param"));
ASSERT_EQ(prim::Return, Symbol::prim("Return"));
ASSERT_EQ(prim::Return.toUnqualString(), std::string("Return"));
ASSERT_EQ(prim::Return.toQualString(), std::string("prim::Return"));
Symbol newsym = Symbol::aten("__NEW_SYMBOL");
size_t symstart = newsym;
ASSERT_EQ(newsym.toQualString(), std::string("aten::__NEW_SYMBOL"));
// TODO: This test is a bit too close to the implementation details.
ASSERT_EQ(Symbol::aten("What"), symstart + 1);
ASSERT_EQ(Symbol::aten("What2"), symstart + 2);
ASSERT_EQ(Symbol::aten("What"), symstart + 1);
ASSERT_EQ(Symbol::aten("What2"), symstart + 2);
ASSERT_EQ(Symbol(symstart + 2).toUnqualString(), std::string("What2"));
}
// Checks Symbol::fromQualString on well-formed qualified names (known and
// new namespaces) and verifies that malformed names are rejected by throwing.
void testFromQualString() {
  ASSERT_EQ(Symbol::fromQualString("prim::Param"), Symbol::prim("Param"));
  ASSERT_EQ(Symbol::fromQualString("aten::mm"), Symbol::aten("mm"));
  ASSERT_EQ(Symbol::fromQualString("onnx::LSTM"), Symbol::onnx("LSTM"));
  ASSERT_EQ(Symbol::fromQualString("attr::value"), Symbol::attr("value"));
  ASSERT_EQ(Symbol::fromQualString("scope::"), Symbol::scope(""));
  ASSERT_EQ(Symbol::fromQualString("::").toUnqualString(), std::string(""));
  ASSERT_EQ(
      Symbol::fromQualString("::").ns().toQualString(),
      std::string("namespaces::"));
  ASSERT_EQ(
      Symbol::fromQualString("new_ns::param").toUnqualString(),
      std::string("param"));
  ASSERT_EQ(
      Symbol::fromQualString("new_ns::param").ns().toUnqualString(),
      std::string("new_ns"));
  ASSERT_EQ(
      Symbol::fromQualString("new_ns::param").ns(),
      Symbol::fromQualString("namespaces::new_ns"));

  // Names without a `::` separator must be rejected.
  auto bad_inputs = {"scope", ":", ""};
  for (auto input : bad_inputs) {
    // Record the outcome in a flag and assert *outside* the try block. If the
    // assertion macro reports failure by throwing a std::exception-derived
    // error, an assertion placed inside the try would be caught by the handler
    // below and the test could never fail.
    bool threw = false;
    try {
      Symbol::fromQualString(input);
    } catch (const std::exception&) {
      threw = true;
    }
    ASSERT_TRUE(threw);
  }
}
void testTHNNConv() {
std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
std::vector<int64_t> kernel_size = {3, 5};
std::vector<int64_t> stride = {1, 2};
std::vector<int64_t> padding = {2, 1};
constexpr int out_channels = 5;
// make inputs
at::Tensor input = torch::randn(input_size);
at::Tensor weight = torch::randn(
{out_channels, input_size[1], kernel_size[0], kernel_size[1]});
at::Tensor bias = torch::randn({out_channels});
// run forward eagerly
at::Tensor output, finput, fgradinput;
std::tie(output, finput, fgradinput) = at::thnn_conv2d_forward(
input, weight, kernel_size, bias, stride, padding);
// make grad_outputs
at::Tensor grad_output =
torch::randn_like(output, at::MemoryFormat::Preserve);
at::Tensor grad_finput =
torch::zeros_like(finput, at::MemoryFormat::Preserve);
at::Tensor grad_fgradinput =
torch::zeros_like(fgradinput, at::MemoryFormat::Preserve);
// run backward eagerly
at::Tensor grad_input, grad_weight, grad_bias;
std::tie(grad_input, grad_weight, grad_bias) = at::thnn_conv2d_backward(
grad_output,
input,
weight,
kernel_size,
stride,
padding,
finput,
fgradinput,
{true, true, true});
// make JIT graph
auto graph = std::make_shared<Graph>();
remove list specialization from ivalue (#30734) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30734 What are specialized lists? The IValues that hold List[int], List[Tensor], and List[AnythingElse] are different C++ types. e.g. List[int] has a std::vector<int> while List[AnythingElse] holds a std::vector<IValue>. Why do we have specialized lists? When we first created the JIT we needed to bind the ATen C++ API which has std::vector<int>, std::vector<Tensor> as inputs. The easiest way to match this API was to make our IValues contain these same types. Conversion was just unwrapping the IValue, very easy and cheap. What is the problem with specialized lists? We end up with significant special cases through the compiler. Other types like Dict are not specialized. So in the Pickler, for instance, there is a single piece of logic to handle their serialization. For Lists, we end up with multiple cases. Furthermore, it doesn't match Python, leading to problems along translation boundaries. Our pickle serialization is slightly different than python, so it is harder to load objects from our IValue serialization as Python values. They also make it harder to provide an easy-to-use user API. We'd like to match pybind11 for C++ bindings to TorchScript. This would entail having a single torch::List class (untemplated) that can be used to construct inputs. This is made much harder if the underlying ivalue needs to be different depending on the type inside the list. The ideal case would be to have a constructor like ``` template<typename T> List(std::vector<T> foo); ``` It would then set up the type tags correctly based on type T, without the need for passing tags. Do specialized lists improve perf? Not in a way we have been able to measure. Our major concern initially was having to translate a std::vector<IValue> to std::vector<int> to call ATen functions. 
This was especially a concern for aten::_convolution which takes a number of mostly-constant lists of integers. However, when we measure the effect of actually having to do this conversion for an aten::_convolution, it does not take measurable time (benchmark results below). This is true even if you use a trivial convolution (e.g. 1x1x1), and comment out the actual convolution code. What are the issues removing them? This PR removes list specialization but keeps the serialization format, and IValue APIs almost exactly the same. The only visible change is that toTensorListRef and family have turned into toTensorVector because they now return by value a copy of the list as a vector. Further PRs can then clean up the complexity issues that arose from speclization. This will likely involve removing the isTensorList/isIntList functions, and refactoring the code that used them to work generically. At some point we will also change serialization to no longer write specialized lists in the pickle binary. This is forward incompatible, so will go in its own PR. 
Benchmark: ``` import torch import torch.nn as nn import torch.nn.functional as F import time class MnistNet(nn.Module): def __init__(self): super(MnistNet, self).__init__() self.conv1 = nn.Conv2d(1, 1, kernel_size=1) self.conv2 = nn.Conv2d(1, 1, kernel_size=1) def forward(self, x): for i in range(10): x = F.relu(self.conv1(x)) x = F.relu(self.conv2(x)) return x model = MnistNet() x = torch.rand(1, 1, 1, 1) r = torch.jit.trace(model, x ) r(x) r(x) r(x) r(x) print(torch.jit.last_executed_optimized_graph()) while True: b = time.time() for i in range(100): r(x) e = time.time() print(e - b) ``` Results (no observable difference): ``` Before (actual conv) 0.13251137733459473 0.13260436058044434 0.13276338577270508 0.1327497959136963 0.13250041007995605 0.13270330429077148 0.13290190696716309 0.13265132904052734 0.13274288177490234 0.1326758861541748 0.13253355026245117 0.13254785537719727 0.13260746002197266 0.13285017013549805 0.13264012336730957 0.132490873336792 0.13280034065246582 0.13243484497070312 0.1325232982635498 0.1326127052307129 0.13264131546020508 0.13274383544921875 0.13298296928405762 0.1326909065246582 ------------------- After (actual conv) 0.13127517700195312 0.13150334358215332 0.13092470169067383 0.13102364540100098 0.13134360313415527 0.13155555725097656 0.13314104080200195 0.13151955604553223 0.13160037994384766 0.1315293312072754 0.13137340545654297 0.13148093223571777 0.131455659866333 0.1327371597290039 0.13134026527404785 0.13152337074279785 0.13151192665100098 0.13165974617004395 0.13403725624084473 0.13251852989196777 0.13135504722595215 0.1315624713897705 0.1317615509033203 0.1314380168914795 0.13157200813293457 -------------------- The following replace the convolution operator with a no-op, to show that even if the conv op was made faster, then we still would not see a difference: Before (fake conv) 0.0069539546966552734 0.0069522857666015625 0.007120847702026367 0.007344722747802734 0.007689952850341797 0.007932662963867188 
0.00761723518371582 0.007501363754272461 0.007532835006713867 0.007141828536987305 0.007174253463745117 0.007114410400390625 0.007071495056152344 ------------------ After (fake conv) 0.007458209991455078 0.007337093353271484 0.007268190383911133 0.007313251495361328 0.007306575775146484 0.007468700408935547 0.0073091983795166016 0.007308483123779297 0.007538318634033203 0.007356882095336914 0.007464170455932617 0.007372140884399414 ``` Test Plan: Imported from OSS Differential Revision: D18814702 Pulled By: zdevito fbshipit-source-id: 0371c73b63068fdc12f24b801371ea90f23531a6
2020-01-13 02:26:36 +00:00
auto ksz_val = graph->insertConstant(kernel_size);
auto kst_val = graph->insertConstant(stride);
auto pad_val = graph->insertConstant(padding);
auto inputg = graph->addInput("self");
auto weightg = graph->addInput("weight");
auto biasg = graph->addInput("bias");
Value* conv = graph->insert(
aten::thnn_conv2d_forward,
{inputg, weightg, ksz_val, biasg, kst_val, pad_val});
auto outputs = conv->node()->outputs();
for (auto output : outputs) {
graph->registerOutput(output);
}
LowerAllTuples(graph);
graph->lint();
// differentiate JIT graph
EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
ConstantPropagation(graph);
auto grad_spec = differentiate(graph);
LowerGradOf(*grad_spec.df);
// prepare JIT inputs / gradients
tensor_list tensors_in;
tensors_in.push_back(input);
tensors_in.push_back(weight);
tensors_in.push_back(bias);
tensor_list tensor_grads_in;
tensor_grads_in.push_back(grad_output);
tensor_grads_in.push_back(grad_finput);
tensor_grads_in.push_back(grad_fgradinput);
// Get outputs from the interpreter
tensor_list tensors_out, tensor_grads_out;
std::tie(tensors_out, tensor_grads_out) =
runGradient(grad_spec, tensors_in, tensor_grads_in);
// prepare expected structs
tensor_list expected_tensors_out, expected_tensor_grads_out;
expected_tensors_out.push_back(output);
expected_tensors_out.push_back(finput);
expected_tensors_out.push_back(fgradinput);
expected_tensor_grads_out.push_back(grad_input);
expected_tensor_grads_out.push_back(grad_weight);
expected_tensor_grads_out.push_back(grad_bias);
// Compare results
assertAllClose(tensors_out, expected_tensors_out);
assertAllClose(tensor_grads_out, expected_tensor_grads_out);
}
void testATenNativeBatchNorm() {
// aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor
// running_mean, Tensor running_var, bool training, float momentum, float eps)
// -> (Tensor, Tensor, Tensor)
std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
bool training = true;
float momentum = 0.9;
float eps = 1e-5;
// make inputs
at::Tensor input = torch::randn(input_size);
at::Tensor weight = torch::randn({input_size[1]});
at::Tensor bias = torch::randn({input_size[1]});
at::Tensor running_mean = torch::randn({input_size[1]});
at::Tensor running_var = torch::randn({input_size[1]});
// running_mean and running_var are changed in-place, so clone and send them
at::Tensor running_mean_eager = running_mean.clone();
at::Tensor running_var_eager = running_var.clone();
at::Tensor running_mean_jit = running_mean.clone();
at::Tensor running_var_jit = running_var.clone();
// run forward eagerly
at::Tensor output, savemean, saveinvstd;
std::tie(output, savemean, saveinvstd) = at::native_batch_norm(
input,
weight,
bias,
running_mean_eager,
running_var_eager,
training,
momentum,
eps);
// make grad_outputs
at::Tensor grad_output =
torch::randn_like(output, at::MemoryFormat::Preserve);
at::Tensor grad_savemean =
torch::zeros_like(savemean, at::MemoryFormat::Preserve);
at::Tensor grad_saveinvstd =
torch::zeros_like(saveinvstd, at::MemoryFormat::Preserve);
// run backward eagerly
at::Tensor grad_input, grad_weight, grad_bias;
// aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor
// weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor
// save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor,
// Tensor, Tensor)
std::tie(grad_input, grad_weight, grad_bias) = at::native_batch_norm_backward(
grad_output,
input,
weight,
running_mean_eager,
running_var_eager,
savemean,
saveinvstd,
training,
eps,
{true, true, true});
// make JIT graph
auto graph = std::make_shared<Graph>();
auto training_val = graph->insertConstant(IValue(training));
auto momentum_val = graph->insertConstant(IValue(momentum));
auto eps_val = graph->insertConstant(IValue(eps));
auto inputg = graph->addInput("self");
auto weightg = graph->addInput("weight");
auto biasg = graph->addInput("bias");
auto running_meang = graph->addInput("running_mean");
auto running_varg = graph->addInput("running_var");
Value* bn = graph->insert(
aten::native_batch_norm,
{inputg,
weightg,
biasg,
running_meang,
running_varg,
training_val,
momentum_val,
eps_val});
auto outputs = bn->node()->outputs();
for (auto output : outputs) {
graph->registerOutput(output);
}
LowerAllTuples(graph);
graph->lint();
// differentiate JIT graph
EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
ConstantPropagation(graph);
auto grad_spec = differentiate(graph);
LowerGradOf(*grad_spec.df);
// prepare JIT inputs / gradients
tensor_list tensors_in;
tensors_in.push_back(input);
tensors_in.push_back(weight);
tensors_in.push_back(bias);
tensors_in.push_back(running_mean_jit);
tensors_in.push_back(running_var_jit);
tensor_list tensor_grads_in;
tensor_grads_in.push_back(grad_output);
tensor_grads_in.push_back(grad_savemean);
tensor_grads_in.push_back(grad_saveinvstd);
// Get outputs from the interpreter
tensor_list tensors_out, tensor_grads_out;
std::tie(tensors_out, tensor_grads_out) =
runGradient(grad_spec, tensors_in, tensor_grads_in);
// prepare expected structs
tensor_list expected_tensors_out, expected_tensor_grads_out;
expected_tensors_out.push_back(output);
expected_tensors_out.push_back(savemean);
expected_tensors_out.push_back(saveinvstd);
expected_tensors_out.push_back(running_mean_eager);
expected_tensors_out.push_back(running_var_eager);
expected_tensor_grads_out.push_back(grad_input);
expected_tensor_grads_out.push_back(grad_weight);
expected_tensor_grads_out.push_back(grad_bias);
tensors_out.push_back(running_mean_jit);
tensors_out.push_back(running_var_jit);
// Compare results
assertAllClose(tensors_out, expected_tensors_out);
assertAllClose(tensor_grads_out, expected_tensor_grads_out);
}
// Parses a two-multiply graph, runs CustomFuseGraph with a predicate that
// accepts every node except prim::Param, and checks that a prim::FusionGroup
// node exists whose subgraph contains exactly the two multiplications.
void testCustomFusion() {
  // NOTE(review): the statement lines are indented beneath the `graph(...)`
  // header because the textual IR parser appears to delimit the graph body by
  // indentation -- confirm against the IR parser before changing the layout.
  auto graph_string = R"IR(
    graph(%0 : Float(2, 3, 4),
          %1 : Float(2, 3, 4)):
      %2 : Tensor = aten::mul(%0, %1)
      %3 : Tensor = aten::mul(%2, %0)
      return (%3))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  // CPU fusion is switched on only for the duration of the fusion pass.
  torch::jit::overrideCanFuseOnCPU(true);
  CustomFuseGraph(
      g,
      [](Node* n) { return n->kind() != prim::Param; },
      Symbol::fromQualString("prim::FusionGroup"));
  torch::jit::overrideCanFuseOnCPU(false);

  const auto& nodes = g->nodes();
  auto fusion_group =
      std::find_if(nodes.begin(), nodes.end(), [](const Node* node) {
        return node->kind() == Symbol::fromQualString("prim::FusionGroup");
      });
  AT_ASSERT(fusion_group != nodes.end());

  auto subgraph = fusion_group->g(attr::Subgraph);
  auto hits = 0;
  // two multiplications
  for (const auto& n : subgraph->nodes()) {
    (void)n;
    hits++;
  }
  AT_ASSERT(hits == 2);
}
// Parses a graph containing a prim::If with two nested blocks, fuses only
// aten::mul nodes via CustomFuseGraph, and checks that a prim::FusionGroup
// node was created somewhere in the (possibly nested) block structure.
void testCustomFusionNestedBlocks() {
  // NOTE(review): indentation is significant here -- `block0`/`block1` must be
  // nested under the prim::If node and their statements nested under the block
  // headers; the textual IR parser appears to rely on this. Confirm against the
  // IR parser before changing the layout.
  auto graph_string = R"IR(
  graph(%0 : Float(2, 3, 4),
        %1 : Float(2, 3, 4),
        %2 : Float(2, 3, 4)):
    %3 : int = prim::Constant[value=1]()
    %4 : Tensor = prim::If(%2)
      block0():
        %5 : Tensor = aten::mul(%0, %2)
        %6 : Tensor = aten::mul(%5, %1)
        -> (%6)
      block1():
        %7 : Tensor = aten::add(%0, %2, %3)
        %8 : Tensor = aten::add(%7, %1, %3)
        -> (%8)
    %9 : Tensor = aten::add(%4, %2, %3)
    return (%4))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  CustomFuseGraph(
      g,
      [](Node* n) { return n->kind() == aten::mul; },
      Symbol::fromQualString("prim::FusionGroup"));

  // Depth-first search for a node of kind `s`, descending into nested blocks.
  // Could be done in more efficient ways, but this is only a test.
  std::function<bool(const Block*, Symbol)> dfs = [&](const Block* b,
                                                      Symbol s) {
    for (auto node : b->nodes()) {
      if (node->kind() == s) {
        return true;
      }
      for (auto nested_b : node->blocks()) {
        if (dfs(nested_b, s)) {
          return true;
        }
      }
    }
    return false;
  };

  AT_ASSERT(dfs(g->block(), Symbol::fromQualString("prim::FusionGroup")));
}
// TorchScript source for the control-flow tests: an if/else, a one-armed if,
// and a while loop. The script uses Python syntax, so the indentation of each
// function body and of each if/else/while body is significant.
static const auto cf_examples = R"JIT(
def if_test(a, b):
    # FIXME: use 0 instead of a.
    # c = 0
    c = a
    if bool(a < b):
        c = b
    else:
        c = a
    return c
def if_one(a, b):
    c = b
    if bool(a < b):
        c = a
    return c
def while_test(a, i):
    while bool(i < 3):
        a *= a
        i += 1
    return a
)JIT";
void testControlFlow() {
First class modules in the compiler, round 2 (#19167) Summary: This PR propagates where we use first-class modules objects into the compiler. This creates a transitionary state where: * compiler.cpp creates Graphs where `self` is a Module class and attributes/parameters/buffers/submodules are looked up with `prim::GetAttr` * GraphExecutor still runs "lowered graphs" where the self object has been removed by a compiler pass `lower_first_class_method`. * Tracing still creates "lowered graphs", and a pass "lift_lowered_method" creates a first-class method graph for things. * This PR separates out Method and Function. A script::Function is a pure Graph with no `self` bound. Similar to Python, a script::Method is just a bound `self` and its underlying `script::Function`. * This PR also separates CompilationUnit from Module. A CompilationUnit is just a list of named script::Functions. Class's have a CompilationUnit holding the class methods, and Modules also have a CompilationUnit holding their Methods. This avoids the weird circular case Module --has a-> Class -> has a -> Module ... Details: * In this transitionary state, we maintain two copies of a Graph, first-class module and lowered. Th first-class one has a self argument that is the module's class type. The lowered one is the lowered graph that uses the initial_ivalues inputs. * When defining lowered methods using `_defined_lowered` we immediately create the first-class equivalent. The reverse is done lazily, creating lowered_methods on demand from the class. * The two way conversions will be deleted in a future PR when the executor itself runs first-class objects. However this requires more changes to (1) the traces, (2) the python bindings, and (3) the onnx export pass and would make this PR way to large. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19167 Differential Revision: D14891966 Pulled By: zdevito fbshipit-source-id: 0b5f03118aa65448a15c7a7818e64089ec93d7ea
2019-04-11 20:30:42 +00:00
auto cu = compile(cf_examples);
auto run = [&](const std::string& name, std::vector<IValue> stack) {
First class modules in the compiler, round 2 (#19167) Summary: This PR propagates where we use first-class modules objects into the compiler. This creates a transitionary state where: * compiler.cpp creates Graphs where `self` is a Module class and attributes/parameters/buffers/submodules are looked up with `prim::GetAttr` * GraphExecutor still runs "lowered graphs" where the self object has been removed by a compiler pass `lower_first_class_method`. * Tracing still creates "lowered graphs", and a pass "lift_lowered_method" creates a first-class method graph for things. * This PR separates out Method and Function. A script::Function is a pure Graph with no `self` bound. Similar to Python, a script::Method is just a bound `self` and its underlying `script::Function`. * This PR also separates CompilationUnit from Module. A CompilationUnit is just a list of named script::Functions. Class's have a CompilationUnit holding the class methods, and Modules also have a CompilationUnit holding their Methods. This avoids the weird circular case Module --has a-> Class -> has a -> Module ... Details: * In this transitionary state, we maintain two copies of a Graph, first-class module and lowered. Th first-class one has a self argument that is the module's class type. The lowered one is the lowered graph that uses the initial_ivalues inputs. * When defining lowered methods using `_defined_lowered` we immediately create the first-class equivalent. The reverse is done lazily, creating lowered_methods on demand from the class. * The two way conversions will be deleted in a future PR when the executor itself runs first-class objects. However this requires more changes to (1) the traces, (2) the python bindings, and (3) the onnx export pass and would make this PR way to large. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19167 Differential Revision: D14891966 Pulled By: zdevito fbshipit-source-id: 0b5f03118aa65448a15c7a7818e64089ec93d7ea
2019-04-11 20:30:42 +00:00
auto graph = cu->get_function(name).graph();
improved TorchScript traceback (#33834) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33834 This changes how we report Tracebacks to make them more clear when there are both serialized and non-serialized ranges. It now looks like: ``` Traceback (most recent call last): File "foo.py", line 25, in <module> s2(a, b) File "/scratch/zdevito/pytorch/torch/nn/modules/module.py", line 550, in __call__ result = self.forward(*input, **kwargs) RuntimeError: The following operation failed in the TorchScript interpreter. Traceback of TorchScript, serialized code (most recent call last): File "code/__torch__.py", line 7, in forward x: Tensor, y: Tensor) -> Tensor: return (self).bar(x, y, ) ~~~~~~~~~ <--- HERE def bar(self: __torch__.Moo, x: Tensor, File "code/__torch__.py", line 11, in bar x: Tensor, y: Tensor) -> Tensor: _0 = (self).baz(x, y, ) ~~~~~~~~~ <--- HERE _1 = torch.ones([3], dtype=None, layout=None, device=None, pin_memory=None) return torch.add(_0, _1, alpha=1) File "code/__torch__.py", line 17, in baz x: Tensor, y: Tensor) -> Tensor: return torch.add(x, y, alpha=1) ~~~~~~~~~ <--- HERE Traceback of TorchScript, original code (most recent call last): File "foo.py", line 11, in forward def forward(self, x, y): return self.bar(x, y) ~~~~~~~~ <--- HERE File "foo.py", line 9, in bar def bar(self, x, y): return self.baz(x, y) + torch.ones(3) ~~~~~~~~ <--- HERE File "foo.py", line 7, in baz def baz(self, x, y): return x + y ~~~~~ <--- HERE RuntimeError: The size of tensor a (4) must match the size of tensor b (5) at non-singleton dimension 1 ``` It follows Python convension of having the most important information last and reading from the bottom up. Changes: * Moved the error message to the end, to copy Python * Report original traceback separate from serialized traceback * Make sure root functions have names in the interpreter trace. 
Test Plan: Imported from OSS Differential Revision: D20126136 Pulled By: zdevito fbshipit-source-id: fd01f9985e5d74e04c4d064c02e8bc320f4fac13
2020-03-03 20:24:28 +00:00
Code code(graph, "");
InterpreterState interp(code);
interp.run(stack);
return stack;
};
auto L = [](int64_t l) { return IValue(scalar_to_tensor(at::Scalar(l))); };
auto V = [](IValue t) { return std::move(t).toTensor().item<int64_t>(); };
auto run_binary = [&](const std::string& name, int64_t a, int64_t b) {
return V(run(name, {L(a), L(b)})[0]);
};
ASSERT_EQ(2, run_binary("if_test", 1, 2));
ASSERT_EQ(3, run_binary("if_test", 3, 2));
ASSERT_EQ(2, run_binary("if_one", 2, 3));
ASSERT_EQ(2, run_binary("if_one", 3, 2));
ASSERT_EQ(256, run_binary("while_test", 2, 0));
}
void testProto() {
::ONNX_NAMESPACE::ModelProto proto;
proto.set_producer_name("foo");
}
void testEvalModeForLoadedModule() {
if (isSandcastle())
return; // The module file to load is not generated in Sandcastle
std::string module_path = "dropout_model.pt";
torch::jit::Module module = torch::jit::load(module_path);
AT_ASSERT(module.attr("dropout").toModule().is_training());
module.eval();
AT_ASSERT(!module.attr("dropout").toModule().is_training());
module.train();
AT_ASSERT(module.attr("dropout").toModule().is_training());
}
// Loads a pickled IValue from "ivalue.pt" and checks that it is a tuple whose
// first two elements are a 2x2 tensor of ones and a 3x5 tensor of twos.
void testSerializationInterop() {
  if (isSandcastle()) {
    // The module file to load is not generated in Sandcastle
    return;
  }

  // This should be generated by `test/cpp/jit/tests_setup.py`
  // The file holds binary data: open it in binary mode and read it with
  // istreambuf_iterator. istream_iterator<char> goes through formatted
  // extraction, which skips whitespace bytes and would corrupt the payload.
  std::ifstream input_stream("ivalue.pt", std::ios::in | std::ios::binary);
  ASSERT_TRUE(input_stream.is_open());
  std::vector<char> input;
  input.assign(
      std::istreambuf_iterator<char>(input_stream),
      std::istreambuf_iterator<char>());
  IValue ivalue = pickle_load(input);

  auto elements = ivalue.toTuple()->elements();
  auto ones = torch::ones({2, 2});
  ASSERT_TRUE(ones.equal(elements.at(0).toTensor()));

  auto twos = torch::ones({3, 5}) * 2;
  ASSERT_TRUE(twos.equal(elements.at(1).toTensor()));
}
// Verifies that torch::jit::load raises a std::exception when given
// "eager_value.pt".
void testTorchSaveError() {
  if (isSandcastle()) {
    // The file to load is not generated in Sandcastle
    return;
  }

  // This should be generated by `test/cpp/jit/tests_setup.py`
  bool load_threw = false;
  try {
    torch::jit::load("eager_value.pt");
  } catch (const std::exception&) {
    load_threw = true;
  }

  // Loading must have failed with an exception rather than succeeded.
  ASSERT_TRUE(load_threw);
}
// test a few features that are not directly used in schemas yet
void testSchemaParser() {
// nested arrays
auto s = parseSchema("at::what(int[][4] foo) -> ()");
ASSERT_TRUE(s.arguments().at(0).N() == 4);
ASSERT_TRUE(IntType::get()->isSubtypeOf(s.arguments()
.at(0)
.type()
->expect<ListType>()
->getElementType()
->expect<ListType>()
->getElementType()));
auto s2 = parseSchema("at::what(int[][] foo) -> ()");
ASSERT_TRUE(IntType::get()->isSubtypeOf(s2.arguments()
.at(0)
.type()
->expect<ListType>()
->getElementType()
->expect<ListType>()
->getElementType()));
// named returns
parseSchema("at::what(Tensor! i_will_be_written_to) -> ()");
auto s3 =
parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)");
ASSERT_TRUE(s3.returns().at(0).name() == "the_return");
ASSERT_TRUE(s3.returns().at(1).name() == "the_return2");
// futures
auto s4 = parseSchema("at::what(Future(int) foo) -> ()");
ASSERT_TRUE(IntType::get()->isSubtypeOf(
s4.arguments().at(0).type()->expect<FutureType>()->getElementType()));
// test tensor with annotated alias sets
parseSchema("at::what(Tensor(a) foo) -> (Tensor(a))");
{
const auto s = parseSchema(
"at::what(Tensor(b|c)[](a!) list, Tensor(c) element)"
" -> (Tensor(b|c)[](a!))");
// The list itself is annotated with `a`
const auto& aliasInfo = *s.arguments().at(0).alias_info();
ASSERT_TRUE(
aliasInfo.beforeSets() ==
std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
ASSERT_TRUE(aliasInfo.isWrite());
// Check the contained types
ASSERT_TRUE(!aliasInfo.containedTypes().empty());
const auto& containedAliasInfo = aliasInfo.containedTypes()[0];
const auto expected = std::unordered_set<Symbol>{
Symbol::fromQualString("alias::b"),
Symbol::fromQualString("alias::c"),
};
ASSERT_TRUE(containedAliasInfo.beforeSets() == expected);
ASSERT_TRUE(containedAliasInfo.afterSets() == expected);
ASSERT_FALSE(containedAliasInfo.isWrite());
}
{
const auto s = parseSchema(
"at::what(Tensor(b -> b|c)[](a!) list, Tensor(c) element)"
" -> (Tensor(b|c)[](a!))");
// The list itself is annotated with `a`
const auto& aliasInfo = *s.arguments().at(0).alias_info();
ASSERT_EQ(
aliasInfo.beforeSets(),
std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
ASSERT_EQ(
aliasInfo.afterSets(),
std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
ASSERT_TRUE(aliasInfo.isWrite());
ASSERT_EQ(aliasInfo.containedTypes().size(), 1);
// Check the contained types
ASSERT_TRUE(!aliasInfo.containedTypes().empty());
const auto& containedAliasInfo = aliasInfo.containedTypes()[0];
const auto expectedBefore = std::unordered_set<Symbol>{
Symbol::fromQualString("alias::b"),
};
const auto expectedAfter = std::unordered_set<Symbol>{
Symbol::fromQualString("alias::b"), Symbol::fromQualString("alias::c")};
ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore);
ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter);
ASSERT_FALSE(containedAliasInfo.isWrite());
}
}
void testTopologicalIndex() {
{
Graph graph;
auto node1 = graph.create(prim::AutogradZero);
auto node2 = graph.create(prim::AutogradZero);
auto node3 = graph.create(prim::AutogradZero);
auto node4 = graph.create(prim::AutogradZero);
graph.appendNode(node4);
graph.prependNode(node1);
node2->insertAfter(node1);
node3->insertBefore(node4);
// nodes should be in numerical order
ASSERT_TRUE(node1->isBefore(node2));
ASSERT_TRUE(node1->isBefore(node3));
ASSERT_TRUE(node1->isBefore(node4));
ASSERT_TRUE(node2->isAfter(node1));
ASSERT_TRUE(node2->isBefore(node3));
ASSERT_TRUE(node2->isBefore(node4));
ASSERT_FALSE(node3->isBefore(node1));
ASSERT_FALSE(node3->isBefore(node2));
ASSERT_FALSE(node3->isAfter(node4));
// Built up a block structure
// node3
// /\ ...
// A B block1
// \ ...
// C block2
auto block1 = node3->addBlock();
auto A = graph.create(prim::AutogradZero);
block1->appendNode(A);
auto B = graph.create(prim::AutogradZero);
block1->appendNode(B);
auto block2 = B->addBlock();
auto C = graph.create(prim::AutogradZero);
block2->appendNode(C);
// Check isAfter on different block levels
ASSERT_TRUE(node1->isBefore(A));
ASSERT_TRUE(A->isBefore(B));
ASSERT_TRUE(A->isBefore(C));
// make sure things don't blow up on deletions
node2->destroy();
auto node2p = graph.create(prim::AutogradZero);
node2p->insertAfter(node1);
ASSERT_TRUE(node1->isBefore(node2p));
ASSERT_TRUE(node2p->isBefore(node3));
}
{
// Induce reindexing to test that path
Graph graph;
std::map<size_t, Node*> nodes;
auto anchor = graph.create(prim::AutogradZero);
graph.appendNode(anchor);
// Inserting to the same place a lot will trigger reindexing
for (auto i = 0; i < 100; ++i) {
auto n = graph.create(prim::AutogradZero);
n->insertAfter(anchor);
nodes[i] = n;
}
// Nodes should be in reverse order
for (auto i = 0; i < 100; ++i) {
for (auto j = i + 1; j < 100; ++j) {
ASSERT_TRUE(nodes[i]->isAfter(nodes[j]));
}
}
}
}
// Invokes RECORD_FUNCTION under the name "test" with `t` as its single input,
// then returns `t` squared via Tensor::pow.
at::Tensor invokeTestRecordFunction(at::Tensor& t) {
  RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));
  return t.pow(2);
}
// TorchScript source for the scripted counterpart of invokeTestRecordFunction:
// a single `forward` function that squares its input. The script is
// indentation-sensitive, so the function body must be indented deeper than
// the `def` line.
static const auto invokeTestRecordFunction_JIT = R"JIT(
  def forward(t):
    t2 = t.pow(2)
    return t2
)JIT";
// Scripted variant of invokeTestRecordFunction: invokes RECORD_FUNCTION under
// the name "test", compiles the TorchScript source held in
// invokeTestRecordFunction_JIT, and returns the result of running its
// `forward` function on `t`.
at::Tensor invokeTestRecordFunctionJIT(at::Tensor& t) {
  RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));
  const auto unit = compile(invokeTestRecordFunction_JIT);
  auto&& forward_fn = unit->get_function("forward");
  return forward_fn({t}).toTensor();
}
// Log of observed RecordFunction calls as filled in by testRecordFunction:
// each entry pairs the recorded function name with one size vector per
// recorded input (a tensor contributes its sizes, a scalar contributes an
// empty vector).
using TracedTestInputs =
std::vector<std::tuple<std::string, std::vector<std::vector<int64_t>>>>;
// Validates a log of traced calls. Entries named "test", "pow" and "mul" must
// each appear at least once, and for each of them the first input must have
// sizes {1, 2, 3}:
//   - "test": exactly one input;
//   - "pow":  exactly two inputs, the second with empty sizes;
//   - "mul":  more than one input.
// Entries with any other name are ignored. Failures are raised through
// TORCH_CHECK.
void checkTracedInputs(const TracedTestInputs& inputs) {
  bool found_test = false;
  bool found_pow = false;
  bool found_mul = false;

  for (const auto& record : inputs) {
    const std::string& op_name = std::get<0>(record);
    const std::vector<std::vector<int64_t>>& sizes = std::get<1>(record);

    if (op_name == "test") {
      found_test = true;
      TORCH_CHECK(sizes.size() == 1);
      TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
      continue;
    }
    if (op_name == "pow") {
      found_pow = true;
      TORCH_CHECK(sizes.size() == 2);
      TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
      TORCH_CHECK(sizes[1].empty());
      continue;
    }
    if (op_name == "mul") {
      found_mul = true;
      TORCH_CHECK(sizes.size() > 1);
      TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
    }
  }

  TORCH_CHECK(found_test);
  TORCH_CHECK(found_pow);
  TORCH_CHECK(found_mul);
}
void testRecordFunction() {
// [(fn, [[sizes], [sizes], ...]), ...]
TracedTestInputs traced_inputs;
First class modules in the compiler, round 2 (#19167) Summary: This PR propagates where we use first-class modules objects into the compiler. This creates a transitionary state where: * compiler.cpp creates Graphs where `self` is a Module class and attributes/parameters/buffers/submodules are looked up with `prim::GetAttr` * GraphExecutor still runs "lowered graphs" where the self object has been removed by a compiler pass `lower_first_class_method`. * Tracing still creates "lowered graphs", and a pass "lift_lowered_method" creates a first-class method graph for things. * This PR separates out Method and Function. A script::Function is a pure Graph with no `self` bound. Similar to Python, a script::Method is just a bound `self` and its underlying `script::Function`. * This PR also separates CompilationUnit from Module. A CompilationUnit is just a list of named script::Functions. Class's have a CompilationUnit holding the class methods, and Modules also have a CompilationUnit holding their Methods. This avoids the weird circular case Module --has a-> Class -> has a -> Module ... Details: * In this transitionary state, we maintain two copies of a Graph, first-class module and lowered. Th first-class one has a self argument that is the module's class type. The lowered one is the lowered graph that uses the initial_ivalues inputs. * When defining lowered methods using `_defined_lowered` we immediately create the first-class equivalent. The reverse is done lazily, creating lowered_methods on demand from the class. * The two way conversions will be deleted in a future PR when the executor itself runs first-class objects. However this requires more changes to (1) the traces, (2) the python bindings, and (3) the onnx export pass and would make this PR way to large. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19167 Differential Revision: D14891966 Pulled By: zdevito fbshipit-source-id: 0b5f03118aa65448a15c7a7818e64089ec93d7ea
2019-04-11 20:30:42 +00:00
autograd::profiler::pushCallback(
[&traced_inputs](const autograd::profiler::RecordFunction& fn) {
auto inputs = fn.inputs();
std::vector<std::vector<int64_t>> sizes;
for (const auto& input : inputs) {
First class modules in the compiler, round 2 (#19167) Summary: This PR propagates where we use first-class modules objects into the compiler. This creates a transitionary state where: * compiler.cpp creates Graphs where `self` is a Module class and attributes/parameters/buffers/submodules are looked up with `prim::GetAttr` * GraphExecutor still runs "lowered graphs" where the self object has been removed by a compiler pass `lower_first_class_method`. * Tracing still creates "lowered graphs", and a pass "lift_lowered_method" creates a first-class method graph for things. * This PR separates out Method and Function. A script::Function is a pure Graph with no `self` bound. Similar to Python, a script::Method is just a bound `self` and its underlying `script::Function`. * This PR also separates CompilationUnit from Module. A CompilationUnit is just a list of named script::Functions. Class's have a CompilationUnit holding the class methods, and Modules also have a CompilationUnit holding their Methods. This avoids the weird circular case Module --has a-> Class -> has a -> Module ... Details: * In this transitionary state, we maintain two copies of a Graph, first-class module and lowered. Th first-class one has a self argument that is the module's class type. The lowered one is the lowered graph that uses the initial_ivalues inputs. * When defining lowered methods using `_defined_lowered` we immediately create the first-class equivalent. The reverse is done lazily, creating lowered_methods on demand from the class. * The two way conversions will be deleted in a future PR when the executor itself runs first-class objects. However this requires more changes to (1) the traces, (2) the python bindings, and (3) the onnx export pass and would make this PR way to large. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19167 Differential Revision: D14891966 Pulled By: zdevito fbshipit-source-id: 0b5f03118aa65448a15c7a7818e64089ec93d7ea
2019-04-11 20:30:42 +00:00
if (input.isTensor()) {
sizes.push_back(input.toTensor().sizes().vec());
} else if (input.isScalar()) {
sizes.push_back(std::vector<int64_t>());
First class modules in the compiler, round 2 (#19167) Summary: This PR propagates where we use first-class modules objects into the compiler. This creates a transitionary state where: * compiler.cpp creates Graphs where `self` is a Module class and attributes/parameters/buffers/submodules are looked up with `prim::GetAttr` * GraphExecutor still runs "lowered graphs" where the self object has been removed by a compiler pass `lower_first_class_method`. * Tracing still creates "lowered graphs", and a pass "lift_lowered_method" creates a first-class method graph for things. * This PR separates out Method and Function. A script::Function is a pure Graph with no `self` bound. Similar to Python, a script::Method is just a bound `self` and its underlying `script::Function`. * This PR also separates CompilationUnit from Module. A CompilationUnit is just a list of named script::Functions. Class's have a CompilationUnit holding the class methods, and Modules also have a CompilationUnit holding their Methods. This avoids the weird circular case Module --has a-> Class -> has a -> Module ... Details: * In this transitionary state, we maintain two copies of a Graph, first-class module and lowered. Th first-class one has a self argument that is the module's class type. The lowered one is the lowered graph that uses the initial_ivalues inputs. * When defining lowered methods using `_defined_lowered` we immediately create the first-class equivalent. The reverse is done lazily, creating lowered_methods on demand from the class. * The two way conversions will be deleted in a future PR when the executor itself runs first-class objects. However this requires more changes to (1) the traces, (2) the python bindings, and (3) the onnx export pass and would make this PR way to large. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19167 Differential Revision: D14891966 Pulled By: zdevito fbshipit-source-id: 0b5f03118aa65448a15c7a7818e64089ec93d7ea
2019-04-11 20:30:42 +00:00
}
}
traced_inputs.push_back(std::make_tuple(fn.name().str(), sizes));
},
[](const autograd::profiler::RecordFunction&) {},
/* needs_inputs */ true);
auto t = torch::randn({1, 2, 3}, at::kCPU);
t.set_requires_grad(true);
auto t2 = invokeTestRecordFunction(t);
t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
auto eager_inputs = traced_inputs;
traced_inputs.clear();
t = torch::randn({1, 2, 3}, at::kCPU);
t.set_requires_grad(true);
t2 = invokeTestRecordFunctionJIT(t);
t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
auto jit_inputs = traced_inputs;
traced_inputs.clear();
autograd::profiler::popCallback();
checkTracedInputs(eager_inputs);
checkTracedInputs(jit_inputs);
// test sampled callbacks
int sampled_cb_ctr = 0;
autograd::profiler::pushCallback(
[&sampled_cb_ctr](const autograd::profiler::RecordFunction& fn) {
if (std::string(fn.name().str()) == "test") {
++sampled_cb_ctr;
}
},
[](const autograd::profiler::RecordFunction&) {},
/* needs_inputs */ false,
/* sampled */ true);
int non_sampled_cb_ctr = 0;
autograd::profiler::pushCallback(
[&non_sampled_cb_ctr](const autograd::profiler::RecordFunction& fn) {
if (std::string(fn.name().str()) == "test") {
++non_sampled_cb_ctr;
}
},
[](const autograd::profiler::RecordFunction&) {},
/* needs_inputs */ false,
/* sampled */ false);
auto run_test_function = []() {
auto t = torch::randn({1, 2, 3}, at::kCPU);
for (auto k = 0; k < 1000; k++) {
invokeTestRecordFunction(t);
}
};
autograd::profiler::setSamplingProbability(0.5);
run_test_function();
TORCH_CHECK(non_sampled_cb_ctr == 1000);
TORCH_CHECK(sampled_cb_ctr > 0 && sampled_cb_ctr < 1000);
sampled_cb_ctr = 0;
autograd::profiler::setSamplingProbability(0.0);
run_test_function();
TORCH_CHECK(non_sampled_cb_ctr == 2000);
TORCH_CHECK(sampled_cb_ctr == 0);
sampled_cb_ctr = 0;
autograd::profiler::setSamplingProbability(1.0);
run_test_function();
TORCH_CHECK(non_sampled_cb_ctr == 3000);
TORCH_CHECK(sampled_cb_ctr == 1000);
autograd::profiler::popCallback();
autograd::profiler::popCallback();
}
class TestThreadLocalDebugInfo : public at::ThreadLocalDebugInfoBase {
public:
int getModelId() const {
return model_id_;
}
void setModelId(int model_id) {
model_id_ = model_id;
}
virtual ~TestThreadLocalDebugInfo() {}
private:
int model_id_ = 0;
};
// Installs a TestThreadLocalDebugInfo with model id 42 as the thread-local
// debug info and checks that it is visible on the calling thread, inside a
// task started with at::launch, and inside a profiler callback that is
// registered while a backward pass runs. The debug info is cleared again at
// the end.
void testThreadLocalDebugInfo() {
  // Asserts that the current thread sees a TestThreadLocalDebugInfo whose
  // model id is 42.
  auto checkDebugInfo = []() {
    auto debug_info = at::getThreadLocalDebugInfo();
    TORCH_CHECK(debug_info != nullptr);
    auto* test_debug_info =
        dynamic_cast<TestThreadLocalDebugInfo*>(debug_info.get());
    TORCH_CHECK(test_debug_info != nullptr);
    TORCH_CHECK(test_debug_info->getModelId() == 42);
  };

  // Nothing is installed before the test sets it.
  TORCH_CHECK(at::getThreadLocalDebugInfo() == nullptr);

  auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
  debug_info->setModelId(42);
  at::setThreadLocalDebugInfo(debug_info);
  checkDebugInfo();

  // check that thread local debug info is propagated through fork calls
  std::atomic<bool> done{false};
  at::launch([checkDebugInfo, &done]() {
    checkDebugInfo();
    done.store(true);
  });
  // Spin until the launched task reports completion.
  while (!done.load()) {
  }
  checkDebugInfo();

  // check that thread local debug info is propagated through backward pass
  autograd::profiler::pushCallback(
      [&checkDebugInfo](const autograd::profiler::RecordFunction&) {
        checkDebugInfo();
      },
      [](const autograd::profiler::RecordFunction&) {});
  {
    auto t = torch::randn({1, 2, 3}, at::kCPU);
    t.set_requires_grad(true);
    auto t2 = t.pow(2);
    t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
  }
  autograd::profiler::popCallback();
  checkDebugInfo();

  // Clear the debug info and confirm it is gone.
  at::setThreadLocalDebugInfo(nullptr);
  TORCH_CHECK(at::getThreadLocalDebugInfo() == nullptr);
}
// Runs the `lstm` test helper 100 times under a RecordProfile guard that
// writes to a string stream, then checks that the text "tanh" occurs exactly
// 200 times in the emitted profile.
void testAutogradProfiler() {
  constexpr int batch_size = 4;
  constexpr int input_size = 256;
  constexpr int seq_len = 32;
  constexpr int hidden_size = 2 * input_size;

  auto input = torch::randn({seq_len, batch_size, input_size}, at::kCPU);
  auto hx = torch::randn({batch_size, hidden_size}, at::kCPU);
  auto cx = torch::randn({batch_size, hidden_size}, at::kCPU);
  auto w_ih = t_def(torch::randn({4 * hidden_size, input_size}, at::kCPU));
  auto w_hh = t_def(torch::randn({4 * hidden_size, hidden_size}, at::kCPU));

  std::stringstream ss;
  {
    // The guard is scoped so that it is destroyed before `ss` is read.
    autograd::profiler::RecordProfile guard(ss);
    for (int step = 0; step < 100; ++step) {
      std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
    }
  }

  // Count occurrences of "tanh", resuming each search one character past the
  // previous hit.
  const std::string result = ss.str();
  size_t count = 0;
  size_t pos = result.find("tanh");
  while (pos != std::string::npos) {
    ++count;
    pos = result.find("tanh", pos + 1);
  }
  TORCH_CHECK(count == 200);
}
void testNoneSchemaMatch() {
RegisterOperators reg({
Operator(
"prim::test_none() -> int?",
remove unnecessary Node* ops (#32760) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32760 Minor changes to the way ops are implemented to remove incidental use of Node* in the operator implementation. Current state for operators that previously took Node: ``` TBD: USES NODE: prim::DifferentiableGraph(...) -> (...) USES NODE: prim::profile(...) -> (...) USES NODE: prim::FusionGroup(...) -> (...) USES NODE: prim::PythonOp(...) -> (...) USES NODE: prim::ImplicitTensorToNum(Tensor a) -> Scalar # next PR Should be made interpreter primitives: USES NODE: prim::TupleUnpack(...) -> (...) USES NODE: prim::TupleSlice(...) -> (...) USES NODE: prim::TupleConstruct(...) -> (...) USES NODE: prim::ListUnpack(...) -> (...) USES NODE: prim::ListConstruct(...) -> (...) USES NODE: prim::DictConstruct(...) -> (...) USES NODE: prim::Constant() -> (...) USES NODE: prim::isinstance(...) -> (...) USES NODE: prim::CreateObject(...) -> (...) USES NODE: prim::fork(...) -> (...) USES NODE: aten::warn(str message, *, int stacklevel=2) -> () # need stack level information, so ideally in interpreter so it can look at the stack Should be made into vararg operators, i.e. the operators last argument should be an IValue that contains the number of arguments. USES NODE: prim::FusedConcat(...) -> (...) USES NODE: prim::MMTreeReduce(...) -> (...) USES NODE: prim::MMBatchSide(...) -> (...) USES NODE: prim::ConstantChunk(...) -> (...) USES NODE: prim::AutogradAnyNonZero(...) -> bool USES NODE: prim::BroadcastSizes(...) -> (...) USES NODE: prim::ChunkSizes(...) -> (...) USES NODE: aten::format(str self, ...) -> str USES NODE: prim::Print(...) -> (...) fixed: USES NODE: aten::extend(Tensor[](a!) self, Tensor [] other) -> () USES NODE: aten::copy(Tensor[](a) self) -> Tensor[] USES NODE: aten::extend(int[](a!) self, int [] other) -> () USES NODE: aten::copy(int[](a) self) -> int[] USES NODE: aten::extend(float[](a!) 
self, float [] other) -> () USES NODE: aten::copy(float[](a) self) -> float[] USES NODE: aten::extend(bool[](a!) self, bool [] other) -> () USES NODE: aten::copy(bool[](a) self) -> bool[] USES NODE: aten::extend(t[](a!) self, t [] other) -> () USES NODE: aten::copy(t[](a) self) -> t[] USES NODE: aten::keys(Dict(str, t) self) -> str[](*) USES NODE: aten::values(Dict(str, t) self) -> t[](*) USES NODE: aten::dict((str, tVal)[] inputs) -> Dict(str, tVal) USES NODE: aten::keys(Dict(int, t) self) -> int[](*) USES NODE: aten::values(Dict(int, t) self) -> t[](*) USES NODE: aten::dict((int, tVal)[] inputs) -> Dict(int, tVal) USES NODE: aten::keys(Dict(float, t) self) -> float[](*) USES NODE: aten::values(Dict(float, t) self) -> t[](*) USES NODE: aten::dict((float, tVal)[] inputs) -> Dict(float, tVal) USES NODE: aten::keys(Dict(Tensor, t) self) -> Tensor[](*) USES NODE: aten::values(Dict(Tensor, t) self) -> t[](*) USES NODE: aten::dict((Tensor, tVal)[] inputs) -> Dict(Tensor, tVal) USES NODE: aten::test_vartype2(t a, t[] b) -> (t[]) USES NODE: aten::_ncf_unsqueeze(Tensor self, int ndim) -> Tensor USES NODE: aten::_ncf_view(Tensor self, int[] input_shape, int normalized_ndim) -> Tensor USES NODE: prim::is_none(int? a) -> bool USES NODE: aten::__interpolate(Tensor input, int? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? recompute_scale_factor = None) -> Tensor USES NODE: aten::__interpolate(Tensor input, int[]? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? recompute_scale_factor = None) -> Tensor USES NODE: aten::__interpolate(Tensor input, int? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? recompute_scale_factor = None) -> Tensor USES NODE: aten::__interpolate(Tensor input, int[]? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? 
recompute_scale_factor = None) -> Tensor USES NODE: aten::sorted(t[](a) self) -> (t[]) USES NODE: aten::sort(t[](a!) self, bool reverse=False) -> () USES NODE: aten::test_vartype(t[] a, t b) -> (t) USES NODE: prim::unchecked_unwrap_optional(t(a)? optional) -> t(a) USES NODE: prim::unchecked_cast(...) -> (...) USES NODE: aten::dict() -> Dict(str, Tensor) USES NODE: prim::Load(...) -> (...) USES NODE: prim::Store(...) -> (...) USES NODE: prim::Drop(...) -> (...) USES NODE: aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor USES NODE: aten::as_tensor(t[] data, *, ScalarType? dtype=None, Device? device=None) -> Tensor ``` Test Plan: Imported from OSS Differential Revision: D19615387 Pulled By: zdevito fbshipit-source-id: 95298c3c4249b9f812c332d13f0fb79daeecb662
2020-02-12 22:45:44 +00:00
[](Stack& stack) {
push(stack, IValue());
return 0;
},
aliasAnalysisFromSchema()),
Operator(
"prim::is_none(int? a) -> bool",
remove unnecessary Node* ops (#32760) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32760 Minor changes to the way ops are implemented to remove incidental use of Node* in the operator implementation. Current state for operators that previously took Node: ``` TBD: USES NODE: prim::DifferentiableGraph(...) -> (...) USES NODE: prim::profile(...) -> (...) USES NODE: prim::FusionGroup(...) -> (...) USES NODE: prim::PythonOp(...) -> (...) USES NODE: prim::ImplicitTensorToNum(Tensor a) -> Scalar # next PR Should be made interpreter primitives: USES NODE: prim::TupleUnpack(...) -> (...) USES NODE: prim::TupleSlice(...) -> (...) USES NODE: prim::TupleConstruct(...) -> (...) USES NODE: prim::ListUnpack(...) -> (...) USES NODE: prim::ListConstruct(...) -> (...) USES NODE: prim::DictConstruct(...) -> (...) USES NODE: prim::Constant() -> (...) USES NODE: prim::isinstance(...) -> (...) USES NODE: prim::CreateObject(...) -> (...) USES NODE: prim::fork(...) -> (...) USES NODE: aten::warn(str message, *, int stacklevel=2) -> () # need stack level information, so ideally in interpreter so it can look at the stack Should be made into vararg operators, i.e. the operators last argument should be an IValue that contains the number of arguments. USES NODE: prim::FusedConcat(...) -> (...) USES NODE: prim::MMTreeReduce(...) -> (...) USES NODE: prim::MMBatchSide(...) -> (...) USES NODE: prim::ConstantChunk(...) -> (...) USES NODE: prim::AutogradAnyNonZero(...) -> bool USES NODE: prim::BroadcastSizes(...) -> (...) USES NODE: prim::ChunkSizes(...) -> (...) USES NODE: aten::format(str self, ...) -> str USES NODE: prim::Print(...) -> (...) fixed: USES NODE: aten::extend(Tensor[](a!) self, Tensor [] other) -> () USES NODE: aten::copy(Tensor[](a) self) -> Tensor[] USES NODE: aten::extend(int[](a!) self, int [] other) -> () USES NODE: aten::copy(int[](a) self) -> int[] USES NODE: aten::extend(float[](a!) 
self, float [] other) -> () USES NODE: aten::copy(float[](a) self) -> float[] USES NODE: aten::extend(bool[](a!) self, bool [] other) -> () USES NODE: aten::copy(bool[](a) self) -> bool[] USES NODE: aten::extend(t[](a!) self, t [] other) -> () USES NODE: aten::copy(t[](a) self) -> t[] USES NODE: aten::keys(Dict(str, t) self) -> str[](*) USES NODE: aten::values(Dict(str, t) self) -> t[](*) USES NODE: aten::dict((str, tVal)[] inputs) -> Dict(str, tVal) USES NODE: aten::keys(Dict(int, t) self) -> int[](*) USES NODE: aten::values(Dict(int, t) self) -> t[](*) USES NODE: aten::dict((int, tVal)[] inputs) -> Dict(int, tVal) USES NODE: aten::keys(Dict(float, t) self) -> float[](*) USES NODE: aten::values(Dict(float, t) self) -> t[](*) USES NODE: aten::dict((float, tVal)[] inputs) -> Dict(float, tVal) USES NODE: aten::keys(Dict(Tensor, t) self) -> Tensor[](*) USES NODE: aten::values(Dict(Tensor, t) self) -> t[](*) USES NODE: aten::dict((Tensor, tVal)[] inputs) -> Dict(Tensor, tVal) USES NODE: aten::test_vartype2(t a, t[] b) -> (t[]) USES NODE: aten::_ncf_unsqueeze(Tensor self, int ndim) -> Tensor USES NODE: aten::_ncf_view(Tensor self, int[] input_shape, int normalized_ndim) -> Tensor USES NODE: prim::is_none(int? a) -> bool USES NODE: aten::__interpolate(Tensor input, int? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? recompute_scale_factor = None) -> Tensor USES NODE: aten::__interpolate(Tensor input, int[]? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? recompute_scale_factor = None) -> Tensor USES NODE: aten::__interpolate(Tensor input, int? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? recompute_scale_factor = None) -> Tensor USES NODE: aten::__interpolate(Tensor input, int[]? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? 
recompute_scale_factor = None) -> Tensor USES NODE: aten::sorted(t[](a) self) -> (t[]) USES NODE: aten::sort(t[](a!) self, bool reverse=False) -> () USES NODE: aten::test_vartype(t[] a, t b) -> (t) USES NODE: prim::unchecked_unwrap_optional(t(a)? optional) -> t(a) USES NODE: prim::unchecked_cast(...) -> (...) USES NODE: aten::dict() -> Dict(str, Tensor) USES NODE: prim::Load(...) -> (...) USES NODE: prim::Store(...) -> (...) USES NODE: prim::Drop(...) -> (...) USES NODE: aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor USES NODE: aten::as_tensor(t[] data, *, ScalarType? dtype=None, Device? device=None) -> Tensor ``` Test Plan: Imported from OSS Differential Revision: D19615387 Pulled By: zdevito fbshipit-source-id: 95298c3c4249b9f812c332d13f0fb79daeecb662
2020-02-12 22:45:44 +00:00
[](Stack& stack) {
IValue a = pop(stack);
if (a.isNone()) {
push(stack, true);
} else {
push(stack, false);
}
return 0;
},
aliasAnalysisFromSchema()),
});
// Constant propagation will run test_none and produce a None,
// testing that its type is set appropriately and schema matching doesn't
// fail when running is_none
auto r = std::make_shared<Graph>();
auto& g = *r;
auto opt_int = g.insert(Symbol::fromQualString("prim::test_none"), {});
auto out_bool = g.insert(Symbol::fromQualString("prim::is_none"), {opt_int});
g.registerOutput(out_bool);
ConstantPropagation(r);
auto nodes = r->block()->nodes();
// checking that constant propagation ran wo/failure
AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1);
}
// NOTE(review): a history/blame annotation for the commit "First class modules
// in the compiler, round 2 (#19167)", dated 2019-04-11 20:30:42 +00:00, was
// embedded here as bare text. It appears to be page residue rather than C++
// source - confirm against the upstream file.
void testModuleDefine() {
Module m("m");
m.register_parameter("foo", torch::ones({}), false);
m.define(R"(
First class modules in the compiler, round 2 (#19167) Summary: This PR propagates where we use first-class modules objects into the compiler. This creates a transitionary state where: * compiler.cpp creates Graphs where `self` is a Module class and attributes/parameters/buffers/submodules are looked up with `prim::GetAttr` * GraphExecutor still runs "lowered graphs" where the self object has been removed by a compiler pass `lower_first_class_method`. * Tracing still creates "lowered graphs", and a pass "lift_lowered_method" creates a first-class method graph for things. * This PR separates out Method and Function. A script::Function is a pure Graph with no `self` bound. Similar to Python, a script::Method is just a bound `self` and its underlying `script::Function`. * This PR also separates CompilationUnit from Module. A CompilationUnit is just a list of named script::Functions. Class's have a CompilationUnit holding the class methods, and Modules also have a CompilationUnit holding their Methods. This avoids the weird circular case Module --has a-> Class -> has a -> Module ... Details: * In this transitionary state, we maintain two copies of a Graph, first-class module and lowered. Th first-class one has a self argument that is the module's class type. The lowered one is the lowered graph that uses the initial_ivalues inputs. * When defining lowered methods using `_defined_lowered` we immediately create the first-class equivalent. The reverse is done lazily, creating lowered_methods on demand from the class. * The two way conversions will be deleted in a future PR when the executor itself runs first-class objects. However this requires more changes to (1) the traces, (2) the python bindings, and (3) the onnx export pass and would make this PR way to large. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19167 Differential Revision: D14891966 Pulled By: zdevito fbshipit-source-id: 0b5f03118aa65448a15c7a7818e64089ec93d7ea
2019-04-11 20:30:42 +00:00
def add_it(self, x, b : int = 4):
return self.foo + x + b
)");
auto result = m.run_method("add_it", torch::ones({}));
AT_ASSERT(result.toTensor().item<float>() == 6);
First class modules in the compiler, round 2 (#19167) Summary: This PR propagates where we use first-class modules objects into the compiler. This creates a transitionary state where: * compiler.cpp creates Graphs where `self` is a Module class and attributes/parameters/buffers/submodules are looked up with `prim::GetAttr` * GraphExecutor still runs "lowered graphs" where the self object has been removed by a compiler pass `lower_first_class_method`. * Tracing still creates "lowered graphs", and a pass "lift_lowered_method" creates a first-class method graph for things. * This PR separates out Method and Function. A script::Function is a pure Graph with no `self` bound. Similar to Python, a script::Method is just a bound `self` and its underlying `script::Function`. * This PR also separates CompilationUnit from Module. A CompilationUnit is just a list of named script::Functions. Class's have a CompilationUnit holding the class methods, and Modules also have a CompilationUnit holding their Methods. This avoids the weird circular case Module --has a-> Class -> has a -> Module ... Details: * In this transitionary state, we maintain two copies of a Graph, first-class module and lowered. Th first-class one has a self argument that is the module's class type. The lowered one is the lowered graph that uses the initial_ivalues inputs. * When defining lowered methods using `_defined_lowered` we immediately create the first-class equivalent. The reverse is done lazily, creating lowered_methods on demand from the class. * The two way conversions will be deleted in a future PR when the executor itself runs first-class objects. However this requires more changes to (1) the traces, (2) the python bindings, and (3) the onnx export pass and would make this PR way to large. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19167 Differential Revision: D14891966 Pulled By: zdevito fbshipit-source-id: 0b5f03118aa65448a15c7a7818e64089ec93d7ea
2019-04-11 20:30:42 +00:00
}
// Checks that Module::to moves both a registered parameter ("foo") and a
// registered buffer ("bar") between devices. The test creates at::kCUDA
// tensors.
void testModuleConversion() {
  Module m("test");
  const auto cuda = at::kCUDA;
  const auto cpu = at::kCPU;

  {
    // test cuda to cpu for params and buffers
    m.register_parameter("foo", torch::ones({}, cuda), false);
    m.register_buffer("bar", torch::ones({}, cuda));

    m.to(cuda);
    m.to(cpu);
    AT_ASSERT(m.attr("foo").toTensor().device().is_cpu());
    AT_ASSERT(m.attr("bar").toTensor().device().is_cpu());
  }

  {
    // test cpu to cuda for params and buffers
    m.register_parameter("foo", torch::ones({}), false);
    m.register_buffer("bar", torch::ones({}));

    m.to(cuda);
    AT_ASSERT(m.attr("foo").toTensor().device().is_cuda());
    AT_ASSERT(m.attr("bar").toTensor().device().is_cuda());
  }
}
// Counter incremented by fakePass on every invocation.
static int testPassValue = 0;

// Pass body that leaves the graph argument untouched and only increments
// testPassValue.
void fakePass(std::shared_ptr<Graph>& g) {
  ++testPassValue;
}

// File-scope RegisterPass object constructed with fakePass.
RegisterPass p(fakePass);
void testPassManagement() {
std::shared_ptr<Graph> graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%a):
return (%a))IR",
&*graph);
std::vector<IValue> stack = {IValue(torch::randn({22}, at::kCPU))};
auto run = [&](std::shared_ptr<Graph>& graph, std::vector<IValue> stack) {
improved TorchScript traceback (#33834) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33834 This changes how we report Tracebacks to make them more clear when there are both serialized and non-serialized ranges. It now looks like: ``` Traceback (most recent call last): File "foo.py", line 25, in <module> s2(a, b) File "/scratch/zdevito/pytorch/torch/nn/modules/module.py", line 550, in __call__ result = self.forward(*input, **kwargs) RuntimeError: The following operation failed in the TorchScript interpreter. Traceback of TorchScript, serialized code (most recent call last): File "code/__torch__.py", line 7, in forward x: Tensor, y: Tensor) -> Tensor: return (self).bar(x, y, ) ~~~~~~~~~ <--- HERE def bar(self: __torch__.Moo, x: Tensor, File "code/__torch__.py", line 11, in bar x: Tensor, y: Tensor) -> Tensor: _0 = (self).baz(x, y, ) ~~~~~~~~~ <--- HERE _1 = torch.ones([3], dtype=None, layout=None, device=None, pin_memory=None) return torch.add(_0, _1, alpha=1) File "code/__torch__.py", line 17, in baz x: Tensor, y: Tensor) -> Tensor: return torch.add(x, y, alpha=1) ~~~~~~~~~ <--- HERE Traceback of TorchScript, original code (most recent call last): File "foo.py", line 11, in forward def forward(self, x, y): return self.bar(x, y) ~~~~~~~~ <--- HERE File "foo.py", line 9, in bar def bar(self, x, y): return self.baz(x, y) + torch.ones(3) ~~~~~~~~ <--- HERE File "foo.py", line 7, in baz def baz(self, x, y): return x + y ~~~~~ <--- HERE RuntimeError: The size of tensor a (4) must match the size of tensor b (5) at non-singleton dimension 1 ``` It follows Python convension of having the most important information last and reading from the bottom up. Changes: * Moved the error message to the end, to copy Python * Report original traceback separate from serialized traceback * Make sure root functions have names in the interpreter trace. 
Test Plan: Imported from OSS Differential Revision: D20126136 Pulled By: zdevito fbshipit-source-id: fd01f9985e5d74e04c4d064c02e8bc320f4fac13
2020-03-03 20:24:28 +00:00
GraphExecutor executor(graph, "");
executor.run(stack);
return stack;
};
run(graph, stack);
// we will not run fusion in simple mode
if (!getExecutorMode()) {
AT_ASSERT(testPassValue);
}
}
// Asserts that a node's single output is a TensorType whose concrete sizes
// equal `expected`.
//
// n        - node to inspect.
// expected - the sizes the inspected output type must carry.
// prev     - when true (default), the node producing n's first input is
//            inspected instead of n itself.
static void checkShape(
    Node* n,
    std::vector<int64_t> expected,
    bool prev = true) {
  Node* inspected = n;
  if (prev) {
    inspected = n->inputs().at(0)->node();
  }
  const auto tensor_type =
      inspected->output()->type()->expect<TensorType>();
  ASSERT_EQ(tensor_type->sizes().concrete_sizes().value(), expected);
}
void testInsertAndEliminateRedundantGuards() {
static const auto basic_example = R"JIT(
def basic(x, y):
a = x + y
b = x * y
c = x + 1
d = a - c
e = b - c
return d + e
)JIT";
auto cu = compile(basic_example);
auto& fun = cu->get_function("basic");
auto pr = ProfilingRecord::instrumentGraph(fun.graph());
auto x = at::randn({2, 3}, at::kCPU);
auto y = at::randn({2, 3}, at::kCPU);
auto stack = createStack({x, y});
// introduce some profiling information
improved TorchScript traceback (#33834) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33834 This changes how we report Tracebacks to make them more clear when there are both serialized and non-serialized ranges. It now looks like: ``` Traceback (most recent call last): File "foo.py", line 25, in <module> s2(a, b) File "/scratch/zdevito/pytorch/torch/nn/modules/module.py", line 550, in __call__ result = self.forward(*input, **kwargs) RuntimeError: The following operation failed in the TorchScript interpreter. Traceback of TorchScript, serialized code (most recent call last): File "code/__torch__.py", line 7, in forward x: Tensor, y: Tensor) -> Tensor: return (self).bar(x, y, ) ~~~~~~~~~ <--- HERE def bar(self: __torch__.Moo, x: Tensor, File "code/__torch__.py", line 11, in bar x: Tensor, y: Tensor) -> Tensor: _0 = (self).baz(x, y, ) ~~~~~~~~~ <--- HERE _1 = torch.ones([3], dtype=None, layout=None, device=None, pin_memory=None) return torch.add(_0, _1, alpha=1) File "code/__torch__.py", line 17, in baz x: Tensor, y: Tensor) -> Tensor: return torch.add(x, y, alpha=1) ~~~~~~~~~ <--- HERE Traceback of TorchScript, original code (most recent call last): File "foo.py", line 11, in forward def forward(self, x, y): return self.bar(x, y) ~~~~~~~~ <--- HERE File "foo.py", line 9, in bar def bar(self, x, y): return self.baz(x, y) + torch.ones(3) ~~~~~~~~ <--- HERE File "foo.py", line 7, in baz def baz(self, x, y): return x + y ~~~~~ <--- HERE RuntimeError: The size of tensor a (4) must match the size of tensor b (5) at non-singleton dimension 1 ``` It follows Python convension of having the most important information last and reading from the bottom up. Changes: * Moved the error message to the end, to copy Python * Report original traceback separate from serialized traceback * Make sure root functions have names in the interpreter trace. 
Test Plan: Imported from OSS Differential Revision: D20126136 Pulled By: zdevito fbshipit-source-id: fd01f9985e5d74e04c4d064c02e8bc320f4fac13
2020-03-03 20:24:28 +00:00
Code cd(pr->profiled_graph_, "");
InterpreterState is{cd};
is.run(stack);
auto copy = pr->profiled_graph_->copy();
InsertGuards(copy);
auto nodes = copy->block()->nodes();
auto guard = std::find_if(nodes.begin(), nodes.end(), [](Node* n) {
return n->kind() == prim::Guard;
});
ASSERT_NE(guard, nodes.end());
ASSERT_EQ(
guard->input()->type()->expect<TensorType>()->sizes().size(),
c10::nullopt);
checkShape(*guard, {2, 3}, false);
auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
int num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
ASSERT_EQ(num_guards, 12);
// now eliminate as many guards as possible
// we should be left with two guards on x and y's defs
EliminateRedundantGuards(copy);
num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
ASSERT_EQ(num_guards, 2);
}
void testInsertBailOuts() {
static const auto basic_example = R"JIT(
def basic_loop(x, y):
a = x + 1
b = y + 2
c = x + y + 3
for i in range(10):
a = a + b
# invariant
d = b * c
#
a = a - d
e = a + 4
return e
)JIT";
auto cu = compile(basic_example);
auto& fun = cu->get_function("basic_loop");
auto pr = ProfilingRecord::instrumentGraph(fun.graph());
auto x = at::randn({2, 3}, at::kCPU);
auto y = at::randn({2, 3}, at::kCPU);
auto stack = createStack({x, y});
// introduce some profiling information
improved TorchScript traceback (#33834) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33834 This changes how we report Tracebacks to make them more clear when there are both serialized and non-serialized ranges. It now looks like: ``` Traceback (most recent call last): File "foo.py", line 25, in <module> s2(a, b) File "/scratch/zdevito/pytorch/torch/nn/modules/module.py", line 550, in __call__ result = self.forward(*input, **kwargs) RuntimeError: The following operation failed in the TorchScript interpreter. Traceback of TorchScript, serialized code (most recent call last): File "code/__torch__.py", line 7, in forward x: Tensor, y: Tensor) -> Tensor: return (self).bar(x, y, ) ~~~~~~~~~ <--- HERE def bar(self: __torch__.Moo, x: Tensor, File "code/__torch__.py", line 11, in bar x: Tensor, y: Tensor) -> Tensor: _0 = (self).baz(x, y, ) ~~~~~~~~~ <--- HERE _1 = torch.ones([3], dtype=None, layout=None, device=None, pin_memory=None) return torch.add(_0, _1, alpha=1) File "code/__torch__.py", line 17, in baz x: Tensor, y: Tensor) -> Tensor: return torch.add(x, y, alpha=1) ~~~~~~~~~ <--- HERE Traceback of TorchScript, original code (most recent call last): File "foo.py", line 11, in forward def forward(self, x, y): return self.bar(x, y) ~~~~~~~~ <--- HERE File "foo.py", line 9, in bar def bar(self, x, y): return self.baz(x, y) + torch.ones(3) ~~~~~~~~ <--- HERE File "foo.py", line 7, in baz def baz(self, x, y): return x + y ~~~~~ <--- HERE RuntimeError: The size of tensor a (4) must match the size of tensor b (5) at non-singleton dimension 1 ``` It follows Python convension of having the most important information last and reading from the bottom up. Changes: * Moved the error message to the end, to copy Python * Report original traceback separate from serialized traceback * Make sure root functions have names in the interpreter trace. 
Test Plan: Imported from OSS Differential Revision: D20126136 Pulled By: zdevito fbshipit-source-id: fd01f9985e5d74e04c4d064c02e8bc320f4fac13
2020-03-03 20:24:28 +00:00
Code cd(pr->profiled_graph_, "");
InterpreterState is{cd};
is.run(stack);
auto copy = pr->profiled_graph_->copy();
InsertGuards(copy);
EliminateRedundantGuards(copy);
auto nodes = copy->block()->nodes();
auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
auto num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
ASSERT_EQ(num_guards, 3);
InsertBailOuts(copy);
auto is_bailout = [](Node* n) { return n->kind() == prim::BailOut; };
auto num_bailouts = std::count_if(nodes.begin(), nodes.end(), is_bailout);
ASSERT_EQ(num_guards, num_bailouts);
std::vector<Node*> bailouts(num_bailouts);
std::copy_if(nodes.begin(), nodes.end(), bailouts.begin(), is_bailout);
for (auto blo : bailouts) {
ASSERT_EQ(blo->inputs().at(0)->node()->kind(), prim::BailoutTemplate);
}
}
void testProfiler() {
constexpr int batch_size = 4;
constexpr int input_size = 256;
int hidden_size = 2 * input_size;
auto input = at::randn({batch_size, input_size}, at::kCPU);
auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));
auto g = build_lstm();
auto stack = createStack({input, hx, cx, w_ih, w_hh});
auto& opt_graph = *g.get();
ArgumentSpecCreator arg_spec_creator(opt_graph);
ArgumentSpec spec =
arg_spec_creator.create(autograd::GradMode::is_enabled(), stack);
Specialize Optional[T] to T (or subtype for Tensor) or None when executing graph (#18407) Summary: This patch specializes `Optional[Tensor]` graph inputs to either a `DimensionedTensorType` (if a Tensor is passed) or `NoneType`. Other `Optional[T]` are specialized to `T` or `None`. - For unwrapping (checked and unchecked) we need to keep the output type, as IR code that follows unwrapping may not work with NoneType (just as it doesn't deal with Optional). While it would not be hit during execution, it will run against the (legitimate) assumptions of the analysis passes. - Function lookup currently will not match NoneType when it expects optional (I'm not entirely sure why this doesn't lead to unhappyness currently, but hey), I amend this at the level of the function matching code (`operator.cpp`), but see Adam's comments. We would run into trouble if we needed to select between functions whose signature only differs in Optional types with different subtypes, but we would have the same problem when calling them directly, so I would think this is OK. - It would enable throwing away branches we can't hit. This also reduces the "blockyness" of the graph, so it may be easier to apply optimizations (e.g. fuse things in `if t is None: ...` and outside the `if`. - Arguments passed into `Optional[Tensor]` arguments will get shape information, which is very handy. - It get's rid of the problem that tensors passed into Optional arguments get requires_grad set erroneously #18270 (though that also affects lists, which aren't fixed here). - `Optional[List[int]]` is needed for #18697. - We're changing typing in a more subtle way than the `TensorType`->`DimensionedTensorType`. - In particular, specializing to NoneType loses the Type information captured in the `OptionalType` element type. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18407 Reviewed By: zdevito Differential Revision: D15216808 Pulled By: eellison fbshipit-source-id: 01f1a7643deaf4962c3f55eff2070d54b0e54b69
2019-05-06 21:54:10 +00:00
arg_spec_creator.specializeTypes(opt_graph, spec);
auto pr = ProfilingRecord::instrumentGraph(g);
improved TorchScript traceback (#33834) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33834 This changes how we report Tracebacks to make them more clear when there are both serialized and non-serialized ranges. It now looks like: ``` Traceback (most recent call last): File "foo.py", line 25, in <module> s2(a, b) File "/scratch/zdevito/pytorch/torch/nn/modules/module.py", line 550, in __call__ result = self.forward(*input, **kwargs) RuntimeError: The following operation failed in the TorchScript interpreter. Traceback of TorchScript, serialized code (most recent call last): File "code/__torch__.py", line 7, in forward x: Tensor, y: Tensor) -> Tensor: return (self).bar(x, y, ) ~~~~~~~~~ <--- HERE def bar(self: __torch__.Moo, x: Tensor, File "code/__torch__.py", line 11, in bar x: Tensor, y: Tensor) -> Tensor: _0 = (self).baz(x, y, ) ~~~~~~~~~ <--- HERE _1 = torch.ones([3], dtype=None, layout=None, device=None, pin_memory=None) return torch.add(_0, _1, alpha=1) File "code/__torch__.py", line 17, in baz x: Tensor, y: Tensor) -> Tensor: return torch.add(x, y, alpha=1) ~~~~~~~~~ <--- HERE Traceback of TorchScript, original code (most recent call last): File "foo.py", line 11, in forward def forward(self, x, y): return self.bar(x, y) ~~~~~~~~ <--- HERE File "foo.py", line 9, in bar def bar(self, x, y): return self.baz(x, y) + torch.ones(3) ~~~~~~~~ <--- HERE File "foo.py", line 7, in baz def baz(self, x, y): return x + y ~~~~~ <--- HERE RuntimeError: The size of tensor a (4) must match the size of tensor b (5) at non-singleton dimension 1 ``` It follows Python convension of having the most important information last and reading from the bottom up. Changes: * Moved the error message to the end, to copy Python * Report original traceback separate from serialized traceback * Make sure root functions have names in the interpreter trace. 
Test Plan: Imported from OSS Differential Revision: D20126136 Pulled By: zdevito fbshipit-source-id: fd01f9985e5d74e04c4d064c02e8bc320f4fac13
2020-03-03 20:24:28 +00:00
Code cd(pr->profiled_graph_, "");
InterpreterState is{cd};
is.run(stack);
auto begin = pr->profiled_graph_->block()->nodes().begin();
auto end = pr->profiled_graph_->block()->nodes().end();
auto mm =
std::find_if(begin, end, [](Node* n) { return n->kind() == aten::mm; });
ASSERT_NE(mm, end);
std::vector<int64_t> mm_expected{4, 256};
std::vector<int64_t> eltwise{4, 512};
checkShape(*mm, mm_expected);
auto sigmoid_n = std::find_if(
begin, end, [](Node* n) { return n->kind() == aten::sigmoid; });
ASSERT_NE(sigmoid_n, end);
checkShape(*sigmoid_n, eltwise);
auto tanh_n =
std::find_if(begin, end, [](Node* n) { return n->kind() == aten::tanh; });
checkShape(*tanh_n, eltwise);
}
void testCallStack() {
const auto text = R"(
def ham(x):
return x/7
def bar(x):
return x*3
def baz(x):
return ham(x)*x
def foo(x):
return bar(x)*baz(x)*11
)";
auto cu = compile(text);
const Function& foo = cu->get_function("foo");
for (Node* n : foo.optimized_graph()->nodes()) {
if (n->kind() == prim::Constant) {
if (!n->hasAttribute(attr::value) ||
n->kindOf(attr::value) != AttributeKind::i) {
continue;
}
int v = n->i(attr::value);
switch (v) {
case 3: {
// Const 3 comes from function 'bar', which gets inlined to 'foo'.
// The callstack for the corresponding node should contain only the
// function 'bar'.
ASSERT_TRUE(n->callstack());
auto callstack_vector = (*n->callstack())->vec();
ASSERT_EQ(callstack_vector.size(), 1);
ASSERT_EQ(callstack_vector[0].first, &cu->get_function("bar"));
break;
}
case 7: {
// Const 7 comes from function 'ham', which gets inlined to 'baz',
// which is then inlined to 'foo'. The callstack for the corresponding
// node should contain these two functions.
ASSERT_TRUE(n->callstack());
auto callstack_vector = (*n->callstack())->vec();
ASSERT_EQ(callstack_vector.size(), 2);
ASSERT_EQ(callstack_vector[0].first, &cu->get_function("baz"));
ASSERT_EQ(callstack_vector[1].first, &cu->get_function("ham"));
break;
}
case 11: {
// Const 11 comes from function 'foo', which is not inlined anywhere
// and thus it should not have a callstack.
ASSERT_FALSE(n->callstack());
break;
}
}
}
}
// Check that inlining doesn't corrupt callstack of the callee's nodes.
const Function& baz = cu->get_function("baz");
for (Node* n : baz.optimized_graph()->nodes()) {
if (n->kind() == prim::Constant) {
if (!n->hasAttribute(attr::value) ||
n->kindOf(attr::value) != AttributeKind::i) {
continue;
}
int v = n->i(attr::value);
ASSERT_TRUE(v == 7);
// Const 7 comes from function 'ham', which gets inlined to 'baz'. 'baz'
// was also inlined into 'foo', but when looking at the graph of 'baz' we
// should only see a callstack of depth 1 (containing only 'ham').
ASSERT_TRUE(n->callstack());
auto callstack_vector = (*n->callstack())->vec();
ASSERT_EQ(callstack_vector.size(), 1);
ASSERT_EQ(callstack_vector[0].first, &cu->get_function("ham"));
}
}
}
void testCallStackCaching() {
const auto text = R"(
def a(x):
print("a1")
print("a2")
return x
def b(x):
print("b1")
print("b2")
a(x)
return x
def c(x):
print("c1")
print("c2")
b(x)
return x
)";
auto cu = compile(text);
const Function& baz = cu->get_function("c");
std::unordered_map<std::string, InlinedCallStack*> callstack_objects;
for (Node* n : baz.optimized_graph()->nodes()) {
if (n->kind() == prim::Constant) {
if (!n->hasAttribute(attr::value) ||
n->kindOf(attr::value) != AttributeKind::s) {
continue;
}
std::string v = n->s(attr::value);
if (n->callstack()) {
callstack_objects[v] = n->callstack()->get();
}
}
}
// We expect to see nodes prim::Constant[value="a1"] and
// prim::Constant[value="a2"] inlined to function 'c'. Their callstacks are
// the same (a->b->c), so we want to make sure we're not creating different
// callstack entries for them.
ASSERT_TRUE(callstack_objects.count("a1") && callstack_objects.count("a2"));
ASSERT_TRUE(callstack_objects.at("a1") == callstack_objects.at("a2"));
}
void testAutogradSymbols() {
Symbol sym = Symbol::fromQualString("aten::test_symbol");
Graph graph;
auto node = graph.create(sym);
TORCH_CHECK(canRunWithAutograd(node));
sym = Symbol::fromQualString("prim::test_symbol");
node = graph.create(sym);
TORCH_CHECK(canRunWithAutograd(node));
sym = Symbol::fromQualString("prim::FusionGroup");
node = graph.create(sym);
TORCH_CHECK(!canRunWithAutograd(node));
sym = Symbol::fromQualString("custom::test_symbol");
node = graph.create(sym);
TORCH_CHECK(!canRunWithAutograd(node));
}
} // namespace jit
} // namespace torch