Enabling alias annotation checks for all operations during autograd tests (#46601)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46601

Enables alias annotation checks for all operations (* except excluded tests and magic methods) during autograd tests. See https://github.com/pytorch/pytorch/issues/38731.

Previously, we'd only run these tests for in-place operations. Since this enables a lot more tests, the following issues that came up when running them are fixed as well:

- Updated the schema of conj() to reflect its existing behaviour.
- Updated the deepEquals method in check_alias_annotation.cpp to reuse the overloaded == operator; the previous implementation did not cover all types of IValues.
- Corrected the order in which inputs are passed during autograd testing of 'view' and 'reshape'.
- Subbed out aten::ger with the function it is aliased to, aten::outer, for testing, since the alias annotation checking code doesn't handle aliased operators properly.

ghstack-source-id: 114830903

Test Plan: Ran all tests in test:jit and verified they pass.

Reviewed By: eellison

Differential Revision: D24424955

fbshipit-source-id: 382d7e2585911b81b1573f21fff1d54a5e9a2054
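
For readers unfamiliar with alias annotations: a schema that returns Tensor(a) declares that the output may share storage with the input marked (a). A minimal Python illustration of that aliasing, not part of the commit, using view, whose output always aliases its input:

import torch

x = torch.randn(4, 4)
v = x.view(16)      # 'view' has an aliasing schema: v shares x's storage
v[0] = 42.0
print(x[0, 0].item() == 42.0)  # True: the write through v is visible in x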

This commit is contained in:
parent 33e82c0269
commit adbb50ea67

7 changed files with 49 additions and 40 deletions

@@ -314,7 +314,7 @@
   use_c10_dispatcher: full
   variants: function
 
-- func: conj(Tensor self) -> Tensor
+- func: conj(Tensor(a) self) -> Tensor(a)
   use_c10_dispatcher: full
   variants: function, method
 
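
A hedged sketch of why conj() warrants the Tensor(a) annotation: on real-valued tensors conj() is a no-op and, in PyTorch of this era, may return the input tensor itself, i.e. the output aliases the input. The data_ptr comparison below is illustrative, not part of the test suite:

import torch

x = torch.randn(3)
c = x.conj()
print(c.data_ptr() == x.data_ptr())  # True when conj() returns an alias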

@@ -128,6 +128,7 @@ allow_list = [
     ("aten::_foreach_addcdiv_", datetime.date(2020, 10, 15)),
     ("aten::_foreach_addcdiv", datetime.date(2020, 10, 15)),
     ("aten::_foreach_addcmul", datetime.date(2020, 10, 15)),
+    ("aten::conj", datetime.date(2020, 11, 10)),
 ]
 
 def allow_listed(schema, allow_list):
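
The entry added above suppresses the backward-compatibility failure for the changed aten::conj schema until 2020-11-10. A rough Python sketch of how such a date-gated allow list can work (the real allow_listed helper may match schemas differently, e.g. by pattern rather than exact name):

import datetime

def allow_listed_sketch(schema_name, allow_list):
    # an entry suppresses the BC failure for a schema until its expiry date
    today = datetime.date.today()
    return any(name == schema_name and today < expiry
               for name, expiry in allow_list)

allow_list = [("aten::conj", datetime.date(2020, 11, 10))]
print(allow_listed_sketch("aten::conj", allow_list))  # False once expired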

@@ -15734,7 +15734,7 @@ def add_autograd_test(
                 check_types=check_types)
 
             # alias annotation testing
-            if is_inplace and test_name not in EXCLUDE_SCRIPT:
+            if not is_magic_method and test_name not in EXCLUDE_SCRIPT:
                 check_alias_annotation(name, (self_variable,) + args_variable, kwargs_variable)
 
         check(name)
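
The effect of the changed condition, restated as a standalone sketch (the helper below is hypothetical; is_magic_method and EXCLUDE_SCRIPT are the names from the hunk): before the commit, only in-place variants were gated in; now everything runs except magic methods and excluded tests.

def should_check_alias_annotation(test_name, exclude_script, is_magic_method):
    # before: is_inplace and test_name not in EXCLUDE_SCRIPT
    # after: every scriptable op except magic methods and excluded tests
    return not is_magic_method and test_name not in exclude_script

print(should_check_alias_annotation('test_add', set(), False))  # True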

@@ -6,30 +6,6 @@ namespace jit {
 
 namespace {
 
-// map from op alias -> normalized op
-static const std::unordered_map<Symbol, Symbol> alias_map = {
-    {aten::absolute, aten::abs}, {aten::absolute_, aten::abs_},
-    {aten::clip, aten::clamp}, {aten::clip_, aten::clamp_},
-    {aten::linalg_det, aten::det}, {aten::ger, aten::outer},
-    {aten::arccos, aten::acos}, {aten::arccos_, aten::acos_},
-    {aten::arcsin, aten::asin}, {aten::arcsin_, aten::asin_},
-    {aten::arctan, aten::atan}, {aten::arctan_, aten::atan_},
-    {aten::arccosh, aten::acosh}, {aten::arccosh_, aten::acosh_},
-    {aten::arcsinh, aten::asinh}, {aten::arcsinh_, aten::asinh_},
-    {aten::arctanh, aten::atanh}, {aten::arctanh_, aten::atanh_},
-    {aten::fix, aten::trunc}, {aten::fix_, aten::trunc_},
-    {aten::negative, aten::neg}, {aten::negative_, aten::neg_},
-    {aten::subtract, aten::sub}, {aten::subtract_, aten::sub_},
-    {aten::greater_equal, aten::ge}, {aten::greater_equal_, aten::ge_},
-    {aten::greater, aten::gt}, {aten::greater_, aten::gt_},
-    {aten::less_equal, aten::le}, {aten::less_equal_, aten::le_},
-    {aten::less, aten::lt}, {aten::less_, aten::lt_},
-    {aten::not_equal, aten::ne}, {aten::not_equal_, aten::ne_},
-    {aten::divide, aten::div}, {aten::divide_, aten::div_},
-    {aten::multiply, aten::mul}, {aten::multiply_, aten::mul_},
-    {aten::true_divide, aten::div}, {aten::true_divide_, aten::div_},
-};
-
 void replaceNodeWithNewSymbol(Node* node, Symbol new_symbol) {
   WithInsertPoint insert_guard{node};
   auto graph = node->owningGraph();

@@ -53,8 +29,8 @@ void replaceNodeWithNewSymbol(Node* node, Symbol new_symbol) {
 // difficult to consumer for downstream user of the IR, such as our own
 // optimization passes here, we convert op aliases into a standard form
 bool normalizeOpAliases(graph_node_list_iterator& iter) {
-  auto alias = alias_map.find(iter->kind());
-  if (alias != alias_map.end()) {
+  auto alias = getOperatorAliasMap().find(iter->kind());
+  if (alias != getOperatorAliasMap().end()) {
     replaceNodeWithNewSymbol(*iter, alias->second);
     iter.destroyCurrent();
     return true;

@@ -79,6 +55,33 @@ void NormalizeOps(Block* block) {
 
 } // namespace
 
+const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap() {
+  // map from op alias -> normalized op
+  static const std::unordered_map<Symbol, Symbol> alias_map = {
+      {aten::absolute, aten::abs}, {aten::absolute_, aten::abs_},
+      {aten::clip, aten::clamp}, {aten::clip_, aten::clamp_},
+      {aten::linalg_det, aten::det}, {aten::ger, aten::outer},
+      {aten::arccos, aten::acos}, {aten::arccos_, aten::acos_},
+      {aten::arcsin, aten::asin}, {aten::arcsin_, aten::asin_},
+      {aten::arctan, aten::atan}, {aten::arctan_, aten::atan_},
+      {aten::arccosh, aten::acosh}, {aten::arccosh_, aten::acosh_},
+      {aten::arcsinh, aten::asinh}, {aten::arcsinh_, aten::asinh_},
+      {aten::arctanh, aten::atanh}, {aten::arctanh_, aten::atanh_},
+      {aten::fix, aten::trunc}, {aten::fix_, aten::trunc_},
+      {aten::negative, aten::neg}, {aten::negative_, aten::neg_},
+      {aten::subtract, aten::sub}, {aten::subtract_, aten::sub_},
+      {aten::greater_equal, aten::ge}, {aten::greater_equal_, aten::ge_},
+      {aten::greater, aten::gt}, {aten::greater_, aten::gt_},
+      {aten::less_equal, aten::le}, {aten::less_equal_, aten::le_},
+      {aten::less, aten::lt}, {aten::less_, aten::lt_},
+      {aten::not_equal, aten::ne}, {aten::not_equal_, aten::ne_},
+      {aten::divide, aten::div}, {aten::divide_, aten::div_},
+      {aten::multiply, aten::mul}, {aten::multiply_, aten::mul_},
+      {aten::true_divide, aten::div}, {aten::true_divide_, aten::div_},
+  };
+  return alias_map;
+}
+
 void NormalizeOps(const std::shared_ptr<Graph>& graph) {
   NormalizeOps(graph->block());
 }
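
The normalization pass these hunks refactor can be observed from Python: scripting a function that calls an alias such as torch.multiply should produce a graph whose nodes use the canonical op, aten::mul. A small check (exact graph text may vary across versions):

import torch

@torch.jit.script
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.multiply(x, y)  # alias of torch.mul

print("aten::mul" in str(f.graph))  # expected True after normalization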

@@ -12,5 +12,7 @@ namespace jit {
 // Currently only handles normalization of op aliases.
 TORCH_API void NormalizeOps(const std::shared_ptr<Graph>& graph);
 
+const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap();
+
 } // namespace jit
 } // namespace torch
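
One of the map entries, {aten::ger, aten::outer}, is the pair the commit message mentions: ger is an alias for the outer product. A quick Python confirmation (torch.ger is deprecated in later releases but still callable):

import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([3.0, 4.0])
print(torch.allclose(torch.ger(a, b), torch.outer(a, b)))  # True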

@@ -1,5 +1,6 @@
 #include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
 #include <torch/csrc/jit/passes/constant_propagation.h>
+#include <torch/csrc/jit/passes/normalize_ops.h>
 #include <torch/csrc/jit/runtime/operator.h>
 
 namespace torch {

@@ -61,19 +62,11 @@ Stack deepCopy(const Stack& stack) {
 }
 
 bool deepEquals(const IValue& lhs, const IValue& rhs) {
-  if (lhs.isInt() && rhs.isInt()) {
-    return lhs.toInt() == rhs.toInt();
-  } else if (lhs.isDouble() && rhs.isDouble()) {
-    return lhs.toDouble() == rhs.toDouble();
-  } else if (lhs.isNone() && rhs.isNone()) {
-    return true;
-  } else if (lhs.isIntList() && rhs.isIntList()) {
-    return lhs.toIntVector() == rhs.toIntVector();
-  } else if (lhs.isTensor() && rhs.isTensor()) {
+  if (lhs.isTensor() && rhs.isTensor()) {
     return lhs.toTensor().equal(rhs.toTensor());
   }
 
-  throw std::runtime_error("Deep equals not implemented for type");
+  return lhs == rhs;
 }
 
 struct AliasAndIValue {
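
A Python analogue of the simplified deepEquals, illustrative only: special-case tensors, where == is elementwise, and fall back to generic equality instead of enumerating every supported type.

import torch

def deep_equals(lhs, rhs):
    if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
        return lhs.equal(rhs)  # elementwise comparison reduced to one bool
    return lhs == rhs          # ints, floats, None, lists, ...

print(deep_equals(torch.ones(2), torch.ones(2)), deep_equals([1, 2], [1, 2]))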

@@ -146,6 +139,16 @@ const Node* findNodeForOp(
       return node;
     }
   }
 
+  // Check for alias-ed operator names
+  const auto aliasOp = torch::jit::getOperatorAliasMap().find(opName);
+  AT_ASSERT(aliasOp != torch::jit::getOperatorAliasMap().end());
+  for (const auto node : g.nodes()) {
+    if (node->kind() == aliasOp->second) {
+      return node;
+    }
+  }
+
   AT_ASSERT(false);
 }
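
A sketch of the fallback added to findNodeForOp, in Python for brevity (the one-entry alias_map below stands in for the real getOperatorAliasMap): if no node matches the queried name, retry under its canonical name.

alias_map = {"aten::ger": "aten::outer"}  # tiny stand-in for the real map

def find_node(node_kinds, op_name):
    for kind in node_kinds:
        if kind == op_name:
            return kind
    # fallback: KeyError if op_name is not a known alias (mirrors AT_ASSERT)
    canonical = alias_map[op_name]
    for kind in node_kinds:
        if kind == canonical:
            return kind
    raise AssertionError

print(find_node(["aten::outer"], "aten::ger"))  # aten::outer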

@@ -586,13 +586,13 @@ def method_tests():
     ('transpose', (S, S, S), (2, 0), '3d', (False,)),
     ('t', (1, 2), NO_ARGS, '', (False,)),
     ('view', (S, S, S), (S * S, S), '', (False,)),
-    ('view', (S, S, S), (torch.Size([S * S, S]),), 'size', (False,)),
+    ('view', (torch.Size([S * S, S]),), (S, S, S), 'size', (False,)),
     ('view', (S,), (S,), '1d', (False,)),
     ('view', (), (dont_convert(()),), 'scalar_to_scalar', (False,)),
     ('view', (), (1,), 'scalar_to_1d', (False,)),
     ('ravel', (S, S, S), NO_ARGS, '', (False,)),
     ('reshape', (S, S, S), (S * S, S), '', (False,)),
-    ('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size', (False,)),
+    ('reshape', (torch.Size([S * S, S]),), (S, S, S), 'size', (False,)),
     ('reshape', (S,), (S,), '1d', (False,)),
     ('reshape', (), (dont_convert(()),), 'scalar_to_scalar', (False,)),
     ('reshape', (), (1,), 'scalar_to_1d', (False,)),
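
Reading the corrected tuples under method_tests' conventions, (name, self_size, args, variant, ...): the 'size' variant now builds its input from torch.Size([S * S, S]) and passes plain ints to view/reshape, which the alias annotation checker handles more readily than a torch.Size argument. Roughly:

import torch

S = 5
self_variable = torch.randn(torch.Size([S * S, S]))  # from self_size
out = self_variable.view(S, S, S)                    # args are plain ints
print(out.shape)  # torch.Size([5, 5, 5])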