Add numerical stability to SoftmaxGrad test inputs. (#3857)

* Increase the tolerance for SoftmaxGrad CPU-GPU compare tests.

* Increase the tolerance for SoftmaxGrad CPU-GPU compare tests.

* Add 1e-2 to Y for numerical stability.

* Fix build break.

* comments.

* PR feedback.

* PR feedback.
This commit is contained in:
M. Zeeshan Siddiqui 2020-05-11 17:59:24 -07:00 committed by GitHub
parent af7d453435
commit c46a9e8d65
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -44,8 +44,10 @@ static void TestSoftmaxGrad(const std::vector<int64_t>& dY_dims,
// create rand inputs
RandomValueGenerator random{};
std::vector<float> dY_data = random.Uniform<float>(dY_dims, -10.0f, 10.0f);
std::vector<float> Y_data = random.Uniform<float>(Y_dims, -10.0f, 10.0f);
std::vector<float> dY_data = random.Uniform<float>(dY_dims, 0.0f, 1.0f);
// Sample Y from the range 0.02 to 1.02 (lower bound kept above zero) for numerical stability, to prevent zero probability.
std::vector<float> Y_data = random.Uniform<float>(Y_dims, 0.02f, 1.02f);
test.AddInput<float>("dY", dY_dims, dY_data);
test.AddInput<float>("Y", Y_dims, Y_data);
@ -59,22 +61,14 @@ TEST(CudaKernelTest, SoftmaxGrad_SmallTensor) {
std::vector<int64_t> dY_dims{8, 2, 128, 128};
std::vector<int64_t> Y_dims{8, 2, 128, 128};
std::vector<int64_t> dX_dims{8, 2, 128, 128};
const double per_sample_tolerance = 1e-4;
const double relative_per_sample_tolerance = 5e-3;
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, per_sample_tolerance, relative_per_sample_tolerance);
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims);
}
// TODO fix flaky test
// failing random seed: 552621640
TEST(CudaKernelTest, DISABLED_SoftmaxGrad_LargeTensor) {
TEST(CudaKernelTest, SoftmaxGrad_LargeTensor) {
std::vector<int64_t> dY_dims{8, 16, 512, 512};
std::vector<int64_t> Y_dims{8, 16, 512, 512};
std::vector<int64_t> dX_dims{8, 16, 512, 512};
const double per_sample_tolerance = 1e-4;
const double relative_per_sample_tolerance = 5e-3;
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, per_sample_tolerance, relative_per_sample_tolerance);
TestSoftmaxGrad(dY_dims, Y_dims, dX_dims);
}
} // namespace test