adding concat logic when initial path is empty (#4525)

* concat

* add path_utils

* address feedback

* use string in test

* convert wstring to sting in windows

* address feedback

* address feedback

* fix comment
This commit is contained in:
Xueyun Zhu 2020-07-16 23:46:12 -07:00 committed by GitHub
parent d1f45f9361
commit 183098e344
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 69 additions and 49 deletions

View file

@ -253,18 +253,22 @@ Path& Path::Append(const Path& other) {
return *this;
}
Path& Path::Concat(const PathString& string) {
components_.back() += string;
return *this;
}
Path& Path::Concat(const PathString& value) {
auto first_separator = std::find_if(value.begin(), value.end(),
[](PathChar c) {
return std::find(
k_valid_path_separators.begin(),
k_valid_path_separators.end(),
c) != k_valid_path_separators.end();
});
ORT_ENFORCE(first_separator == value.end(),
"Cannot concatenate with a string containing a path separator. String: ", ToMBString(value));
Path& Path::ConcatIndex(const int index) {
#ifdef _WIN32
auto index_str = std::to_wstring(index);
#else
auto index_str = std::to_string(index);
#endif
components_.back() += index_str;
if (components_.empty()) {
components_.push_back(value);
} else {
components_.back() += value;
}
return *this;
}

View file

@ -68,12 +68,6 @@ class Path {
*/
Path& Concat(const PathString& string);
/**
* Concatenates an index by the end of current path.
* Similar to Concat() except the argument is an index.
*/
Path& ConcatIndex(const int index);
/** Equivalent to this->Append(other). */
Path& operator/=(const Path& other) {
return Append(other);

View file

@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/common/path_string.h"
namespace onnxruntime {
namespace path_utils {
/** Return a PathString with concatenated args.
* TODO: add support for arguments of type std::wstring. Currently it is not supported as the underneath
* MakeString doesn't support this type.
*/
template <typename... Args>
PathString MakePathString(const Args&... args) {
const std::string str = onnxruntime::MakeString(args...);
return ToPathString(str);
}
}
}

View file

@ -5,6 +5,7 @@
#include "gtest/gtest.h"
#include "core/common/optional.h"
#include "test/util/include/asserts.h"
namespace onnxruntime {
@ -224,31 +225,29 @@ TEST(PathTest, RelativePathFailure) {
TEST(PathTest, Concat) {
auto check_concat =
[](const std::string& a, const std::string& b, const std::string& expected_a) {
[](const optional<std::string>& a, const std::string& b, const std::string& expected_a, bool expect_throw = false) {
Path p_a{}, p_expected_a{};
ASSERT_STATUS_OK(Path::Parse(ToPathString(a), p_a));
if (a.has_value()) {
ASSERT_STATUS_OK(Path::Parse(ToPathString(a.value()), p_a));
}
ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_a), p_expected_a));
EXPECT_EQ(p_a.Concat(ToPathString(b)).ToPathString(), p_expected_a.ToPathString());
if (expect_throw) {
EXPECT_THROW(p_a.Concat(ToPathString(b)).ToPathString(), OnnxRuntimeException);
} else {
EXPECT_EQ(p_a.Concat(ToPathString(b)).ToPathString(), p_expected_a.ToPathString());
}
};
check_concat("/a/b", "c", "/a/bc");
check_concat("a/b", "cd", "a/bcd");
}
TEST(PathTest, ConcatIndex) {
auto check_concat_index =
[](const std::string& a, const int i, const std::string& expected_a) {
Path p_a{}, p_expected_a{};
ASSERT_STATUS_OK(Path::Parse(ToPathString(a), p_a));
ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_a), p_expected_a));
EXPECT_EQ(p_a.ConcatIndex(i).ToPathString(), p_expected_a.ToPathString());
};
check_concat_index("/a/b", 0, "/a/b0");
check_concat_index("a/b", 123, "a/b123");
check_concat_index("a/b", -1, "a/b-1");
check_concat({"/a/b"}, "c", "/a/bc");
check_concat({"a/b"}, "cd", "a/bcd");
check_concat({""}, "cd", "cd");
check_concat({}, "c", "c");
#ifdef _WIN32
check_concat({"a/b"}, R"(c\d)", "", true /* expect_throw */);
#else
check_concat({"a/b"}, "c/d", "", true /* expect_throw */);
#endif
}
} // namespace test

View file

