trace s_copy_ (#15690)

Summary:
s_copy_ was previously special-cased for out of place tracing.
This adds support for inplace tracing, which fixes tracing of
inception_v3

Fixes #15216
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15690

Differential Revision: D13572011

Pulled By: zdevito

fbshipit-source-id: 1d565dec039a4b8c59179254285e61d2517ef9a9
This commit is contained in:
Zachary DeVito 2019-01-03 12:14:17 -08:00 committed by Facebook Github Bot
parent 78442f04fc
commit d42e90991b
6 changed files with 33 additions and 17 deletions

View file

@ -12,6 +12,6 @@ graph(%0 : Double(4, 4)) {
%11 : int = prim::Constant[value=0]()
%12 : Device = prim::Constant[value="cpu"]()
%13 : Double(4, 4) = aten::zeros(%9, %10, %11, %12)
%14 : Double(4, 4) = aten::expand_as(%0, %13)
%14 : Double(4, 4) = aten::copy_(%13, %0)
return (%14);
}

View file

@ -1160,6 +1160,14 @@ class TestJit(JitTestCase):
if RUN_CUDA_MULTI_GPU:
run(device="cuda:1")
def test_trace_indexed_assignment(self):
def stuff(x, y):
x = x.clone()
x[0] = y
return x
example = torch.rand(3, 4)
self.checkTrace(stuff, (example, example[0] + 1))
# TODO: implement
@unittest.expectedFailure
def test_output_unflatten(self):
@ -8098,9 +8106,7 @@ a")
x[i, :] = torch.zeros(4)
return x
self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(3, 4)]),
'Output nr 1. of the traced function does not match the '
'corresponding output of the Python function')
self.checkTrace(foo, (torch.rand(3, 4),))
def test_trace_checker_inplace_on_view(self):
def foo(x):

View file

@ -1,5 +1,4 @@
#!/bin/bash
set -e
echo "Running pre-commit flake8"
FLAKE8_OUT=$(python tools/flake8_hook.py)
if [[ ${FLAKE8_OUT} ]]
@ -8,7 +7,6 @@ then
exit 1
fi
if [ $(which clang-tidy) ]
then
echo "Running pre-commit clang-tidy"

View file

@ -239,16 +239,24 @@ void VariableType::set_data(Tensor & self, Tensor new_data) const {
as_variable_ref(self).set_data(new_data);
}
Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
jit::Node* node = nullptr;
jit::Value* output = nullptr;
if(torch::jit::tracer::isTracing()) {
auto& graph = jit::tracer::getTracingState()->graph;
// if you have no views of self, then an in place copy is equivalent to
// making sure we expand src to the same size as self
node = graph->create(jit::aten::expand_as, /*num_outputs=*/0);
jit::tracer::addInputs(node, "src", src);
jit::tracer::addInputs(node, "self", self);
graph->appendNode(node);
jit::tracer::ensureUniqueIfOutOfPlaced("copy_ (possibly due to an assignment)", self);
const jit::tracer::TracingState& state = *jit::tracer::getTracingState();
auto& graph = state.graph;
if (state.force_outplace) {
// if you have no views of self, then an in place copy is equivalent to
// making sure we expand src to the same size as self
jit::Node* node = graph->create(jit::aten::expand_as, /*num_outputs=*/1);
jit::tracer::addInputs(node, "src", src);
jit::tracer::addInputs(node, "self", self);
graph->appendNode(node);
jit::tracer::ensureUniqueIfOutOfPlaced("copy_ (possibly due to an assignment)", self);
output = node->output();
} else {
output = graph->insert(
jit::aten::copy_,
{jit::tracer::getValueTrace(self), jit::tracer::getValueTrace(src)});
}
}
// TODO: once copy is exposed in Declarations.yaml we may be able to bind
// it automatically
@ -270,7 +278,7 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_block
increment_version(self);
rebase_history(as_variable_ref( self ), std::move(grad_fn));
if(torch::jit::tracer::isTracing()) {
jit::tracer::addOutput(node, self);
jit::tracer::setOutput(output, self);
}
return self;
}

View file

@ -192,7 +192,10 @@ void addInputs(Node* n, const char* name, const ArrayRef<double>& value) {
}
void addOutput(Node* node, const at::Tensor& output) {
Value* value = node->addOutput();
setOutput(node->addOutput(), output);
}
void setOutput(Value* value, const at::Tensor& output) {
if (output.defined()) {
value->inferTypeFrom(output);
setValueTrace(autograd::as_variable_ref(output), value);

View file

@ -287,6 +287,7 @@ void addOutput(Node* node, T&&) {
" in the JIT tracer. File a bug report.");
}
TORCH_API void addOutput(Node* node, const at::Tensor& tensor);
TORCH_API void setOutput(Value* value, const at::Tensor& output);
TORCH_API void addOutput(Node* node, const std::vector<at::Tensor>& list);
TORCH_API autograd::Variable getSizeOf(