pytorch/test/cpp/api/operations.cpp
Yujun Zhao 9ea7476d9c Add test to lerp function (#42266)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42266

The functions `lerp_kernel_scalar` and `lerp_kernel_tensor` in `Aten/native/cpu/LerpKernel.cpp` are not covered by tests; add tests to cover them.

Test Plan:
1. Test locally to check new lines are covered
2. CI

https://pxl.cl/1fXPd

Reviewed By: malfet

Differential Revision: D22832164

fbshipit-source-id: b1eaabbf8bfa08b4dedc1a468abfdfb619a50e3c
2020-07-29 22:47:37 -07:00

32 lines
954 B
C++

#include <gtest/gtest.h>
#include <torch/torch.h>
#include <test/cpp/api/support.h>
using namespace torch::nn;
using namespace std;
// Test fixture for the tensor-operation tests in this file. It adds no state
// of its own and only inherits from torch::test::SeedingFixture.
// NOTE(review): presumably the base fixture seeds the RNG so that torch::rand
// inputs are reproducible across runs — confirm in test/cpp/api/support.h.
struct OperationTest : torch::test::SeedingFixture {};
// Verifies torch::lerp against the closed-form interpolation
//   start + weight * (end - start)
// for both a scalar weight and a tensor weight, using freshly sampled 3x5
// inputs on every trial. Each case checks the result dtype and then numeric
// closeness to the reference expression.
TEST_F(OperationTest, Lerp) {
  constexpr int kTrials = 10;
  for (int trial = 0; trial < kTrials; ++trial) {
    // Endpoints shared by both sub-cases of this trial.
    const auto start = torch::rand({3, 5});
    const auto end = torch::rand({3, 5});

    // Case 1: scalar weight (exercises lerp_kernel_scalar).
    {
      const double scalar_weight = 0.5;
      const auto reference = start + scalar_weight * (end - start);
      const auto result = torch::lerp(start, end, scalar_weight);
      ASSERT_EQ(result.dtype(), reference.dtype());
      ASSERT_TRUE(result.allclose(reference));
    }

    // Case 2: element-wise tensor weight (exercises lerp_kernel_tensor).
    {
      const auto tensor_weight = torch::rand({3, 5});
      const auto reference = start + tensor_weight * (end - start);
      const auto result = torch::lerp(start, end, tensor_weight);
      ASSERT_EQ(result.dtype(), reference.dtype());
      ASSERT_TRUE(result.allclose(reference));
    }
  }
}