@ -5,6 +5,7 @@
#include "gtest/gtest.h"
#include "orttraining/core/optimizer/gist_encode_decode.h"
#include "test/providers/provider_test_utils.h"
#include "core/common/path_utils.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/session/environment.h"
#include "orttraining/models/runner/training_runner.h"
@ -19,6 +20,7 @@
using namespace onnxruntime::logging;
using namespace onnxruntime::training;
using namespace google::protobuf::util;
using namespace onnxruntime::path_utils;
namespace onnxruntime {
namespace test {
@ -1113,10 +1115,8 @@ void RetrieveSendRecvOperators(
}
}
PathString GenerateFileNameWithIndex(const PathString& base_str, int index, const PathString& file_suffix) {
Path p;
ORT_ENFORCE(Path::Parse(base_str, p).IsOK());
return p.ConcatIndex(index).Concat(file_suffix).ToPathString();
PathString GenerateFileNameWithIndex(const std::string& base_str, int index, const std::string& file_suffix) {
return path_utils::MakePathString(base_str, index, file_suffix);
}
TEST(GradientGraphBuilderTest, PipelineOnlinePartition_bert_tiny) {
@ -1147,7 +1147,7 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition_bert_tiny) {
// graph is partitioned into 3 parts.
for (int i = 0; i < static_cast<int>(total_partition_count); ++i) {
PathString output_file = GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), i, ORT_TSTR("_back.onnx"));
PathString output_file = GenerateFileNameWithIndex("pipeline_partition_", i, "_back.onnx");
auto config = MakeBasicTrainingConfig();
if (i == static_cast<int>(total_partition_count - 1)) {
@ -1239,7 +1239,7 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition_MLP) {
for(auto is_fp32 : test_with_fp32) {
// graph is partitioned into 3 parts.
for (int i = 0; i < 3; ++i) {
PathString output_file = GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), i, ORT_TSTR("_back.onnx"));
PathString output_file = GenerateFileNameWithIndex("pipeline_partition_", i, "_back.onnx");
auto config = MakeBasicTrainingConfig();
@ -1286,7 +1286,7 @@ Status RunOnlinePartition(const std::vector<TrainingSession::TrainingConfigurati
pipe.cut_list = cut_list;
for (int i = 0; i < pipeline_stage_size; ++i) {
PathString output_file = GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), i, ORT_TSTR("_back.onnx"));
PathString output_file = GenerateFileNameWithIndex("pipeline_partition_", i, "_back.onnx");
auto config = MakeBasicTrainingConfig();
config.pipeline_config = pipe;
@ -1331,7 +1331,7 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition_Invalid_Input) {
// verify pipeline config can load and gradient graph can construct.
TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
PathString filename_base = ORT_TSTR("testdata/test_training_model_");
std::string filename_base = "testdata/test_training_model_";
auto load_and_check_gradient_graph = [](int stageIdx, PathString& input_file, PathString& output_file) {
auto config = MakeBasicTrainingConfig();
@ -1437,8 +1437,8 @@ TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
};
for (int i = 0; i < 3; ++i) {
PathString input_file = GenerateFileNameWithIndex(filename_base, i, ORT_TSTR(".onnx"));
PathString output_file = GenerateFileNameWithIndex(filename_base, i, ORT_TSTR("_back.onnx"));
PathString input_file = GenerateFileNameWithIndex(filename_base, i, ".onnx");
PathString output_file = GenerateFileNameWithIndex(filename_base, i, "_back.onnx");
load_and_check_gradient_graph(i, input_file, output_file);
}
@ -1505,7 +1505,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) {
std::vector<PathString> sub_model_files(num_subs);
for (size_t sub_id = 0; sub_id < num_subs; ++sub_id) {
sub_model_files[sub_id] = GenerateFileNameWithIndex(ORT_TSTR("sub_"), static_cast<int>(sub_id), ORT_TSTR(".onnx"));
sub_model_files[sub_id] = GenerateFileNameWithIndex("sub_", static_cast<int>(sub_id), ".onnx");
}
PipelineSplitter splitter;
@ -1525,7 +1525,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) {
for (size_t sub_id = 0; sub_id < num_subs; ++sub_id) {
auto& sub_sess = subs[sub_id];
sub_sess.so.enable_profiling = true;
sub_sess.so.profile_file_prefix = GenerateFileNameWithIndex(ORT_TSTR("pipeline"), static_cast<int>(sub_id), ORT_TSTR(""));
sub_sess.so.profile_file_prefix = GenerateFileNameWithIndex("pipeline", static_cast<int>(sub_id), "");
sub_sess.run_options.run_log_verbosity_level = sub_sess.so.session_log_verbosity_level;
sub_sess.run_options.run_tag = sub_sess.so.session_logid;