[RecordFunction] More efficient machinery to determine which callbacks to run. (#75807)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75807
There is a tension in RecordFunction between two use cases:
1) In the normal eager path we don't run any callbacks, so we need to bail out of the profiling path as soon as possible to minimize eager overhead.
2) When profiling we want to determine which callbacks to run as efficiently as possible to minimize instrumentation overhead.
The confounding factor in all of this is sampling callbacks, because they change which callbacks will run on each call, even in steady-state operation. This has traditionally been handled with a two-stage procedure: first we flip a coin to determine if a sampled callback *might* run. If false (which it usually is), do nothing. This solves (1). If true, check whether we need to build the full callback set or whether it was a false positive. This procedure has two negative effects:
* It forces us to rebuild the set of callbacks to run on every step when profiling
* It leaks the sampling abstraction, requiring other parts of the code to bump certain values and forces RecordFunction to lazily initialize.
This change introduces a multi-level cache which can (in the common case) quickly determine which callbacks *will* run, rather than if callbacks *might* run. This means that rather than call `shouldRunRecordFunction`, we can simply get the callbacks for an invocation and check if they are empty. (And completely removes the pre-sampling heuristic.) Another major benefit of the new cache structure is that it allows thread-safe registration and unregistration of global callbacks.
It's worth briefly discussing how this maintains eager performance. In the standard eager case (only sampling callbacks registered) the cache first checks that the global callbacks haven't changed (an atomic read), decrements a counter to see if a sampling callback fired, and then returns the active callbacks, which is simply a SmallVector of pointer pairs and a couple of POD values (scope, needs inputs/outputs/ids). The biggest cost according to perf is the SmallVector logic; we could consider adopting a hard limit on active callbacks, since more than half a dozen callbacks *running* in a single step would be quite a lot. But the total cost relative to `PYTORCH_DISABLE_PER_OP_PROFILING` is only ~10ns, so it's debatable whether switching to `std::array` is worth it.
The primary change is in `record_function.cpp`, which has a more detailed description of the new cache structure. `record_function.h` has some minor changes to align with the new calling convention and the remaining files are simply changes to the call sites.
Future work:
* RecordFunction no longer needs to be lazily initialized.
* We can deprecate the disable/reenable APIs, since we can now safely add and remove global callbacks.
Test Plan:
I tested eager mode performance using the overhead benchmark and found that the non-profiled path was unaffected. However, the no-op observer dropped from 0.41us to 0.37us (0.25us if no observers are active), which is roughly a one-third reduction in the cost of the callback selection machinery.
I also added several C++ unit tests, as the core RecordFunction machinery (especially sampling) was largely untested.
Reviewed By: swolchok, davidberard98
Differential Revision: D35276158
fbshipit-source-id: 35135f444724fba4eb97c0ae7f3f710f0f9016fd
(cherry picked from commit 9e359b87422c18f2a195185f32e7e85c82f956fd)
2022-04-19 20:40:00 +00:00
#include <array>
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <fmt/format.h>
#include <gtest/gtest.h>

#include <ATen/Parallel.h>
#include <ATen/record_function.h>
#include <c10/util/irange.h>

// Test that we can add and remove callbacks (both global and thread local.)
TEST(RecordFunctionTest, AddRemove) {
  at::clearCallbacks();
  ASSERT_FALSE(at::hasCallbacks());

  auto start_callback =
      [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
    return nullptr;
  };
  auto end_callback = [](const at::RecordFunction& fn, at::ObserverContext*) {};

  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(start_callback, end_callback));

  ASSERT_TRUE(at::hasCallbacks());
  ASSERT_TRUE(at::hasThreadLocalCallbacks());
  ASSERT_FALSE(at::hasGlobalCallbacks());

  at::removeCallback(handle);
  ASSERT_FALSE(at::hasCallbacks());

  handle = at::addGlobalCallback(
      at::RecordFunctionCallback(start_callback, end_callback));

  ASSERT_TRUE(at::hasCallbacks());
  ASSERT_FALSE(at::hasThreadLocalCallbacks());
  ASSERT_TRUE(at::hasGlobalCallbacks());

  at::removeCallback(handle);
  ASSERT_FALSE(at::hasCallbacks());
}

