Use custom AppendOnlyList for op_events to reduce the number of atomic operations (#78643)
Summary:
- Use an atomic counter in the block storage constructor, plus the offset within the block, to calculate the correlation_id.
- Deduce the correlation id implicitly, so it no longer needs to be stored while profiling.

Test Plan: Added test_profiler_correlation_id() in test_profiler.py to check the uniqueness of correlation ids. To run the test: python test_profiler.py

Differential Revision: D36793803
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78643
Approved by: https://github.com/robieta
parent: d9a6f76a9e
commit: f754d2501d

5 changed files with 94 additions and 28 deletions
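In isolation, the scheme works like this: each 512-slot storage block claims a contiguous id range with a single atomic increment in its constructor, and an event's correlation id is recovered later from its pointer offset within the block. A minimal standalone sketch of that idea (simplified, illustrative names; not the actual profiler types):

#include <array>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <iostream>

// Each block of ChunkSize events claims the id range
// [id_start_, id_start_ + ChunkSize) with one atomic increment; an event's
// id is then derived from its offset, so it is never stored per event.
template <typename T, std::size_t ChunkSize>
class EventBlock : public std::array<T, ChunkSize> {
 public:
  EventBlock() {
    static std::atomic<std::uint64_t> counter{0};
    id_start_ = 1 + ChunkSize * counter++;
  }

  // Pointer arithmetic recovers the id without any per-event storage.
  std::uint64_t correlation_id(const T* ptr) const {
    return id_start_ + static_cast<std::uint64_t>(ptr - this->data());
  }

 private:
  std::uint64_t id_start_;
};

int main() {
  EventBlock<int, 512> a;
  EventBlock<int, 512> b;
  std::cout << a.correlation_id(&a[0]) << "\n";  // 1
  std::cout << a.correlation_id(&a[5]) << "\n";  // 6
  std::cout << b.correlation_id(&b[0]) << "\n";  // 513 (next block's range)
  return 0;
}

The cost drops to one atomic read-modify-write per 512 events instead of one per event; the trade-off is that a block consumes its whole id range even when only partially filled.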
test/test_profiler.py

@@ -1,5 +1,4 @@
 # Owner(s): ["oncall: profiler"]
-
 import collections
 import gc
 import io

@@ -1084,6 +1083,30 @@ class TestProfiler(TestCase):
         with profile():
             self.assertEqual(profiler_type(), ActiveProfilerType.KINETO)
 
+    def test_profiler_correlation_id(self):
+        '''
+        We expect the correlation_id to be unique across multiple invocations
+        of the profiler, so we reuse id_uniqueness_set across iterations.
+        '''
+        id_uniqueness_set = set()
+        model = torch.nn.Sequential(
+            nn.Conv2d(16, 33, 18),
+            nn.ReLU(),
+            nn.Linear(243, 243),
+            nn.ReLU(),
+        )
+        inputs = torch.randn(40, 16, 18, 260)
+        uint32_max = 2**32 - 1
+        for i in range(5):
+            with profile() as prof:
+                model(inputs)
+            for event in prof.profiler.kineto_results.events():
+                corr_id = event.correlation_id()
+                if corr_id:
+                    self.assertTrue(corr_id not in id_uniqueness_set)
+                    id_uniqueness_set.add(corr_id)
+                    self.assertTrue(corr_id < uint32_max)
+
 
 if __name__ == '__main__':
     run_tests()
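An observation on the test's upper bound (not part of the patch): since each 512-slot block reserves a contiguous id range, ids stay below 2**32 - 1 until roughly 2**32 / 512 = 8,388,608 blocks have been constructed, i.e. about 4.3 billion profiled op events in one process lifetime, so the uint32_max assertion is comfortably safe for a five-iteration test.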
torch/csrc/autograd/profiler_kineto.cpp

@@ -45,12 +45,6 @@ namespace autograd {
 namespace profiler {
 
 namespace {
-// TODO: consider TLS (tid + tls counter)
-uint64_t next_correlation_id() {
-  static std::atomic<uint64_t> corr_id_{1};
-  return corr_id_++;
-}
-
 inline int64_t getTimeUs() {
 #ifdef USE_KINETO
   return libkineto::timeSinceEpoch(std::chrono::system_clock::now());

@@ -386,7 +380,6 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase {
   // reenable the forward/backward correlation when kineto fix the following
   // raw pointer
   // GenericTraceActivity.flow.linkedActivity
 
   /*
   std::unordered_map<uint64_t, libkineto::GenericTraceActivity*>
       tidSeq2activity;

@@ -538,13 +531,7 @@ std::unique_ptr<at::ObserverContext> onFunctionEnter(
   if (!state_ptr) {
     return nullptr;
   }
-  auto corr_id = next_correlation_id();
-  if (fn.scope() == at::RecordScope::USER_SCOPE) {
-    torch::profiler::impl::kineto::pushUserCorrelationId(corr_id);
-  } else {
-    torch::profiler::impl::kineto::pushCorrelationId(corr_id);
-  }
-  return state_ptr->record_queue_.getSubqueue()->begin_op(fn, corr_id);
+  return state_ptr->record_queue_.getSubqueue()->begin_op(fn);
 }
 
 // @lint-ignore CLANGTIDY clang-diagnostic-unused-parameter
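The hunk above deletes the per-op next_correlation_id() counter, so the only atomic traffic left for id assignment is the per-block claim. A toy, self-contained comparison of the two schemes' atomic read-modify-write counts (illustrative names and loops, not the real observer path):

#include <atomic>
#include <cstdint>
#include <cstdio>

// Counts the atomic RMW operations each id-assignment scheme needs to hand
// out kOps ids. The old scheme pays one per op; the new scheme pays one per
// 512-slot block.
int main() {
  constexpr std::uint64_t kOps = 1u << 20;
  constexpr std::uint64_t kBlockSize = 512;

  std::atomic<std::uint64_t> per_op_counter{1};
  std::uint64_t rmw_old = 0;
  for (std::uint64_t i = 0; i < kOps; ++i) {
    per_op_counter.fetch_add(1, std::memory_order_relaxed);
    ++rmw_old;
  }

  std::atomic<std::uint64_t> block_counter{0};
  std::uint64_t rmw_new = 0;
  for (std::uint64_t issued = 0; issued < kOps; issued += kBlockSize) {
    // One atomic claim covers ids [start, start + kBlockSize).
    const std::uint64_t start = 1 + kBlockSize * block_counter.fetch_add(1);
    (void)start;
    ++rmw_new;
  }

  // Prints 1048576 vs 2048 RMWs for 2^20 ids: a 512x reduction.
  std::printf("per-op: %llu RMWs, per-block: %llu RMWs\n",
              (unsigned long long)rmw_old, (unsigned long long)rmw_new);
  return 0;
}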
torch/csrc/profiler/collection.cpp

@@ -258,6 +258,23 @@ DEFINE_VISITOR(
 #undef DEFINE_VISITOR
 #undef OUT_T
 
+template <typename T, size_t ChunkSize>
+ThreadLocalSubqueue::EventBlock<T, ChunkSize>::EventBlock() {
+  static std::atomic<uint64_t> counter_{0};
+  id_start_ = 1 + ChunkSize * counter_++;
+}
+template <class... Args>
+std::pair<KinetoObserverContext::Event*, uint64_t> ThreadLocalSubqueue::OpList::
+    emplace_back(Args&&... args) {
+  maybe_grow();
+  *next_ = {std::forward<Args>(args)...};
+  auto corr_id = buffer_last_->correlation_id(next_);
+  return {next_++, corr_id};
+}
+uint64_t ThreadLocalSubqueue::OpList::correlationID(const OpList::Iterator& e) {
+  return e.address().first->correlation_id(&*e);
+}
+
 ThreadLocalSubqueue::ThreadLocalSubqueue(
     const uint64_t tid,
     const ProfilerConfig& config)

@@ -266,10 +283,10 @@ ThreadLocalSubqueue::ThreadLocalSubqueue(
 }
 
 std::unique_ptr<KinetoObserverContext> ThreadLocalSubqueue::begin_op(
-    const at::RecordFunction& fn,
-    uint64_t correlation_id) {
-  auto event = op_events_.emplace_back(
-      correlation_id,
+    const at::RecordFunction& fn) {
+  KinetoObserverContext::Event* event;
+  uint64_t corr_id;
+  std::tie(event, corr_id) = op_events_.emplace_back(
       fn.seqNr(),
       fn.forwardThreadId(),
       fn.scope(),

@@ -279,6 +296,11 @@ std::unique_ptr<KinetoObserverContext> ThreadLocalSubqueue::begin_op(
   if (config_.report_input_shapes) {
     inputs_outputs_.push(fn.inputs());
   }
+  if (fn.scope() == at::RecordScope::USER_SCOPE) {
+    torch::profiler::impl::kineto::pushUserCorrelationId(corr_id);
+  } else {
+    torch::profiler::impl::kineto::pushCorrelationId(corr_id);
+  }
 
 #if !defined BUILD_LITE_INTERPRETER && !defined C10_MOBILE
   // backward nodes source range corresponds to the forward node

@@ -518,7 +540,9 @@ std::vector<std::shared_ptr<Result>> RecordQueue::getRecords(
     auto jit_module_it = queue.jit_modules_.begin();
     auto extra_args_it = queue.extra_args_.begin();
     auto gpu_fallback_it = queue.gpu_fallback_.begin();
-    for (auto& i : queue.op_events_) {
+    for (auto event = queue.op_events_.begin(); event != queue.op_events_.end();
+         ++event) {
+      auto& i = *event;
       auto start_time = converter(i.start_time_);
       out.emplace_back(Result::create(
           start_time,

@@ -527,6 +551,7 @@ std::vector<std::shared_ptr<Result>> RecordQueue::getRecords(
           /*extra_fields_=*/
           ExtraFields<EventType::TorchOp>(
               std::move(i.basic_fields_),
+              ThreadLocalSubqueue::OpList::correlationID(event),
               converter(i.end_time_),
               input_getter(),
               steal_or_default(jit_stack_it),
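For readers without the surrounding sources: emplace_back above returns both the new slot and the id implied by the slot's position, so the id is never stored alongside the event. A self-contained sketch of that pattern, assuming a single-threaded writer and inlining the maybe_grow() step; MiniAppendOnlyList and its members are illustrative names, not PyTorch API:

#include <array>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <forward_list>
#include <iostream>
#include <utility>

// A pared-down append-only list in the spirit of the patch: storage grows in
// fixed-size blocks, and emplace_back hands back {slot pointer, derived id}.
template <typename T, std::size_t ChunkSize>
class MiniAppendOnlyList {
  struct Block : std::array<T, ChunkSize> {
    Block() {
      static std::atomic<std::uint64_t> counter{0};
      id_start_ = 1 + ChunkSize * counter++;
    }
    std::uint64_t id_start_;
  };

 public:
  template <class... Args>
  std::pair<T*, std::uint64_t> emplace_back(Args&&... args) {
    if (next_ == end_) {  // grow: claim a fresh block (one atomic op)
      blocks_.emplace_front();
      next_ = blocks_.front().data();
      end_ = next_ + ChunkSize;
    }
    *next_ = T{std::forward<Args>(args)...};
    const std::uint64_t id =
        blocks_.front().id_start_ +
        static_cast<std::uint64_t>(next_ - blocks_.front().data());
    return {next_++, id};
  }

 private:
  std::forward_list<Block> blocks_;
  T* next_ = nullptr;
  T* end_ = nullptr;
};

int main() {
  MiniAppendOnlyList<int, 512> list;
  for (int i = 0; i < 3; ++i) {
    auto [ptr, id] = list.emplace_back(i);
    std::cout << "value " << *ptr << " -> id " << id << "\n";  // ids 1, 2, 3
  }
  return 0;
}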
torch/csrc/profiler/collection.h

@@ -33,7 +33,6 @@ template <EventType>
 struct ExtraFields;
 
 struct TorchOpBasicFields {
-  uint64_t correlation_id_;
   int64_t sequence_number_;
   uint64_t forward_tid_;
   at::RecordScope scope_;

@@ -63,6 +62,7 @@ template <>
 struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields {
   ExtraFields(
       TorchOpBasicFields&& f,
+      uint64_t correlation_id,
       time_t end_time_ns,
       Inputs&& inputs,
       jit_stack_t&& jit_stack,

@@ -70,13 +70,14 @@ struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields {
       extra_args_t&& extra_args,
       FallbackPair&& gpu_fallback)
       : TorchOpBasicFields(std::move(f)),
+        correlation_id_{correlation_id},
         end_time_ns_{end_time_ns},
         inputs_{std::move(inputs)},
         jit_stack_{std::move(jit_stack)},
         jit_modules_{std::move(jit_modules)},
         extra_args_{std::move(extra_args)},
         gpu_fallback_{std::move(gpu_fallback)} {}
 
+  uint64_t correlation_id_;
   time_t end_time_ns_;
   Inputs inputs_;
   jit_stack_t jit_stack_;

@@ -323,9 +324,7 @@ class TORCH_API ThreadLocalSubqueue {
  public:
   ThreadLocalSubqueue(const uint64_t tid, const ProfilerConfig& config);
 
-  std::unique_ptr<KinetoObserverContext> begin_op(
-      const at::RecordFunction& fn,
-      uint64_t correlation_id);
+  std::unique_ptr<KinetoObserverContext> begin_op(const at::RecordFunction& fn);
 
   template <class... Args>
   void emplace_backend_event(Args&&... args) {

@@ -358,7 +357,33 @@ class TORCH_API ThreadLocalSubqueue {
   friend class RecordQueue;
   // See `containers.h` for block size benchmarks.
   static constexpr size_t BlockSize = 512;
-  AppendOnlyList<KinetoObserverContext::Event, BlockSize> op_events_;
+
+  template <typename T, size_t ChunkSize>
+  class EventBlock : public std::array<T, ChunkSize> {
+   public:
+    EventBlock();
+    uint64_t correlation_id(const T* ptr) const {
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+          ptr >= this->data() && ptr < this->data() + ChunkSize);
+      return id_start_ + (ptr - this->data());
+    }
+
+   private:
+    uint64_t id_start_;
+  };
+
+  class OpList : public AppendOnlyList<
+                     KinetoObserverContext::Event,
+                     BlockSize,
+                     EventBlock> {
+   public:
+    template <class... Args>
+    std::pair<KinetoObserverContext::Event*, uint64_t> emplace_back(
+        Args&&... args);
+    static uint64_t correlationID(const OpList::Iterator& e);
+  };
+
+  OpList op_events_;
 
   // report_input_shapes
   InputOutputEncoder inputs_outputs_;
torch/csrc/profiler/containers.h

@@ -39,10 +39,16 @@ namespace impl {
 // Performance drops off for larger values, so testing on a case-by-case basis
 // is recommended if performance is absolutely critical.
 
-template <typename T, size_t ChunkSize>
+template <
+    typename T,
+    size_t ChunkSize,
+    template <typename U, size_t N> class block_t = std::array>
 class AppendOnlyList {
  public:
-  using array_t = std::array<T, ChunkSize>;
+  using array_t = block_t<T, ChunkSize>;
+  static_assert(
+      std::is_base_of<std::array<T, ChunkSize>, array_t>::value,
+      "AppendOnlyList expects raw low level pointer storage.");
   static_assert(ChunkSize > 0, "Block cannot be empty.");
 
   AppendOnlyList() : buffer_last_{buffer_.before_begin()} {}
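The static_assert guards the new customization point: a substitute block type must still derive from std::array so that contiguous storage and raw pointer arithmetic keep working. A compile-time sketch of how the template-template parameter is meant to be used (illustrative names, not the real profiler types):

#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>

// The storage block defaults to std::array but can be swapped for any type
// derived from it, e.g. a block carrying extra per-chunk metadata.
template <
    typename T,
    std::size_t ChunkSize,
    template <typename U, std::size_t N> class block_t = std::array>
class AppendOnlyListSketch {
 public:
  using array_t = block_t<T, ChunkSize>;
  // Blocks must still be std::arrays underneath: the list relies on
  // contiguous raw storage for its pointer arithmetic.
  static_assert(
      std::is_base_of<std::array<T, ChunkSize>, array_t>::value,
      "AppendOnlyList expects raw low level pointer storage.");
  static_assert(ChunkSize > 0, "Block cannot be empty.");
};

// A custom block in the style of EventBlock: std::array plus metadata.
template <typename T, std::size_t N>
struct TaggedBlock : std::array<T, N> {
  std::uint64_t id_start_ = 0;
};

// Both instantiations compile; a non-array block would trip the assert.
AppendOnlyListSketch<int, 64> default_blocks;
AppendOnlyListSketch<int, 64, TaggedBlock> tagged_blocks;

int main() { return 0; }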