mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42266 The functions `lerp_kernel_scalar` and `lerp_kernel_tensor` in `ATen/native/cpu/LerpKernel.cpp` were not covered by tests; add tests to cover them. Test Plan: 1. Test locally to check that the new lines are covered. 2. CI https://pxl.cl/1fXPd Reviewed By: malfet Differential Revision: D22832164 fbshipit-source-id: b1eaabbf8bfa08b4dedc1a468abfdfb619a50e3c
32 lines
954 B
C++
32 lines
954 B
C++
#include <gtest/gtest.h>
|
|
|
|
#include <torch/torch.h>
|
|
|
|
#include <test/cpp/api/support.h>
|
|
using namespace torch::nn;
|
|
using namespace std;
|
|
// GoogleTest fixture shared by the tensor-operation tests in this file.
// All behavior comes from torch::test::SeedingFixture (declared in
// test/cpp/api/support.h); presumably it seeds the RNG before each test so
// torch::rand draws are reproducible -- confirm against support.h.
struct OperationTest : public torch::test::SeedingFixture {};
|
|
|
|
// Checks torch::lerp against the reference formula start + w * (end - start)
// for both weight forms:
//   * scalar weight  -> torch::lerp(Tensor, Tensor, Scalar)
//   * tensor weight  -> torch::lerp(Tensor, Tensor, Tensor)
// Each iteration draws fresh random 3x5 inputs from torch::rand.
TEST_F(OperationTest, Lerp) {
  for (int i = 0; i < 10; i++) {
    // test lerp_kernel_scalar
    auto start = torch::rand({3, 5});
    auto end = torch::rand({3, 5});
    // Sweep the weight over 0.0, 0.1, ..., 0.9 rather than reusing a single
    // fixed value, so the scalar path is checked at weights on both sides of
    // 0.5 and at the w == 0 endpoint (where the result must equal `start`).
    const double scalar = i / 10.0;
    // expected and actual
    auto scalar_expected = start + scalar * (end - start);
    auto out = torch::lerp(start, end, scalar);
    // compare: dtype, exact shape (allclose alone broadcasts), then values
    ASSERT_EQ(out.dtype(), scalar_expected.dtype());
    ASSERT_EQ(out.sizes(), scalar_expected.sizes());
    ASSERT_TRUE(out.allclose(scalar_expected));

    // test lerp_kernel_tensor
    // Per-element weights drawn from torch::rand.
    auto weight = torch::rand({3, 5});
    // expected and actual
    auto tensor_expected = start + weight * (end - start);
    out = torch::lerp(start, end, weight);
    // compare: dtype, exact shape (allclose alone broadcasts), then values
    ASSERT_EQ(out.dtype(), tensor_expected.dtype());
    ASSERT_EQ(out.sizes(), tensor_expected.sizes());
    ASSERT_TRUE(out.allclose(tensor_expected));
  }
}
|