mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
pr feedback
This commit is contained in:
parent
c9f7f2eff4
commit
35bddb6b7e
3 changed files with 26 additions and 13 deletions
|
|
@ -99,6 +99,7 @@ class View(Function):
|
|||
def primspec(g, i, sizes):
|
||||
n = g.appendNode(g.create("Reshape", [i]).is_("shape", sizes))
|
||||
real_out = g.appendNode(g.createSelect(n, 0))
|
||||
# TODO: remove from toffee
|
||||
g.appendNode(g.createSelect(n, 1))
|
||||
return real_out
|
||||
|
||||
|
|
@ -327,6 +328,7 @@ class Concat(Function):
|
|||
def primspec(g, dim, *inputs):
|
||||
n = g.appendNode(g.create("Concat", inputs).i_("axis", dim))
|
||||
real = g.appendNode(g.createSelect(n, 0))
|
||||
# TODO: remove from toffee
|
||||
g.appendNode(g.createSelect(n, 1))
|
||||
return real
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,13 @@
|
|||
#include "torch/csrc/Exceptions.h"
|
||||
|
||||
#include <toffee/toffee.pb.h>
|
||||
#include <toffee/schema.h>
|
||||
#include <toffee/defs/schema.h>
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <google/protobuf/io/zero_copy_stream_impl.h>
|
||||
|
||||
#include "torch/csrc/autograd/functions/convolution.h"
|
||||
#include "torch/csrc/jit/dead_code_elimination.h"
|
||||
#include "torch/csrc/utils/functional.h"
|
||||
|
||||
#include <fstream>
|
||||
#undef NDEBUG
|
||||
|
|
@ -25,14 +26,6 @@ std::string node_name(Node* n) {
|
|||
return n->uniqueName();
|
||||
}
|
||||
|
||||
template<typename R, typename T>
|
||||
static std::vector<R> mapv(const std::vector<T> & inputs, std::function<R(const T &)> fn) {
|
||||
std::vector<R> r;
|
||||
r.reserve(inputs.size());
|
||||
for(auto & input : inputs)
|
||||
r.push_back(fn(input));
|
||||
return r;
|
||||
}
|
||||
// transform PythonOps and Cpp Ops into Node's that match ToffeeIR
|
||||
// semantics.
|
||||
// Eventually this should just be part of init_pass but we should avoid
|
||||
|
|
@ -43,7 +36,7 @@ std::shared_ptr<Graph> ToToffeeIR(std::shared_ptr<Graph>& g) {
|
|||
ctx.graph = std::make_shared<Graph>();
|
||||
for (auto input : g->inputs())
|
||||
env[input] = ctx.graph->addInput()->setType(input->typeOption());
|
||||
auto envFn = [&](Node * n) {
|
||||
auto envFn = [&env](Node * n) {
|
||||
return env.at(n);
|
||||
};
|
||||
// put the new outputs in our environment map, and
|
||||
|
|
@ -76,7 +69,7 @@ std::shared_ptr<Graph> ToToffeeIR(std::shared_ptr<Graph>& g) {
|
|||
JIT_ASSERT(env.count(value) > 0);
|
||||
IR_ELSEIFM(CppOp)
|
||||
if (auto fn = std::dynamic_pointer_cast<autograd::HasPrimSpec>(value->fn)) {
|
||||
auto outputs = fn->primspec(&ctx, mapv<Node*,Node*>(node->inputs(),envFn));
|
||||
auto outputs = fn->primspec(&ctx, fmap<Node*,Node*>(node->inputs(),envFn));
|
||||
setOutputs(node,outputs);
|
||||
} else {
|
||||
throw std::runtime_error("CppOp doesn't define primspec " + value->name());
|
||||
|
|
@ -112,7 +105,7 @@ std::shared_ptr<Graph> ToToffeeIR(std::shared_ptr<Graph>& g) {
|
|||
}
|
||||
py_primspec_args[input_nr++] = obj;
|
||||
}
|
||||
py::object raw_output = py::reinterpret_borrow<py::object>(PyObject_CallObject(primspec_fn.ptr(), py_primspec_args.ptr()));
|
||||
py::object raw_output = py::reinterpret_steal<py::object>(PyObject_CallObject(primspec_fn.ptr(), py_primspec_args.ptr()));
|
||||
if(!raw_output)
|
||||
throw python_error();
|
||||
if(raw_output.ptr() == Py_None)
|
||||
|
|
@ -132,7 +125,9 @@ std::shared_ptr<Graph> ToToffeeIR(std::shared_ptr<Graph>& g) {
|
|||
if(node->hasMultipleOutputs()) {
|
||||
int i = 0;
|
||||
for(auto s : node->uses()) {
|
||||
env[s.user] = ctx.graph->createSelect(n_,i++);
|
||||
auto new_node = ctx.graph->createSelect(n_,i++);
|
||||
new_node->setType(s.user->typeOption());
|
||||
env[s.user] = new_node;
|
||||
}
|
||||
} else {
|
||||
env[node] = n_;
|
||||
|
|
|
|||
16
torch/csrc/utils/functional.h
Normal file
16
torch/csrc/utils/functional.h
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace torch {
|
||||
|
||||
template<typename R, typename T>
|
||||
static std::vector<R> fmap(const std::vector<T> & inputs, std::function<R(const T &)> fn) {
|
||||
std::vector<R> r;
|
||||
r.reserve(inputs.size());
|
||||
for(auto & input : inputs)
|
||||
r.push_back(fn(input));
|
||||
return r;
|
||||
}
|
||||
|
||||
}
|
||||
Loading…
Reference in a new issue