trace s_copy_ (#15690)
Summary: s_copy_ was previously special-cased to be traced as an out-of-place copy. This adds support for tracing it in place, which fixes tracing of inception_v3. Fixes #15216.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/15690
Differential Revision: D13572011
Pulled By: zdevito
fbshipit-source-id: 1d565dec039a4b8c59179254285e61d2517ef9a9
This commit is contained in:
parent 78442f04fc
commit d42e90991b

6 changed files with 33 additions and 17 deletions
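A minimal sketch of what this change enables, mirroring the test_trace_indexed_assignment test added below (shapes are arbitrary):

import torch

def stuff(x, y):
    x = x.clone()   # avoid mutating the traced input in place
    x[0] = y        # indexed assignment dispatches to an in-place copy_
    return x

example = torch.rand(3, 4)
traced = torch.jit.trace(stuff, (example, example[0] + 1))
print(traced.graph)  # the trace now records the in-place copy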
@@ -12,6 +12,6 @@ graph(%0 : Double(4, 4)) {
   %11 : int = prim::Constant[value=0]()
   %12 : Device = prim::Constant[value="cpu"]()
   %13 : Double(4, 4) = aten::zeros(%9, %10, %11, %12)
-  %14 : Double(4, 4) = aten::expand_as(%0, %13)
+  %14 : Double(4, 4) = aten::copy_(%13, %0)
   return (%14);
 }
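The expect-file hunk above shows the traced graph keeping aten::copy_ where the tracer previously rewrote it as aten::expand_as. A minimal sketch of a function whose trace produces a graph of this shape (the exact test behind this expect file is not shown in this excerpt):

import torch

def f(x):
    buf = torch.zeros(4, 4, dtype=torch.double)  # traced as aten::zeros plus constants
    return buf.copy_(x)                          # now traced as aten::copy_(buf, x)

print(torch.jit.trace(f, torch.rand(4, 4, dtype=torch.double)).graph)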
@@ -1160,6 +1160,14 @@ class TestJit(JitTestCase):
         if RUN_CUDA_MULTI_GPU:
             run(device="cuda:1")

+    def test_trace_indexed_assignment(self):
+        def stuff(x, y):
+            x = x.clone()
+            x[0] = y
+            return x
+        example = torch.rand(3, 4)
+        self.checkTrace(stuff, (example, example[0] + 1))
+
     # TODO: implement
     @unittest.expectedFailure
     def test_output_unflatten(self):
@@ -8098,9 +8106,7 @@ a")
             x[i, :] = torch.zeros(4)
             return x

-        self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(3, 4)]),
-                              'Output nr 1. of the traced function does not match the '
-                              'corresponding output of the Python function')
+        self.checkTrace(foo, (torch.rand(3, 4),))

     def test_trace_checker_inplace_on_view(self):
         def foo(x):
@@ -1,5 +1,4 @@
 #!/bin/bash
-set -e
 echo "Running pre-commit flake8"
 FLAKE8_OUT=$(python tools/flake8_hook.py)
 if [[ ${FLAKE8_OUT} ]]

@@ -8,7 +7,6 @@ then
   exit 1
 fi

-
 if [ $(which clang-tidy) ]
 then
   echo "Running pre-commit clang-tidy"
@@ -239,16 +239,24 @@ void VariableType::set_data(Tensor & self, Tensor new_data) const {
   as_variable_ref(self).set_data(new_data);
 }
 Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
-  jit::Node* node = nullptr;
+  jit::Value* output = nullptr;
   if(torch::jit::tracer::isTracing()) {
-    auto& graph = jit::tracer::getTracingState()->graph;
-    // if you have no views of self, then an in place copy is equivalent to
-    // making sure we expand src to the same size as self
-    node = graph->create(jit::aten::expand_as, /*num_outputs=*/0);
-    jit::tracer::addInputs(node, "src", src);
-    jit::tracer::addInputs(node, "self", self);
-    graph->appendNode(node);
-    jit::tracer::ensureUniqueIfOutOfPlaced("copy_ (possibly due to an assignment)", self);
+    const jit::tracer::TracingState& state = *jit::tracer::getTracingState();
+    auto& graph = state.graph;
+    if (state.force_outplace) {
+      // if you have no views of self, then an in place copy is equivalent to
+      // making sure we expand src to the same size as self
+      jit::Node* node = graph->create(jit::aten::expand_as, /*num_outputs=*/1);
+      jit::tracer::addInputs(node, "src", src);
+      jit::tracer::addInputs(node, "self", self);
+      graph->appendNode(node);
+      jit::tracer::ensureUniqueIfOutOfPlaced("copy_ (possibly due to an assignment)", self);
+      output = node->output();
+    } else {
+      output = graph->insert(
+          jit::aten::copy_,
+          {jit::tracer::getValueTrace(self), jit::tracer::getValueTrace(src)});
+    }
   }
   // TODO: once copy is exposed in Declarations.yaml we may be able to bind
   // it automatically

@@ -270,7 +278,7 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
   increment_version(self);
   rebase_history(as_variable_ref(self), std::move(grad_fn));
   if(torch::jit::tracer::isTracing()) {
-    jit::tracer::addOutput(node, self);
+    jit::tracer::setOutput(output, self);
   }
   return self;
 }
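The two branches above correspond to two tracing modes: force_outplace keeps the old expand_as rewrite for callers that need purely functional graphs, while the default path now records the genuine in-place copy_. A hedged sketch of the default behavior; the _force_outplace flag mentioned in the comment is an assumption about the Python API of this era, not shown in this diff:

import torch

def f(x, y):
    return x.clone().copy_(y)   # in-place copy into a fresh tensor

x, y = torch.rand(3, 4), torch.rand(3, 4)
print(torch.jit.trace(f, (x, y)).graph)  # now contains an aten::copy_ node
# Tracing with force_outplace enabled (e.g. via torch.jit.trace's private
# _force_outplace flag, if available in this build) would instead record
# aten::expand_as, matching the branch kept above.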
@@ -192,7 +192,10 @@ void addInputs(Node* n, const char* name, const ArrayRef<double>& value) {
 }

 void addOutput(Node* node, const at::Tensor& output) {
-  Value* value = node->addOutput();
+  setOutput(node->addOutput(), output);
+}
+
+void setOutput(Value* value, const at::Tensor& output) {
   if (output.defined()) {
     value->inferTypeFrom(output);
     setValueTrace(autograd::as_variable_ref(output), value);
@@ -287,6 +287,7 @@ void addOutput(Node* node, T&&) {
       " in the JIT tracer. File a bug report.");
 }
 TORCH_API void addOutput(Node* node, const at::Tensor& tensor);
+TORCH_API void setOutput(Value* value, const at::Tensor& output);
 TORCH_API void addOutput(Node* node, const std::vector<at::Tensor>& list);

 TORCH_API autograd::Variable getSizeOf(