pytorch/test/cpp/lazy/test_lazy_graph_executor.cpp
rpsilva 4abff4b271 Introduce cache clearing APIs for the lazy graph executor (#144489)
This PR introduces two new methods to the LazyGraphExecutor class:

- ClearComputationCache(): Allows clearing the entire computation cache.
- RemoveFromComputationCache(hash): Enables removal of specific cache entries based on their hash.
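
For reference, the new surface looks roughly as follows (a paraphrased sketch; see torch/csrc/lazy/core/lazy_graph_executor.h for the authoritative declarations and qualifiers):

```cpp
// Paraphrased sketch of the additions; exact signatures may differ.
class LazyGraphExecutor {
 public:
  // Drop every entry from the computation cache.
  void ClearComputationCache();
  // Drop the entry keyed by `hash`; removing a missing entry is a no-op.
  void RemoveFromComputationCache(hash_t hash);
  // ... existing members elided ...
};
```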

The main objective is to expose cache-management functionality for debugging cache hits and misses across different computations. For instance, these APIs make it possible to:
- Reset the cache state in tests, allowing reuse of the same computation client to evaluate cache logic consistently (see the sketch after this list).
- Selectively remove cache entries to analyze the impact on subsequent computations.
- Improve observability into the cache behavior, aiding in the investigation of cache-related issues or optimizations.
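
For the first point, a minimal sketch of a cache-isolated test fixture (the fixture name is illustrative, not part of this PR):

```cpp
// Hypothetical fixture: finish each test with an empty computation cache,
// so cache hits/misses are always evaluated from a known state.
class CacheIsolatedTest : public ::testing::Test {
 protected:
  void TearDown() override {
    LazyGraphExecutor::Get()->ClearComputationCache();
  }
};
```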

In the XLA lazy graph executor, we want to run a series of tests that modify parts of the computation's HLO module proto, and we need a way to ensure that the hash is agnostic to certain elements (e.g., OpMetadata in the XLA proto data). With these APIs, such a test is easy to parameterize: clear the cache between runs and validate that the resulting hash is identical each time. Otherwise, we would need to hardcode the resulting serialized hash.
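
A sketch of what such a check could look like (`BuildComputation` is a hypothetical helper that rebuilds the same graph while toggling OpMetadata; it is not part of this PR):

```cpp
// Illustrative only: the hash must match across runs that differ only in
// OpMetadata; clearing the cache between them avoids any stale state and
// avoids hardcoding a serialized hash.
std::vector<LazyTensorPtr> plain = BuildComputation(/*with_metadata=*/false);
hash_t baseline = LazyGraphExecutor::Get()->GetGraphHash(plain);
LazyGraphExecutor::Get()->ClearComputationCache();

std::vector<LazyTensorPtr> annotated = BuildComputation(/*with_metadata=*/true);
hash_t hash = LazyGraphExecutor::Get()->GetGraphHash(annotated);
EXPECT_EQ(hash, baseline) << "Hash should be agnostic to OpMetadata";
```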

Another motivation is added flexibility for users: they can clear specific computation hashes from their applications and implement their own custom strategies for maintaining the cache, rather than relying on the default LRU eviction.
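
As a hedged illustration (the function is hypothetical), such a strategy could be as simple as explicitly evicting entries the application knows are stale:

```cpp
// Hypothetical user-side policy: drop entries the application knows are
// stale instead of waiting for the default LRU policy to age them out.
void EvictStaleComputations(const std::vector<hash_t>& stale_hashes) {
  for (hash_t hash : stale_hashes) {
    LazyGraphExecutor::Get()->RemoveFromComputationCache(hash);
  }
}
```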

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144489
Approved by: https://github.com/wconstab
2025-01-29 17:38:01 +00:00

#include <gtest/gtest.h>

#include <test/cpp/lazy/test_lazy_ops_util.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>

#include <vector>

namespace torch {
namespace lazy {
namespace {

class LazyGraphExecutorTest : public ::testing::Test {
 protected:
  void SetUp() override {
    executor_ = LazyGraphExecutor::Get();
  }

  using CachedComputationType = LazyGraphExecutor::CachedComputation;

  std::shared_ptr<CachedComputationType> GetCachedComputation(hash_t hash) {
    return executor_->GetComputationCache()->Get(hash);
  }

  void EnsureComputationIsCached(
      std::vector<LazyTensorPtr>& tensors,
      hash_t hash) {
    // Force the computation to be cached by syncing the tensors.
    executor_->SyncTensorsGraph(
        &tensors, /* devices */ {}, /* wait */ true, /* sync_ltc_data */ true);
    // Ensure that the computation cache entry exists.
    auto cached_computation = GetCachedComputation(hash);
    EXPECT_NE(cached_computation, nullptr)
        << "Computation should be cached after sync";
  }

  LazyGraphExecutor* executor_;
};
TEST_F(LazyGraphExecutorTest, TestClearComputationCache) {
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor tensor_a =
        torch::rand({2, 2}, at::TensorOptions(torch::kFloat));
    torch::Tensor tensor_b =
        torch::rand({2, 2}, at::TensorOptions(torch::kFloat));
    torch::Tensor xla_tensor_a = CopyToDevice(tensor_a, device);
    torch::Tensor xla_tensor_b = CopyToDevice(tensor_b, device);
    torch::Tensor result = xla_tensor_a + xla_tensor_b;

    std::vector<LazyTensorPtr> tensors{TryGetLtcTensor(result)};
    hash_t hash = executor_->GetGraphHash(tensors);
    EnsureComputationIsCached(tensors, hash);
    EXPECT_EQ(executor_->GetComputationCache()->Numel(), 1);

    // Clear the entire computation cache.
    executor_->ClearComputationCache();

    // Ensure that there are no cache entries.
    EXPECT_EQ(executor_->GetComputationCache()->Numel(), 0);
    auto cached_computation = GetCachedComputation(hash);
    EXPECT_EQ(cached_computation, nullptr)
        << "Cache entry should be null after clearing";
  });
}
TEST_F(LazyGraphExecutorTest, TestRemoveSpecificCacheEntry) {
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor tensor_a =
        torch::rand({2, 2}, at::TensorOptions(torch::kFloat));
    torch::Tensor tensor_b =
        torch::rand({2, 2}, at::TensorOptions(torch::kFloat));
    torch::Tensor xla_tensor_a = CopyToDevice(tensor_a, device);
    torch::Tensor xla_tensor_b = CopyToDevice(tensor_b, device);
    torch::Tensor result = xla_tensor_a + xla_tensor_b;

    std::vector<LazyTensorPtr> tensors{TryGetLtcTensor(result)};
    hash_t hash = executor_->GetGraphHash(tensors);
    EnsureComputationIsCached(tensors, hash);

    // Remove a specific cache entry.
    executor_->RemoveFromComputationCache(hash);

    // Ensure that the cache entry has been removed.
    auto cached_computation = GetCachedComputation(hash);
    EXPECT_EQ(cached_computation, nullptr)
        << "Cache entry should be null after removal";

    // Attempting to remove again should not do anything.
    executor_->RemoveFromComputationCache(hash);
  });
}
} // namespace
} // namespace lazy
} // namespace torch