mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Generate nearbyint for Round in tensorexpr llvm codegen, match torch.round result (#104430)
Fixes #103465, which matches the behavior of `torch.round` ([doc](https://pytorch.org/docs/stable/generated/torch.round.html?highlight=round#torch.round)) - “round half to even” Using the repro code, the output is correct: ``` Using torch version=2.1.0a0+git84fedbc and optimization enabled=True [cpu ] Python = 2, Torch = 2, Torch traced = 2 Using torch version=2.1.0a0+git84fedbc and optimization enabled=False [cpu ] Python = 2, Torch = 2, Torch traced = 2 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/104430 Approved by: https://github.com/jgong5, https://github.com/davidberard98
This commit is contained in:
parent
8ce3a18b6a
commit
a2fe6953bc
2 changed files with 15 additions and 2 deletions
|
|
@ -926,6 +926,19 @@ class TestTensorExprFuser(BaseTestClass):
|
|||
# print("Failed on dev=", dev, "function=", torch_fn)
|
||||
# # np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())
|
||||
|
||||
|
||||
def test_round_2(self):
|
||||
def round(x):
|
||||
return torch.round(x)
|
||||
|
||||
for data_type in [torch.float32, torch.double]:
|
||||
a = torch.tensor([0.2, 1.6, 2.5, 3.5]).to(data_type)
|
||||
traced = torch.jit.trace(round, (a))
|
||||
x = warmup_and_run_forward(traced, a)
|
||||
self.assertLastGraphAllFused()
|
||||
y = round(x)
|
||||
self.assertEqual(x, y)
|
||||
|
||||
def test_rand_like(self):
|
||||
N = 1 << 16
|
||||
|
||||
|
|
|
|||
|
|
@ -2024,7 +2024,7 @@ void LLVMCodeGenImpl::visit(IntrinsicsPtr v) {
|
|||
SIMD_UNARY_MATH_CASE(kFloor, "floorf", FloatTy_)
|
||||
SIMD_UNARY_MATH_CASE(kCeil, "ceilf", FloatTy_)
|
||||
SIMD_UNARY_MATH_CASE(kTrunc, "truncf", FloatTy_)
|
||||
SIMD_UNARY_MATH_CASE(kRound, "roundf", FloatTy_)
|
||||
SIMD_UNARY_MATH_CASE(kRound, "nearbyint", FloatTy_)
|
||||
SIMD_UNARY_MATH_CASE(kErf, "erff", FloatTy_)
|
||||
SIMD_UNARY_MATH_CASE(kErfc, "erfcf", FloatTy_)
|
||||
SIMD_UNARY_MATH_CASE(kTan, "tanf", FloatTy_)
|
||||
|
|
@ -2082,7 +2082,7 @@ void LLVMCodeGenImpl::visit(IntrinsicsPtr v) {
|
|||
SIMD_UNARY_MATH_CASE(kFloor, "floor", DoubleTy_)
|
||||
SIMD_UNARY_MATH_CASE(kCeil, "ceil", DoubleTy_)
|
||||
SIMD_UNARY_MATH_CASE(kTrunc, "trunc", DoubleTy_)
|
||||
SIMD_UNARY_MATH_CASE(kRound, "round", DoubleTy_)
|
||||
SIMD_UNARY_MATH_CASE(kRound, "nearbyint", DoubleTy_)
|
||||
SIMD_UNARY_MATH_CASE(kErf, "erf", DoubleTy_)
|
||||
SIMD_UNARY_MATH_CASE(kErfc, "erfc", DoubleTy_)
|
||||
SIMD_UNARY_MATH_CASE(kTan, "tan", DoubleTy_)
|
||||
|
|
|
|||
Loading…
Reference in a new issue