diff --git a/onnxruntime/core/providers/cpu/rnn/rnn.h b/onnxruntime/core/providers/cpu/rnn/rnn.h index 2c3b91c272..3dfd708c40 100644 --- a/onnxruntime/core/providers/cpu/rnn/rnn.h +++ b/onnxruntime/core/providers/cpu/rnn/rnn.h @@ -35,8 +35,9 @@ class RNN : public OpKernel { } ORT_ENFORCE(activations_.size() == static_cast(num_directions)); - for (int direction = 1; direction < num_directions; direction++) { - ORT_ENFORCE(allowed_activations.find(activations_[direction]) != allowed_activations.end()); + for (int direction = 0; direction < num_directions; direction++) { + ORT_ENFORCE(allowed_activations.find(activations_[direction]) != allowed_activations.end(), + "RNN op: Invalid activation attribute - ", activations_[direction]); } } diff --git a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc index b3bc19cddc..57cc5b9fa9 100644 --- a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc @@ -430,7 +430,7 @@ TEST(RNNTest, RNN_bidirectional_1) { std::vector R_dims = {num_directions, hidden_size, hidden_size}; std::vector R_data({// forward 1.0F, 1.0F, - 1.0F, 1.0F, + 1.0F, 1.0F, // reverse 1.0F, 1.0F, 1.0F, 1.0F}); @@ -837,5 +837,35 @@ TEST(RNNTest, RNN_bidirectional_with_sequence_lens) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); } +TEST(RNNTest, RNN_with_invalid_activation_load_failure) { + OpTester test("RNN"); + int64_t num_directions = 1, input_size = 1, hidden_size = 1, seq_length = 1; + + test.AddAttribute("activations", vector(num_directions, "Invalid_activation")); + test.AddAttribute("direction", "reverse"); + test.AddAttribute("hidden_size", hidden_size); + + int batch_size = 1; + + std::vector X_dims = {seq_length, batch_size, input_size}; + std::vector X_data{0.F}; + test.AddInput("X", X_dims, X_data); + + std::vector W_dims = {num_directions, hidden_size, input_size}; + std::vector W_data({0.F}); + test.AddInput("W", W_dims, W_data); + + std::vector R_dims = {num_directions, hidden_size, hidden_size}; + std::vector R_data({0.F}); + test.AddInput("R", R_dims, R_data); + + std::vector Y_dims = {seq_length, num_directions, batch_size, hidden_size}; + std::vector Y_data({0.F}); + test.AddOutput("Y", Y_dims, Y_data); + + test.Run(OpTester::ExpectResult::kExpectFailure, "RNN op: Invalid activation attribute - Invalid_activation", + {kCudaExecutionProvider, kTensorrtExecutionProvider}); +} + } // namespace test } // namespace onnxruntime