#include <benchmark/benchmark.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include "deep_wide_pt.h"

const int embedding_size = 32;
const int num_features = 50;

using namespace torch;

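// Benchmarks for the "deep and wide" test model declared in deep_wide_pt.h,
// comparing a hand-written C++ module, the TorchScript executors, and the
// Static Runtime.

// Eager baseline: the DeepAndWide C++ module, called directly.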
static void BM_deep_wide_base(benchmark::State& state) {
  std::shared_ptr<DeepAndWide> net =
      std::make_shared<DeepAndWide>(num_features);

  const int batch_size = state.range(0);
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
  auto user_emb = torch::randn({batch_size, 1, embedding_size});
  auto wide = torch::randn({batch_size, num_features});
  // warmup
  net->forward(ad_emb_packed, user_emb, wide);
  for (auto _ : state) {
    net->forward(ad_emb_packed, user_emb, wide);
  }
}

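// DeepAndWideFast: an optimized variant of the same model (see deep_wide_pt.h).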
static void BM_deep_wide_fast(benchmark::State& state) {
  std::shared_ptr<DeepAndWideFast> net =
      std::make_shared<DeepAndWideFast>(num_features);

  const int batch_size = state.range(0);
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
  auto user_emb = torch::randn({batch_size, 1, embedding_size});
  auto wide = torch::randn({batch_size, num_features});
  // warmup
  net->forward(ad_emb_packed, user_emb, wide);
  for (auto _ : state) {
    net->forward(ad_emb_packed, user_emb, wide);
  }
}

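// TorchScript version of the model, forced onto the legacy graph executor:
// TORCH_JIT_DISABLE_NEW_EXECUTOR=1 disables the newer profiling executor.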
static void BM_deep_wide_jit_graph_executor(benchmark::State& state) {
  auto mod = getDeepAndWideSciptModel();

  const int batch_size = state.range(0);
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
  auto user_emb = torch::randn({batch_size, 1, embedding_size});
  auto wide = torch::randn({batch_size, num_features});

  std::vector<IValue> inputs({ad_emb_packed, user_emb, wide});

  CHECK_EQ(setenv("TORCH_JIT_DISABLE_NEW_EXECUTOR", "1", 1), 0);

  mod.forward(inputs);
  for (auto _ : state) {
    mod.forward(inputs);
  }
}

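// Same TorchScript model under the default profiling executor (the
// environment override from the previous benchmark is removed).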
static void BM_deep_wide_jit_profiling_executor(benchmark::State& state) {
  auto mod = getDeepAndWideSciptModel();

  const int batch_size = state.range(0);
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
  auto user_emb = torch::randn({batch_size, 1, embedding_size});
  auto wide = torch::randn({batch_size, num_features});

  std::vector<IValue> inputs({ad_emb_packed, user_emb, wide});

  CHECK_EQ(unsetenv("TORCH_JIT_DISABLE_NEW_EXECUTOR"), 0);

  mod.forward(inputs);
  for (auto _ : state) {
    mod.forward(inputs);
  }
}

static void BM_deep_wide_static(benchmark::State& state) {
  auto mod = getDeepAndWideSciptModel();
  torch::jit::StaticModule smod(mod);

  const int batch_size = state.range(0);
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
  auto user_emb = torch::randn({batch_size, 1, embedding_size});
  auto wide = torch::randn({batch_size, num_features});

  std::vector<c10::IValue> inputs({ad_emb_packed, user_emb, wide});

  smod(inputs, {});
  for (auto _ : state) {
    smod(inputs, {});
  }
}

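// A single StaticModule shared by all threads of the threaded benchmark below.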
std::shared_ptr<torch::jit::StaticModule> getStaticModule() {
  static auto smod =
      std::make_shared<torch::jit::StaticModule>(getDeepAndWideSciptModel());
  return smod;
}

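// Each benchmark thread constructs its own StaticRuntime on top of the shared
// StaticModule; registered with ->Threads(8) below.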
static void BM_deep_wide_static_threaded(benchmark::State& state) {
  auto sm = getStaticModule();
  torch::jit::StaticRuntime sr(*sm);

  const int batch_size = 1; // state.range(0);
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
  auto user_emb = torch::randn({batch_size, 1, embedding_size});
  auto wide = torch::randn({batch_size, num_features});

  std::vector<c10::IValue> inputs({ad_emb_packed, user_emb, wide});

  sr(inputs, {});
  for (auto _ : state) {
    sr(inputs, {});
  }
}

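// LeakyReLU microbenchmarks on the Static Runtime: the "const" variant bakes
// the negative slope into the graph, the other passes it as a runtime input.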
static void BM_leaky_relu_const(benchmark::State& state) {
  auto mod = getLeakyReLUConstScriptModel();
  torch::jit::StaticModule smod(mod);

  const int batch_size = state.range(0);
  auto data = torch::randn({batch_size, num_features});
  std::vector<c10::IValue> inputs({data});

  smod(inputs, {});
  for (auto _ : state) {
    smod(inputs, {});
  }
}

static void BM_leaky_relu(benchmark::State& state) {
  auto mod = getLeakyReLUScriptModel();
  torch::jit::StaticModule smod(mod);

  const int batch_size = state.range(0);
  auto neg_slope = torch::randn(1);
  auto data = torch::randn({batch_size, num_features});
  std::vector<c10::IValue> inputs({data, neg_slope[0]});

  smod(inputs, {});
  for (auto _ : state) {
    smod(inputs, {});
  }
}

BENCHMARK(BM_leaky_relu)->RangeMultiplier(8)->Ranges({{1, 20}});
BENCHMARK(BM_leaky_relu_const)->RangeMultiplier(8)->Ranges({{1, 20}});

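// Static Runtime benchmark for the signed_log1p test model; state.range(0)
// sets the number of elements in the flat input tensor.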
static void BM_signed_log1p(benchmark::State& state) {
  auto mod = getSignedLog1pModel();
  torch::jit::StaticModule smod(mod);

  const int num_elements = state.range(0);
  auto data = torch::randn({num_elements});
  std::vector<c10::IValue> inputs({data});

  smod(inputs, {});
  for (auto _ : state) {
    smod(inputs, {});
  }
}

BENCHMARK(BM_signed_log1p)->RangeMultiplier(8)->Ranges({{16, 65536}});

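// Exercises the Static Runtime memory planner: state.range(0) is the square
// tensor dimension N and state.range(1) toggles opts.optimize_memory, so the
// Args pairs below compare runs with and without memory planning.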
static void BM_long_static_memory_optimization(benchmark::State& state) {
  auto mod = getLongScriptModel();
  torch::jit::StaticModuleOptions opts;
  opts.optimize_memory = state.range(1);
  torch::jit::StaticModule smod(mod, false, opts);

  const auto N = state.range(0);
  auto a = torch::randn({N, N});
  auto b = torch::randn({N, N});
  auto c = torch::randn({N, N});
  std::vector<c10::IValue> inputs({a, b, c});

  smod(inputs, {});
  for (auto _ : state) {
    smod(inputs, {});
  }
}

BENCHMARK(BM_deep_wide_base)->RangeMultiplier(8)->Ranges({{1, 20}});
BENCHMARK(BM_deep_wide_fast)->RangeMultiplier(8)->Ranges({{1, 20}});

BENCHMARK(BM_deep_wide_jit_graph_executor)
    ->RangeMultiplier(8)
    ->Ranges({{1, 20}});

BENCHMARK(BM_deep_wide_jit_profiling_executor)
    ->RangeMultiplier(8)
    ->Ranges({{1, 20}});

BENCHMARK(BM_deep_wide_static)->RangeMultiplier(8)->Ranges({{1, 20}});
BENCHMARK(BM_deep_wide_static_threaded)->Threads(8);

BENCHMARK(BM_long_static_memory_optimization)
    ->Args({2 << 0, 0})
    ->Args({2 << 2, 0})
    ->Args({2 << 4, 0})
    ->Args({2 << 8, 0})
    ->Args({2 << 0, 1})
    ->Args({2 << 2, 1})
    ->Args({2 << 4, 1})
    ->Args({2 << 8, 1});

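// Parse c10/caffe2 command-line flags first, then hand the remaining
// arguments to Google Benchmark.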
int main(int argc, char** argv) {
  c10::ParseCommandLineFlags(&argc, &argv);
  ::benchmark::Initialize(&argc, argv);
  ::benchmark::RunSpecifiedBenchmarks();
}