#include #include #include #include namespace torch { namespace lazy { namespace { class LazyGraphExecutorTest : public ::testing::Test { protected: void SetUp() override { executor_ = LazyGraphExecutor::Get(); } using CachedComputationType = LazyGraphExecutor::CachedComputation; std::shared_ptr GetCachedComputation(hash_t hash) { return executor_->GetComputationCache()->Get(hash); } void EnsureComputationIsCached( std::vector& tensors, hash_t hash) { // Force computation to be cached by syncing the tensors. executor_->SyncTensorsGraph( &tensors, /* devices */ {}, /* wait */ true, /* sync_ltc_data */ true); // Ensure that the computation cache entry exists. auto cached_computation = GetCachedComputation(hash); EXPECT_NE(cached_computation, nullptr) << "Computation should be cached after sync"; } LazyGraphExecutor* executor_; }; TEST_F(LazyGraphExecutorTest, TestClearComputationCache) { ForEachDevice([&](const torch::Device& device) { torch::Tensor tensor_a = torch::rand({2, 2}, at::TensorOptions(torch::kFloat)); torch::Tensor tensor_b = torch::rand({2, 2}, at::TensorOptions(torch::kFloat)); torch::Tensor xla_tensor_a = CopyToDevice(tensor_a, device); torch::Tensor xla_tensor_b = CopyToDevice(tensor_b, device); torch::Tensor result = xla_tensor_a + xla_tensor_b; std::vector tensors{TryGetLtcTensor(result)}; hash_t hash = executor_->GetGraphHash(tensors); EnsureComputationIsCached(tensors, hash); EXPECT_EQ(executor_->GetComputationCache()->Numel(), 1); // Clear the entire computation cache. executor_->ClearComputationCache(); // Ensure that there are no cache entries. EXPECT_EQ(executor_->GetComputationCache()->Numel(), 0); auto cached_computation = GetCachedComputation(hash); EXPECT_EQ(cached_computation, nullptr) << "Cache entry should be null after clearing"; }); } TEST_F(LazyGraphExecutorTest, TestRemoveSpecificCacheEntry) { ForEachDevice([&](const torch::Device& device) { torch::Tensor tensor_a = torch::rand({2, 2}, at::TensorOptions(torch::kFloat)); torch::Tensor tensor_b = torch::rand({2, 2}, at::TensorOptions(torch::kFloat)); torch::Tensor xla_tensor_a = CopyToDevice(tensor_a, device); torch::Tensor xla_tensor_b = CopyToDevice(tensor_b, device); torch::Tensor result = xla_tensor_a + xla_tensor_b; std::vector tensors{TryGetLtcTensor(result)}; hash_t hash = executor_->GetGraphHash(tensors); EnsureComputationIsCached(tensors, hash); // Remove a specific cache entry. executor_->RemoveFromComputationCache(hash); // Ensure that the cache entry has been removed. auto cached_computation = GetCachedComputation(hash); EXPECT_EQ(cached_computation, nullptr) << "Cache entry should be null after removal"; // Attempting to remove again should not do anything. executor_->RemoveFromComputationCache(hash); }); } } // namespace } // namespace lazy } // namespace torch