diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc index 0f32bedec2..7b7f46ccf0 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc @@ -134,8 +134,10 @@ Status EinsumComputePreprocessor::ProcessSubscripts() { num_of_ellipsis_dims_ = static_cast(current_num_of_ellipsis_dims); } - // We reserve 'EinsumOp::num_of_letters' for broadcasted dims as we only allow 'a' - 'z' (0 - 25) for non-broadcasted dims - // We will assign appropriate indices (based on number of dimensions the ellipsis corresponds to) during broadcasting related post-processing + // We reserve 'EinsumOp::num_of_letters' for broadcasted dims as we only allow 'a' - 'z' + // and 'A' - 'Z' (0 - 51) for non-broadcasted dims. + // We will assign appropriate indices (based on number of dimensions the ellipsis corresponds to) + // during broadcasting related post-processing. for (size_t i = 0; i < num_of_ellipsis_dims_; ++i) { current_subscript_indices.push_back(EinsumOp::num_of_letters); } @@ -150,12 +152,13 @@ Status EinsumComputePreprocessor::ProcessSubscripts() { "Found '.' not part of an ellipsis in input: ", input_index); } - if (!(subscript_label >= 'a' && subscript_label <= 'z')) { + auto letter_index = EinsumOp::LetterToIndex(subscript_label); + if (letter_index == -1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "The only subscript labels allowed are lowercase letters (a-z)"); + "The only subscript labels allowed are lower-cased letters (a-z) and " + "upper-cased letters (A-Z)"); } - auto letter_index = static_cast(subscript_label - 'a'); auto dim_value = dims[dim_counter]; // Subscript label not found in global subscript label array diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h index 3c11d49e33..7c7b25f2a7 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h @@ -19,7 +19,27 @@ namespace onnxruntime { namespace EinsumOp { -constexpr size_t num_of_letters = 26; +// Einsum accepts 'a' - 'z' and 'A' - 'Z' and needs to differentiate between lower-cased +// and upper-cased letters in the equation string (26 * 2 = 52). +constexpr size_t num_of_letters = 52; + +/** Returns the index associated with the input character + * Returns a value between 0 - 25 for input in 'a' - 'z' + * Returns a value between 26 - 51 for input in 'A' - 'Z' + * Returns -1 for invalid input not in 'a' - 'z' or 'A' - 'Z' (caller should handle the returned result) + */ +inline int64_t LetterToIndex(char ch) { + if (ch >= 'a' && ch <= 'z') { + return static_cast(ch - 'a'); + } + + if (ch >= 'A' && ch <= 'Z') { + return 26 + static_cast(ch - 'A'); + } + + // invalid character - return error value + return -1; +} } // namespace EinsumOp diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc index cf030e60f1..c5917de22a 100644 --- a/onnxruntime/test/providers/cpu/math/einsum_test.cc +++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc @@ -177,6 +177,16 @@ TEST(Einsum, ExplicitEinsumAsMatmul) { test.Run(); } +TEST(Einsum, ExplicitEinsumAsMatmulWithUpperCasedLabel) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + // 'K' != 'k' (and dim values differ too) and Einsum should handle be able to handle that + test.AddAttribute("equation", "iK,Kk->ik"); + test.AddInput("x", {2, 1}, {1.f, 2.f}); + test.AddInput("y", {1, 2}, {1.f, 2.f}); + test.AddOutput("o", {2, 2}, {1.f, 2.f, 2.f, 4.f}); + test.Run(); +} + TEST(Einsum, ExplicitEinsumAsMatmul_Multi_Input) { OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); test.AddAttribute("equation", "ij,jk,kl->li");