// Test that the callbacks that we register are actually run.
TEST(RecordFunctionTest, ThreadLocalState) {
  at::clearCallbacks();
  ASSERT_FALSE(at::hasCallbacks());

  static int tls_test_start_counter;
  static int tls_test_end_counter;
  tls_test_start_counter = 0;
  tls_test_end_counter = 0;

  auto start_callback =
      [](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
    ++tls_test_start_counter;
    return nullptr;
  };
  auto end_callback = [](const at::RecordFunction&, at::ObserverContext*) {
    ++tls_test_end_counter;
  };

  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(start_callback, end_callback));

  {
    at::RecordFunction guard(at::RecordScope::USER_SCOPE);
    guard.before("Test");
    EXPECT_EQ(tls_test_start_counter, 1);
    EXPECT_EQ(tls_test_end_counter, 0);
  }
  EXPECT_EQ(tls_test_start_counter, 1);
  EXPECT_EQ(tls_test_end_counter, 1);

  {
    tls_test_start_counter = 0;
    tls_test_end_counter = 0;
    at::DisableRecordFunctionGuard no_profile_guard;
    at::RecordFunction guard(at::RecordScope::USER_SCOPE);
    guard.before("Test");
    EXPECT_EQ(tls_test_start_counter, 0);
    EXPECT_EQ(tls_test_end_counter, 0);
  }
  EXPECT_EQ(tls_test_start_counter, 0);
  EXPECT_EQ(tls_test_end_counter, 0);

  {
    tls_test_start_counter = 0;
    tls_test_end_counter = 0;
    RECORD_FUNCTION("Test", {});
    EXPECT_EQ(tls_test_start_counter, 1);
    EXPECT_EQ(tls_test_end_counter, 0);
  }
  EXPECT_EQ(tls_test_start_counter, 1);
  EXPECT_EQ(tls_test_end_counter, 1);

  at::removeCallback(handle);
  ASSERT_FALSE(at::hasCallbacks());
}

// Test that callbacks are run in the order that they are registered.
TEST(RecordFunctionTest, CallOrder) {
  at::clearCallbacks();
  ASSERT_FALSE(at::hasCallbacks());

  static int current_index;
  current_index = 0;

  static std::array<std::string, 8> expected_order = {
      "Start Callback 0 Outer",
      "Start Callback 1 Outer",
      "Start Callback 0 Inner",
      "Start Callback 1 Inner",
      "End Callback 0 Inner",
      "End Callback 1 Inner",
      "End Callback 0 Outer",
      "End Callback 1 Outer",
  };

#define REGISTER_CALLBACK(index)                                       \
  at::addThreadLocalCallback(                                          \
      at::RecordFunctionCallback(                                      \
          [](const at::RecordFunction& fn)                             \
              -> std::unique_ptr<at::ObserverContext> {                \
            EXPECT_EQ(                                                 \
                fmt::format("Start Callback {} {}", index, fn.name()), \
                expected_order[current_index++]);                      \
            return nullptr;                                            \
          },                                                           \
          [](const at::RecordFunction& fn, at::ObserverContext*) {     \
            EXPECT_EQ(                                                 \
                fmt::format("End Callback {} {}", index, fn.name()),   \
                expected_order[current_index++]);                      \
          })                                                           \
          .scopes({at::RecordScope::FUNCTION}))

  REGISTER_CALLBACK(0);
  REGISTER_CALLBACK(1);
#undef REGISTER_CALLBACK

  RECORD_FUNCTION("Outer", {});
  { RECORD_FUNCTION("Inner", {}); }

  at::clearCallbacks();
  ASSERT_FALSE(at::hasCallbacks());
}

// Make sure TLS migrates when tasks are launched.
TEST(RecordFunctionTest, ThreadMigration) {
  at::clearCallbacks();
  ASSERT_FALSE(at::hasCallbacks());

  static int call_count;
  call_count = 0;

  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(
          [](const at::RecordFunction&)
              -> std::unique_ptr<at::ObserverContext> { return nullptr; },
          [](const at::RecordFunction&, at::ObserverContext*) { ++call_count; })
          .scopes({at::RecordScope::FUNCTION}));

  EXPECT_EQ(call_count, 0);

  std::condition_variable cv;
  std::mutex lock;
  at::launch([&cv]() {
    RECORD_FUNCTION("Test", {});
    cv.notify_all();
  });
  auto guard = std::unique_lock<std::mutex>(lock);
  cv.wait(guard, [] { return call_count > 0; });

  EXPECT_EQ(call_count, 1);

  at::removeCallback(handle);
  ASSERT_FALSE(at::hasCallbacks());
}

