diff --git a/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc b/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc
index 3d7357dac0..8501479877 100644
--- a/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc
+++ b/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc
@@ -44,8 +44,10 @@ static void TestSoftmaxGrad(const std::vector& dY_dims,
 
   // create rand inputs
   RandomValueGenerator random{};
-  std::vector dY_data = random.Uniform(dY_dims, -10.0f, 10.0f);
-  std::vector Y_data = random.Uniform(Y_dims, -10.0f, 10.0f);
+  std::vector dY_data = random.Uniform(dY_dims, 0.0f, 1.0f);
+  // Add 1e-2 for numerical stability to prevent zero probability.
+  std::vector Y_data = random.Uniform(Y_dims, 0.02f, 1.02f);
+
   test.AddInput("dY", dY_dims, dY_data);
   test.AddInput("Y", Y_dims, Y_data);
 
@@ -59,22 +61,14 @@ TEST(CudaKernelTest, SoftmaxGrad_SmallTensor) {
   std::vector dY_dims{8, 2, 128, 128};
   std::vector Y_dims{8, 2, 128, 128};
   std::vector dX_dims{8, 2, 128, 128};
-
-  const double per_sample_tolerance = 1e-4;
-  const double relative_per_sample_tolerance = 5e-3;
-  TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, per_sample_tolerance, relative_per_sample_tolerance);
+  TestSoftmaxGrad(dY_dims, Y_dims, dX_dims);
 }
 
-// TODO fix flaky test
-// failing random seed: 552621640
-TEST(CudaKernelTest, DISABLED_SoftmaxGrad_LargeTensor) {
+TEST(CudaKernelTest, SoftmaxGrad_LargeTensor) {
   std::vector dY_dims{8, 16, 512, 512};
   std::vector Y_dims{8, 16, 512, 512};
   std::vector dX_dims{8, 16, 512, 512};
-
-  const double per_sample_tolerance = 1e-4;
-  const double relative_per_sample_tolerance = 5e-3;
-  TestSoftmaxGrad(dY_dims, Y_dims, dX_dims, per_sample_tolerance, relative_per_sample_tolerance);
+  TestSoftmaxGrad(dY_dims, Y_dims, dX_dims);
 }
 
 } // namespace test