mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-28 03:20:58 +00:00
Merge remote-tracking branch 'upstream/master' into DmlDev
This commit is contained in:
commit
907952ef8c
10 changed files with 177 additions and 101 deletions
|
|
@ -12,7 +12,7 @@ Examples of these abstractions are: ([threadpool.h](https://github.com/microsoft
|
|||
* TryBatchParallelFor
|
||||
* TryParallelFor
|
||||
* TrySimpleParallelFor
|
||||
* static version of NumThreads
|
||||
* DegreeOfParallelism
|
||||
|
||||
**Please do not write #ifdef pragma omp in operator code**.
|
||||
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class ThreadPoolTempl;
|
|||
namespace concurrency {
|
||||
|
||||
class ExtendedThreadPoolInterface;
|
||||
class BatchHandle;
|
||||
class LoopCounter;
|
||||
|
||||
class ThreadPool {
|
||||
public:
|
||||
|
|
@ -118,27 +118,30 @@ class ThreadPool {
|
|||
#else
|
||||
using NAME_CHAR_TYPE = char;
|
||||
#endif
|
||||
// Constructs a pool that contains "num_threads" threads with specified
|
||||
// "name". env->StartThread() is used to create individual threads with the
|
||||
// given ThreadOptions. If "low_latency_hint" is true the thread pool
|
||||
// Constructs a pool for running with "degree_of_parallelism" threads with
|
||||
// specified "name". env->StartThread() is used to create individual threads
|
||||
// with the given ThreadOptions. If "low_latency_hint" is true the thread pool
|
||||
// implementation may use it as a hint that lower latency is preferred at the
|
||||
// cost of higher CPU usage, e.g. by letting one or more idle threads spin
|
||||
// wait. Conversely, if the threadpool is used to schedule high-latency
|
||||
// operations like I/O the hint should be set to false.
|
||||
//
|
||||
// REQUIRES: num_threads > 0
|
||||
// REQUIRES: degree_of_parallelism > 0
|
||||
// The allocator parameter is only used for creating a Eigen::ThreadPoolDevice to be used with Eigen Tensor classes.
|
||||
ThreadPool(Env* env,
|
||||
const ThreadOptions& thread_options,
|
||||
const NAME_CHAR_TYPE* name,
|
||||
int num_threads,
|
||||
int degree_of_parallelism,
|
||||
bool low_latency_hint);
|
||||
|
||||
// Waits until all scheduled work has finished and then destroys the
|
||||
// set of threads.
|
||||
~ThreadPool();
|
||||
|
||||
// Schedules fn() for execution in the pool of threads.
|
||||
// Schedules fn() for execution in the pool of threads. The function may run
|
||||
// synchronously if it cannot be enqueued. This will occur if the thread pool's
|
||||
// degree-of-parallelism is 1, but it may also occur for implementation-dependent
|
||||
// reasons such as if queues used for buffering work are full.
|
||||
void Schedule(std::function<void()> fn);
|
||||
|
||||
// Returns the number of shards used by ParallelForFixedBlockSizeScheduling
|
||||
|
|
@ -171,7 +174,7 @@ class ThreadPool {
|
|||
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t last)>& fn) {
|
||||
#ifdef _OPENMP
|
||||
ORT_UNUSED_PARAMETER(cost_per_unit);
|
||||
std::ptrdiff_t num_threads = concurrency::ThreadPool::NumThreads(tp);
|
||||
std::ptrdiff_t num_threads = concurrency::ThreadPool::DegreeOfParallelism(tp);
|
||||
if (total < num_threads) {
|
||||
num_threads = total;
|
||||
}
|
||||
|
|
@ -199,7 +202,7 @@ class ThreadPool {
|
|||
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t last)>& fn) {
|
||||
#ifdef _OPENMP
|
||||
ORT_UNUSED_PARAMETER(scheduling_params);
|
||||
std::ptrdiff_t num_threads = concurrency::ThreadPool::NumThreads(tp);
|
||||
std::ptrdiff_t num_threads = concurrency::ThreadPool::DegreeOfParallelism(tp);
|
||||
if (total < num_threads) {
|
||||
num_threads = total;
|
||||
}
|
||||
|
|
@ -217,16 +220,15 @@ class ThreadPool {
|
|||
#endif
|
||||
}
|
||||
|
||||
// Prefer using this API to get the number of threads unless you know what you're doing.
|
||||
// This API takes into account if openmp is enabled/disabled and if the thread pool ptr is nullptr.
|
||||
static int NumThreads(const concurrency::ThreadPool* tp);
|
||||
|
||||
// Returns the number of threads in the pool. Preferably use the static version of this API instead.
|
||||
int NumThreads() const;
|
||||
|
||||
// Returns current thread id between 0 and NumThreads() - 1, if called from a
|
||||
// thread in the pool. Returns -1 otherwise.
|
||||
int CurrentThreadId() const;
|
||||
// Return the degree of parallelism that code should assume when using the thread pool.
|
||||
// This API takes into account if OpenMP is enabled/disabled, and if the thread pool ptr is
|
||||
// nullptr. It decouples the degree of parallelism for use with the thread pool from
|
||||
// the implementation choice of whether this matches the number of threads created in
|
||||
// the pool.
|
||||
//
|
||||
// Currently, a loop with degree-of-parallelism N is supported by a pool of N-1 threads
|
||||
// working in combination with the thread initiating the loop.
|
||||
static int DegreeOfParallelism(const concurrency::ThreadPool* tp);
|
||||
|
||||
// Directly schedule the 'total' tasks to the underlying threadpool, without
|
||||
// cutting them by halves
|
||||
|
|
@ -254,7 +256,7 @@ class ThreadPool {
|
|||
|
||||
/**
|
||||
* Tries to call the given function in parallel, with calls split into (num_batches) batches.
|
||||
*\param num_batches If it is zero, it will be replaced to the value of NumThreads().
|
||||
*\param num_batches If it is zero, it will be replaced with the value of DegreeOfParallelism().
|
||||
*\param fn A std::function or STL style functor with signature of "void f(int32_t);"
|
||||
* Pitfall: Caller should cap `num_batches` to a reasonable value based on the cost of `fn` and the value of `total`.
|
||||
*For example, if fn is as simple as: int sum=0; fn = [&](int i){sum +=i;} and `total` is 100, then num_batches should
|
||||
|
|
@ -288,7 +290,7 @@ class ThreadPool {
|
|||
}
|
||||
|
||||
if (num_batches <= 0) {
|
||||
num_batches = std::min<ptrdiff_t>(total, tp->NumThreads());
|
||||
num_batches = std::min<ptrdiff_t>(total, DegreeOfParallelism(tp));
|
||||
}
|
||||
|
||||
if (num_batches <= 1) {
|
||||
|
|
@ -334,6 +336,16 @@ class ThreadPool {
|
|||
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ThreadPool);
|
||||
|
||||
private:
|
||||
friend class LoopCounter;
|
||||
|
||||
// Returns the number of threads created in the pool. This may be different from the
|
||||
// value returned by DegreeOfParallelism to code using the pool.
|
||||
int NumThreads() const;
|
||||
|
||||
// Returns current thread id between 0 and NumThreads() - 1, if called from a
|
||||
// thread in the pool. Returns -1 otherwise.
|
||||
int CurrentThreadId() const;
|
||||
|
||||
// Run fn with up to n degree-of-parallelism enlisting the thread pool for
|
||||
// help. The degree-of-parallelism includes the caller, and so if n==1
|
||||
// then the function will run directly in the caller. The fork-join
|
||||
|
|
@ -359,11 +371,14 @@ class ThreadPool {
|
|||
const std::ptrdiff_t block_size = 1) const;
|
||||
|
||||
ThreadOptions thread_options_;
|
||||
// underlying_threadpool_ is the user_threadpool if user_threadpool is
|
||||
// provided in the constructor. Otherwise it is the eigen_threadpool_.
|
||||
ExtendedThreadPoolInterface* underlying_threadpool_;
|
||||
// eigen_threadpool_ is instantiated and owned by thread::ThreadPool if
|
||||
// user_threadpool is not in the constructor.
|
||||
|
||||
// If a thread pool is created with degree_of_parallelism != 1 then an underlying
|
||||
// EigenThreadPool is used to create OS threads and handle work distribution to them.
|
||||
// If degree_of_parallelism == 1 then underlying_threadpool_ is left as nullptr
|
||||
// and parallel work is run directly by the caller.
|
||||
ExtendedThreadPoolInterface* underlying_threadpool_ = nullptr;
|
||||
|
||||
// If used, extended_eigen_threadpool_ is instantiated and owned by the ThreadPool; underlying_threadpool_ points to it.
|
||||
std::unique_ptr<ThreadPoolTempl<Env> > extended_eigen_threadpool_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -74,19 +74,19 @@ public:
|
|||
// does not need to be unique, but we aim for a good distribution, particularly in the case where
|
||||
// most/all of the thread pool's threads are active in the loop. Threads outside the pool may
|
||||
// also be claiming work, with CurrentThreadId -1.
|
||||
int num_threads = _tp.NumThreads();
|
||||
int my_thread_idx = (_tp.CurrentThreadId() + 1) % num_threads;
|
||||
assert(my_thread_idx >= 0 && my_thread_idx < num_threads);
|
||||
int d_of_p = ThreadPool::DegreeOfParallelism(&_tp);
|
||||
int my_thread_idx = (_tp.CurrentThreadId() + 1) % d_of_p;
|
||||
assert(my_thread_idx >= 0 && my_thread_idx < d_of_p);
|
||||
|
||||
int home_shard;
|
||||
if (num_threads >= NUM_SHARDS) {
|
||||
if (d_of_p >= NUM_SHARDS) {
|
||||
// More threads than shards => allocate them home shards round-robin, aiming to spread the load across
|
||||
// the shards
|
||||
home_shard = my_thread_idx % NUM_SHARDS;
|
||||
} else {
|
||||
// Fewer threads than shards => spread the threads evenly across the shards, so each will work
|
||||
// on a run of successive shards before contention
|
||||
home_shard = (my_thread_idx * NUM_SHARDS) / num_threads;
|
||||
home_shard = (my_thread_idx * NUM_SHARDS) / d_of_p;
|
||||
}
|
||||
assert(home_shard >= 0 && home_shard < NUM_SHARDS);
|
||||
return home_shard;
|
||||
|
|
@ -126,13 +126,26 @@ private:
|
|||
#pragma warning(pop) /* Padding added in LoopCounterShard, LoopCounter */
|
||||
#endif
|
||||
|
||||
ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options, const NAME_CHAR_TYPE* name, int num_threads,
|
||||
ThreadPool::ThreadPool(Env* env,
|
||||
const ThreadOptions& thread_options,
|
||||
const NAME_CHAR_TYPE* name,
|
||||
int degree_of_parallelism,
|
||||
bool low_latency_hint)
|
||||
: thread_options_(thread_options) {
|
||||
ORT_ENFORCE(num_threads >= 1);
|
||||
extended_eigen_threadpool_ =
|
||||
onnxruntime::make_unique<ThreadPoolTempl<Env>>(name, num_threads, low_latency_hint, *env, thread_options_);
|
||||
underlying_threadpool_ = extended_eigen_threadpool_.get();
|
||||
// In the current implementation, a thread pool with degree_of_parallelism==1 uses
|
||||
// the caller as one of the threads for executing work. Hence we only create
|
||||
// additional thread(s) for degree_of_parallelism>=2.
|
||||
ORT_ENFORCE(degree_of_parallelism >= 1);
|
||||
if (degree_of_parallelism >= 2) {
|
||||
int threads_to_create = degree_of_parallelism - 1;
|
||||
extended_eigen_threadpool_ =
|
||||
onnxruntime::make_unique<ThreadPoolTempl<Env>>(name,
|
||||
threads_to_create,
|
||||
low_latency_hint,
|
||||
*env,
|
||||
thread_options_);
|
||||
underlying_threadpool_ = extended_eigen_threadpool_.get();
|
||||
}
|
||||
}
|
||||
|
||||
ThreadPool::~ThreadPool() = default;
|
||||
|
|
@ -153,8 +166,8 @@ void ThreadPool::ParallelForFixedBlockSizeScheduling(const std::ptrdiff_t total,
|
|||
|
||||
// Split the work across threads in the pool. Each work item will run a loop claiming iterations,
|
||||
// hence we need at most one for each thread, even if the number of blocks of iterations is larger.
|
||||
int num_threads = NumThreads();
|
||||
int num_work_items = static_cast<int>(std::min(static_cast<std::ptrdiff_t>(num_threads), total));
|
||||
auto d_of_p = DegreeOfParallelism(this);
|
||||
int num_work_items = static_cast<int>(std::min(static_cast<std::ptrdiff_t>(d_of_p), total));
|
||||
assert(num_work_items > 0);
|
||||
|
||||
LoopCounter lc(*this, total, block_size);
|
||||
|
|
@ -184,12 +197,20 @@ void ThreadPool::SimpleParallelFor(std::ptrdiff_t total, const std::function<voi
|
|||
|
||||
void ThreadPool::Schedule(std::function<void()> fn) {
|
||||
ORT_ENFORCE(fn != nullptr);
|
||||
underlying_threadpool_->Schedule(std::move(fn));
|
||||
if (underlying_threadpool_) {
|
||||
underlying_threadpool_->Schedule(std::move(fn));
|
||||
} else {
|
||||
fn();
|
||||
}
|
||||
}
|
||||
|
||||
void ThreadPool::RunInParallel(std::function<void()> fn, int n) {
|
||||
ORT_ENFORCE(fn != nullptr);
|
||||
underlying_threadpool_->RunInParallel(std::move(fn), n);
|
||||
if (underlying_threadpool_) {
|
||||
underlying_threadpool_->RunInParallel(std::move(fn), n);
|
||||
} else {
|
||||
fn();
|
||||
}
|
||||
}
|
||||
|
||||
bool ThreadPool::ShouldParallelizeLoop(const std::ptrdiff_t num_iterations,
|
||||
|
|
@ -201,9 +222,10 @@ bool ThreadPool::ShouldParallelizeLoop(const std::ptrdiff_t num_iterations,
|
|||
|
||||
// Do not parallelize loops with only a single thread available. If the
|
||||
// caller is outside the current pool (ID == -1) then we parallelize
|
||||
// via the pool's thread(s). If the caller is inside the current pool
|
||||
// if the pool has any threads. If the caller is inside the current pool
|
||||
// (ID != -1) then we require at least one additional thread in the pool.
|
||||
if (CurrentThreadId() != -1 && NumThreads() == 1) {
|
||||
if ((CurrentThreadId() == -1 && NumThreads() == 0) ||
|
||||
(CurrentThreadId() != -1 && NumThreads() == 1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -304,14 +326,17 @@ void ThreadPool::ParallelFor(std::ptrdiff_t n, const TensorOpCost& c,
|
|||
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t)>& f) {
|
||||
ORT_ENFORCE(n >= 0);
|
||||
Eigen::TensorOpCost cost{c.bytes_loaded, c.bytes_stored, c.compute_cycles};
|
||||
auto d_of_p = DegreeOfParallelism(this);
|
||||
// Compute small problems directly in the caller thread.
|
||||
if ((!ShouldParallelizeLoop(n)) ||
|
||||
Eigen::TensorCostModel<Eigen::ThreadPoolDevice>::numThreads(static_cast<double>(n), cost, static_cast<int>(NumThreads())) == 1) {
|
||||
Eigen::TensorCostModel<Eigen::ThreadPoolDevice>::numThreads(static_cast<double>(n),
|
||||
cost,
|
||||
d_of_p) == 1) {
|
||||
f(0, n);
|
||||
return;
|
||||
}
|
||||
|
||||
ptrdiff_t block = CalculateParallelForBlock(n, cost, nullptr, NumThreads());
|
||||
ptrdiff_t block = CalculateParallelForBlock(n, cost, nullptr, d_of_p);
|
||||
ParallelForFixedBlockSizeScheduling(n, block, f);
|
||||
}
|
||||
|
||||
|
|
@ -320,23 +345,38 @@ void ThreadPool::ParallelFor(std::ptrdiff_t total, double cost_per_unit,
|
|||
ParallelFor(total, TensorOpCost{0, 0, static_cast<double>(cost_per_unit)}, fn);
|
||||
}
|
||||
|
||||
int ThreadPool::NumThreads(const concurrency::ThreadPool* tp) {
|
||||
int ThreadPool::DegreeOfParallelism(const concurrency::ThreadPool* tp) {
|
||||
#ifdef _OPENMP
|
||||
// When using OpenMP, omp_get_num_threads() returns the number of threads in the
|
||||
// current parallel region. Hence if this is 1 then we aim to parallelise
|
||||
// across the number of threads configured. Otherwise, given that we do not
|
||||
// use nested parallelism, we do not parallelise further.
|
||||
ORT_UNUSED_PARAMETER(tp);
|
||||
return (omp_get_num_threads() == 1) ? omp_get_max_threads() : 1;
|
||||
#else
|
||||
return tp ? tp->NumThreads() : 1;
|
||||
// When not using OpenMP, we parallelise over the N threads created by the pool
|
||||
// tp, plus 1 for the thread entering a loop.
|
||||
return tp ? (tp->NumThreads()+1) : 1;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Return the number of threads created by the pool.
|
||||
int ThreadPool::NumThreads() const {
|
||||
return underlying_threadpool_->NumThreads();
|
||||
if (underlying_threadpool_) {
|
||||
return underlying_threadpool_->NumThreads();
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Return ID of the current thread within this pool. Returns -1 for a thread outside the
|
||||
// current pool.
|
||||
int ThreadPool::CurrentThreadId() const {
|
||||
return underlying_threadpool_->CurrentThreadId();
|
||||
if (underlying_threadpool_) {
|
||||
return underlying_threadpool_->CurrentThreadId();
|
||||
} else {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace concurrency
|
||||
|
|
|
|||
|
|
@ -774,7 +774,7 @@ MlasGetMaximumThreadCount(
|
|||
return 1;
|
||||
#endif
|
||||
#else
|
||||
return onnxruntime::concurrency::ThreadPool::NumThreads(ThreadPool);
|
||||
return onnxruntime::concurrency::ThreadPool::DegreeOfParallelism(ThreadPool);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -164,7 +164,7 @@ static void FindTopKElements(const Tensor* input, const TensorShape& input_shape
|
|||
const int64_t num_blocks = input_shape[axis_parsed];
|
||||
const int64_t block_slice = reduced_cols / k;
|
||||
|
||||
int64_t tp_threads = concurrency::ThreadPool::NumThreads(threadpool);
|
||||
int64_t tp_threads = concurrency::ThreadPool::DegreeOfParallelism(threadpool);
|
||||
int64_t num_threads = std::min(tp_threads, rows); // split on rows so can't have more threads than rows
|
||||
|
||||
// rough attempt to make sure there's enough work for each thread. if there's insufficient work the usage of
|
||||
|
|
|
|||
|
|
@ -326,7 +326,7 @@ void TreeEnsembleCommon<ITYPE, OTYPE>::ComputeAgg(concurrency::ThreadPool* ttp,
|
|||
} else {
|
||||
// split the work into one block per thread so we can re-use the 'private_scores' vector as much as possible
|
||||
// TODO: Refine the number of threads used
|
||||
auto num_threads = std::min<int32_t>(concurrency::ThreadPool::NumThreads(ttp), SafeInt<int32_t>(n_trees_));
|
||||
auto num_threads = std::min<int32_t>(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt<int32_t>(n_trees_));
|
||||
OrtMutex merge_mutex;
|
||||
concurrency::ThreadPool::TrySimpleParallelFor(
|
||||
ttp,
|
||||
|
|
@ -361,7 +361,7 @@ void TreeEnsembleCommon<ITYPE, OTYPE>::ComputeAgg(concurrency::ThreadPool* ttp,
|
|||
} else {
|
||||
// split the work into one block per thread so we can re-use the 'scores' vector as much as possible
|
||||
// TODO: Refine the number of threads used.
|
||||
auto num_threads = std::min<int32_t>(concurrency::ThreadPool::NumThreads(ttp), SafeInt<int32_t>(N));
|
||||
auto num_threads = std::min<int32_t>(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt<int32_t>(N));
|
||||
concurrency::ThreadPool::TrySimpleParallelFor(
|
||||
ttp,
|
||||
num_threads,
|
||||
|
|
|
|||
|
|
@ -1069,7 +1069,7 @@ void UniDirectionalLstm<T>::GateComputations(
|
|||
|
||||
template <typename T>
|
||||
void UniDirectionalLstm<T>::SetNumThreads() {
|
||||
int threads = concurrency::ThreadPool::NumThreads(thread_pool_);
|
||||
int threads = concurrency::ThreadPool::DegreeOfParallelism(thread_pool_);
|
||||
|
||||
if (threads < 1)
|
||||
threads = 1;
|
||||
|
|
|
|||
|
|
@ -10,20 +10,24 @@ from onnxruntime_test_ort_trainer import runBertTrainingTest
|
|||
|
||||
class TestOrtTrainer(unittest.TestCase):
|
||||
def testBertTrainingMixedPrecision(self):
|
||||
expected_losses = [11.0234375, 11.09375, 11.0078125, 11.0625, 11.03125, 11.0390625, 11.046875, 10.9921875]
|
||||
expected_all_finites = [False, True, True, True, True, True, True, True]
|
||||
expected_eval_loss = [10.960938]
|
||||
expected_losses = [
|
||||
11.034248352050781, 11.125300407409668, 11.006105422973633, 11.047048568725586,
|
||||
11.027417182922363, 11.015759468078613, 11.060905456542969, 10.971782684326172]
|
||||
expected_all_finites = [True, True, True, True, True, True, True, True]
|
||||
expected_eval_loss = [10.959012985229492]
|
||||
actual_losses, actual_all_finites, actual_eval_loss = runBertTrainingTest(
|
||||
gradient_accumulation_steps=1, use_mixed_precision=True, allreduce_post_accumulation=False, use_simple_model_desc=False)
|
||||
|
||||
rtol = 1e-04
|
||||
rtol = 1e-02
|
||||
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
|
||||
assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch")
|
||||
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
|
||||
|
||||
def testBertTrainingMixedPrecisionInternalLossScale(self):
|
||||
expected_losses = [11.0234375, 11.09375, 11.0078125, 11.0625, 11.03125, 11.0390625, 11.046875, 10.9921875]
|
||||
expected_eval_loss = [10.960938]
|
||||
expected_losses = [
|
||||
11.034248352050781, 11.125300407409668, 11.006105422973633, 11.047048568725586,
|
||||
11.027417182922363, 11.015759468078613, 11.060905456542969, 10.971782684326172]
|
||||
expected_eval_loss = [10.959012985229492]
|
||||
actual_losses, actual_eval_loss = runBertTrainingTest(
|
||||
gradient_accumulation_steps=1,
|
||||
use_mixed_precision=True,
|
||||
|
|
@ -31,18 +35,20 @@ class TestOrtTrainer(unittest.TestCase):
|
|||
use_simple_model_desc=False,
|
||||
use_internel_loss_scale=True)
|
||||
|
||||
rtol = 1e-04
|
||||
rtol = 1e-02
|
||||
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
|
||||
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
|
||||
|
||||
def testBertTrainingGradientAccumulationMixedPrecision(self):
|
||||
expected_losses = [11.0234375, 11.09375, 11.0078125, 11.0625, 11.03125, 11.0390625, 11.046875, 10.9921875]
|
||||
expected_all_finites = [False, True]
|
||||
expected_eval_loss = [10.960938]
|
||||
expected_losses = [
|
||||
11.034248352050781, 11.125300407409668, 11.006077766418457, 11.047025680541992,
|
||||
11.027434349060059, 11.0156831741333, 11.060973167419434, 10.971841812133789]
|
||||
expected_all_finites = [True, True]
|
||||
expected_eval_loss = [10.95903205871582]
|
||||
actual_losses, actual_all_finites, actual_eval_loss = runBertTrainingTest(
|
||||
gradient_accumulation_steps=4, use_mixed_precision=True, allreduce_post_accumulation=False, use_simple_model_desc=False)
|
||||
|
||||
rtol = 1e-04
|
||||
rtol = 1e-02
|
||||
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
|
||||
assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch")
|
||||
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
|
||||
|
|
|
|||
|
|
@ -66,54 +66,56 @@ class ORTGlueTest(unittest.TestCase):
|
|||
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "glue_test_output/")
|
||||
self.cache_dir = '/tmp/glue/'
|
||||
self.logging_steps = 10
|
||||
self.rtol = 1e-02
|
||||
|
||||
|
||||
def test_roberta_with_mrpc(self):
|
||||
expected_acc = 0.8897058823529411
|
||||
expected_f1 = 0.9200710479573712
|
||||
expected_acc_and_f1 = 0.9048884651551561
|
||||
expected_loss = 0.2911236987394445
|
||||
expected_acc = 0.8676470588235294
|
||||
expected_f1 = 0.9035714285714286
|
||||
expected_acc_and_f1 = 0.885609243697479
|
||||
expected_loss = 0.3022572344862947
|
||||
|
||||
results = self.run_glue(model_name="roberta-base", task_name="MRPC", fp16=False)
|
||||
assert_allclose(results['acc'], expected_acc)
|
||||
assert_allclose(results['f1'], expected_f1)
|
||||
assert_allclose(results['acc_and_f1'], expected_acc_and_f1)
|
||||
assert_allclose(results['loss'], expected_loss)
|
||||
assert_allclose(results['acc'], expected_acc, rtol=self.rtol)
|
||||
assert_allclose(results['f1'], expected_f1, rtol=self.rtol)
|
||||
assert_allclose(results['acc_and_f1'], expected_acc_and_f1, rtol=self.rtol)
|
||||
assert_allclose(results['loss'], expected_loss, rtol=self.rtol)
|
||||
|
||||
def test_roberta_fp16_with_mrpc(self):
|
||||
expected_acc = 0.8921568627450981
|
||||
expected_f1 = 0.9219858156028369
|
||||
expected_acc_and_f1 = 0.9070713391739675
|
||||
expected_loss = 0.3033953265232198
|
||||
expected_acc = 0.8995098039215687
|
||||
expected_f1 = 0.9279437609841829
|
||||
expected_acc_and_f1 = 0.9137267824528758
|
||||
expected_loss = 0.32052762967114357
|
||||
|
||||
results = self.run_glue(model_name="roberta-base", task_name="MRPC", fp16=True)
|
||||
assert_allclose(results['acc'], expected_acc)
|
||||
assert_allclose(results['f1'], expected_f1)
|
||||
assert_allclose(results['acc_and_f1'], expected_acc_and_f1)
|
||||
assert_allclose(results['loss'], expected_loss)
|
||||
assert_allclose(results['acc'], expected_acc, rtol=self.rtol)
|
||||
assert_allclose(results['f1'], expected_f1, rtol=self.rtol)
|
||||
assert_allclose(results['acc_and_f1'], expected_acc_and_f1, rtol=self.rtol)
|
||||
assert_allclose(results['loss'], expected_loss, rtol=self.rtol)
|
||||
|
||||
def test_bert_with_mrpc(self):
|
||||
expected_acc = 0.8529411764705882
|
||||
expected_f1 = 0.896551724137931
|
||||
expected_acc_and_f1 = 0.8747464503042597
|
||||
expected_loss = 0.4139287974320206
|
||||
expected_acc = 0.8553921568627451
|
||||
expected_f1 = 0.8970331588132635
|
||||
expected_acc_and_f1 = 0.8762126578380043
|
||||
expected_loss = 0.42737212419217707
|
||||
|
||||
results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False)
|
||||
assert_allclose(results['acc'], expected_acc)
|
||||
assert_allclose(results['f1'], expected_f1)
|
||||
assert_allclose(results['acc_and_f1'], expected_acc_and_f1)
|
||||
assert_allclose(results['loss'], expected_loss)
|
||||
assert_allclose(results['acc'], expected_acc, rtol=self.rtol)
|
||||
assert_allclose(results['f1'], expected_f1, rtol=self.rtol)
|
||||
assert_allclose(results['acc_and_f1'], expected_acc_and_f1, rtol=self.rtol)
|
||||
assert_allclose(results['loss'], expected_loss, rtol=self.rtol)
|
||||
|
||||
def test_bert_fp16_with_mrpc(self):
|
||||
expected_acc = 0.8627450980392157
|
||||
expected_f1 = 0.9047619047619047
|
||||
expected_acc_and_f1 = 0.8837535014005602
|
||||
expected_loss = 0.41143255315574945
|
||||
expected_acc = 0.8651960784313726
|
||||
expected_f1 = 0.9063032367972743
|
||||
expected_acc_and_f1 = 0.8857496576143234
|
||||
expected_loss = 0.38716790532948925
|
||||
|
||||
results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=True)
|
||||
assert_allclose(results['acc'], expected_acc)
|
||||
assert_allclose(results['f1'], expected_f1)
|
||||
assert_allclose(results['acc_and_f1'], expected_acc_and_f1)
|
||||
assert_allclose(results['loss'], expected_loss)
|
||||
assert_allclose(results['acc'], expected_acc, rtol=self.rtol)
|
||||
assert_allclose(results['f1'], expected_f1, rtol=self.rtol)
|
||||
assert_allclose(results['acc_and_f1'], expected_acc_and_f1, rtol=self.rtol)
|
||||
assert_allclose(results['loss'], expected_loss, rtol=self.rtol)
|
||||
|
||||
def model_to_desc(self, model_name, model):
|
||||
if model_name.startswith('bert') or model_name.startswith('xlnet'):
|
||||
|
|
|
|||
|
|
@ -1101,6 +1101,15 @@ def adb_shell(*args, **kwargs):
|
|||
def run_training_python_frontend_tests(cwd):
|
||||
run_subprocess([sys.executable, 'onnxruntime_test_ort_trainer.py'], cwd=cwd)
|
||||
run_subprocess([sys.executable, 'onnxruntime_test_training_unit_tests.py'], cwd=cwd)
|
||||
run_subprocess([
|
||||
sys.executable, 'orttraining_test_transformers.py',
|
||||
'BertModelTest.test_for_pretraining_full_precision_list_input'], cwd=cwd)
|
||||
run_subprocess([
|
||||
sys.executable, 'orttraining_test_transformers.py',
|
||||
'BertModelTest.test_for_pretraining_full_precision_dict_input'], cwd=cwd)
|
||||
run_subprocess([
|
||||
sys.executable, 'orttraining_test_transformers.py',
|
||||
'BertModelTest.test_for_pretraining_full_precision_list_and_dict_input'], cwd=cwd)
|
||||
|
||||
|
||||
def run_training_python_frontend_e2e_tests(cwd):
|
||||
|
|
@ -1120,16 +1129,20 @@ def run_training_python_frontend_e2e_tests(cwd):
|
|||
[sys.executable, 'orttraining_run_glue.py', 'ORTGlueTest.test_bert_fp16_with_mrpc', '-v'],
|
||||
cwd=cwd, env={'CUDA_VISIBLE_DEVICES': '0'})
|
||||
|
||||
run_subprocess(
|
||||
[sys.executable, 'orttraining_run_glue.py', 'ORTGlueTest.test_roberta_with_mrpc', '-v'],
|
||||
cwd=cwd, env={'CUDA_VISIBLE_DEVICES': '0'})
|
||||
|
||||
run_subprocess(
|
||||
[sys.executable, 'orttraining_run_glue.py', 'ORTGlueTest.test_roberta_fp16_with_mrpc', '-v'],
|
||||
cwd=cwd, env={'CUDA_VISIBLE_DEVICES': '0'})
|
||||
|
||||
run_subprocess([sys.executable, 'onnxruntime_test_ort_trainer_with_mixed_precision.py'], cwd=cwd)
|
||||
|
||||
run_subprocess([
|
||||
sys.executable, 'orttraining_test_transformers.py',
|
||||
'BertModelTest.test_for_pretraining_mixed_precision_all'], cwd=cwd)
|
||||
|
||||
run_subprocess([
|
||||
sys.executable, 'orttraining_test_transformers.py',
|
||||
'BertModelTest.test_for_pretraining_full_precision_all'], cwd=cwd)
|
||||
|
||||
|
||||
def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
|
||||
for config in configs:
|
||||
|
|
|
|||
Loading…
Reference in a new issue