Make operator names consistent between export_opnames and the lite interpreter (#34674)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34674

Two changes to make sure the op_names dumped in export_opnames() are consistent to what are actually used in bytecode.
* Inline graph before dumping the operator names.
* Use code of the graph (which is used in bytecode) instead of the nodes of graph.

Test Plan: Imported from OSS

Differential Revision: D20610715

Pulled By: iseeyuan

fbshipit-source-id: 53fa9c3b36f4f242b7f2b99b421f4adf20d4b1f6
This commit is contained in:
Martin Yuan 2020-03-26 22:43:27 -07:00 committed by Facebook GitHub Bot
parent 8c90ae11b3
commit da4e68faed
7 changed files with 66 additions and 30 deletions

View file

@ -4627,6 +4627,7 @@ graph(%Ra, %Rb):
super(Bar, self).__init__()
self.sub = Foo()
@torch.jit.script_method
def forward(self, x):
# type: (Tensor) -> Tensor
return self.sub.forward(x)
@ -4634,7 +4635,7 @@ graph(%Ra, %Rb):
bar = Bar()
ops = torch.jit.export_opnames(bar)
expected = ['aten::add.Tensor', 'aten::mul.Scalar']
self.assertEqual(ops, expected)
self.assertTrue(set(expected).issubset(set(ops)))
def test_pytorch_jit_env_off(self):
import subprocess

View file

@ -454,12 +454,16 @@ class TestScriptPy3(JitTestCase):
return mod_list[0].forward(x) + mod_list[1].forward(x)
scripted_M_mod = torch.jit.script(M())
self.assertEqual(torch.jit.export_opnames(scripted_M_mod),
['aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'])
# Temporarily test empty output because lite interpreter does not support interface call
# Replace it with the issubset call when interface call is supported.
self.assertTrue(len(torch.jit.export_opnames(scripted_M_mod)) == 0)
# self.assertTrue(set(['aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal']).issubset(
# set(torch.jit.export_opnames(scripted_M_mod))))
scripted_M_mod.sub = torch.jit.script(FooMod())
self.assertEqual(torch.jit.export_opnames(scripted_M_mod),
['aten::add.Tensor', 'aten::mul.Scalar'])
self.assertTrue(len(torch.jit.export_opnames(scripted_M_mod)) == 0)
# self.assertTrue(set(['aten::add.Tensor', 'aten::mul.Scalar']).issubset(
# set(torch.jit.export_opnames(scripted_M_mod))))
if __name__ == '__main__':

View file

@ -4,6 +4,7 @@
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/unpickler.h>
#include <torch/custom_class.h>
@ -45,7 +46,6 @@ using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;
OpCode parseOpCode(const char* str);
namespace {
IValue expect_field(
IValue tup,
@ -61,6 +61,7 @@ IValue expect_field(
return row->elements().at(1);
}
namespace {
void print_unsupported_ops_and_throw(
const std::unordered_set<std::string>& unsupported_ops) {
std::string error_message("{");
@ -83,13 +84,19 @@ void parseMethods(
new mobile::Function(c10::QualifiedName(function_name)));
const auto& ins_list =
expect_field(table, "instructions", 0).toTuple()->elements();
expect_field(table, "instructions", BYTECODE_INDEX_INSTRUCTION)
.toTuple()
->elements();
const auto& ops_list =
expect_field(table, "operators", 1).toTuple()->elements();
expect_field(table, "operators", BYTECODE_INDEX_OPERATOR)
.toTuple()
->elements();
const auto& consts_list =
expect_field(table, "constants", 2).toTuple()->elements();
expect_field(table, "constants", BYTECODE_INDEX_CONSTANT)
.toTuple()
->elements();
const auto& types_list =
expect_field(table, "types", 3).toTuple()->elements();
expect_field(table, "types", BYTECODE_INDEX_TYPE).toTuple()->elements();
const auto& register_size = expect_field(table, "register_size", 4).toInt();
for (const auto& ins : ins_list) {

View file

@ -1,12 +1,15 @@
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/autograd/symbolic.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h>
#include <torch/csrc/onnx/onnx.h>
#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h>
#include <onnx/checker.h>
#include <onnx/onnx_pb.h>
@ -969,28 +972,29 @@ void check_onnx_proto(const std::string& proto_string) {
}
namespace {
void export_opnames(const Module& m, std::set<std::string>& opnames) {
for (const auto& method : m.get_methods()) {
const auto& func = method.function();
for (const auto& node : func.graph()->nodes()) {
auto schema = node->maybeSchema();
if (schema) {
auto opname = schema->operator_name();
std::string namestr = opname.name;
if (!opname.overload_name.empty()) {
namestr += "." + opname.overload_name;
}
opnames.emplace(namestr);
}
void export_opnames(const script::Module& m, std::set<std::string>& opnames) {
std::vector<c10::IValue> elements;
moduleMethodsTuple(m, elements);
for (const auto& element : elements) {
auto table = element.toTuple()->elements()[1];
const auto& ops_list =
expect_field(table, "operators", BYTECODE_INDEX_OPERATOR)
.toTuple()
->elements();
for (const auto& op : ops_list) {
auto op_item = op.toTuple()->elements();
TORCH_CHECK(
op_item.size() == 2,
"There should be two parts in an operator name.");
opnames.emplace(
op_item[0].toString()->string() + "." +
op_item[1].toString()->string());
}
}
for (const auto& sub_m : m.children()) {
export_opnames(sub_m, opnames);
}
}
} // namespace
std::vector<std::string> export_opnames(const Module& m) {
std::vector<std::string> export_opnames(const script::Module& m) {
std::set<std::string> names;
export_opnames(m, names);
return std::vector<std::string>(names.begin(), names.end());

View file

@ -137,6 +137,7 @@ void setstateTuple(const IValue& ivalue, std::vector<c10::IValue>& elements) {
}
}
}
} // namespace
void moduleMethodsTuple(
const Module& module,
@ -151,8 +152,6 @@ void moduleMethodsTuple(
setstateTuple(module._ivalue(), elements);
}
} // namespace
void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) {
GetExtraFilesHook() = hook;
}

View file

@ -0,0 +1,10 @@
#pragma once
namespace torch {
namespace jit {
constexpr size_t BYTECODE_INDEX_INSTRUCTION = 0;
constexpr size_t BYTECODE_INDEX_OPERATOR = 1;
constexpr size_t BYTECODE_INDEX_CONSTANT = 2;
constexpr size_t BYTECODE_INDEX_TYPE = 3;
} // namespace jit
} // namespace torch

View file

@ -0,0 +1,11 @@
#pragma once
// Functions that are used in both import and export processes
namespace torch {
namespace jit {
void moduleMethodsTuple(
const Module& module,
std::vector<c10::IValue>& elements);
IValue expect_field(IValue tup, const std::string& expected_name, size_t entry);
} // namespace jit
} // namespace torch