Make add_relu an internal function (#46676)
Summary: Cleanup for 1.7

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46676

Reviewed By: gchanan

Differential Revision: D24458565

Pulled By: albanD

fbshipit-source-id: b1e4b4630233d3f1a4bac20e3077411d1ae17f7b
This commit is contained in:
  parent 870a5a0d6d
  commit 27e2ea4cea

9 changed files with 21 additions and 20 deletions
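The change below renames the fused add+relu operator from aten::add_relu to the underscore-prefixed aten::_add_relu, marking it as internal rather than public API. As a rough sketch of the user-visible effect, based on the updated TestAddRelu hunk further down (torch._VF._add_relu is the binding exercised there; treat the exact call path as an internal detail, not a stable interface):

import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)

# Reference result: separate add followed by relu.
expected = torch.relu(a + b)

# Fused internal op after this change (formerly exposed as torch.add_relu).
# The torch._VF._add_relu spelling mirrors the updated test in this diff;
# it is an internal entry point and may change between releases.
fused = torch._VF._add_relu(a, b)

assert torch.allclose(fused, expected)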
@@ -37,9 +37,9 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
   m.impl("add.out", CppFunction::makeFallthrough());
   m.impl("add_.Scalar", CppFunction::makeFallthrough());
   m.impl("add_.Tensor", CppFunction::makeFallthrough());
-  m.impl("add_relu.Tensor", CppFunction::makeFallthrough());
-  m.impl("add_relu.out", CppFunction::makeFallthrough());
-  m.impl("add_relu_.Tensor", CppFunction::makeFallthrough());
+  m.impl("_add_relu.Tensor", CppFunction::makeFallthrough());
+  m.impl("_add_relu.out", CppFunction::makeFallthrough());
+  m.impl("_add_relu_.Tensor", CppFunction::makeFallthrough());
   m.impl("addcdiv", CppFunction::makeFallthrough());
   m.impl("addcdiv.out", CppFunction::makeFallthrough());
   m.impl("addcdiv_", CppFunction::makeFallthrough());

@@ -388,19 +388,19 @@
     SparseCUDA: add_out_sparse_cuda
     MkldnnCPU: mkldnn_add_out
 
-- func: add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+- func: _add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
   use_c10_dispatcher: full
   variants: function
   dispatch:
     CPU: add_relu
 
-- func: add_relu_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+- func: _add_relu_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
   use_c10_dispatcher: full
   variants: function
   dispatch:
     CPU: add_relu_
 
-- func: add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+- func: _add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
   variants: function
   dispatch:
     CPU: add_relu_out

@@ -129,6 +129,8 @@ allow_list = [
     ("aten::_foreach_addcdiv", datetime.date(2020, 10, 15)),
     ("aten::_foreach_addcmul", datetime.date(2020, 10, 15)),
     ("aten::conj", datetime.date(2020, 11, 10)),
+    ("aten::add_relu", datetime.date(2020, 10, 28)),
+    ("aten::add_relu_", datetime.date(2020, 10, 28)),
 ]
 
 def allow_listed(schema, allow_list):

@@ -584,7 +584,7 @@ class TestJit(JitTestCase):
         m = torch.jit.load(buffer)
         new_res = m(a, b, c)
         FileCheck().check_not("aten::relu(") \
-            .check("aten::add_relu(") \
+            .check("aten::_add_relu(") \
             .run(m.graph)
         torch.testing.assert_allclose(orig_res, new_res)
 

@@ -603,7 +603,7 @@ class TestJit(JitTestCase):
         m = torch.jit.load(buffer)
         new_res = m(a, b, c)
         FileCheck().check_not("aten::relu_(") \
-            .check("aten::add_relu(") \
+            .check("aten::_add_relu(") \
             .run(m.graph)
         torch.testing.assert_allclose(orig_res, new_res)
 

@@ -634,10 +634,10 @@ class TestJit(JitTestCase):
         new_res = m(a_copy, b)
         FileCheck().check_not("aten::add_(") \
             .check_not("aten::relu_(") \
-            .check("aten::add_relu_(") \
+            .check("aten::_add_relu_(") \
             .run(m.graph)
         torch.testing.assert_allclose(orig_res, new_res)
-        # Since add_relu_ does inplace mutation ensure
+        # Since _add_relu_ does inplace mutation ensure
         # a_copy is modified
         torch.testing.assert_allclose(orig_res, a_copy)
 

@@ -672,10 +672,10 @@ class TestJit(JitTestCase):
         new_res = m(a_copy, b)
         FileCheck().check_not("aten::add(") \
             .check_not("aten::relu_(") \
-            .check("aten::add_relu(") \
+            .check("aten::_add_relu(") \
             .run(m.graph)
         torch.testing.assert_allclose(orig_res, new_res)
-        # Since add_relu_ with out=a does inplace mutation ensure
+        # Since _add_relu_ with out=a does inplace mutation ensure
         # a_copy is modified
         torch.testing.assert_allclose(orig_res, a_copy)
 

@@ -95,7 +95,7 @@ class TestOptimizer(unittest.TestCase):
             .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
             .check_not("aten::add(") \
             .check_not("aten::relu(") \
-            .check_count("aten::add_relu(", 1, exactly=True) \
+            .check_count("aten::_add_relu(", 1, exactly=True) \
             .run(optimized_scripted_model.graph)
         torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
 

@@ -9210,7 +9210,7 @@ class TestAddRelu(TestCase):
         a = a + 5
         add_res = a + b
         relu_res = torch.relu(add_res)
-        add_relu_res = torch.add_relu(a, b)
+        add_relu_res = torch._VF._add_relu(a, b)
 
         self.assertTrue(torch.allclose(add_relu_res, relu_res))
 

@@ -1903,7 +1903,7 @@
   - name: aten::resize_as_
   - name: aten::scalar_tensor
   - name: aten::to
-- name: aten::add_relu
+- name: aten::_add_relu
   depends:
   - name: aten::as_strided_
   - name: aten::copy_

@@ -1915,7 +1915,7 @@
   - name: aten::resize_
   - name: aten::resize_as_
   - name: aten::to
-- name: aten::add_relu_
+- name: aten::_add_relu_
   depends:
   - name: aten::as_strided_
   - name: aten::copy_

@@ -17,7 +17,7 @@ void fuseAddReluImpl(std::shared_ptr<Graph>& graph) {
         return (%res))";
   std::string add_relu_fused = R"(
     graph(%a, %b, %alpha):
-        %res = aten::add_relu(%a, %b, %alpha)
+        %res = aten::_add_relu(%a, %b, %alpha)
         return (%res))";
   rewriter.RegisterRewritePattern(add_relu_0, add_relu_fused);
 

@@ -35,7 +35,7 @@ void fuseAddReluImpl(std::shared_ptr<Graph>& graph) {
         return (%res))";
   std::string add_inplace_relu_fused = R"(
     graph(%a, %b, %alpha):
-        %res = aten::add_relu_(%a, %b, %alpha)
+        %res = aten::_add_relu_(%a, %b, %alpha)
         return (%res))";
   rewriter.RegisterRewritePattern(add_inplace_relu_1, add_inplace_relu_fused);
 

@@ -46,7 +46,7 @@ void fuseAddReluImpl(std::shared_ptr<Graph>& graph) {
         return (%res))";
   std::string add_out_relu_fused = R"(
     graph(%a, %b, %alpha, %out):
-        %res = aten::add_relu(%a, %b, %alpha, %out)
+        %res = aten::_add_relu(%a, %b, %alpha, %out)
         return (%res))";
 
   rewriter.RegisterRewritePattern(add_out_relu, add_out_relu_fused);

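The three rewrite patterns above (out-of-place, in-place, and out= variants) now target the renamed internal op. As a minimal sketch of how the fusion can be exercised from Python, mirroring the updated TestJit hunks earlier in this diff; the pass name torch._C._jit_pass_fuse_add_relu is inferred from those tests and is an internal, unstable API:

import torch
from torch.testing import FileCheck

class AddRelu(torch.nn.Module):
    def forward(self, a, b):
        # Separate add + relu that the registered patterns can fuse.
        return torch.relu(torch.add(a, b))

m = torch.jit.script(AddRelu())
# Assumption: this internal pass runs fuseAddReluImpl's rewrite patterns.
torch._C._jit_pass_fuse_add_relu(m.graph)

# After fusion the graph should reference the internal fused op.
FileCheck().check_not("aten::relu(") \
           .check("aten::_add_relu(") \
           .run(m.graph)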
@@ -209,7 +209,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.arccos: lambda input, out=None: -1,
         torch.acosh: lambda input, out=None: -1,
         torch.arccosh: lambda input, out=None: -1,
-        torch.add_relu: lambda input, other, out=None: -1,
         torch.add: lambda input, other, out=None: -1,
         torch.addbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
         torch.addcdiv: lambda input, tensor1, tensor2, value=1, out=None: -1,