pr feedback

This commit is contained in:
Zach DeVito 2017-08-23 11:14:48 -07:00 committed by Soumith Chintala
parent c9f7f2eff4
commit 35bddb6b7e
3 changed files with 26 additions and 13 deletions

View file

@ -99,6 +99,7 @@ class View(Function):
def primspec(g, i, sizes):
n = g.appendNode(g.create("Reshape", [i]).is_("shape", sizes))
real_out = g.appendNode(g.createSelect(n, 0))
# TODO: remove from toffee
g.appendNode(g.createSelect(n, 1))
return real_out
@ -327,6 +328,7 @@ class Concat(Function):
def primspec(g, dim, *inputs):
n = g.appendNode(g.create("Concat", inputs).i_("axis", dim))
real = g.appendNode(g.createSelect(n, 0))
# TODO: remove from toffee
g.appendNode(g.createSelect(n, 1))
return real

View file

@ -4,12 +4,13 @@
#include "torch/csrc/Exceptions.h"
#include <toffee/toffee.pb.h>
#include <toffee/schema.h>
#include <toffee/defs/schema.h>
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include "torch/csrc/autograd/functions/convolution.h"
#include "torch/csrc/jit/dead_code_elimination.h"
#include "torch/csrc/utils/functional.h"
#include <fstream>
#undef NDEBUG
@ -25,14 +26,6 @@ std::string node_name(Node* n) {
return n->uniqueName();
}
template<typename R, typename T>
static std::vector<R> mapv(const std::vector<T> & inputs, std::function<R(const T &)> fn) {
std::vector<R> r;
r.reserve(inputs.size());
for(auto & input : inputs)
r.push_back(fn(input));
return r;
}
// transform PythonOps and Cpp Ops into Node's that match ToffeeIR
// semantics.
// Eventually this should just be part of init_pass but we should avoid
@ -43,7 +36,7 @@ std::shared_ptr<Graph> ToToffeeIR(std::shared_ptr<Graph>& g) {
ctx.graph = std::make_shared<Graph>();
for (auto input : g->inputs())
env[input] = ctx.graph->addInput()->setType(input->typeOption());
auto envFn = [&](Node * n) {
auto envFn = [&env](Node * n) {
return env.at(n);
};
// put the new outputs in our environment map, and
@ -76,7 +69,7 @@ std::shared_ptr<Graph> ToToffeeIR(std::shared_ptr<Graph>& g) {
JIT_ASSERT(env.count(value) > 0);
IR_ELSEIFM(CppOp)
if (auto fn = std::dynamic_pointer_cast<autograd::HasPrimSpec>(value->fn)) {
auto outputs = fn->primspec(&ctx, mapv<Node*,Node*>(node->inputs(),envFn));
auto outputs = fn->primspec(&ctx, fmap<Node*,Node*>(node->inputs(),envFn));
setOutputs(node,outputs);
} else {
throw std::runtime_error("CppOp doesn't define primspec " + value->name());
@ -112,7 +105,7 @@ std::shared_ptr<Graph> ToToffeeIR(std::shared_ptr<Graph>& g) {
}
py_primspec_args[input_nr++] = obj;
}
py::object raw_output = py::reinterpret_borrow<py::object>(PyObject_CallObject(primspec_fn.ptr(), py_primspec_args.ptr()));
py::object raw_output = py::reinterpret_steal<py::object>(PyObject_CallObject(primspec_fn.ptr(), py_primspec_args.ptr()));
if(!raw_output)
throw python_error();
if(raw_output.ptr() == Py_None)
@ -132,7 +125,9 @@ std::shared_ptr<Graph> ToToffeeIR(std::shared_ptr<Graph>& g) {
if(node->hasMultipleOutputs()) {
int i = 0;
for(auto s : node->uses()) {
env[s.user] = ctx.graph->createSelect(n_,i++);
auto new_node = ctx.graph->createSelect(n_,i++);
new_node->setType(s.user->typeOption());
env[s.user] = new_node;
}
} else {
env[node] = n_;

View file

@ -0,0 +1,16 @@
#pragma once
#include <vector>
namespace torch {
template<typename R, typename T>
static std::vector<R> fmap(const std::vector<T> & inputs, std::function<R(const T &)> fn) {
std::vector<R> r;
r.reserve(inputs.size());
for(auto & input : inputs)
r.push_back(fn(input));
return r;
}
}