mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
adding concat logic when initial path is empty (#4525)
* concat * add path_utils * address feedback * use string in test * convert wstring to sting in windows * address feedback * address feedback * fix comment
This commit is contained in:
parent
d1f45f9361
commit
183098e344
5 changed files with 69 additions and 49 deletions
|
|
@ -253,18 +253,22 @@ Path& Path::Append(const Path& other) {
|
|||
return *this;
|
||||
}
|
||||
|
||||
Path& Path::Concat(const PathString& string) {
|
||||
components_.back() += string;
|
||||
return *this;
|
||||
}
|
||||
Path& Path::Concat(const PathString& value) {
|
||||
auto first_separator = std::find_if(value.begin(), value.end(),
|
||||
[](PathChar c) {
|
||||
return std::find(
|
||||
k_valid_path_separators.begin(),
|
||||
k_valid_path_separators.end(),
|
||||
c) != k_valid_path_separators.end();
|
||||
});
|
||||
ORT_ENFORCE(first_separator == value.end(),
|
||||
"Cannot concatenate with a string containing a path separator. String: ", ToMBString(value));
|
||||
|
||||
Path& Path::ConcatIndex(const int index) {
|
||||
#ifdef _WIN32
|
||||
auto index_str = std::to_wstring(index);
|
||||
#else
|
||||
auto index_str = std::to_string(index);
|
||||
#endif
|
||||
components_.back() += index_str;
|
||||
if (components_.empty()) {
|
||||
components_.push_back(value);
|
||||
} else {
|
||||
components_.back() += value;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -68,12 +68,6 @@ class Path {
|
|||
*/
|
||||
Path& Concat(const PathString& string);
|
||||
|
||||
/**
|
||||
* Concatenates an index by the end of current path.
|
||||
* Similar to Concat() except the argument is an index.
|
||||
*/
|
||||
Path& ConcatIndex(const int index);
|
||||
|
||||
/** Equivalent to this->Append(other). */
|
||||
Path& operator/=(const Path& other) {
|
||||
return Append(other);
|
||||
|
|
|
|||
23
onnxruntime/core/common/path_utils.h
Normal file
23
onnxruntime/core/common/path_utils.h
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/path_string.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace path_utils {
|
||||
|
||||
/** Return a PathString with concatenated args.
|
||||
* TODO: add support for arguments of type std::wstring. Currently it is not supported as the underneath
|
||||
* MakeString doesn't support this type.
|
||||
*/
|
||||
template <typename... Args>
|
||||
PathString MakePathString(const Args&... args) {
|
||||
const std::string str = onnxruntime::MakeString(args...);
|
||||
return ToPathString(str);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "core/common/optional.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -224,31 +225,29 @@ TEST(PathTest, RelativePathFailure) {
|
|||
|
||||
TEST(PathTest, Concat) {
|
||||
auto check_concat =
|
||||
[](const std::string& a, const std::string& b, const std::string& expected_a) {
|
||||
[](const optional<std::string>& a, const std::string& b, const std::string& expected_a, bool expect_throw = false) {
|
||||
Path p_a{}, p_expected_a{};
|
||||
ASSERT_STATUS_OK(Path::Parse(ToPathString(a), p_a));
|
||||
if (a.has_value()) {
|
||||
ASSERT_STATUS_OK(Path::Parse(ToPathString(a.value()), p_a));
|
||||
}
|
||||
ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_a), p_expected_a));
|
||||
|
||||
EXPECT_EQ(p_a.Concat(ToPathString(b)).ToPathString(), p_expected_a.ToPathString());
|
||||
if (expect_throw) {
|
||||
EXPECT_THROW(p_a.Concat(ToPathString(b)).ToPathString(), OnnxRuntimeException);
|
||||
} else {
|
||||
EXPECT_EQ(p_a.Concat(ToPathString(b)).ToPathString(), p_expected_a.ToPathString());
|
||||
}
|
||||
};
|
||||
|
||||
check_concat("/a/b", "c", "/a/bc");
|
||||
check_concat("a/b", "cd", "a/bcd");
|
||||
}
|
||||
|
||||
TEST(PathTest, ConcatIndex) {
|
||||
auto check_concat_index =
|
||||
[](const std::string& a, const int i, const std::string& expected_a) {
|
||||
Path p_a{}, p_expected_a{};
|
||||
ASSERT_STATUS_OK(Path::Parse(ToPathString(a), p_a));
|
||||
ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_a), p_expected_a));
|
||||
|
||||
EXPECT_EQ(p_a.ConcatIndex(i).ToPathString(), p_expected_a.ToPathString());
|
||||
};
|
||||
|
||||
check_concat_index("/a/b", 0, "/a/b0");
|
||||
check_concat_index("a/b", 123, "a/b123");
|
||||
check_concat_index("a/b", -1, "a/b-1");
|
||||
check_concat({"/a/b"}, "c", "/a/bc");
|
||||
check_concat({"a/b"}, "cd", "a/bcd");
|
||||
check_concat({""}, "cd", "cd");
|
||||
check_concat({}, "c", "c");
|
||||
#ifdef _WIN32
|
||||
check_concat({"a/b"}, R"(c\d)", "", true /* expect_throw */);
|
||||
#else
|
||||
check_concat({"a/b"}, "c/d", "", true /* expect_throw */);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#include "gtest/gtest.h"
|
||||
#include "orttraining/core/optimizer/gist_encode_decode.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "core/common/path_utils.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "core/session/environment.h"
|
||||
#include "orttraining/models/runner/training_runner.h"
|
||||
|
|
@ -19,6 +20,7 @@
|
|||
using namespace onnxruntime::logging;
|
||||
using namespace onnxruntime::training;
|
||||
using namespace google::protobuf::util;
|
||||
using namespace onnxruntime::path_utils;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
|
@ -1113,10 +1115,8 @@ void RetrieveSendRecvOperators(
|
|||
}
|
||||
}
|
||||
|
||||
PathString GenerateFileNameWithIndex(const PathString& base_str, int index, const PathString& file_suffix) {
|
||||
Path p;
|
||||
ORT_ENFORCE(Path::Parse(base_str, p).IsOK());
|
||||
return p.ConcatIndex(index).Concat(file_suffix).ToPathString();
|
||||
PathString GenerateFileNameWithIndex(const std::string& base_str, int index, const std::string& file_suffix) {
|
||||
return path_utils::MakePathString(base_str, index, file_suffix);
|
||||
}
|
||||
|
||||
TEST(GradientGraphBuilderTest, PipelineOnlinePartition_bert_tiny) {
|
||||
|
|
@ -1147,7 +1147,7 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition_bert_tiny) {
|
|||
// graph is partitioned into 3 parts.
|
||||
for (int i = 0; i < static_cast<int>(total_partition_count); ++i) {
|
||||
|
||||
PathString output_file = GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), i, ORT_TSTR("_back.onnx"));
|
||||
PathString output_file = GenerateFileNameWithIndex("pipeline_partition_", i, "_back.onnx");
|
||||
auto config = MakeBasicTrainingConfig();
|
||||
|
||||
if (i == static_cast<int>(total_partition_count - 1)) {
|
||||
|
|
@ -1239,7 +1239,7 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition_MLP) {
|
|||
for(auto is_fp32 : test_with_fp32) {
|
||||
// graph is partitioned into 3 parts.
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
PathString output_file = GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), i, ORT_TSTR("_back.onnx"));
|
||||
PathString output_file = GenerateFileNameWithIndex("pipeline_partition_", i, "_back.onnx");
|
||||
|
||||
auto config = MakeBasicTrainingConfig();
|
||||
|
||||
|
|
@ -1286,7 +1286,7 @@ Status RunOnlinePartition(const std::vector<TrainingSession::TrainingConfigurati
|
|||
pipe.cut_list = cut_list;
|
||||
|
||||
for (int i = 0; i < pipeline_stage_size; ++i) {
|
||||
PathString output_file = GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), i, ORT_TSTR("_back.onnx"));
|
||||
PathString output_file = GenerateFileNameWithIndex("pipeline_partition_", i, "_back.onnx");
|
||||
|
||||
auto config = MakeBasicTrainingConfig();
|
||||
config.pipeline_config = pipe;
|
||||
|
|
@ -1331,7 +1331,7 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition_Invalid_Input) {
|
|||
|
||||
// verify pipeline config can load and gradient graph can construct.
|
||||
TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
|
||||
PathString filename_base = ORT_TSTR("testdata/test_training_model_");
|
||||
std::string filename_base = "testdata/test_training_model_";
|
||||
|
||||
auto load_and_check_gradient_graph = [](int stageIdx, PathString& input_file, PathString& output_file) {
|
||||
auto config = MakeBasicTrainingConfig();
|
||||
|
|
@ -1437,8 +1437,8 @@ TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
|
|||
};
|
||||
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
PathString input_file = GenerateFileNameWithIndex(filename_base, i, ORT_TSTR(".onnx"));
|
||||
PathString output_file = GenerateFileNameWithIndex(filename_base, i, ORT_TSTR("_back.onnx"));
|
||||
PathString input_file = GenerateFileNameWithIndex(filename_base, i, ".onnx");
|
||||
PathString output_file = GenerateFileNameWithIndex(filename_base, i, "_back.onnx");
|
||||
|
||||
load_and_check_gradient_graph(i, input_file, output_file);
|
||||
}
|
||||
|
|
@ -1505,7 +1505,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) {
|
|||
|
||||
std::vector<PathString> sub_model_files(num_subs);
|
||||
for (size_t sub_id = 0; sub_id < num_subs; ++sub_id) {
|
||||
sub_model_files[sub_id] = GenerateFileNameWithIndex(ORT_TSTR("sub_"), static_cast<int>(sub_id), ORT_TSTR(".onnx"));
|
||||
sub_model_files[sub_id] = GenerateFileNameWithIndex("sub_", static_cast<int>(sub_id), ".onnx");
|
||||
}
|
||||
|
||||
PipelineSplitter splitter;
|
||||
|
|
@ -1525,7 +1525,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) {
|
|||
for (size_t sub_id = 0; sub_id < num_subs; ++sub_id) {
|
||||
auto& sub_sess = subs[sub_id];
|
||||
sub_sess.so.enable_profiling = true;
|
||||
sub_sess.so.profile_file_prefix = GenerateFileNameWithIndex(ORT_TSTR("pipeline"), static_cast<int>(sub_id), ORT_TSTR(""));
|
||||
sub_sess.so.profile_file_prefix = GenerateFileNameWithIndex("pipeline", static_cast<int>(sub_id), "");
|
||||
|
||||
sub_sess.run_options.run_log_verbosity_level = sub_sess.so.session_log_verbosity_level;
|
||||
sub_sess.run_options.run_tag = sub_sess.so.session_logid;
|
||||
|
|
|
|||
Loading…
Reference in a new issue