mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Update intra_inter_benchmark (#22051)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22051 ghimport-source-id: 70710b3866b1a5e21656b77d2695ada74d00254e Test Plan: PARALLEL_BACKEND=NATIVE_TBB USE_OPENMP=0 USE_TBB=1 MKL_SEQ=1 MKLDNN_THREADING=SEQ USE_CUDA=0 BLAS=MKL USE_MKLDNN=1 BUILD_BINARY=1 python setup.py develop --cmake ./build/bin/intra_inter_benchmark Imported from OSS Differential Revision: D15933951 Pulled By: ilia-cher fbshipit-source-id: 88ad8f7a1634c1612ffaa68f22721ffc73d9b2ba
This commit is contained in:
parent
91bf0a9f9d
commit
7b1d6c8912
4 changed files with 141 additions and 53 deletions
|
|
@ -50,11 +50,11 @@ void launch_tasks_and_wait(int tasks_num) {
|
|||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
caffe2::GlobalInit(&argc, &argv);
|
||||
if (!c10::ParseCommandLineFlags(&argc, &argv)) {
|
||||
std::cout << "Failed to parse command line flags" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
caffe2::unsafeRunCaffe2InitFunction("registerThreadPools");
|
||||
at::init_num_threads();
|
||||
|
||||
if (FLAGS_inter_op_threads > 0) {
|
||||
|
|
|
|||
|
|
@ -1,66 +1,114 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include "ATen/ATen.h"
|
||||
#include "ATen/Parallel.h"
|
||||
|
||||
#include "c10/util/Flags.h"
|
||||
#include "caffe2/core/init.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <condition_variable>
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <mutex>
|
||||
#include <ctime>
|
||||
#include <thread>
|
||||
|
||||
C10_DEFINE_int(iter, 100, "Number of iterations (tasks)");
|
||||
C10_DEFINE_int(sub_iter, 100, "Number of subtasks")
|
||||
C10_DEFINE_int(warmup_iter, 10, "Number of warmup iterations")
|
||||
C10_DEFINE_int(iter_pow, 10, "Number of tasks, 2^N");
|
||||
C10_DEFINE_int(sub_iter, 1024, "Number of subtasks");
|
||||
C10_DEFINE_int(warmup_iter_pow, 3, "Number of warmup tasks, 2^N");
|
||||
C10_DEFINE_int(inter_op_threads, 0, "Number of inter-op threads");
|
||||
C10_DEFINE_int(intra_op_threads, 0, "Number of intra-op threads");
|
||||
C10_DEFINE_int(tensor_dim, 2000, "Tensor dim");
|
||||
C10_DEFINE_int(benchmark_iter, 3, "Number of times to run benchmark")
|
||||
C10_DEFINE_int(tensor_dim, 50, "Tensor dim");
|
||||
C10_DEFINE_int(benchmark_iter, 10, "Number of times to run benchmark")
|
||||
C10_DEFINE_bool(extra_stats, false,
|
||||
"Collect extra stats; warning: skews results");
|
||||
C10_DEFINE_string(task_type, "add", "Tensor operation: add or mm");
|
||||
|
||||
namespace {
|
||||
std::atomic<int> counter{0};
|
||||
int overall_tasks = 0;
|
||||
std::condition_variable cv;
|
||||
std::mutex mutex;
|
||||
}
|
||||
std::mutex tasks_mutex;
|
||||
bool run_mm = false;
|
||||
|
||||
void launch_task(at::Tensor& left, at::Tensor& right) {
|
||||
at::launch([&left, &right]() {
|
||||
at::parallel_for(0, FLAGS_sub_iter, 1,
|
||||
[&left, &right](int64_t begin, int64_t end) {
|
||||
for (auto k = begin; k < end; ++k) {
|
||||
auto result = left.add(right);
|
||||
auto cur_ctr = ++counter;
|
||||
if (cur_ctr == overall_tasks) {
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
cv.notify_one();
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
std::mutex stats_mutex;
|
||||
std::unordered_set<std::thread::id> tids;
|
||||
}
|
||||
|
||||
void wait() {
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
std::unique_lock<std::mutex> lk(tasks_mutex);
|
||||
while (counter < overall_tasks) {
|
||||
cv.wait(lk);
|
||||
}
|
||||
}
|
||||
|
||||
void launch_tasks_and_wait(at::Tensor& left, at::Tensor& right, int tasks_num) {
|
||||
overall_tasks = tasks_num * FLAGS_sub_iter;
|
||||
counter = 0;
|
||||
for (auto idx = 0; idx < tasks_num; ++idx) {
|
||||
launch_task(left, right);
|
||||
void _launch_tasks_tree(
|
||||
int level, int end_level, at::Tensor& left, at::Tensor& right) {
|
||||
if (level == end_level) {
|
||||
at::parallel_for(0, FLAGS_sub_iter, 1,
|
||||
[&left, &right](int64_t begin, int64_t end) {
|
||||
if (FLAGS_extra_stats) {
|
||||
std::unique_lock<std::mutex> lk(stats_mutex);
|
||||
tids.insert(std::this_thread::get_id());
|
||||
}
|
||||
for (auto k = begin; k < end; ++k) {
|
||||
if (run_mm) {
|
||||
left.mm(right);
|
||||
} else {
|
||||
left.add(right);
|
||||
}
|
||||
auto cur_ctr = ++counter;
|
||||
if (cur_ctr == overall_tasks) {
|
||||
std::unique_lock<std::mutex> lk(tasks_mutex);
|
||||
cv.notify_one();
|
||||
}
|
||||
}
|
||||
});
|
||||
} else {
|
||||
at::launch([&left, &right, level, end_level]() {
|
||||
_launch_tasks_tree(level + 1, end_level, left, right);
|
||||
});
|
||||
at::launch([&left, &right, level, end_level]() {
|
||||
_launch_tasks_tree(level + 1, end_level, left, right);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
void launch_tasks_and_wait(at::Tensor& left, at::Tensor& right, int iter_pow) {
|
||||
overall_tasks = pow(2, iter_pow) * FLAGS_sub_iter;
|
||||
counter = 0;
|
||||
|
||||
_launch_tasks_tree(0, iter_pow, left, right);
|
||||
wait();
|
||||
}
|
||||
|
||||
void reset_extra_stats() {
|
||||
tids.clear();
|
||||
}
|
||||
|
||||
void print_extra_stats() {
|
||||
std::cout << "# threads: " << tids.size() << std::endl;
|
||||
}
|
||||
|
||||
void print_runtime_stats(const std::vector<float>& runtimes) {
|
||||
TORCH_INTERNAL_ASSERT(!runtimes.empty());
|
||||
float sum = 0.0;
|
||||
float sqr_sum = 0.0;
|
||||
size_t N = runtimes.size();
|
||||
for (size_t idx = 0; idx < N; ++idx) {
|
||||
sum += runtimes[idx];
|
||||
sqr_sum += runtimes[idx] * runtimes[idx];
|
||||
}
|
||||
float mean = sum / N;
|
||||
float sd = std::sqrt(sqr_sum / N - mean * mean);
|
||||
std::cout << "N = " << N << ", mean = " << mean << ", sd = " << sd
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (!c10::ParseCommandLineFlags(&argc, &argv)) {
|
||||
std::cout << "Failed to parse command line flags" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
caffe2::unsafeRunCaffe2InitFunction("registerThreadPools");
|
||||
at::init_num_threads();
|
||||
|
||||
if (FLAGS_inter_op_threads > 0) {
|
||||
|
|
@ -70,36 +118,49 @@ int main(int argc, char** argv) {
|
|||
at::set_num_threads(FLAGS_intra_op_threads);
|
||||
}
|
||||
|
||||
TORCH_CHECK(FLAGS_task_type == "add" || FLAGS_task_type == "mm");
|
||||
run_mm = FLAGS_task_type == "mm";
|
||||
|
||||
auto left = at::ones({FLAGS_tensor_dim, FLAGS_tensor_dim}, at::kFloat);
|
||||
auto right = at::ones({FLAGS_tensor_dim, FLAGS_tensor_dim}, at::kFloat);
|
||||
|
||||
std::cout << "Launching " << FLAGS_warmup_iter << " warmup tasks" << std::endl;
|
||||
std::cout << "Launching " << pow(2, FLAGS_warmup_iter_pow)
|
||||
<< " warmup tasks" << std::endl;
|
||||
|
||||
typedef std::chrono::high_resolution_clock clock;
|
||||
typedef std::chrono::milliseconds ms;
|
||||
|
||||
std::chrono::time_point<clock> start_time = clock::now();
|
||||
launch_tasks_and_wait(left, right, FLAGS_warmup_iter);
|
||||
launch_tasks_and_wait(left, right, FLAGS_warmup_iter_pow);
|
||||
auto duration = static_cast<float>(
|
||||
std::chrono::duration_cast<ms>(clock::now() - start_time).count());
|
||||
|
||||
std::cout << "Warmup time: " << duration << " ms." << std::endl;
|
||||
|
||||
std::cout << "Launching " << FLAGS_iter << " tasks with "
|
||||
std::cout << "Launching " << pow(2, FLAGS_iter_pow) << " tasks with "
|
||||
<< FLAGS_sub_iter << " subtasks each, using "
|
||||
<< at::get_num_interop_threads() << " inter-op threads and "
|
||||
<< at::get_num_threads() << " intra-op threads, "
|
||||
<< "tensor dim: " << FLAGS_tensor_dim << std::endl;
|
||||
<< "tensor dim: " << FLAGS_tensor_dim
|
||||
<< ", task type: " << FLAGS_task_type << std::endl;
|
||||
|
||||
std::vector<float> runtimes;
|
||||
for (auto bench_iter = 0; bench_iter < FLAGS_benchmark_iter; ++bench_iter) {
|
||||
reset_extra_stats();
|
||||
start_time = clock::now();
|
||||
launch_tasks_and_wait(left, right, FLAGS_iter);
|
||||
launch_tasks_and_wait(left, right, FLAGS_iter_pow);
|
||||
duration = static_cast<float>(
|
||||
std::chrono::duration_cast<ms>(clock::now() - start_time).count());
|
||||
runtimes.push_back(duration);
|
||||
|
||||
std::cout << "Time to run " << FLAGS_iter << " iterations "
|
||||
<< (duration/1000.0) << " s." << std::endl;
|
||||
if (FLAGS_extra_stats) {
|
||||
print_extra_stats();
|
||||
}
|
||||
|
||||
std::cout << "Runtime: " << duration << " ms." << std::endl;
|
||||
}
|
||||
|
||||
print_runtime_stats(runtimes);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -101,4 +101,9 @@ bool GlobalInit() {
|
|||
char** mobile_argv = &mobile_name;
|
||||
return ::caffe2::GlobalInit(&mobile_argc, &mobile_argv);
|
||||
}
|
||||
|
||||
bool unsafeRunCaffe2InitFunction(const char* name, int* pargc, char*** pargv) {
|
||||
return internal::Caffe2InitializeRegistry::Registry()->RunNamedFunction(
|
||||
name, pargc, pargv);
|
||||
}
|
||||
} // namespace caffe2
|
||||
|
|
|
|||
|
|
@ -15,8 +15,14 @@ class CAFFE2_API Caffe2InitializeRegistry {
|
|||
// multiple shared libraries loaded with RTLD_LOCAL
|
||||
static Caffe2InitializeRegistry* Registry();
|
||||
|
||||
void
|
||||
Register(InitFunction function, bool run_early, const char* description) {
|
||||
void Register(
|
||||
InitFunction function,
|
||||
bool run_early,
|
||||
const char* description,
|
||||
const char* name = nullptr) {
|
||||
if (name) {
|
||||
named_functions_[name] = function;
|
||||
}
|
||||
if (run_early) {
|
||||
// Disallow registration after GlobalInit of early init functions
|
||||
CAFFE_ENFORCE(!early_init_functions_run_yet_);
|
||||
|
|
@ -57,6 +63,13 @@ class CAFFE2_API Caffe2InitializeRegistry {
|
|||
return RunRegisteredInitFunctionsInternal(init_functions_, pargc, pargv);
|
||||
}
|
||||
|
||||
bool RunNamedFunction(const char* name, int* pargc, char*** pargv) {
|
||||
if (named_functions_.count(name)) {
|
||||
return named_functions_[name](pargc, pargv);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
// Run all registered initialization functions. This has to be called AFTER
|
||||
// all static initialization are finished and main() has started, since we are
|
||||
|
|
@ -77,31 +90,40 @@ class CAFFE2_API Caffe2InitializeRegistry {
|
|||
Caffe2InitializeRegistry() {}
|
||||
vector<std::pair<InitFunction, const char*> > early_init_functions_;
|
||||
vector<std::pair<InitFunction, const char*> > init_functions_;
|
||||
std::unordered_map<std::string, InitFunction> named_functions_;
|
||||
bool early_init_functions_run_yet_ = false;
|
||||
bool init_functions_run_yet_ = false;
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
CAFFE2_API bool unsafeRunCaffe2InitFunction(
|
||||
const char* name,
|
||||
int* pargc = nullptr,
|
||||
char*** pargv = nullptr);
|
||||
|
||||
class CAFFE2_API InitRegisterer {
|
||||
public:
|
||||
InitRegisterer(internal::Caffe2InitializeRegistry::InitFunction function,
|
||||
bool run_early, const char* description) {
|
||||
internal::Caffe2InitializeRegistry::Registry()
|
||||
->Register(function, run_early, description);
|
||||
InitRegisterer(
|
||||
internal::Caffe2InitializeRegistry::InitFunction function,
|
||||
bool run_early,
|
||||
const char* description,
|
||||
const char* name = nullptr) {
|
||||
internal::Caffe2InitializeRegistry::Registry()->Register(
|
||||
function, run_early, description, name);
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_CAFFE2_INIT_FUNCTION(name, function, description) \
|
||||
namespace { \
|
||||
::caffe2::InitRegisterer g_caffe2_initregisterer_##name( \
|
||||
function, false, description); \
|
||||
} // namespace
|
||||
#define REGISTER_CAFFE2_INIT_FUNCTION(name, function, description) \
|
||||
namespace { \
|
||||
::caffe2::InitRegisterer \
|
||||
g_caffe2_initregisterer_##name(function, false, description, #name); \
|
||||
} // namespace
|
||||
|
||||
#define REGISTER_CAFFE2_EARLY_INIT_FUNCTION(name, function, description) \
|
||||
namespace { \
|
||||
::caffe2::InitRegisterer g_caffe2_initregisterer_##name( \
|
||||
function, true, description); \
|
||||
} // namespace
|
||||
#define REGISTER_CAFFE2_EARLY_INIT_FUNCTION(name, function, description) \
|
||||
namespace { \
|
||||
::caffe2::InitRegisterer \
|
||||
g_caffe2_initregisterer_##name(function, true, description, #name); \
|
||||
} // namespace
|
||||
|
||||
/**
|
||||
* @brief Determine whether GlobalInit has already been run
|
||||
|
|
|
|||
Loading…
Reference in a new issue