Add function to materialize COW storages (#117053)

Summary: From Kurt Mohler, see https://github.com/pytorch/pytorch/pull/113396 (manually imported due to ghimport problems)

Test Plan: sandcastle, OSS CI

Differential Revision: D52610522

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117053
Approved by: https://github.com/malfet, https://github.com/kurtamohler
Edward Yang 2024-01-10 15:34:16 +00:00 committed by PyTorch MergeBot
parent ec98df70f3
commit b4a35632f9
32 changed files with 343 additions and 50 deletions


@@ -311,6 +311,7 @@ struct MetaAllocator final : public at::Allocator {
DeleterFnPtr raw_deleter() const override {
return deleter;
}
void copy_data(void* dest, const void* src, std::size_t count) const final {}
};
static MetaAllocator g_meta_alloc;


@@ -301,6 +301,10 @@ class CUDAHostAllocator {
}
}
void copy_data(void* dest, const void* src, std::size_t count) const {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for CUDAHostAllocator");
}
private:
void process_events() {
while (true) {
@@ -496,6 +500,10 @@ struct CUDAHostAllocatorWrapper final : public at::Allocator {
&CUDAHostAllocatorDeleter,
at::DeviceType::CPU};
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
getCUDAHostAllocator().copy_data(dest, src, count);
}
};
static CUDAHostAllocatorWrapper cuda_host_allocator;


@@ -23,6 +23,9 @@ public:
DeleterFnPtr raw_deleter() const override {
return allocator_->raw_deleter();
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
allocator_->copy_data(dest, src, count);
}
};
}} // namespace c10::hip


@@ -819,6 +819,10 @@ struct TORCH_API MPSAllocator final : public IMPSAllocator {
return _getAllocImpl().format_size(size);
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
default_copy_data(dest, src, count);
}
private:
bool m_has_unified_memory;
uint32_t m_usage;


@@ -94,13 +94,14 @@ void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes) {
if (size_bytes != 0) {
new_data = storage->allocator()->allocate(size_bytes);
}
-at::DataPtr old_data = storage->set_data_ptr(std::move(new_data));
+const at::DataPtr& old_data = storage->data_ptr();
const auto old_capacity = storage->nbytes();
-storage->set_nbytes(size_bytes);
const auto copy_capacity = std::min(size_bytes, old_capacity);
if (old_data != nullptr && copy_capacity > 0) {
-memcpy(storage->mutable_data(), old_data.get(), copy_capacity);
+memcpy(new_data.get(), old_data.get(), copy_capacity);
}
+storage->set_data_ptr_noswap(std::move(new_data));
+storage->set_nbytes(size_bytes);
}
// Call the sparse implementation in SparseTensor.cpp directly.
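Editor's note on the hunk above (my reading, not text from the patch): set_data_ptr() now materializes an outgoing COW buffer just so it can be returned as mutable, which would add a wasted copy when resize only needs to read the old bytes once. A sketch of the new sequence in one piece, with the rationale as comments (resize_sketch is a hypothetical name; it assumes the StorageImpl API from this PR):

void resize_sketch(c10::StorageImpl* storage, size_t size_bytes) {
  at::DataPtr new_data;
  if (size_bytes != 0) {
    new_data = storage->allocator()->allocate(size_bytes);
  }
  // Read the old bytes through the const accessor; no materialization.
  const at::DataPtr& old_data = storage->data_ptr();
  const size_t copy_capacity = std::min(size_bytes, storage->nbytes());
  if (old_data != nullptr && copy_capacity > 0) {
    memcpy(new_data.get(), old_data.get(), copy_capacity);
  }
  // Install the new buffer without routing through set_data_ptr().
  storage->set_data_ptr_noswap(std::move(new_data));
  storage->set_nbytes(size_bytes);
}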


@@ -130,6 +130,7 @@ struct ZeroTensorAllocator final : public at::Allocator {
DeleterFnPtr raw_deleter() const override {
return deleter;
}
void copy_data(void* dest, const void* src, std::size_t count) const final {}
at::Device device_;
};


@@ -13,6 +13,7 @@ list(APPEND ATen_CPU_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/atest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/basic.cpp
${CMAKE_CURRENT_SOURCE_DIR}/broadcast_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpu_allocator_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpu_generator_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpu_profiling_allocator_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpu_rng_test.cpp
@@ -54,6 +55,7 @@ list(APPEND ATen_CPU_TEST_SRCS
)
list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cuda_allocator_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_apply_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_atomic_ops_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_caching_host_allocator_test.cpp


@@ -0,0 +1,35 @@
#pragma once
#include <gtest/gtest.h>
#include <ATen/ATen.h>
void test_allocator_clone(c10::Allocator* allocator) {
ASSERT_TRUE(allocator != nullptr);
c10::Storage a_storage(c10::make_intrusive<c10::StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
0,
allocator,
/*resizable=*/true));
c10::Storage b_storage(c10::make_intrusive<c10::StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
0,
allocator,
/*resizable=*/true));
at::Tensor a = at::empty({0}, at::TensorOptions().device(a_storage.device())).set_(a_storage);
at::Tensor b = at::empty({0}, at::TensorOptions().device(b_storage.device())).set_(b_storage);
std::vector<int64_t> sizes({13, 4, 5});
at::rand_out(a, sizes);
at::rand_out(b, sizes);
ASSERT_TRUE(a_storage.nbytes() == static_cast<size_t>(a.numel() * a.element_size()));
ASSERT_TRUE(a_storage.nbytes() == b_storage.nbytes());
void* a_data_ptr = a_storage.mutable_data();
b_storage.set_data_ptr(allocator->clone(a_data_ptr, a_storage.nbytes()));
ASSERT_TRUE((a == b).all().item<bool>());
}
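The helper above is instantiated once per backend, as the next two files do for CPU and CUDA. A hedged sketch of the same wiring for a further backend (MyBackendAllocator is a hypothetical stand-in for whatever c10::Allocator the backend exposes):

#include <ATen/test/allocator_clone_test.h>

TEST(AllocatorTestMyBackend, test_clone) {
  MyBackendAllocator allocator;  // hypothetical allocator type
  test_allocator_clone(&allocator);
}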


@@ -0,0 +1,10 @@
#include <gtest/gtest.h>
#include <c10/core/CPUAllocator.h>
#include <ATen/ATen.h>
#include <ATen/test/allocator_clone_test.h>
TEST(AllocatorTestCPU, test_clone) {
test_allocator_clone(c10::GetDefaultCPUAllocator());
}


@@ -0,0 +1,10 @@
#include <gtest/gtest.h>
#include <ATen/ATen.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/test/allocator_clone_test.h>
TEST(AllocatorTestCUDA, test_clone) {
test_allocator_clone(c10::cuda::CUDACachingAllocator::get());
}


@@ -2,6 +2,8 @@
#include <ATen/ATen.h>
#include <ATen/test/allocator_clone_test.h>
using namespace at;
void XLAFree(void *ptr) {
@@ -22,6 +24,9 @@ struct XLAAllocator final : public at::Allocator {
at::DeleterFnPtr raw_deleter() const override {
return &XLAFree;
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
default_copy_data(dest, src, count);
}
};
TEST(XlaTensorTest, TestNoStorage) {
@@ -33,3 +38,11 @@ TEST(XlaTensorTest, TestNoStorage) {
at::Tensor t(std::move(tensor_impl));
ASSERT_TRUE(t.device() == at::Device(DeviceType::XLA, 0));
}
TEST(XlaTensorTest, test_allocator_clone) {
if (!at::hasXLA()) {
return;
}
XLAAllocator allocator;
test_allocator_clone(&allocator);
}


@@ -30,7 +30,6 @@ file(GLOB C10_SRCS
*.cpp
core/*.cpp
core/impl/*.cpp
-core/impl/cow/*.cpp
mobile/*.cpp
macros/*.cpp
util/*.cpp
@@ -39,7 +38,6 @@ file(GLOB C10_HEADERS
*.h
core/*.h
core/impl/*.h
-core/impl/cow/*.h
mobile/*.h
macros/*.h
util/*.h


@@ -4,6 +4,23 @@
namespace c10 {
DataPtr Allocator::clone(const void* data, std::size_t n) const {
DataPtr new_data = allocate(n);
copy_data(new_data.mutable_get(), data, n);
return new_data;
}
void Allocator::default_copy_data(
void* dest,
const void* src,
std::size_t count) const {
std::memcpy(dest, src, count);
}
bool Allocator::is_simple_data_ptr(const DataPtr& data_ptr) const {
return data_ptr.get() == data_ptr.get_context();
}
static void deleteInefficientStdFunctionContext(void* ptr) {
delete static_cast<InefficientStdFunctionContext*>(ptr);
}


@@ -162,6 +162,21 @@ struct C10_API Allocator {
virtual DataPtr allocate(size_t n) const = 0;
// Clones an allocation that came from this allocator.
//
// To perform the copy, this function calls `copy_data`, which
// must be implemented by derived classes.
//
// Note that this explicitly ignores any context that may have been
// attached to the input data.
//
// Requires: input data was allocated by the same allocator.
DataPtr clone(const void* data, std::size_t n) const;
// Checks if DataPtr has a simple context, i.e. one that is not wrapped
// in any out-of-the-ordinary context.
virtual bool is_simple_data_ptr(const DataPtr& data_ptr) const;
// If this returns a non nullptr, it means that allocate()
// is guaranteed to return a unique_ptr with this deleter attached;
// it means the rawAllocate and rawDeallocate APIs are safe to use.
@@ -179,6 +194,22 @@
AT_ASSERT(d);
d(ptr);
}
// Copies data from one allocation to another.
// Pure virtual, so derived classes must define behavior.
// Derived class implementation can simply call `default_copy_data`
// to use `std::memcpy`.
//
// Requires: src and dest were allocated by this allocator
// Requires: src and dest both have length >= count
virtual void copy_data(void* dest, const void* src, std::size_t count)
const = 0;
protected:
// Uses `std::memcpy` to copy data.
// Child classes can use this as `copy_data` when an alternative copy
// API is not needed.
void default_copy_data(void* dest, const void* src, std::size_t count) const;
};
// This context is used to generate DataPtr which have arbitrary
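To make the new contract concrete, here is a minimal sketch of a conforming allocator (editor's illustration; SketchAllocator and delete_fn are hypothetical names, not part of this diff). copy_data is the only new pure-virtual member; plain host allocators can satisfy it by delegating to default_copy_data, after which the inherited clone() works without further changes:

#include <c10/core/Allocator.h>

#include <cstdlib>

static void delete_fn(void* ptr) {
  std::free(ptr);
}

struct SketchAllocator final : c10::Allocator {
  c10::DataPtr allocate(size_t n) const override {
    void* data = n == 0 ? nullptr : std::malloc(n);
    // Simple data pointer: data and context are the same address.
    return {data, data, &delete_fn, c10::Device(c10::DeviceType::CPU)};
  }
  c10::DeleterFnPtr raw_deleter() const override {
    return &delete_fn;
  }
  // Plain host memory, so memcpy semantics suffice.
  void copy_data(void* dest, const void* src, std::size_t count) const final {
    default_copy_data(dest, src, count);
  }
};

// Usage: clone() allocates count bytes, then dispatches to copy_data().
//   SketchAllocator alloc;
//   c10::DataPtr src = alloc.allocate(16);
//   c10::DataPtr dst = alloc.clone(src.get(), 16);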


@@ -40,6 +40,10 @@ struct C10_API DefaultCPUAllocator final : at::Allocator {
at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete;
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
default_copy_data(dest, src, count);
}
};
ProfiledCPUMemoryReporter& profiledCPUMemoryReporter() {
@@ -142,6 +146,16 @@ class DefaultMobileCPUAllocator final : public at::Allocator {
DeleterFnPtr raw_deleter() const override {
return deleter;
}
bool is_simple_data_ptr(const c10::DataPtr& data_ptr) const final {
return reinterpret_cast<const uint8_t*>(data_ptr.get()) ==
reinterpret_cast<const uint8_t*>(data_ptr.get_context()) +
PreGuardBytes;
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
default_copy_data(dest, src, count);
}
};
void NoDelete(void*) {}
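For orientation, an editor's sketch of the layout the mobile override above encodes (illustrative, not from the patch; PreGuardBytes is the class's own constant):

// ctx --> [ PreGuardBytes of guard | user data ... ]
//                                    ^ data_ptr.get()
//
// A "simple" mobile DataPtr therefore satisfies
//   data == static_cast<const uint8_t*>(ctx) + PreGuardBytes,
// which is exactly what is_simple_data_ptr() checks above, replacing the
// allocator-specific special case removed from COW.cpp later in this commit.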


@@ -4,6 +4,8 @@
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/SymInt.h>
#include <c10/core/impl/COW.h>
#include <c10/core/impl/COWDeleter.h>
#include <c10/core/impl/PyObjectSlot.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
@@ -117,6 +119,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
}
at::DataPtr& mutable_data_ptr() {
maybe_materialize_cow();
return data_ptr_;
}
@@ -126,9 +129,10 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
// Returns the previous data_ptr
at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) {
-at::DataPtr old_data_ptr(std::move(data_ptr_));
-data_ptr_ = std::move(data_ptr);
-return old_data_ptr;
+// We need to materialize the old COW DataPtr because it is
+// being returned as mutable.
+maybe_materialize_cow();
+return set_data_ptr_no_materialize_cow(std::move(data_ptr));
}
void set_data_ptr_noswap(at::DataPtr&& data_ptr) {
@@ -140,6 +144,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
}
void* mutable_data() {
maybe_materialize_cow();
return data_ptr_.mutable_get();
}
@@ -217,7 +222,26 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
return &pyobj_slot_;
}
protected:
// materialize_cow_storage needs to call set_data_ptr_no_materialize_cow
friend void c10::impl::cow::materialize_cow_storage(StorageImpl& storage);
// Returns the previous data_ptr. If the old data_ptr was COW,
// this avoids materializing it.
at::DataPtr set_data_ptr_no_materialize_cow(at::DataPtr&& data_ptr) {
at::DataPtr old_data_ptr(std::move(data_ptr_));
data_ptr_ = std::move(data_ptr);
return old_data_ptr;
}
private:
// Triggers a copy if this is a copy-on-write tensor.
void maybe_materialize_cow() {
if (data_ptr_.get_deleter() == impl::cow::cow_deleter) {
impl::cow::materialize_cow_storage(*this);
}
}
DataPtr data_ptr_;
SymInt size_bytes_;
bool size_bytes_is_heap_allocated_;
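Taken together, the new hooks give StorageImpl the following observable behavior. A hedged sketch (demo_materialize is a hypothetical function; it uses lazy_clone_storage from c10/core/impl/COW.h, shown later in this commit):

#include <c10/core/StorageImpl.h>
#include <c10/core/impl/COW.h>

// Hypothetical demo, not part of this diff. Assumes `storage` is a CPU
// storage with a simple data pointer, so lazy_clone_storage() succeeds.
void demo_materialize(c10::StorageImpl& storage) {
  // Lazy clone: both storages now share one COW context; no bytes copied.
  c10::intrusive_ptr<c10::StorageImpl> cloned =
      c10::impl::cow::lazy_clone_storage(storage);
  // Const accessors do not materialize; the data is still shared.
  TORCH_INTERNAL_ASSERT(cloned->data() == storage.data());
  // The first mutable access sees cow_deleter in maybe_materialize_cow()
  // and calls materialize_cow_storage(), giving the clone its own buffer.
  void* owned = cloned->mutable_data();
  TORCH_INTERNAL_ASSERT(owned != storage.data());
}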


@@ -62,7 +62,6 @@ def define_targets(rules):
exclude = [
"CPUAllocator.cpp",
"impl/alloc_cpu.cpp",
-"impl/cow/*.cpp",
],
),
hdrs = rules.glob(
@@ -73,7 +72,6 @@ def define_targets(rules):
exclude = [
"CPUAllocator.h",
"impl/alloc_cpu.h",
-"impl/cow/*.h",
],
),
linkstatic = True,
@@ -92,22 +90,6 @@ def define_targets(rules):
alwayslink = True,
)
-rules.cc_library(
-name = "impl_cow",
-srcs = rules.glob([
-"impl/cow/*.cpp",
-]),
-hdrs = rules.glob([
-"impl/cow/*.h",
-]),
-deps = [
-":base",
-":CPUAllocator",
-],
-visibility = ["//c10/test:__pkg__"],
-)
rules.filegroup(
name = "headers",
srcs = rules.glob(


@@ -1,10 +1,9 @@
-#include <c10/core/impl/cow/COW.h>
+#include <c10/core/impl/COW.h>
#include <c10/core/Allocator.h>
#include <c10/core/CPUAllocator.h>
#include <c10/core/StorageImpl.h>
#include <c10/core/alignment.h>
-#include <c10/core/impl/cow/COWDeleter.h>
+#include <c10/core/impl/COWDeleter.h>
#include <c10/util/Exception.h>
#include <c10/util/UniqueVoidPtr.h>
@@ -30,24 +29,18 @@ at::DataPtr copy_data_ptr(at::DataPtr const& data_ptr) {
return make_data_ptr(data_ptr, *ctx);
}
-bool is_simple_context(
-const void* context,
-const void* data,
-const at::Allocator* allocator) {
-if (allocator == c10::GetDefaultMobileCPUAllocator()) {
-return reinterpret_cast<size_t>(data) ==
-reinterpret_cast<size_t>(context) + c10::gAlignment;
-} else {
-return data == context;
-}
-}
} // namespace
bool has_simple_data_ptr(const c10::StorageImpl& storage) {
const c10::DataPtr& data_ptr = storage.data_ptr();
-return is_simple_context(
-data_ptr.get_context(), data_ptr.get(), storage.allocator());
+const void* ctx = data_ptr.get_context();
+const void* data = data_ptr.get();
+const c10::Allocator* allocator = storage.allocator();
+if (allocator != nullptr) {
+return allocator->is_simple_data_ptr(data_ptr);
+} else {
+return ctx == data;
+}
}
bool is_cow_data_ptr(const c10::DataPtr& data_ptr) {
@@ -88,8 +81,6 @@ c10::intrusive_ptr<StorageImpl> lazy_clone_storage(StorageImpl& storage) {
// Case 1) We have a simple data pointer: wrap it.
std::unique_ptr<void, DeleterFnPtr> original_ctx =
storage.mutable_data_ptr().move_context();
-TORCH_INTERNAL_ASSERT(is_simple_context(
-original_ctx.get(), data_ptr.get(), storage.allocator()));
// Save this for the result.
new_data_ptr = make_data_ptr(
@@ -117,4 +108,40 @@ c10::intrusive_ptr<StorageImpl> lazy_clone_storage(StorageImpl& storage) {
storage.resizable());
}
C10_API void materialize_cow_storage(StorageImpl& storage) {
const at::DataPtr& data_ptr = storage.data_ptr();
auto* ctx = data_ptr.cast_context<cow::COWDeleterContext>(cow::cow_deleter);
TORCH_INTERNAL_ASSERT(ctx != nullptr);
auto result = ctx->decrement_refcount();
// This must be set by each branch below.
std::optional<DataPtr> new_data_ptr;
if (std::holds_alternative<cow::COWDeleterContext::LastReference>(result)) {
// This is the only reference to the data. If there were any racing writes,
// the context ensured they finished before giving us the result.
std::unique_ptr<void, DeleterFnPtr> data =
std::get<cow::COWDeleterContext::LastReference>(std::move(result));
TORCH_INTERNAL_ASSERT(data.get() == data_ptr.get());
new_data_ptr = DataPtr(
data.release(), data_ptr.get(), data.get_deleter(), data_ptr.device());
} else {
TORCH_INTERNAL_ASSERT(
std::holds_alternative<cow::COWDeleterContext::NotLastReference>(
result));
// We don't need to consume the result; it's just a shared lock ensuring
// that the data will remain while we copy it.
new_data_ptr = storage.allocator()->clone(data_ptr.get(), storage.nbytes());
}
TORCH_INTERNAL_ASSERT(new_data_ptr.has_value());
DataPtr old_data_ptr =
storage.set_data_ptr_no_materialize_cow(*std::move(new_data_ptr));
// The refcount of the context was already decremented above. Release the
// reference to the context so the refcount doesn't get decremented again.
old_data_ptr.release_context();
}
} // namespace c10::impl::cow
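The function above follows a steal-or-copy pattern keyed on a std::variant: LastReference hands back the buffer itself, while NotLastReference is a token whose lifetime keeps the shared bytes alive long enough to copy. A self-contained sketch of the same pattern with simplified types (editor's illustration, not the real COWDeleterContext interface):

#include <cstddef>
#include <cstring>
#include <memory>
#include <variant>

using Buffer = std::unique_ptr<void, void (*)(void*)>;
struct NotLastToken {};  // stands in for the shared lock held while copying
using DecrementResult = std::variant<NotLastToken, Buffer>;

Buffer steal_or_copy(DecrementResult result, const void* src, std::size_t n) {
  if (std::holds_alternative<Buffer>(result)) {
    // Last reference: take ownership of the existing bytes; no copy needed.
    return std::get<Buffer>(std::move(result));
  }
  // Other references remain; copy while the token keeps src alive.
  Buffer copy(::operator new(n), [](void* p) { ::operator delete(p); });
  std::memcpy(copy.get(), src, n);
  return copy;
}

In materialize_cow_storage itself, the LastReference branch additionally rebuilds the DataPtr with the buffer's original deleter and device, and release_context() on the displaced DataPtr prevents the refcount from being decremented a second time.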


@@ -26,4 +26,7 @@ C10_API bool has_simple_data_ptr(const c10::StorageImpl& storage);
// Check if a DataPtr is COW
C10_API bool is_cow_data_ptr(const c10::DataPtr& data_ptr);
// Eagerly copies a COW storage's data, turning it into a non-COW storage.
C10_API void materialize_cow_storage(StorageImpl& storage);
} // namespace c10::impl::cow


@@ -1,4 +1,4 @@
-#include <c10/core/impl/cow/COWDeleter.h>
+#include <c10/core/impl/COWDeleter.h>
#include <c10/util/Exception.h>
#include <mutex>


@@ -3243,6 +3243,10 @@ class NativeCachingAllocator : public CUDAAllocator {
std::string name() override {
return "native";
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
C10_CUDA_CHECK(
cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice));
}
};
NativeCachingAllocator allocator;
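With this override in place, the inherited clone() path works for device memory. A hedged sketch mirroring the new cuda_allocator_test.cpp (clone_on_device is a hypothetical helper; src is assumed to be a device pointer valid for nbytes bytes):

#include <c10/cuda/CUDACachingAllocator.h>

c10::DataPtr clone_on_device(const void* src, std::size_t nbytes) {
  c10::Allocator* allocator = c10::cuda::CUDACachingAllocator::get();
  // allocate() then copy_data(): a device-to-device cudaMemcpy.
  return allocator->clone(src, nbytes);
}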


@@ -875,6 +875,10 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
std::string name() override {
return "cudaMallocAsync";
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
C10_CUDA_CHECK(
cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice));
}
};
CudaMallocAsyncAllocator device_allocator;


@@ -21,7 +21,6 @@ def define_targets(rules):
"//c10/core:base",
"//c10/util:base",
"//c10/core:CPUAllocator",
-"//c10/core:impl_cow",
"@com_google_googletest//:gtest_main",
],
)


@@ -1,5 +1,5 @@
-#include <c10/core/impl/cow/COW.h>
-#include <c10/core/impl/cow/COWDeleter.h>
+#include <c10/core/impl/COW.h>
+#include <c10/core/impl/COWDeleter.h>
#include <c10/core/CPUAllocator.h>
#include <c10/core/StorageImpl.h>
@@ -167,6 +167,84 @@ TEST(lazy_clone_storage_test, already_copy_on_write) {
ASSERT_THAT(new_storage->data(), testing::Eq(original_storage.data()));
}
TEST(materialize_test, not_copy_on_write_context) {
StorageImpl storage(
{}, /*size_bytes=*/6, GetCPUAllocator(), /*resizable=*/false);
ASSERT_THAT(storage, testing::Not(is_copy_on_write()));
void const* original_data = storage.data();
// Nothing to materialize.
ASSERT_THAT(storage.mutable_data(), testing::Eq(original_data));
}
TEST(materialize_test, copy_on_write_single_reference) {
// A copy-on-write storage with only a single reference can just
// drop the copy-on-write context upon materialization.
std::unique_ptr<void, DeleterFnPtr> data(
new std::byte[4],
+[](void* bytes) { delete[] static_cast<std::byte*>(bytes); });
void* data_ptr = data.get();
StorageImpl storage(
{},
/*size_bytes=*/4,
at::DataPtr(
/*data=*/data_ptr,
/*ctx=*/new cow::COWDeleterContext(std::move(data)),
cow::cow_deleter,
Device(Device::Type::CPU)),
/*allocator=*/nullptr,
/*resizable=*/false);
ASSERT_THAT(storage, is_copy_on_write());
ASSERT_THAT(storage.data(), testing::Eq(data_ptr));
void const* original_data = storage.data();
// Materializes storage. Only reference, so no new allocation.
ASSERT_THAT(storage.mutable_data(), testing::Eq(original_data));
// But it is no longer copy-on-write.
ASSERT_THAT(storage, testing::Not(is_copy_on_write()));
}
bool buffers_are_equal(const void* a, const void* b, size_t nbytes) {
const char* a_ = static_cast<const char*>(a);
const char* b_ = static_cast<const char*>(b);
for (size_t idx = 0; idx < nbytes; idx++) {
if (a_[idx] != b_[idx]) {
return false;
}
}
return true;
}
TEST(materialize_test, copy_on_write) {
StorageImpl original_storage(
{}, /*size_bytes=*/6, GetCPUAllocator(), /*resizable=*/false);
std::memcpy(original_storage.mutable_data(), "abcd", 4);
void const* original_data = original_storage.data();
auto new_storage = cow::lazy_clone_storage(original_storage);
ASSERT_THAT(new_storage, testing::NotNull());
auto context = new_storage->data_ptr().cast_context<cow::COWDeleterContext>(
cow::cow_deleter);
ASSERT_THAT(context, testing::NotNull());
// Materialized storage has new copy of data.
ASSERT_THAT(new_storage->mutable_data(), testing::Ne(original_data));
// But the original storage still has the original copy.
ASSERT_THAT(original_storage.data(), testing::Eq(original_data));
// And their data is the same
ASSERT_TRUE(new_storage->nbytes() == original_storage.nbytes());
ASSERT_TRUE(buffers_are_equal(
new_storage->data(), original_storage.data(), new_storage->nbytes()));
}
} // namespace
} // namespace c10::impl
// NOLINTEND(clang-analyzer-cplusplus*)


@@ -336,6 +336,10 @@ struct CAFFE2_CUDA_API PinnedCPUAllocator final : public at::Allocator {
return &Delete;
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for PinnedCPUAllocator");
}
private:
static void Delete(void* data) {
if (!data) {
@@ -581,6 +585,10 @@ struct DefaultCUDAAllocator final : public at::Allocator {
return &Delete;
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for DefaultCUDAAllocator");
}
private:
static void Delete(void* ptr) {
// lock the mutex


@@ -1177,7 +1177,6 @@ def main():
"include/ATen/core/dispatch/*.h",
"include/ATen/core/op_registration/*.h",
"include/c10/core/impl/*.h",
-"include/c10/core/impl/cow/*.h",
"include/c10/util/*.h",
"include/c10/cuda/*.h",
"include/c10/cuda/impl/*.h",


@@ -188,6 +188,10 @@ struct DummyCustomAllocator final : at::Allocator {
at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete;
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
default_copy_data(dest, src, count);
}
};
// Register our dummy allocator


@@ -81,6 +81,10 @@ struct DummyCustomAllocator final : at::Allocator {
at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete;
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
default_copy_data(dest, src, count);
}
};
// Register our dummy allocator


@@ -330,6 +330,14 @@ std::string CUDAPluggableAllocator::name() {
return "pluggable";
}
void CUDAPluggableAllocator::copy_data(
void* dest,
const void* src,
std::size_t count) const {
C10_CUDA_CHECK(
cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice));
}
std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
current_custom_allocator;


@@ -122,6 +122,7 @@ struct CUDAPluggableAllocator
cudaStream_t stream,
bool p2p_enabled) override;
std::string name() override;
void copy_data(void* dest, const void* src, std::size_t count) const final;
protected:
std::function<void*(size_t, int, cudaStream_t)> alloc_fn_;