// Test sampling logic and validate that callbacks fire at the correct times.
TEST(RecordFunctionTest, Sampling) {
  at::clearCallbacks();
  ASSERT_FALSE(at::hasCallbacks());

  static int sample_test_counter;
  sample_test_counter = 0;

  uint32_t seed = 12345;
  double p = 0.25;

  at::set_record_function_seed_for_testing(seed);
  std::mt19937 generator;
  generator.seed(seed);
  auto dist = std::geometric_distribution<int>(p);

  // Make sure we know which steps should fire.
  auto outcomes = std::array<int, 5>{7, 0, 0, 6, 2};
  for (const auto i : c10::irange(outcomes.size())) {
    ASSERT_EQ(dist(generator), outcomes[i]);
  }

  std::vector<int> expected_counts;
  int running_count = 0;
  for (const auto i : c10::irange(outcomes.size())) {
    for ([[maybe_unused]] const auto j : c10::irange(outcomes[i])) {
      expected_counts.push_back(running_count);
    }
    expected_counts.push_back(++running_count);
  }

  auto start_callback =
      [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
    ++sample_test_counter;
    return nullptr;
  };
  auto end_callback = [](const at::RecordFunction& fn, at::ObserverContext*) {};

  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(start_callback, end_callback)
          .samplingProb(p)
          .scopes({at::RecordScope::FUNCTION}));

  for (const auto i : c10::irange(expected_counts.size())) {
    RECORD_FUNCTION("Test", {});
    EXPECT_EQ(sample_test_counter, expected_counts[i]);
  }

  at::removeCallback(handle);
  ASSERT_FALSE(at::hasCallbacks());
}

// Validate sampling against a simple reference implementation for a complex set
// of registered callbacks.
TEST(RecordFunctionTest, MultipleCallbacks) {
  at::clearCallbacks();
  ASSERT_FALSE(at::hasCallbacks());

  uint32_t seed = 54321;

  std::mt19937 generator;
  generator.seed(seed);

  auto sample = [&](double p) {
    return (p < 1.0 ? std::geometric_distribution<int>(p)(generator) : 0) + 1;
  };

  std::array<double, 4> probabilities{0.1, 1.0, 1.0, 0.3};
  std::array<int, 4> next_call;
  std::array<int, 4> counts{};
  static std::array<int, 4> counts_from_rec_fn;
  counts_from_rec_fn.fill(0);

  auto end_callback = [](const at::RecordFunction& fn, at::ObserverContext*) {};

#define REGISTER_CALLBACK(register_fn, index)                   \
  register_fn(at::RecordFunctionCallback(                       \
                  [](const at::RecordFunction& fn)              \
                      -> std::unique_ptr<at::ObserverContext> { \
                    ++counts_from_rec_fn[index];                \
                    return nullptr;                             \
                  },                                            \
                  end_callback)                                 \
                  .samplingProb(probabilities[index])           \
                  .scopes({at::RecordScope::FUNCTION}))

  REGISTER_CALLBACK(at::addGlobalCallback, 0);
  REGISTER_CALLBACK(at::addGlobalCallback, 1);
  REGISTER_CALLBACK(at::addThreadLocalCallback, 2);

  // The RecordFunction machinery will rebuild callbacks whenever a new observer
  // is registered, so we need to wait until the last callback to seed the
  // random number generator.
  at::set_record_function_seed_for_testing(seed);
  REGISTER_CALLBACK(at::addThreadLocalCallback, 3);
#undef REGISTER_CALLBACK

  for (const auto i : c10::irange(probabilities.size())) {
    next_call[i] = sample(probabilities[i]);
  }

  for ([[maybe_unused]] const auto i : c10::irange(50)) {
    RECORD_FUNCTION("Test", {});
    for (const auto j : c10::irange(next_call.size())) {
      if (!(--next_call[j])) {
        ++counts[j];
        next_call[j] = sample(probabilities[j]);
      }
      EXPECT_EQ(counts[j], counts_from_rec_fn[j]);
    }
  }

  at::clearCallbacks();
  ASSERT_FALSE(at::hasCallbacks());
}

// Test that KwargsOnly callbacks are run in USER_SCOPE.
TEST(RecordFunctionTest, KwargsOnly) {
  at::clearCallbacks();
  ASSERT_FALSE(at::hasCallbacks());
  static const std::unordered_map<std::string, c10::IValue> myMap = {
      {"a", 1}, {"b", 2.5}};

#define REGISTER_CALLBACK()                                          \
  at::addThreadLocalCallback(                                        \
      at::RecordFunctionCallback(                                    \
          [](const at::RecordFunction& fn)                           \
              -> std::unique_ptr<at::ObserverContext> {              \
            EXPECT_EQ(myMap, fn.kwinputs());                         \
            return nullptr;                                          \
          },                                                         \
          [](const at::RecordFunction& fn, at::ObserverContext*) {}) \
          .needsInputs(true)                                         \
          .scopes({at::RecordScope::USER_SCOPE}))

  REGISTER_CALLBACK();
#undef REGISTER_CALLBACK

  RECORD_USER_SCOPE_WITH_KWARGS_ONLY("Test", &myMap);

  at::clearCallbacks();
  ASSERT_FALSE(at::hasCallbacks());
}