diff --git a/onnxruntime/core/common/path.cc b/onnxruntime/core/common/path.cc index 0d60e884d1..1050f096a4 100644 --- a/onnxruntime/core/common/path.cc +++ b/onnxruntime/core/common/path.cc @@ -253,18 +253,22 @@ Path& Path::Append(const Path& other) { return *this; } -Path& Path::Concat(const PathString& string) { - components_.back() += string; - return *this; -} +Path& Path::Concat(const PathString& value) { + auto first_separator = std::find_if(value.begin(), value.end(), + [](PathChar c) { + return std::find( + k_valid_path_separators.begin(), + k_valid_path_separators.end(), + c) != k_valid_path_separators.end(); + }); + ORT_ENFORCE(first_separator == value.end(), + "Cannot concatenate with a string containing a path separator. String: ", ToMBString(value)); -Path& Path::ConcatIndex(const int index) { -#ifdef _WIN32 - auto index_str = std::to_wstring(index); -#else - auto index_str = std::to_string(index); -#endif - components_.back() += index_str; + if (components_.empty()) { + components_.push_back(value); + } else { + components_.back() += value; + } return *this; } diff --git a/onnxruntime/core/common/path.h b/onnxruntime/core/common/path.h index 9dac03b725..514a636033 100644 --- a/onnxruntime/core/common/path.h +++ b/onnxruntime/core/common/path.h @@ -68,12 +68,6 @@ class Path { */ Path& Concat(const PathString& string); - /** - * Concatenates an index by the end of current path. - * Similar to Concat() except the argument is an index. - */ - Path& ConcatIndex(const int index); - /** Equivalent to this->Append(other). */ Path& operator/=(const Path& other) { return Append(other); diff --git a/onnxruntime/core/common/path_utils.h b/onnxruntime/core/common/path_utils.h new file mode 100644 index 0000000000..4e133a0dd5 --- /dev/null +++ b/onnxruntime/core/common/path_utils.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/common/path_string.h" + +namespace onnxruntime { + +namespace path_utils { + +/** Return a PathString with concatenated args. + * TODO: add support for arguments of type std::wstring. Currently it is not supported as the underneath + * MakeString doesn't support this type. +*/ +template +PathString MakePathString(const Args&... args) { + const std::string str = onnxruntime::MakeString(args...); + return ToPathString(str); +} +} +} diff --git a/onnxruntime/test/common/path_test.cc b/onnxruntime/test/common/path_test.cc index ea8908affc..dbd9990c0d 100644 --- a/onnxruntime/test/common/path_test.cc +++ b/onnxruntime/test/common/path_test.cc @@ -5,6 +5,7 @@ #include "gtest/gtest.h" +#include "core/common/optional.h" #include "test/util/include/asserts.h" namespace onnxruntime { @@ -224,31 +225,29 @@ TEST(PathTest, RelativePathFailure) { TEST(PathTest, Concat) { auto check_concat = - [](const std::string& a, const std::string& b, const std::string& expected_a) { + [](const optional& a, const std::string& b, const std::string& expected_a, bool expect_throw = false) { Path p_a{}, p_expected_a{}; - ASSERT_STATUS_OK(Path::Parse(ToPathString(a), p_a)); + if (a.has_value()) { + ASSERT_STATUS_OK(Path::Parse(ToPathString(a.value()), p_a)); + } ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_a), p_expected_a)); - EXPECT_EQ(p_a.Concat(ToPathString(b)).ToPathString(), p_expected_a.ToPathString()); + if (expect_throw) { + EXPECT_THROW(p_a.Concat(ToPathString(b)).ToPathString(), OnnxRuntimeException); + } else { + EXPECT_EQ(p_a.Concat(ToPathString(b)).ToPathString(), p_expected_a.ToPathString()); + } }; - check_concat("/a/b", "c", "/a/bc"); - check_concat("a/b", "cd", "a/bcd"); -} - -TEST(PathTest, ConcatIndex) { - auto check_concat_index = - [](const std::string& a, const int i, const std::string& expected_a) { - Path p_a{}, p_expected_a{}; - ASSERT_STATUS_OK(Path::Parse(ToPathString(a), p_a)); - ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_a), p_expected_a)); - - EXPECT_EQ(p_a.ConcatIndex(i).ToPathString(), p_expected_a.ToPathString()); - }; - - check_concat_index("/a/b", 0, "/a/b0"); - check_concat_index("a/b", 123, "a/b123"); - check_concat_index("a/b", -1, "a/b-1"); + check_concat({"/a/b"}, "c", "/a/bc"); + check_concat({"a/b"}, "cd", "a/bcd"); + check_concat({""}, "cd", "cd"); + check_concat({}, "c", "c"); +#ifdef _WIN32 + check_concat({"a/b"}, R"(c\d)", "", true /* expect_throw */); +#else + check_concat({"a/b"}, "c/d", "", true /* expect_throw */); +#endif } } // namespace test diff --git a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc index 571152b107..2e7ae8ffe9 100644 --- a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc +++ b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc @@ -5,6 +5,7 @@ #include "gtest/gtest.h" #include "orttraining/core/optimizer/gist_encode_decode.h" #include "test/providers/provider_test_utils.h" +#include "core/common/path_utils.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/session/environment.h" #include "orttraining/models/runner/training_runner.h" @@ -19,6 +20,7 @@ using namespace onnxruntime::logging; using namespace onnxruntime::training; using namespace google::protobuf::util; +using namespace onnxruntime::path_utils; namespace onnxruntime { namespace test { @@ -1113,10 +1115,8 @@ void RetrieveSendRecvOperators( } } -PathString GenerateFileNameWithIndex(const PathString& base_str, int index, const PathString& file_suffix) { - Path p; - ORT_ENFORCE(Path::Parse(base_str, p).IsOK()); - return p.ConcatIndex(index).Concat(file_suffix).ToPathString(); +PathString GenerateFileNameWithIndex(const std::string& base_str, int index, const std::string& file_suffix) { + return path_utils::MakePathString(base_str, index, file_suffix); } TEST(GradientGraphBuilderTest, PipelineOnlinePartition_bert_tiny) { @@ -1147,7 +1147,7 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition_bert_tiny) { // graph is partitioned into 3 parts. for (int i = 0; i < static_cast(total_partition_count); ++i) { - PathString output_file = GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), i, ORT_TSTR("_back.onnx")); + PathString output_file = GenerateFileNameWithIndex("pipeline_partition_", i, "_back.onnx"); auto config = MakeBasicTrainingConfig(); if (i == static_cast(total_partition_count - 1)) { @@ -1239,7 +1239,7 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition_MLP) { for(auto is_fp32 : test_with_fp32) { // graph is partitioned into 3 parts. for (int i = 0; i < 3; ++i) { - PathString output_file = GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), i, ORT_TSTR("_back.onnx")); + PathString output_file = GenerateFileNameWithIndex("pipeline_partition_", i, "_back.onnx"); auto config = MakeBasicTrainingConfig(); @@ -1286,7 +1286,7 @@ Status RunOnlinePartition(const std::vector sub_model_files(num_subs); for (size_t sub_id = 0; sub_id < num_subs; ++sub_id) { - sub_model_files[sub_id] = GenerateFileNameWithIndex(ORT_TSTR("sub_"), static_cast(sub_id), ORT_TSTR(".onnx")); + sub_model_files[sub_id] = GenerateFileNameWithIndex("sub_", static_cast(sub_id), ".onnx"); } PipelineSplitter splitter; @@ -1525,7 +1525,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) { for (size_t sub_id = 0; sub_id < num_subs; ++sub_id) { auto& sub_sess = subs[sub_id]; sub_sess.so.enable_profiling = true; - sub_sess.so.profile_file_prefix = GenerateFileNameWithIndex(ORT_TSTR("pipeline"), static_cast(sub_id), ORT_TSTR("")); + sub_sess.so.profile_file_prefix = GenerateFileNameWithIndex("pipeline", static_cast(sub_id), ""); sub_sess.run_options.run_log_verbosity_level = sub_sess.so.session_log_verbosity_level; sub_sess.run_options.run_tag = sub_sess.so.session_logid;