Enabling alias annotation checks for all operations during autograd tests (#46601)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46601

* except excluded tests and magic methods.

https://github.com/pytorch/pytorch/issues/38731

Previously, we'd only do run these tests for inplace operations. Since this is a lot more tests, fixed these issues that came up when running them -
- Updated schema of conj() to reflect existing behaviour.
- Updated deepEquals method in check_alias_annotation.cpp to re-use the overloaded == operator. Previous implementation did not cover all types of IValues.
- Corrected the order inputs are passed in during autograd testing of 'view' & 'reshape'.
- Subbed out atn::ger with the func its aliased to, atn::outer, for testing. The alias annotation checking code doesn't handle aliased operators properly.
ghstack-source-id: 114830903

Test Plan: Ran all tests in test:jit and verified they pass.

Reviewed By: eellison

Differential Revision: D24424955

fbshipit-source-id: 382d7e2585911b81b1573f21fff1d54a5e9a2054
This commit is contained in:
Rahul Nambiar 2020-10-21 19:59:47 -07:00 committed by Facebook GitHub Bot
parent 33e82c0269
commit adbb50ea67
7 changed files with 49 additions and 40 deletions

View file

@ -314,7 +314,7 @@
use_c10_dispatcher: full
variants: function
- func: conj(Tensor self) -> Tensor
- func: conj(Tensor(a) self) -> Tensor(a)
use_c10_dispatcher: full
variants: function, method

View file

@ -128,6 +128,7 @@ allow_list = [
("aten::_foreach_addcdiv_", datetime.date(2020, 10, 15)),
("aten::_foreach_addcdiv", datetime.date(2020, 10, 15)),
("aten::_foreach_addcmul", datetime.date(2020, 10, 15)),
("aten::conj", datetime.date(2020, 11, 10)),
]
def allow_listed(schema, allow_list):

View file

@ -15734,7 +15734,7 @@ def add_autograd_test(
check_types=check_types)
# alias annotation testing
if is_inplace and test_name not in EXCLUDE_SCRIPT:
if not is_magic_method and test_name not in EXCLUDE_SCRIPT:
check_alias_annotation(name, (self_variable,) + args_variable, kwargs_variable)
check(name)

View file

@ -6,30 +6,6 @@ namespace jit {
namespace {
// map from op alias -> normalized op
static const std::unordered_map<Symbol, Symbol> alias_map = {
{aten::absolute, aten::abs}, {aten::absolute_, aten::abs_},
{aten::clip, aten::clamp}, {aten::clip_, aten::clamp_},
{aten::linalg_det, aten::det}, {aten::ger, aten::outer},
{aten::arccos, aten::acos}, {aten::arccos_, aten::acos_},
{aten::arcsin, aten::asin}, {aten::arcsin_, aten::asin_},
{aten::arctan, aten::atan}, {aten::arctan_, aten::atan_},
{aten::arccosh, aten::acosh}, {aten::arccosh_, aten::acosh_},
{aten::arcsinh, aten::asinh}, {aten::arcsinh_, aten::asinh_},
{aten::arctanh, aten::atanh}, {aten::arctanh_, aten::atanh_},
{aten::fix, aten::trunc}, {aten::fix_, aten::trunc_},
{aten::negative, aten::neg}, {aten::negative_, aten::neg_},
{aten::subtract, aten::sub}, {aten::subtract_, aten::sub_},
{aten::greater_equal, aten::ge}, {aten::greater_equal_, aten::ge_},
{aten::greater, aten::gt}, {aten::greater_, aten::gt_},
{aten::less_equal, aten::le}, {aten::less_equal_, aten::le_},
{aten::less, aten::lt}, {aten::less_, aten::lt_},
{aten::not_equal, aten::ne}, {aten::not_equal_, aten::ne_},
{aten::divide, aten::div}, {aten::divide_, aten::div_},
{aten::multiply, aten::mul}, {aten::multiply_, aten::mul_},
{aten::true_divide, aten::div}, {aten::true_divide_, aten::div_},
};
void replaceNodeWithNewSymbol(Node* node, Symbol new_symbol) {
WithInsertPoint insert_guard{node};
auto graph = node->owningGraph();
@ -53,8 +29,8 @@ void replaceNodeWithNewSymbol(Node* node, Symbol new_symbol) {
// difficult to consumer for downstream user of the IR, such as our own
// optimization passes here, we convert op aliases into a standard form
bool normalizeOpAliases(graph_node_list_iterator& iter) {
auto alias = alias_map.find(iter->kind());
if (alias != alias_map.end()) {
auto alias = getOperatorAliasMap().find(iter->kind());
if (alias != getOperatorAliasMap().end()) {
replaceNodeWithNewSymbol(*iter, alias->second);
iter.destroyCurrent();
return true;
@ -79,6 +55,33 @@ void NormalizeOps(Block* block) {
} // namespace
const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap() {
// map from op alias -> normalized op
static const std::unordered_map<Symbol, Symbol> alias_map = {
{aten::absolute, aten::abs}, {aten::absolute_, aten::abs_},
{aten::clip, aten::clamp}, {aten::clip_, aten::clamp_},
{aten::linalg_det, aten::det}, {aten::ger, aten::outer},
{aten::arccos, aten::acos}, {aten::arccos_, aten::acos_},
{aten::arcsin, aten::asin}, {aten::arcsin_, aten::asin_},
{aten::arctan, aten::atan}, {aten::arctan_, aten::atan_},
{aten::arccosh, aten::acosh}, {aten::arccosh_, aten::acosh_},
{aten::arcsinh, aten::asinh}, {aten::arcsinh_, aten::asinh_},
{aten::arctanh, aten::atanh}, {aten::arctanh_, aten::atanh_},
{aten::fix, aten::trunc}, {aten::fix_, aten::trunc_},
{aten::negative, aten::neg}, {aten::negative_, aten::neg_},
{aten::subtract, aten::sub}, {aten::subtract_, aten::sub_},
{aten::greater_equal, aten::ge}, {aten::greater_equal_, aten::ge_},
{aten::greater, aten::gt}, {aten::greater_, aten::gt_},
{aten::less_equal, aten::le}, {aten::less_equal_, aten::le_},
{aten::less, aten::lt}, {aten::less_, aten::lt_},
{aten::not_equal, aten::ne}, {aten::not_equal_, aten::ne_},
{aten::divide, aten::div}, {aten::divide_, aten::div_},
{aten::multiply, aten::mul}, {aten::multiply_, aten::mul_},
{aten::true_divide, aten::div}, {aten::true_divide_, aten::div_},
};
return alias_map;
}
void NormalizeOps(const std::shared_ptr<Graph>& graph) {
NormalizeOps(graph->block());
}

View file

@ -12,5 +12,7 @@ namespace jit {
// Currently only handles normalization of op aliases.
TORCH_API void NormalizeOps(const std::shared_ptr<Graph>& graph);
const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap();
} // namespace jit
} // namespace torch

View file

@ -1,5 +1,6 @@
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/normalize_ops.h>
#include <torch/csrc/jit/runtime/operator.h>
namespace torch {
@ -61,19 +62,11 @@ Stack deepCopy(const Stack& stack) {
}
bool deepEquals(const IValue& lhs, const IValue& rhs) {
if (lhs.isInt() && rhs.isInt()) {
return lhs.toInt() == rhs.toInt();
} else if (lhs.isDouble() && rhs.isDouble()) {
return lhs.toDouble() == rhs.toDouble();
} else if (lhs.isNone() && rhs.isNone()) {
return true;
} else if (lhs.isIntList() && rhs.isIntList()) {
return lhs.toIntVector() == rhs.toIntVector();
} else if (lhs.isTensor() && rhs.isTensor()) {
if (lhs.isTensor() && rhs.isTensor()) {
return lhs.toTensor().equal(rhs.toTensor());
}
throw std::runtime_error("Deep equals not implemented for type");
return lhs == rhs;
}
struct AliasAndIValue {
@ -146,6 +139,16 @@ const Node* findNodeForOp(
return node;
}
}
// Check for alias-ed operator names
const auto aliasOp = torch::jit::getOperatorAliasMap().find(opName);
AT_ASSERT(aliasOp != torch::jit::getOperatorAliasMap().end());
for (const auto node : g.nodes()) {
if (node->kind() == aliasOp->second) {
return node;
}
}
AT_ASSERT(false);
}

View file

@ -586,13 +586,13 @@ def method_tests():
('transpose', (S, S, S), (2, 0), '3d', (False,)),
('t', (1, 2), NO_ARGS, '', (False,)),
('view', (S, S, S), (S * S, S), '', (False,)),
('view', (S, S, S), (torch.Size([S * S, S]),), 'size', (False,)),
('view', (torch.Size([S * S, S]),), (S, S, S), 'size', (False,)),
('view', (S,), (S,), '1d', (False,)),
('view', (), (dont_convert(()),), 'scalar_to_scalar', (False,)),
('view', (), (1,), 'scalar_to_1d', (False,)),
('ravel', (S, S, S), NO_ARGS, '', (False,)),
('reshape', (S, S, S), (S * S, S), '', (False,)),
('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size', (False,)),
('reshape', (torch.Size([S * S, S]),), (S, S, S), 'size', (False,)),
('reshape', (S,), (S,), '1d', (False,)),
('reshape', (), (dont_convert(()),), 'scalar_to_scalar', (False,)),
('reshape', (), (1,), 'scalar_to_1d', (False,)),