mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Fix evaluation issues (#3538)
* allow switching between eval and training modes dynamically Co-authored-by: Tixxx <root@525204a066204ea794f942530b05ae7f000000.axlncovkyjne5caro2tmz3zryb.xx.internal.cloudapp.net>
This commit is contained in:
parent
939589c265
commit
0638565fe0
13 changed files with 243 additions and 31 deletions
|
|
@ -26,6 +26,9 @@ struct OrtRunOptions {
|
|||
// So it is possible that only some of the nodes are executed.
|
||||
bool only_execute_path_to_fetches = false;
|
||||
|
||||
// Set to 'true' to run in training mode.
|
||||
bool training_mode = false;
|
||||
|
||||
OrtRunOptions() = default;
|
||||
~OrtRunOptions() = default;
|
||||
|
||||
|
|
|
|||
|
|
@ -344,20 +344,6 @@ LoopImpl::LoopImpl(OpKernelContextInternal& context,
|
|||
condition_ = cond_tensor ? *cond_tensor->Data<bool>() : true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static OrtValue MakeScalarMLValue(const AllocatorPtr& allocator, T value, bool is_1d) {
|
||||
auto* data_type = DataTypeImpl::GetType<T>();
|
||||
std::unique_ptr<Tensor> p_tensor = onnxruntime::make_unique<Tensor>(data_type,
|
||||
is_1d ? TensorShape({1}) : TensorShape({}),
|
||||
allocator);
|
||||
|
||||
*p_tensor->MutableData<T>() = value;
|
||||
|
||||
auto ml_tensor = DataTypeImpl::GetType<Tensor>();
|
||||
return OrtValue{p_tensor.release(), ml_tensor,
|
||||
ml_tensor->GetDeleteFunc()};
|
||||
}
|
||||
|
||||
Status LoopImpl::Initialize() {
|
||||
auto status = Status::OK();
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,21 @@
|
|||
namespace onnxruntime {
|
||||
class Graph;
|
||||
|
||||
// Creates a scalar MLValue based on given value and allocator.
|
||||
template <typename T>
|
||||
OrtValue MakeScalarMLValue(const AllocatorPtr& allocator, T value, bool is_1d) {
|
||||
auto* data_type = DataTypeImpl::GetType<T>();
|
||||
std::unique_ptr<Tensor> p_tensor = onnxruntime::make_unique<Tensor>(data_type,
|
||||
is_1d ? TensorShape({1}) : TensorShape({}),
|
||||
allocator);
|
||||
|
||||
*p_tensor->MutableData<T>() = value;
|
||||
|
||||
auto ml_tensor = DataTypeImpl::GetType<Tensor>();
|
||||
return OrtValue{p_tensor.release(), ml_tensor,
|
||||
ml_tensor->GetDeleteFunc()};
|
||||
}
|
||||
|
||||
namespace controlflow {
|
||||
|
||||
/** Interface for control flow kernels */
|
||||
|
|
|
|||
|
|
@ -269,8 +269,8 @@ class InferenceSession {
|
|||
*/
|
||||
common::Status NewIOBinding(std::unique_ptr<IOBinding>* io_binding) ORT_MUST_USE_RESULT;
|
||||
|
||||
common::Status Run(const RunOptions& run_options, IOBinding& io_binding) ORT_MUST_USE_RESULT;
|
||||
common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT;
|
||||
virtual common::Status Run(const RunOptions& run_options, IOBinding& io_binding) ORT_MUST_USE_RESULT;
|
||||
virtual common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT;
|
||||
|
||||
/**
|
||||
* @return pair.first = OK; FAIL otherwise. pair.second is non-NULL when pair.first = OK.
|
||||
|
|
|
|||
|
|
@ -717,7 +717,9 @@ Applies to a particular Run() invocation. Default is 0.)pbdoc")
|
|||
R"pbdoc(Set to True to terminate any currently executing calls that are using this
|
||||
RunOptions instance. The individual calls will exit gracefully and return an error status.)pbdoc")
|
||||
.def_readwrite("only_execute_path_to_fetches", &RunOptions::only_execute_path_to_fetches,
|
||||
R"pbdoc(Only execute the nodes needed by fetch list)pbdoc");
|
||||
R"pbdoc(Only execute the nodes needed by fetch list)pbdoc")
|
||||
.def_readwrite("training_mode", &RunOptions::training_mode,
|
||||
R"pbdoc(Choose to run in training or inferencing mode)pbdoc");
|
||||
|
||||
py::class_<ModelMetadata>(m, "ModelMetadata", R"pbdoc(Pre-defined and custom metadata about the model.
|
||||
It is usually used to identify the model used to run the prediction and
|
||||
|
|
|
|||
|
|
@ -343,7 +343,7 @@ class TestOrtTrainer(unittest.TestCase):
|
|||
expected_losses = [
|
||||
11.02906322479248, 11.094074249267578, 11.00899887084961, 11.06129264831543,
|
||||
11.029067039489746, 11.040265083312988, 11.046793937683105, 10.993699073791504]
|
||||
expected_eval_loss = [10.9691801071167]
|
||||
expected_eval_loss = [10.95898914]
|
||||
actual_losses, actual_eval_loss = runBertTrainingTest(
|
||||
gradient_accumulation_steps=1, use_mixed_precision=False, allreduce_post_accumulation=False)
|
||||
|
||||
|
|
@ -354,7 +354,7 @@ class TestOrtTrainer(unittest.TestCase):
|
|||
# print('eval_loss actual: ', actual_eval_loss)
|
||||
# import pdb; pdb.set_trace()
|
||||
|
||||
rtol = 1e-03
|
||||
rtol = 1e-04
|
||||
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
|
||||
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
|
||||
|
||||
|
|
@ -362,7 +362,7 @@ class TestOrtTrainer(unittest.TestCase):
|
|||
expected_losses = [
|
||||
11.02906322479248, 11.094074249267578, 11.008995056152344, 11.061283111572266,
|
||||
11.029059410095215, 11.04024887084961, 11.04680347442627, 10.993708610534668]
|
||||
expected_eval_loss = [10.969207763671875]
|
||||
expected_eval_loss = [10.959011]
|
||||
|
||||
actual_losses, actual_eval_loss = runBertTrainingTest(
|
||||
gradient_accumulation_steps=4, use_mixed_precision=False, allreduce_post_accumulation=False)
|
||||
|
|
@ -374,9 +374,9 @@ class TestOrtTrainer(unittest.TestCase):
|
|||
# print('eval_loss actual: ', actual_eval_loss)
|
||||
# import pdb; pdb.set_trace()
|
||||
|
||||
rtol = 1e-03
|
||||
assert_allclose(expected_losses, actual_losses, err_msg="loss mismatch")
|
||||
assert_allclose(expected_eval_loss, actual_eval_loss, err_msg="evaluation loss mismatch")
|
||||
rtol = 1e-04
|
||||
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
|
||||
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(module=__name__, buffer=True)
|
||||
|
|
|
|||
|
|
@ -12,18 +12,18 @@ class TestOrtTrainer(unittest.TestCase):
|
|||
def testBertTrainingMixedPrecision(self):
|
||||
expected_losses = [11.0234375, 11.09375, 11.0078125, 11.0625, 11.03125, 11.0390625, 11.046875, 10.9921875]
|
||||
expected_all_finites = [False, True, True, True, True, True, True, True]
|
||||
expected_eval_loss = [10.96875]
|
||||
expected_eval_loss = [10.960938]
|
||||
actual_losses, actual_all_finites, actual_eval_loss = runBertTrainingTest(
|
||||
gradient_accumulation_steps=1, use_mixed_precision=True, allreduce_post_accumulation=False, use_simple_model_desc=False)
|
||||
|
||||
rtol = 1e-03
|
||||
rtol = 1e-04
|
||||
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
|
||||
assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch")
|
||||
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
|
||||
|
||||
def testBertTrainingMixedPrecisionInternalLossScale(self):
|
||||
expected_losses = [11.0234375, 11.09375, 11.0078125, 11.0625, 11.03125, 11.0390625, 11.046875, 10.9921875]
|
||||
expected_eval_loss = [10.96875]
|
||||
expected_eval_loss = [10.960938]
|
||||
actual_losses, actual_eval_loss = runBertTrainingTest(
|
||||
gradient_accumulation_steps=1,
|
||||
use_mixed_precision=True,
|
||||
|
|
@ -31,18 +31,18 @@ class TestOrtTrainer(unittest.TestCase):
|
|||
use_simple_model_desc=False,
|
||||
use_internel_loss_scale=True)
|
||||
|
||||
rtol = 1e-03
|
||||
rtol = 1e-04
|
||||
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
|
||||
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
|
||||
|
||||
def testBertTrainingGradientAccumulationMixedPrecision(self):
|
||||
expected_losses = [11.0234375, 11.09375, 11.0078125, 11.0625, 11.03125, 11.0390625, 11.046875, 10.9921875]
|
||||
expected_all_finites = [False, True]
|
||||
expected_eval_loss = [10.96875]
|
||||
expected_eval_loss = [10.960938]
|
||||
actual_losses, actual_all_finites, actual_eval_loss = runBertTrainingTest(
|
||||
gradient_accumulation_steps=4, use_mixed_precision=True, allreduce_post_accumulation=False, use_simple_model_desc=False)
|
||||
|
||||
rtol = 1e-03
|
||||
rtol = 1e-04
|
||||
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
|
||||
assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch")
|
||||
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
import sys
|
||||
import copy
|
||||
from numpy.testing import assert_allclose, assert_array_equal
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from onnxruntime_test_ort_trainer import map_optimizer_attributes, ort_trainer_learning_rate_description
|
||||
from helper import get_name
|
||||
import onnxruntime
|
||||
from onnxruntime_test_training_unittest_utils import process_dropout
|
||||
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler, generate_sample
|
||||
|
||||
torch.manual_seed(1)
|
||||
onnxruntime.set_seed(1)
|
||||
|
||||
class TestTrainingDropout(unittest.TestCase):
|
||||
def testTrainingAndEvalDropout(self):
|
||||
# Temporarily disable this test.
|
||||
# The graph below will trigger ORT
|
||||
# to sort backward graph before forward graph which gives incorrect result.
|
||||
# TODO Re-enable when that is fixed.
|
||||
return
|
||||
class TwoDropoutNet(nn.Module):
|
||||
def __init__(self, drop_prb_1, drop_prb_2, dim_size):
|
||||
super(TwoDropoutNet, self).__init__()
|
||||
self.drop_1 = nn.Dropout(drop_prb_1)
|
||||
self.drop_2 = nn.Dropout(drop_prb_2)
|
||||
self.weight_1 = torch.nn.Parameter(torch.zeros(dim_size, dtype=torch.float32))
|
||||
def forward(self, x):
|
||||
x = x + self.weight_1
|
||||
x = self.drop_1(x)
|
||||
x = self.drop_2(x)
|
||||
output = x
|
||||
return output[0]
|
||||
dim_size = 3
|
||||
device = torch.device("cuda", 0)
|
||||
# This will drop all values, therefore expecting all 0 in output tensor
|
||||
model = TwoDropoutNet(0.999, 0.999, dim_size)
|
||||
input_desc = IODescription('input', [dim_size], torch.float32)
|
||||
output_desc = IODescription('output', [], torch.float32)
|
||||
model_desc = ModelDescription([input_desc], [output_desc])
|
||||
lr_desc = ort_trainer_learning_rate_description()
|
||||
model = ORTTrainer(model, None, model_desc, "LambOptimizer",
|
||||
map_optimizer_attributes,
|
||||
lr_desc,
|
||||
device,
|
||||
postprocess_model=process_dropout,
|
||||
world_rank=0, world_size=1)
|
||||
input = torch.ones(dim_size, dtype=torch.float32).to(device)
|
||||
expected_training_output = [0.0]
|
||||
expected_eval_output = [1.0]
|
||||
learning_rate = torch.tensor([1.0000000e+00]).to(device)
|
||||
input_args=[input, learning_rate]
|
||||
train_output = model.train_step(*input_args)
|
||||
|
||||
rtol = 1e-04
|
||||
assert_allclose(expected_training_output, train_output.item(), rtol=rtol, err_msg="dropout training loss mismatch")
|
||||
|
||||
eval_output = model.eval_step(input)
|
||||
assert_allclose(expected_eval_output, eval_output.item(), rtol=rtol, err_msg="dropout eval loss mismatch")
|
||||
|
||||
# Do another train step to make sure it's using original ratios
|
||||
train_output_2 = model.train_step(*input_args)
|
||||
assert_allclose(expected_training_output, train_output_2.item(), rtol=rtol, err_msg="dropout training loss 2 mismatch")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(module=__name__, buffer=True)
|
||||
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
import sys
|
||||
import numpy as np
|
||||
from onnx import numpy_helper
|
||||
|
||||
def get_node_index(model, node):
|
||||
i = 0
|
||||
while i < len(model.graph.node):
|
||||
if model.graph.node[i] == node:
|
||||
break
|
||||
i += 1
|
||||
return i if i < len(model.graph.node) else None
|
||||
|
||||
def add_const(model, name, output, t_value = None, f_value = None):
|
||||
const_node = model.graph.node.add()
|
||||
const_node.op_type = 'Constant'
|
||||
const_node.name = name
|
||||
const_node.output.extend([output])
|
||||
attr = const_node.attribute.add()
|
||||
attr.name = 'value'
|
||||
if t_value is not None:
|
||||
attr.type = 4
|
||||
attr.t.CopyFrom(t_value)
|
||||
else:
|
||||
attr.type = 1
|
||||
attr.f = f_value
|
||||
return const_node
|
||||
|
||||
def process_dropout(model):
|
||||
dropouts = []
|
||||
index = 0
|
||||
for node in model.graph.node:
|
||||
if node.op_type == 'Dropout':
|
||||
new_dropout = model.graph.node.add()
|
||||
new_dropout.op_type = 'TrainableDropout'
|
||||
new_dropout.name = 'TrainableDropout_%d' % index
|
||||
#make ratio node
|
||||
ratio = np.asarray([node.attribute[0].f], dtype=np.float32)
|
||||
print(ratio.shape)
|
||||
ratio_value = numpy_helper.from_array(ratio)
|
||||
ratio_node = add_const(model, 'dropout_node_ratio_%d' % index, 'dropout_node_ratio_%d' % index, t_value=ratio_value)
|
||||
print (ratio_node)
|
||||
new_dropout.input.extend([node.input[0], ratio_node.output[0]])
|
||||
new_dropout.output.extend(node.output)
|
||||
dropouts.append(get_node_index(model, node))
|
||||
index += 1
|
||||
dropouts.sort(reverse=True)
|
||||
for d in dropouts:
|
||||
del model.graph.node[d]
|
||||
model.opset_import[0].version = 10
|
||||
|
|
@ -5,6 +5,8 @@
|
|||
|
||||
#include "core/framework/data_transfer_utils.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/session/IOBinding.h"
|
||||
#include "core/providers/cpu/controlflow/utils.h"
|
||||
#include "orttraining/core/graph/loss_function_builder.h"
|
||||
#include "orttraining/core/graph/optimizer_builder.h"
|
||||
#include "orttraining/core/framework/checkpointing.h"
|
||||
|
|
@ -228,6 +230,9 @@ Status TrainingSession::ConfigureForTraining(
|
|||
}
|
||||
}
|
||||
|
||||
// Set eval feed names for Dropout ratio.
|
||||
ORT_RETURN_IF_ERROR(SetDropoutEvalFeedNames());
|
||||
|
||||
// add Tensorboard
|
||||
if (config.tensorboard_config.has_value()) {
|
||||
const auto& tensorboard_config = config.tensorboard_config.value();
|
||||
|
|
@ -676,6 +681,57 @@ bool TrainingSession::IsGraphOutputFp32Node(const std::string& output_name) cons
|
|||
return IsFP32Node(output_producer_node);
|
||||
}
|
||||
|
||||
common::Status TrainingSession::Run(const RunOptions& run_options, IOBinding& io_binding) {
|
||||
// Override initializers in eval mode.
|
||||
if (!run_options.training_mode) {
|
||||
// override all dropout raiots to 0
|
||||
for (auto& drop_ratio : dropout_eval_feeds_) {
|
||||
OrtValue feed_value;
|
||||
// We allocate on CPU first, copy will be taken care off downstream.
|
||||
auto cpu_allocator = session_state_->GetExecutionProviders()
|
||||
.Get(onnxruntime::kCpuExecutionProvider)
|
||||
->GetAllocator(0, OrtMemTypeDefault);
|
||||
feed_value = onnxruntime::MakeScalarMLValue<float>(cpu_allocator, 0.f, true /*is_1d*/);
|
||||
// Bind new feed to graph input.
|
||||
ORT_RETURN_IF_ERROR(io_binding.BindInput(drop_ratio, feed_value));
|
||||
}
|
||||
}
|
||||
|
||||
// Call Run in inferenceSession
|
||||
return InferenceSession::Run(run_options, io_binding);
|
||||
}
|
||||
|
||||
common::Status TrainingSession::Run(IOBinding& io_binding) {
|
||||
RunOptions run_options;
|
||||
// Set training_mode to true in training session by default.
|
||||
run_options.training_mode = true;
|
||||
return Run(run_options, io_binding);
|
||||
}
|
||||
|
||||
static const std::unordered_set<std::string> Dropout_Nodes = {
|
||||
"TrainableDropout",
|
||||
};
|
||||
// TODO remove this once ONNX properly supports training_mode input.
|
||||
Status TrainingSession::SetDropoutEvalFeedNames() {
|
||||
Graph& graph = model_->MainGraph();
|
||||
|
||||
// add ratio node to graph input for overriding.
|
||||
GraphAugmenter::GraphDefs defs{};
|
||||
|
||||
for (const auto& node : graph.Nodes()) {
|
||||
auto it = Dropout_Nodes.find(node.OpType());
|
||||
if(it != Dropout_Nodes.cend()) {
|
||||
auto& ratio_name = node.InputDefs()[1]->Name();
|
||||
dropout_eval_feeds_.insert(ratio_name);
|
||||
ORT_ENFORCE(model_->MainGraph().GetProducerNode(ratio_name) == nullptr,
|
||||
"Input: " + ratio_name + " should not have any producer node.");
|
||||
defs.AddGraphInputs({ratio_name});
|
||||
}
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(GraphAugmenter::AugmentGraph(graph, defs));
|
||||
return DoPostLoadProcessing(*model_);
|
||||
}
|
||||
|
||||
Status TrainingSession::SetStateTensors(const NameMLValMap& state_tensors, bool strict) {
|
||||
ORT_RETURN_IF_NOT(IsInitialized(), "Can't update initializers before session has been initialized.");
|
||||
|
||||
|
|
|
|||
|
|
@ -212,9 +212,26 @@ class TrainingSession : public InferenceSession {
|
|||
/** Gets the model location. */
|
||||
const PathString& GetModelLocation() const { return model_location_; }
|
||||
|
||||
/** Checks to be see if given graph output is produced by an fp32-only node. */
|
||||
/**
|
||||
* Checks to be see if given graph output is produced by an fp32-only node.
|
||||
* @param The name of the output.
|
||||
* @return Whether output is from fp32-only node or not.
|
||||
*/
|
||||
bool IsGraphOutputFp32Node(const std::string& output_name) const;
|
||||
|
||||
/**
|
||||
* Gets the list of Dropout ratio inputs that will be used as feeds in eval mode,
|
||||
* since each ratio input has its own name.
|
||||
* @return The list of feed names.
|
||||
*/
|
||||
std::unordered_set<std::string> GetDropoutEvalFeeds() const { return dropout_eval_feeds_; }
|
||||
|
||||
/** Override Run function in InferenceSession to inject some training-specific logics **/
|
||||
using InferenceSession::Run; // For overload resolution.
|
||||
common::Status Run(const RunOptions& run_options, IOBinding& io_binding) override;
|
||||
|
||||
common::Status Run(IOBinding& io_binding) override;
|
||||
|
||||
private:
|
||||
/** Configures the loss function.
|
||||
The loss function can either be provided externally or built from the provided loss function information.
|
||||
|
|
@ -312,6 +329,8 @@ class TrainingSession : public InferenceSession {
|
|||
|
||||
std::unordered_set<std::string> GetStateTensorNames() const;
|
||||
|
||||
common::Status SetDropoutEvalFeedNames();
|
||||
|
||||
NameMLValMap GetWeights() const;
|
||||
|
||||
static bool IsImmutableWeight(const ImmutableWeights& immutable_weights,
|
||||
|
|
@ -335,6 +354,7 @@ class TrainingSession : public InferenceSession {
|
|||
std::unique_ptr<ILossFunction> loss_graph_builder_;
|
||||
optional<LossFunctionInfo> loss_function_info_;
|
||||
|
||||
std::unordered_set<std::string> dropout_eval_feeds_;
|
||||
OptimizerGraphConfig opt_graph_config_;
|
||||
std::unordered_map<std::string, OptimizerNodeConfig> opt_configs_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ def ort_training_session_run_helper(session, iobinding, inputs, input_descs, out
|
|||
device_index = input_get_device_index(input)
|
||||
iobinding.bind_input(input_desc.name_, input.device.type, device_index, dtype_torch_to_numpy(input.dtype),
|
||||
list(input.size()), input.data_ptr())
|
||||
|
||||
|
||||
output_descs_resolved = resolve_symbolic_dimensions(inputs, input_descs, output_descs)
|
||||
torch_outputs = {}
|
||||
for output_desc in output_descs_resolved:
|
||||
|
|
@ -787,6 +787,7 @@ class ORTTrainer():
|
|||
elif self.current_step % self.gradient_accumulation_steps != 0:
|
||||
run_options = ort.RunOptions()
|
||||
run_options.only_execute_path_to_fetches = True
|
||||
run_options.training_mode = True
|
||||
output_desc = self.output_desc_with_group_accumulated_gradients
|
||||
elif self.use_mixed_precision:
|
||||
has_if_all_finite = True
|
||||
|
|
@ -864,6 +865,7 @@ class ORTTrainer():
|
|||
|
||||
run_options = ort.RunOptions()
|
||||
run_options.only_execute_path_to_fetches = True
|
||||
run_options.training_mode = False
|
||||
|
||||
session_run_results = ort_training_session_run_helper(self.session, self.eval_io_binding, input,
|
||||
input_desc,
|
||||
|
|
|
|||
|
|
@ -1095,6 +1095,9 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs,
|
|||
run_subprocess(
|
||||
[sys.executable, 'onnxruntime_test_ort_trainer.py'],
|
||||
cwd=cwd, dll_path=dll_path)
|
||||
run_subprocess(
|
||||
[sys.executable, 'onnxruntime_test_training_unit_tests.py'],
|
||||
cwd=cwd, dll_path=dll_path)
|
||||
|
||||
# run additional frontend tests for orttraining-linux-gpu-frontend_test_ci-pipeline
|
||||
if args.enable_training_python_frontend_e2e_tests:
|
||||
|
|
|
|||
Loading…
Reference in a new issue