From b2da700e4d953239833e40f9a1b39b15936cc6dd Mon Sep 17 00:00:00 2001 From: Andrew McDowell <35566934+AndrewMcDowell@users.noreply.github.com> Date: Mon, 26 Oct 2020 01:11:12 +0000 Subject: [PATCH] Allow Upper case letters in RHS of einsum equations. (#5569) Co-authored-by: Andrew McDowell --- .../math/einsum_utils/einsum_compute_preprocessor.cc | 9 ++++----- onnxruntime/test/providers/cpu/math/einsum_test.cc | 10 ++++++++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc index 7b7f46ccf0..91c700568a 100644 --- a/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc +++ b/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc @@ -369,14 +369,13 @@ Status EinsumComputePreprocessor::CalculateOutputShape() { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Found '.' not part of an ellipsis in the output subscript provided"); } - if (!(subscript_label >= 'a' && subscript_label <= 'z')) { + auto letter_index = EinsumOp::LetterToIndex(subscript_label); + if (letter_index == -1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "The only subscript labels allowed in the output subscript " - "are lowercase letters (a-z)"); + "The only subscript labels allowed are lower-cased letters (a-z) and " + "upper-cased letters (A-Z)"); } - auto letter_index = static_cast(subscript_label - 'a'); - if (output_letter_to_count[letter_index] != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Output subscript contains repeated letters"); diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc index c5917de22a..0d22ba6f14 100644 --- a/onnxruntime/test/providers/cpu/math/einsum_test.cc +++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc @@ -187,6 +187,16 @@ TEST(Einsum, ExplicitEinsumAsMatmulWithUpperCasedLabel) { test.Run(); } +TEST(Einsum, ExplicitEinsumAsMatmulWithUpperCasedOutputLabel) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + // Einsum should handle be able to handle upper case on both LHS and RHS + test.AddAttribute("equation", "Ki,ik->Kk"); + test.AddInput("x", {2, 1}, {1.f, 2.f}); + test.AddInput("y", {1, 2}, {1.f, 2.f}); + test.AddOutput("o", {2, 2}, {1.f, 2.f, 2.f, 4.f}); + test.Run(); +} + TEST(Einsum, ExplicitEinsumAsMatmul_Multi_Input) { OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); test.AddAttribute("equation", "ij,jk,kl->li");