From 3a8d7463bd024dcbd00d04a323cdb10dfe867f18 Mon Sep 17 00:00:00 2001 From: Iurii Zdebskyi Date: Tue, 9 Jul 2019 21:47:47 -0700 Subject: [PATCH] Enabled BFloat16 storage (#21523) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/21523 ghimport-source-id: 698b3cbd6b21c09b9ff8bf8011980df8e35c33b0 Test Plan: Imported from OSS Differential Revision: D15819368 Pulled By: izdeby fbshipit-source-id: f6b3bba7b3ca8ee677bd80a231dbb3920c07d61c --- aten/src/ATen/DLConvertor.cpp | 3 +++ aten/src/ATen/function_wrapper.py | 1 + aten/src/TH/CMakeLists.txt | 1 + aten/src/TH/THGenerateBFloat16Type.h | 19 +++++++++++++++++++ aten/src/TH/THStorageFunctions.cpp | 6 ++++++ aten/src/TH/THStorageFunctions.h | 6 ++++++ aten/src/TH/THTensor.cpp | 3 +++ aten/src/TH/THTensor.h | 3 +++ aten/src/TH/THTensor.hpp | 3 +++ aten/src/TH/generic/THStorage.h | 1 + aten/src/TH/generic/THStorageCopy.cpp | 1 + aten/src/TH/generic/THStorageCopy.h | 1 + aten/src/TH/generic/THTensor.h | 1 + aten/src/THC/CMakeLists.txt | 1 + aten/src/THC/THCGenerateBFloat16Type.h | 25 +++++++++++++++++++++++++ aten/src/THC/THCStorage.cpp | 3 +++ aten/src/THC/THCStorage.cu | 3 +++ aten/src/THC/THCStorage.h | 3 +++ aten/src/THC/THCStorageCopy.cpp | 3 +++ aten/src/THC/THCStorageCopy.cu | 3 +++ aten/src/THC/THCStorageCopy.h | 3 +++ aten/src/THC/THCTensor.cpp | 5 +++++ aten/src/THC/THCTensor.cu | 3 +++ aten/src/THC/THCTensor.h | 4 ++++ aten/src/THC/THCTensor.hpp | 3 +++ aten/src/THC/THCTensorCopy.cu | 11 +++++++++++ aten/src/THC/THCTensorCopy.h | 3 +++ aten/src/THC/generic/THCStorage.h | 19 ++++++++++--------- aten/src/THC/generic/THCStorageCopy.cpp | 2 ++ aten/src/THC/generic/THCStorageCopy.cu | 1 + aten/src/THC/generic/THCStorageCopy.h | 3 +++ aten/src/THC/generic/THCTensor.h | 1 + c10/core/ScalarType.h | 3 ++- c10/test/util/bfloat16_test.cpp | 2 +- c10/util/BFloat16.h | 5 ++--- test/test_torch.py | 12 ++++++++++++ torch/__init__.py | 8 +++++++- torch/_storage_docs.py | 1 + torch/csrc/Module.cpp | 15 +++++++++++++++ torch/csrc/Storage.cpp | 3 +++ torch/csrc/Storage.h | 6 ++++++ torch/csrc/byte_order.cpp | 11 ++++++++++- torch/csrc/byte_order.h | 2 ++ torch/csrc/cuda/Module.cpp | 1 + torch/csrc/cuda/Storage.cpp | 3 +++ torch/csrc/cuda/Storage.h | 22 ++++++++++++++-------- torch/csrc/cuda/serialization.cpp | 3 +++ torch/csrc/cuda/serialization.h | 3 +++ torch/csrc/cuda/utils.cpp | 3 +++ torch/csrc/cuda/utils.h | 3 +++ torch/csrc/generic/Storage.cpp | 3 +++ torch/csrc/generic/StorageMethods.cpp | 2 ++ torch/csrc/serialization.cpp | 3 +++ torch/csrc/serialization.h | 3 +++ torch/csrc/utils.cpp | 3 +++ torch/csrc/utils.h | 7 +++++++ torch/csrc/utils/tensor_dtypes.cpp | 2 ++ torch/cuda/__init__.py | 9 +++++++-- torch/storage.py | 4 ++++ 59 files changed, 264 insertions(+), 26 deletions(-) create mode 100644 aten/src/TH/THGenerateBFloat16Type.h create mode 100644 aten/src/THC/THCGenerateBFloat16Type.h diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index e75db146ed2..2cdd4e044f3 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -39,6 +39,9 @@ static DLDataType getDLDataType(const Tensor& t) { case ScalarType::Bool: dtype.code = DLDataTypeCode::kDLUInt; break; + case ScalarType::BFloat16: + throw std::logic_error("BFloat16 is not supported by dlpack"); + break; case ScalarType::QInt8: throw std::logic_error("QInt8 is not supported by dlpack"); break; diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index 17e6ae552e3..193793da613 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -209,6 +209,7 @@ scalar_types = [ ('Long', 'int64_t', 'Long', False), ('Short', 'int16_t', 'Long', False), ('Half', 'Half', 'Double', True), + ('BFloat16', 'BFloat16', 'BFloat16AccrealNotDefined', True), ] diff --git a/aten/src/TH/CMakeLists.txt b/aten/src/TH/CMakeLists.txt index df9e66f260e..98927c6cab1 100644 --- a/aten/src/TH/CMakeLists.txt +++ b/aten/src/TH/CMakeLists.txt @@ -64,6 +64,7 @@ INSTALL(FILES THFilePrivate.h ${CMAKE_CURRENT_BINARY_DIR}/THGeneral.h THGenerateAllTypes.h + THGenerateBFloat16Type.h THGenerateBoolType.h THGenerateDoubleType.h THGenerateFloatType.h diff --git a/aten/src/TH/THGenerateBFloat16Type.h b/aten/src/TH/THGenerateBFloat16Type.h new file mode 100644 index 00000000000..40c34b66e41 --- /dev/null +++ b/aten/src/TH/THGenerateBFloat16Type.h @@ -0,0 +1,19 @@ +#ifndef TH_GENERIC_FILE +#error "You must define TH_GENERIC_FILE before including THGenerateBFloat16Type.h" +#endif + +#include +#define scalar_t at::BFloat16 +#define TH_CONVERT_ACCREAL_TO_REAL(_val) (scalar_t)(_val) +#define Real BFloat16 +#define TH_REAL_IS_BFLOAT16 +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef scalar_t +#undef Real +#undef TH_REAL_IS_BFLOAT16 +#undef TH_CONVERT_ACCREAL_TO_REAL + +#ifndef THGenerateManyTypes +#undef TH_GENERIC_FILE +#endif diff --git a/aten/src/TH/THStorageFunctions.cpp b/aten/src/TH/THStorageFunctions.cpp index e25bfd45efd..9fcc19c9fb7 100644 --- a/aten/src/TH/THStorageFunctions.cpp +++ b/aten/src/TH/THStorageFunctions.cpp @@ -15,6 +15,9 @@ #include #include +#include +#include + #include #include @@ -27,6 +30,9 @@ #include #include +#include +#include + THStorage* THStorage_new(caffe2::TypeMeta data_type) { THStorage* storage = c10::make_intrusive( data_type, diff --git a/aten/src/TH/THStorageFunctions.h b/aten/src/TH/THStorageFunctions.h index 246f740232a..adaccb435af 100644 --- a/aten/src/TH/THStorageFunctions.h +++ b/aten/src/TH/THStorageFunctions.h @@ -17,6 +17,9 @@ #include #include +#include +#include + #include #include @@ -29,5 +32,8 @@ #include #include +#include +#include + // This exists to have a data-type independent way of freeing (necessary for THPPointer). TH_API void THStorage_free(THStorage *storage); diff --git a/aten/src/TH/THTensor.cpp b/aten/src/TH/THTensor.cpp index 6959d3ea4b9..067f8202ff4 100644 --- a/aten/src/TH/THTensor.cpp +++ b/aten/src/TH/THTensor.cpp @@ -9,6 +9,9 @@ #include #include +#include +#include + #include #include diff --git a/aten/src/TH/THTensor.h b/aten/src/TH/THTensor.h index 42abf788b6a..021ffc3c8ff 100644 --- a/aten/src/TH/THTensor.h +++ b/aten/src/TH/THTensor.h @@ -16,6 +16,9 @@ #include #include +#include +#include + /* random numbers */ #include #include diff --git a/aten/src/TH/THTensor.hpp b/aten/src/TH/THTensor.hpp index a9c89f22218..b0a0603f542 100644 --- a/aten/src/TH/THTensor.hpp +++ b/aten/src/TH/THTensor.hpp @@ -130,3 +130,6 @@ TH_CPP_API c10::optional> THTensor_compute_stride( #include #include + +#include +#include diff --git a/aten/src/TH/generic/THStorage.h b/aten/src/TH/generic/THStorage.h index 2e432c1daf8..1d0f9420400 100644 --- a/aten/src/TH/generic/THStorage.h +++ b/aten/src/TH/generic/THStorage.h @@ -34,6 +34,7 @@ #define THIntStorage THStorage #define THLongStorage THStorage #define THBoolStorage THStorage +#define THBFloat16Storage THStorage TH_API scalar_t* THStorage_(data)(const THStorage*); TH_API ptrdiff_t THStorage_(size)(const THStorage*); diff --git a/aten/src/TH/generic/THStorageCopy.cpp b/aten/src/TH/generic/THStorageCopy.cpp index 0ce5035edf0..c5eda5699f5 100644 --- a/aten/src/TH/generic/THStorageCopy.cpp +++ b/aten/src/TH/generic/THStorageCopy.cpp @@ -38,5 +38,6 @@ IMPLEMENT_THStorage_COPY(Float) IMPLEMENT_THStorage_COPY(Double) IMPLEMENT_THStorage_COPY(Half) IMPLEMENT_THStorage_COPY(Bool) +IMPLEMENT_THStorage_COPY(BFloat16) #endif diff --git a/aten/src/TH/generic/THStorageCopy.h b/aten/src/TH/generic/THStorageCopy.h index 0301fc6a489..4797ba6761f 100644 --- a/aten/src/TH/generic/THStorageCopy.h +++ b/aten/src/TH/generic/THStorageCopy.h @@ -15,5 +15,6 @@ TH_API void THStorage_(copyFloat)(THStorage *storage, struct THFloatStorage *src TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *src); TH_API void THStorage_(copyHalf)(THStorage *storage, struct THHalfStorage *src); TH_API void THStorage_(copyBool)(THStorage *storage, struct THBoolStorage *src); +TH_API void THStorage_(copyBFloat16)(THStorage *storage, struct THBFloat16Storage *src); #endif diff --git a/aten/src/TH/generic/THTensor.h b/aten/src/TH/generic/THTensor.h index cd750d0cfa9..8f97ba40a85 100644 --- a/aten/src/TH/generic/THTensor.h +++ b/aten/src/TH/generic/THTensor.h @@ -19,6 +19,7 @@ #define THIntTensor THTensor #define THLongTensor THTensor #define THBoolTensor THTensor +#define THBFloat16Tensor THTensor /**** access methods ****/ TH_API THStorage* THTensor_(storage)(const THTensor *self); diff --git a/aten/src/THC/CMakeLists.txt b/aten/src/THC/CMakeLists.txt index 6acc9e5489e..6ce3c9d8fee 100644 --- a/aten/src/THC/CMakeLists.txt +++ b/aten/src/THC/CMakeLists.txt @@ -85,6 +85,7 @@ INSTALL(FILES THCDeviceTensorUtils.cuh THCDeviceTensorUtils-inl.cuh THCGenerateAllTypes.h + THCGenerateBFloat16Type.h THCGenerateBoolType.h THCGenerateByteType.h THCGenerateCharType.h diff --git a/aten/src/THC/THCGenerateBFloat16Type.h b/aten/src/THC/THCGenerateBFloat16Type.h new file mode 100644 index 00000000000..957288a8b4b --- /dev/null +++ b/aten/src/THC/THCGenerateBFloat16Type.h @@ -0,0 +1,25 @@ +#ifndef THC_GENERIC_FILE +#error "You must define THC_GENERIC_FILE before including THCGenerateBFloat16Type.h" +#endif +#include + +#define scalar_t at::BFloat16 +#define Real BFloat16 + +#define CReal CudaBFloat16 + +#define THC_REAL_IS_BFLOAT16 +#line 1 THC_GENERIC_FILE +#include THC_GENERIC_FILE +#undef scalar_t +#undef Real + +#undef CReal + +#undef THC_REAL_IS_BFLOAT16 + +#ifndef THCGenerateAllTypes +#ifndef THCGenerateFloatTypes +#undef THC_GENERIC_FILE +#endif +#endif diff --git a/aten/src/THC/THCStorage.cpp b/aten/src/THC/THCStorage.cpp index af7117925e4..c28d8d0253e 100644 --- a/aten/src/THC/THCStorage.cpp +++ b/aten/src/THC/THCStorage.cpp @@ -11,6 +11,9 @@ #include #include +#include +#include + #include void THCStorage_resize(THCState *state, THCStorage *self, ptrdiff_t size) diff --git a/aten/src/THC/THCStorage.cu b/aten/src/THC/THCStorage.cu index 01d54622d91..b0541345445 100644 --- a/aten/src/THC/THCStorage.cu +++ b/aten/src/THC/THCStorage.cu @@ -14,3 +14,6 @@ #include #include + +#include +#include diff --git a/aten/src/THC/THCStorage.h b/aten/src/THC/THCStorage.h index 19216edcff7..4d4d9abf09e 100644 --- a/aten/src/THC/THCStorage.h +++ b/aten/src/THC/THCStorage.h @@ -12,4 +12,7 @@ #include #include +#include +#include + #endif diff --git a/aten/src/THC/THCStorageCopy.cpp b/aten/src/THC/THCStorageCopy.cpp index 2c15088d32a..818316a480d 100644 --- a/aten/src/THC/THCStorageCopy.cpp +++ b/aten/src/THC/THCStorageCopy.cpp @@ -8,3 +8,6 @@ #include #include + +#include +#include diff --git a/aten/src/THC/THCStorageCopy.cu b/aten/src/THC/THCStorageCopy.cu index 9252e721f74..ebc7f9019e5 100644 --- a/aten/src/THC/THCStorageCopy.cu +++ b/aten/src/THC/THCStorageCopy.cu @@ -11,3 +11,6 @@ #include #include + +#include +#include diff --git a/aten/src/THC/THCStorageCopy.h b/aten/src/THC/THCStorageCopy.h index db971944bb1..3a6dc0c7f69 100644 --- a/aten/src/THC/THCStorageCopy.h +++ b/aten/src/THC/THCStorageCopy.h @@ -11,4 +11,7 @@ #include #include +#include +#include + #endif diff --git a/aten/src/THC/THCTensor.cpp b/aten/src/THC/THCTensor.cpp index 171ec945f57..be5c40c6fb3 100644 --- a/aten/src/THC/THCTensor.cpp +++ b/aten/src/THC/THCTensor.cpp @@ -10,6 +10,9 @@ #include #include +#include +#include + #include #include @@ -66,6 +69,8 @@ THCTensor *THCTensor_new(THCState *state, caffe2::TypeMeta type_meta) { return THCudaDoubleTensor_new(state); case at::ScalarType::Bool: return THCudaBoolTensor_new(state); + case at::ScalarType::BFloat16: + return THCudaBFloat16Tensor_new(state); default: AT_ERROR("unexpected ScalarType: ", toString(scalar_type)); } diff --git a/aten/src/THC/THCTensor.cu b/aten/src/THC/THCTensor.cu index cc25d141ea2..c49bde31a09 100644 --- a/aten/src/THC/THCTensor.cu +++ b/aten/src/THC/THCTensor.cu @@ -6,3 +6,6 @@ #include #include + +#include +#include diff --git a/aten/src/THC/THCTensor.h b/aten/src/THC/THCTensor.h index 9670eb3b62a..c8ebd5d3d66 100644 --- a/aten/src/THC/THCTensor.h +++ b/aten/src/THC/THCTensor.h @@ -19,4 +19,8 @@ typedef struct THC_CLASS THCDescBuff #include #include + +#include +#include + #endif diff --git a/aten/src/THC/THCTensor.hpp b/aten/src/THC/THCTensor.hpp index 3162506e073..b543c0af25c 100644 --- a/aten/src/THC/THCTensor.hpp +++ b/aten/src/THC/THCTensor.hpp @@ -59,3 +59,6 @@ THC_API bool THCTensor_maybeOverlappingIndices(THCState* state, const THCTensor* #include #include + +#include +#include diff --git a/aten/src/THC/THCTensorCopy.cu b/aten/src/THC/THCTensorCopy.cu index 571d0e1ffbb..e5b3ad33803 100644 --- a/aten/src/THC/THCTensorCopy.cu +++ b/aten/src/THC/THCTensorCopy.cu @@ -3,6 +3,7 @@ #include #include #include +#include // Copy operator for the pointwise apply kernel template @@ -23,8 +24,18 @@ struct CopyOp { } }; +template <> +struct CopyOp { + __device__ __forceinline__ void operator()(at::BFloat16* dst, at::BFloat16* src) { + *dst = ScalarConvert::to(*src); + } +}; + #include #include #include #include + +#include +#include diff --git a/aten/src/THC/THCTensorCopy.h b/aten/src/THC/THCTensorCopy.h index ec8ede70fe0..9366c37b04f 100644 --- a/aten/src/THC/THCTensorCopy.h +++ b/aten/src/THC/THCTensorCopy.h @@ -12,4 +12,7 @@ #include #include +#include +#include + #endif diff --git a/aten/src/THC/generic/THCStorage.h b/aten/src/THC/generic/THCStorage.h index 5fdf41d5604..cbcdaf5f3ef 100644 --- a/aten/src/THC/generic/THCStorage.h +++ b/aten/src/THC/generic/THCStorage.h @@ -6,15 +6,16 @@ // These used to be distinct types; for some measure of backwards compatibility and documentation // alias these to the single THCStorage type. -#define THCudaStorage THCStorage -#define THCudaDoubleStorage THCStorage -#define THCudaHalfStorage THCStorage -#define THCudaByteStorage THCStorage -#define THCudaCharStorage THCStorage -#define THCudaShortStorage THCStorage -#define THCudaIntStorage THCStorage -#define THCudaLongStorage THCStorage -#define THCudaBoolStorage THCStorage +#define THCudaStorage THCStorage +#define THCudaDoubleStorage THCStorage +#define THCudaHalfStorage THCStorage +#define THCudaByteStorage THCStorage +#define THCudaCharStorage THCStorage +#define THCudaShortStorage THCStorage +#define THCudaIntStorage THCStorage +#define THCudaLongStorage THCStorage +#define THCudaBoolStorage THCStorage +#define THCudaBFloat16Storage THCStorage THC_API scalar_t* THCStorage_(data)(THCState *state, const THCStorage*); THC_API ptrdiff_t THCStorage_(size)(THCState *state, const THCStorage*); diff --git a/aten/src/THC/generic/THCStorageCopy.cpp b/aten/src/THC/generic/THCStorageCopy.cpp index c132defafbe..1cc31be5011 100644 --- a/aten/src/THC/generic/THCStorageCopy.cpp +++ b/aten/src/THC/generic/THCStorageCopy.cpp @@ -34,6 +34,7 @@ TH_CUDA_STORAGE_IMPLEMENT_COPY(Float) TH_CUDA_STORAGE_IMPLEMENT_COPY(Half) TH_CUDA_STORAGE_IMPLEMENT_COPY(Double) TH_CUDA_STORAGE_IMPLEMENT_COPY(Bool) +TH_CUDA_STORAGE_IMPLEMENT_COPY(BFloat16) void THStorage_(copyCuda)(THCState *state, THStorage *self, struct THCStorage *src) { @@ -67,6 +68,7 @@ TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Float) TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Half) TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Double) TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Bool) +TH_CUDA_STORAGE_IMPLEMENT_COPYTO(BFloat16) #undef TH_CUDA_STORAGE_IMPLEMENT_COPY #undef TH_CUDA_STORAGE_IMPLEMENT_COPYTO diff --git a/aten/src/THC/generic/THCStorageCopy.cu b/aten/src/THC/generic/THCStorageCopy.cu index d372563d98a..18a5c89897c 100644 --- a/aten/src/THC/generic/THCStorageCopy.cu +++ b/aten/src/THC/generic/THCStorageCopy.cu @@ -29,6 +29,7 @@ THC_CUDA_STORAGE_IMPLEMENT_COPY(Float,) // i.e. float THC_CUDA_STORAGE_IMPLEMENT_COPY(Double,Double) THC_CUDA_STORAGE_IMPLEMENT_COPY(Half,Half) THC_CUDA_STORAGE_IMPLEMENT_COPY(Bool,Bool) +THC_CUDA_STORAGE_IMPLEMENT_COPY(BFloat16,BFloat16) #undef THC_CUDA_STORAGE_IMPLEMENT_COPY diff --git a/aten/src/THC/generic/THCStorageCopy.h b/aten/src/THC/generic/THCStorageCopy.h index 2375e186b0f..ffb37a048d1 100644 --- a/aten/src/THC/generic/THCStorageCopy.h +++ b/aten/src/THC/generic/THCStorageCopy.h @@ -15,6 +15,7 @@ THC_API void THCStorage_(copyFloat)(THCState *state, THCStorage *storage, struct THC_API void THCStorage_(copyDouble)(THCState *state, THCStorage *storage, struct THDoubleStorage *src); THC_API void THCStorage_(copyHalf)(THCState *state, THCStorage *storage, struct THHalfStorage *src); THC_API void THCStorage_(copyBool)(THCState *state, THCStorage *storage, struct THBoolStorage *src); +THC_API void THCStorage_(copyBFloat16)(THCState *state, THCStorage *storage, struct THBFloat16Storage *src); THC_API void THCStorage_(copyCudaByte)(THCState *state, THCStorage *storage, struct THCudaByteStorage *src); THC_API void THCStorage_(copyCudaChar)(THCState *state, THCStorage *storage, struct THCudaCharStorage *src); @@ -25,6 +26,7 @@ THC_API void THCStorage_(copyCudaFloat)(THCState *state, THCStorage *storage, st THC_API void THCStorage_(copyCudaDouble)(THCState *state, THCStorage *storage, struct THCudaDoubleStorage *src); THC_API void THCStorage_(copyCudaHalf)(THCState *state, THCStorage *storage, struct THCudaHalfStorage *src); THC_API void THCStorage_(copyCudaBool)(THCState *state, THCStorage *storage, struct THCudaBoolStorage *src); +THC_API void THCStorage_(copyCudaBFloat16)(THCState *state, THCStorage *storage, struct THCudaBFloat16Storage *src); THC_API void TH_CONCAT_2(THByteStorage_copyCuda , Real)(THCState *state, THByteStorage *self, struct THCStorage *src); THC_API void TH_CONCAT_2(THCharStorage_copyCuda , Real)(THCState *state, THCharStorage *self, struct THCStorage *src); @@ -35,6 +37,7 @@ THC_API void TH_CONCAT_2(THFloatStorage_copyCuda , Real)(THCState *state, THFloa THC_API void TH_CONCAT_2(THDoubleStorage_copyCuda, Real)(THCState *state, THDoubleStorage *self, struct THCStorage *src); THC_API void TH_CONCAT_2(THHalfStorage_copyCuda, Real)(THCState *state, THHalfStorage *self, struct THCStorage *src); THC_API void TH_CONCAT_2(THBoolStorage_copyCuda, Real)(THCState *state, THBoolStorage *self, struct THCStorage *src); +THC_API void TH_CONCAT_2(THBFloat16Storage_copyCuda, Real)(THCState *state, THBFloat16Storage *self, struct THCStorage *src); THC_API void THStorage_(copyCuda)(THCState *state, THStorage *self, THCStorage *src); THC_API void THCStorage_(copyCuda)(THCState *state, THCStorage *self, THCStorage *src); diff --git a/aten/src/THC/generic/THCTensor.h b/aten/src/THC/generic/THCTensor.h index 4a2fcc9415b..749eedc62dc 100644 --- a/aten/src/THC/generic/THCTensor.h +++ b/aten/src/THC/generic/THCTensor.h @@ -15,6 +15,7 @@ #define THCudaIntTensor THCTensor #define THCudaLongTensor THCTensor #define THCudaBoolTensor THCTensor +#define THCudaBFloat16Tensor THCTensor /**** access methods ****/ THC_API THCStorage* THCTensor_(storage)(THCState *state, const THCTensor *self); diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 6cb2102b2b0..c16b7aacac0 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -31,7 +31,8 @@ namespace c10 { _(bool, Bool, i) /* 11 */ \ _(c10::qint8, QInt8, i) /* 12 */ \ _(c10::quint8, QUInt8, i) /* 13 */ \ - _(c10::qint32, QInt32, i) /* 14 */ + _(c10::qint32, QInt32, i) /* 14 */ \ + _(c10::BFloat16, BFloat16, d) /* 15 */ // If you want to support ComplexHalf for real, replace occurrences // of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But diff --git a/c10/test/util/bfloat16_test.cpp b/c10/test/util/bfloat16_test.cpp index 9239efdc2aa..7dbc3f0d00b 100644 --- a/c10/test/util/bfloat16_test.cpp +++ b/c10/test/util/bfloat16_test.cpp @@ -33,7 +33,7 @@ namespace { bfloats[i].x = c10::detail::bits_from_f32(in[i]); out[i] = c10::detail::f32_from_bits(bfloats[i].x); - // The relative error should be less than 1/(2^7) since bfloat16 + // The relative error should be less than 1/(2^7) since BFloat16 // has 7 bits mantissa. EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128); } diff --git a/c10/util/BFloat16.h b/c10/util/BFloat16.h index aa739209447..32ba979248b 100644 --- a/c10/util/BFloat16.h +++ b/c10/util/BFloat16.h @@ -53,11 +53,10 @@ struct alignas(2) BFloat16 { C10_HOST_DEVICE BFloat16() = default; #else BFloat16() = default; - #endif - explicit inline C10_HOST_DEVICE BFloat16(float value); - explicit inline C10_HOST_DEVICE operator float() const; + inline C10_HOST_DEVICE BFloat16(float value); + inline C10_HOST_DEVICE operator float() const; }; } // namespace c10 diff --git a/test/test_torch.py b/test/test_torch.py index 8366ff661f2..c5e2a1c35c6 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -8975,6 +8975,7 @@ class _TestTorchMixin(object): float = torch.FloatStorage().element_size() double = torch.DoubleStorage().element_size() bool = torch.BoolStorage().element_size() + bfloat16 = torch.BFloat16Storage().element_size() self.assertEqual(byte, torch.ByteTensor().element_size()) self.assertEqual(char, torch.CharTensor().element_size()) @@ -8983,6 +8984,7 @@ class _TestTorchMixin(object): self.assertEqual(long, torch.LongTensor().element_size()) self.assertEqual(float, torch.FloatTensor().element_size()) self.assertEqual(double, torch.DoubleTensor().element_size()) + self.assertEqual(bool, torch.BoolTensor().element_size()) self.assertGreater(byte, 0) self.assertGreater(char, 0) @@ -8992,6 +8994,7 @@ class _TestTorchMixin(object): self.assertGreater(float, 0) self.assertGreater(double, 0) self.assertGreater(bool, 0) + self.assertGreater(bfloat16, 0) # These tests are portable, not necessarily strict for your system. self.assertEqual(byte, 1) @@ -10303,6 +10306,13 @@ class _TestTorchMixin(object): self.assertEqual(halfStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) self.assertIs(halfStorage.dtype, torch.float16) + bfloat16Storage = storage.bfloat16() + self.assertEqual(bfloat16Storage.size(), 6) + self.assertEqual(bfloat16Storage.tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertEqual(bfloat16Storage.type(), 'torch.BFloat16Storage') + self.assertEqual(bfloat16Storage.int().tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertIs(bfloat16Storage.dtype, torch.bfloat16) + longStorage = storage.long() self.assertEqual(longStorage.size(), 6) self.assertEqual(longStorage.tolist(), [-1, 0, 1, 2, 3, 4]) @@ -10419,6 +10429,8 @@ class _TestTorchMixin(object): obj.__repr__() str(obj) for t in torch._storage_classes: + if t == torch.BFloat16Storage: + continue # Fix once fill is enabled for bfloat16 if t.is_cuda and not torch.cuda.is_available(): continue if t == torch.BoolStorage or t == torch.cuda.BoolStorage: diff --git a/torch/__init__.py b/torch/__init__.py index faaa456d205..bf9705c9131 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -222,6 +222,11 @@ class ByteStorage(_C.ByteStorageBase, _StorageBase): class BoolStorage(_C.BoolStorageBase, _StorageBase): pass + +class BFloat16Storage(_C.BFloat16StorageBase, _StorageBase): + pass + + class QUInt8Storage(_C.QUInt8StorageBase, _StorageBase): pass @@ -235,7 +240,7 @@ class QInt32Storage(_C.QInt32StorageBase, _StorageBase): _storage_classes = { DoubleStorage, FloatStorage, LongStorage, IntStorage, ShortStorage, CharStorage, ByteStorage, HalfStorage, BoolStorage, QUInt8Storage, QInt8Storage, - QInt32Storage + QInt32Storage, BFloat16Storage } # The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings() @@ -286,6 +291,7 @@ del CharStorageBase del ByteStorageBase del BoolStorageBase del QUInt8StorageBase +del BFloat16StorageBase ################################################################################ # Import most common subpackages diff --git a/torch/_storage_docs.py b/torch/_storage_docs.py index fd540329391..672bc9fc3ce 100644 --- a/torch/_storage_docs.py +++ b/torch/_storage_docs.py @@ -13,6 +13,7 @@ storage_classes = [ 'CharStorageBase', 'ByteStorageBase', 'BoolStorageBase', + 'BFloat16StorageBase', ] diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 903557c06b3..0f23e316fdd 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -125,6 +125,7 @@ static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manag THPQUInt8Storage_postInit(module); THPQInt8Storage_postInit(module); THPQInt32Storage_postInit(module); + THPBFloat16Storage_postInit(module); THPAutograd_initFunctions(); Py_RETURN_NONE; END_HANDLE_TH_ERRORS @@ -514,6 +515,7 @@ bool THCPShortStorage_init(PyObject *module); bool THCPCharStorage_init(PyObject *module); bool THCPByteStorage_init(PyObject *module); bool THCPBoolStorage_init(PyObject *module); +bool THCPBFloat16Storage_init(PyObject *module); void THCPStream_init(PyObject *module); void THCPEvent_init(PyObject *module); @@ -536,6 +538,17 @@ void init__THCUNN(PyObject*); }} // namespace torch::nn +bool THDPDoubleStorage_init(PyObject *module); +bool THDPFloatStorage_init(PyObject *module); +//bool THDPHalfStorage_init(PyObject *module); +bool THDPLongStorage_init(PyObject *module); +bool THDPIntStorage_init(PyObject *module); +bool THDPShortStorage_init(PyObject *module); +bool THDPCharStorage_init(PyObject *module); +bool THDPByteStorage_init(PyObject *module); +bool THDPBoolStorage_init(PyObject *module); +bool THDPBFloat16Storage_init(PyObject *module); + static std::vector methods; // TODO: Refactor this in some less manual way @@ -664,6 +677,7 @@ PyObject* initModule() { ASSERT_TRUE(THPQUInt8Storage_init(module)); ASSERT_TRUE(THPQInt8Storage_init(module)); ASSERT_TRUE(THPQInt32Storage_init(module)); + ASSERT_TRUE(THPBFloat16Storage_init(module)); #ifdef USE_CUDA // This will only initialise base classes and attach them to library namespace @@ -679,6 +693,7 @@ PyObject* initModule() { ASSERT_TRUE(THCPCharStorage_init(module)); ASSERT_TRUE(THCPByteStorage_init(module)); ASSERT_TRUE(THCPBoolStorage_init(module)); + ASSERT_TRUE(THCPBFloat16Storage_init(module)); THCPStream_init(module); THCPEvent_init(module); diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index de38daf5b66..a59ae7dd4c4 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -29,6 +29,9 @@ #include #include +#include +#include + #include #include diff --git a/torch/csrc/Storage.h b/torch/csrc/Storage.h index 9799ae07122..5b1f49ac132 100644 --- a/torch/csrc/Storage.h +++ b/torch/csrc/Storage.h @@ -29,6 +29,8 @@ PyObject_IsInstance(obj, THPQInt8StorageClass) #define THPQInt32Storage_Check(obj) \ PyObject_IsInstance(obj, THPQInt32StorageClass) +#define THPBFloat16Storage_Check(obj) \ + PyObject_IsInstance(obj, THPBFloat16StorageClass) #define THPDoubleStorage_CData(obj) (obj)->cdata @@ -43,6 +45,7 @@ #define THPQUInt8Storage_CData(obj) (obj)->cdata #define THPQInt8Storage_CData(obj) (obj)->cdata #define THPQInt32Storage_CData(obj) (obj)->cdata +#define THPBFloat16Storage_CData(obj) (obj)->cdata #ifdef _THP_CORE #define THPStorageType TH_CONCAT_3(THP,Real,StorageType) @@ -58,6 +61,9 @@ #include #include +#include +#include + #include #include diff --git a/torch/csrc/byte_order.cpp b/torch/csrc/byte_order.cpp index 03567912443..cf347e6a4e8 100644 --- a/torch/csrc/byte_order.cpp +++ b/torch/csrc/byte_order.cpp @@ -1,5 +1,5 @@ #include - +#include #include #if defined(_MSC_VER) @@ -140,6 +140,15 @@ void THP_decodeHalfBuffer(THHalf* dst, const uint8_t* src, THPByteOrder order, s } } +void THP_decodeBFloat16Buffer(at::BFloat16* dst, const uint8_t* src, THPByteOrder order, size_t len) +{ + for (size_t i = 0; i < len; i++) { + uint16_t x = (order == THP_BIG_ENDIAN ? decodeUInt16BE(src) : decodeUInt16LE(src)); + std::memcpy(&dst[i], &x, sizeof(dst[i])); + src += sizeof(uint16_t); + } +} + void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, THPByteOrder order, size_t len) { for (size_t i = 0; i < len; i++) { diff --git a/torch/csrc/byte_order.h b/torch/csrc/byte_order.h index c9bb5a4cd22..b8b0f7a22c9 100644 --- a/torch/csrc/byte_order.h +++ b/torch/csrc/byte_order.h @@ -4,6 +4,7 @@ #include #include #include +#include enum THPByteOrder { THP_LITTLE_ENDIAN = 0, @@ -19,6 +20,7 @@ void THP_decodeHalfBuffer(THHalf* dst, const uint8_t* src, THPByteOrder order, s void THP_decodeFloatBuffer(float* dst, const uint8_t* src, THPByteOrder order, size_t len); void THP_decodeDoubleBuffer(double* dst, const uint8_t* src, THPByteOrder order, size_t len); void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, THPByteOrder order, size_t len); +void THP_decodeBFloat16Buffer(at::BFloat16* dst, const uint8_t* src, THPByteOrder order, size_t len); void THP_encodeInt16Buffer(uint8_t* dst, const int16_t* src, THPByteOrder order, size_t len); void THP_encodeInt32Buffer(uint8_t* dst, const int32_t* src, THPByteOrder order, size_t len); diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index edce5c6a303..20eb2117c55 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -316,6 +316,7 @@ static PyObject * THCPModule_initExtension(PyObject *self) THCPCharStorage_postInit(m); THCPByteStorage_postInit(m); THCPBoolStorage_postInit(m); + THCPBFloat16Storage_postInit(m); bool has_half = true; diff --git a/torch/csrc/cuda/Storage.cpp b/torch/csrc/cuda/Storage.cpp index e5d281e64fd..9aea709eab9 100644 --- a/torch/csrc/cuda/Storage.cpp +++ b/torch/csrc/cuda/Storage.cpp @@ -21,3 +21,6 @@ #define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp" #include + +#define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp" +#include diff --git a/torch/csrc/cuda/Storage.h b/torch/csrc/cuda/Storage.h index 19e62f1ef31..81009ffcb42 100644 --- a/torch/csrc/cuda/Storage.h +++ b/torch/csrc/cuda/Storage.h @@ -23,15 +23,18 @@ PyObject_IsInstance(obj, THCPByteStorageClass) #define THCPBoolStorage_Check(obj) \ PyObject_IsInstance(obj, THCPBoolStorageClass) +#define THCPBFloat16Storage_Check(obj) \ + PyObject_IsInstance(obj, THCPBFloat16StorageClass) -#define THCPDoubleStorage_CData(obj) (obj)->cdata -#define THCPFloatStorage_CData(obj) (obj)->cdata -#define THCPLongStorage_CData(obj) (obj)->cdata -#define THCPIntStorage_CData(obj) (obj)->cdata -#define THCPShortStorage_CData(obj) (obj)->cdata -#define THCPCharStorage_CData(obj) (obj)->cdata -#define THCPByteStorage_CData(obj) (obj)->cdata -#define THCPBoolStorage_CData(obj) (obj)->cdata +#define THCPDoubleStorage_CData(obj) (obj)->cdata +#define THCPFloatStorage_CData(obj) (obj)->cdata +#define THCPLongStorage_CData(obj) (obj)->cdata +#define THCPIntStorage_CData(obj) (obj)->cdata +#define THCPShortStorage_CData(obj) (obj)->cdata +#define THCPCharStorage_CData(obj) (obj)->cdata +#define THCPByteStorage_CData(obj) (obj)->cdata +#define THCPBoolStorage_CData(obj) (obj)->cdata +#define THCPBFloat16Storage_CData(obj) (obj)->cdata #ifdef _THP_CORE #define THCPStorageType TH_CONCAT_3(THCP,Real,StorageType) @@ -46,4 +49,7 @@ #define THC_GENERIC_FILE "torch/csrc/generic/Storage.h" #include +#define THC_GENERIC_FILE "torch/csrc/generic/Storage.h" +#include + #endif diff --git a/torch/csrc/cuda/serialization.cpp b/torch/csrc/cuda/serialization.cpp index e83ea6e2a8a..b878196ae0b 100644 --- a/torch/csrc/cuda/serialization.cpp +++ b/torch/csrc/cuda/serialization.cpp @@ -12,3 +12,6 @@ #define THC_GENERIC_FILE "torch/csrc/generic/serialization.cpp" #include + +#define THC_GENERIC_FILE "torch/csrc/generic/serialization.cpp" +#include diff --git a/torch/csrc/cuda/serialization.h b/torch/csrc/cuda/serialization.h index 3e3eb2d090c..f0cd8438a14 100644 --- a/torch/csrc/cuda/serialization.h +++ b/torch/csrc/cuda/serialization.h @@ -9,4 +9,7 @@ #define THC_GENERIC_FILE "torch/csrc/generic/serialization.h" #include +#define THC_GENERIC_FILE "torch/csrc/generic/serialization.h" +#include + #endif diff --git a/torch/csrc/cuda/utils.cpp b/torch/csrc/cuda/utils.cpp index 8e2980303e8..d4c44b9cee0 100644 --- a/torch/csrc/cuda/utils.cpp +++ b/torch/csrc/cuda/utils.cpp @@ -11,6 +11,9 @@ #define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp" #include +#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp" +#include + #ifdef USE_CUDA // NB: It's a list of *optional* CUDAStream; when nullopt, that means to use // whatever the current stream of the device the input is associated with was. diff --git a/torch/csrc/cuda/utils.h b/torch/csrc/cuda/utils.h index 209b453db6c..f05918a958a 100644 --- a/torch/csrc/cuda/utils.h +++ b/torch/csrc/cuda/utils.h @@ -18,4 +18,7 @@ #define THC_GENERIC_FILE "torch/csrc/generic/utils.h" #include + +#define THC_GENERIC_FILE "torch/csrc/generic/utils.h" +#include #endif diff --git a/torch/csrc/generic/Storage.cpp b/torch/csrc/generic/Storage.cpp index f6199e2fcf4..699d8090eae 100644 --- a/torch/csrc/generic/Storage.cpp +++ b/torch/csrc/generic/Storage.cpp @@ -317,6 +317,7 @@ void THPStorage_(initCopyMethods)() THPInsertStorageCopyFunction(&THPFloatStorageType, h, &THWStorage_(copyFloat)); THPInsertStorageCopyFunction(&THPDoubleStorageType, h, &THWStorage_(copyDouble)); THPInsertStorageCopyFunction(&THPBoolStorageType, h, &THWStorage_(copyBool)); + THPInsertStorageCopyFunction(&THPBFloat16StorageType, h, &THWStorage_(copyBFloat16)); #ifdef THC_GENERIC_FILE // copy from GPU types THPInsertStorageCopyFunction(&THCPByteStorageType, h, &THWStorage_(copyCudaByte)); @@ -328,6 +329,7 @@ void THPStorage_(initCopyMethods)() THPInsertStorageCopyFunction(&THCPDoubleStorageType, h, &THWStorage_(copyCudaDouble)); THPInsertStorageCopyFunction(&THCPHalfStorageType, h, &THWStorage_(copyCudaHalf)); THPInsertStorageCopyFunction(&THCPBoolStorageType, h, &THWStorage_(copyCudaBool)); + THPInsertStorageCopyFunction(&THCPBFloat16StorageType, h, &THWStorage_(copyCudaBFloat16)); // add CPU <- GPU copies to base type /// #define THPCpuStorage TH_CONCAT_3(THP, Real, Storage) #define THCpuStorage_(name) TH_CONCAT_4(TH, Real, Storage_, name) @@ -342,6 +344,7 @@ void THPStorage_(initCopyMethods)() THPInsertStorageCopyFunction(&THCPDoubleStorageType, b, &THCpuStorage_(copyCudaDouble)); THPInsertStorageCopyFunction(&THCPHalfStorageType, b, &THCpuStorage_(copyCudaHalf)); THPInsertStorageCopyFunction(&THCPBoolStorageType, b, &THCpuStorage_(copyCudaBool)); + THPInsertStorageCopyFunction(&THCPBFloat16StorageType, b, &THCpuStorage_(copyCudaBFloat16)); #undef THCpuStorage #undef THCpuStorage_ #endif diff --git a/torch/csrc/generic/StorageMethods.cpp b/torch/csrc/generic/StorageMethods.cpp index 66a838d1fd7..994ccb2ddee 100644 --- a/torch/csrc/generic/StorageMethods.cpp +++ b/torch/csrc/generic/StorageMethods.cpp @@ -170,6 +170,8 @@ static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyO THP_decodeInt64Buffer((int64_t*) THWStorage_(data)(storage), src + offset, byte_order, count); #elif defined(TH_REAL_IS_HALF) THP_decodeHalfBuffer(THWStorage_(data)(storage), src + offset, byte_order, count); +#elif defined(TH_REAL_IS_BFLOAT16) + THP_decodeBFloat16Buffer(THWStorage_(data)(storage), src + offset, byte_order, count); #elif defined(TH_REAL_IS_FLOAT) THP_decodeFloatBuffer(THWStorage_(data)(storage), src + offset, byte_order, count); #elif defined(TH_REAL_IS_DOUBLE) diff --git a/torch/csrc/serialization.cpp b/torch/csrc/serialization.cpp index 7db5c59c438..87136fb1fd6 100644 --- a/torch/csrc/serialization.cpp +++ b/torch/csrc/serialization.cpp @@ -183,6 +183,9 @@ void doWrite(io fildes, void* raw_buf, size_t nbytes) { #include #include +#include +#include + #include #include diff --git a/torch/csrc/serialization.h b/torch/csrc/serialization.h index 2a5bacea6dc..d1fd07f27dc 100644 --- a/torch/csrc/serialization.h +++ b/torch/csrc/serialization.h @@ -10,6 +10,9 @@ #include #include +#include +#include + #include #include diff --git a/torch/csrc/utils.cpp b/torch/csrc/utils.cpp index 74dd9b4c025..3681b1dd334 100644 --- a/torch/csrc/utils.cpp +++ b/torch/csrc/utils.cpp @@ -17,6 +17,9 @@ #include #include +#include +#include + #include #include #include diff --git a/torch/csrc/utils.h b/torch/csrc/utils.h index 7fb55113cdd..04b78299bfa 100644 --- a/torch/csrc/utils.h +++ b/torch/csrc/utils.h @@ -81,6 +81,10 @@ #define THPHalfUtils_unpackReal(object) (at::Half)THPUtils_unpackReal_FLOAT(object) #define THPHalfUtils_newReal(value) PyFloat_FromDouble(value) #define THPHalfUtils_newAccreal(value) THPUtils_newReal_FLOAT(value) +#define THPBFloat16Utils_checkReal(object) THPUtils_checkReal_FLOAT(object) +#define THPBFloat16Utils_unpackReal(object) (at::BFloat16)THPUtils_unpackReal_FLOAT(object) +#define THPBFloat16Utils_newReal(value) PyFloat_FromDouble(value) +#define THPBFloat16Utils_newAccreal(value) THPUtils_newReal_FLOAT(value) #define THPBoolUtils_checkReal(object) THPUtils_checkReal_BOOL(object) #define THPBoolUtils_unpackReal(object) THPUtils_unpackReal_BOOL(object) @@ -149,6 +153,9 @@ struct THPUtils_typeTraits {}; #include #include +#include +#include + #include #include diff --git a/torch/csrc/utils/tensor_dtypes.cpp b/torch/csrc/utils/tensor_dtypes.cpp index 7450b381460..967b7328536 100644 --- a/torch/csrc/utils/tensor_dtypes.cpp +++ b/torch/csrc/utils/tensor_dtypes.cpp @@ -47,6 +47,8 @@ static std::pair getDtypeNames( return std::make_pair("quint8", ""); case at::ScalarType::QInt32: return std::make_pair("qint32", ""); + case at::ScalarType::BFloat16: + return std::make_pair("bfloat16", ""); default: throw std::runtime_error("Unimplemented scalar type"); } diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index e8dc4035685..cc89b42f9e9 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -32,7 +32,7 @@ def find_cuda_windows_lib(): # Override the default search process # Fixes https://github.com/pytorch/pytorch/issues/20202 # The libary selection will be done in these directories one by one - # 1. [Package Root]\Lib + # 1. [Package Root]\Lib # That's where our libraries are in, which should be loaded first. # 2. [Python Root]\Library\bin # That's where `cudatoolkit` store the cuda libraries. @@ -596,7 +596,7 @@ def _dummy_type(name): if not hasattr(torch._C, 'CudaDoubleStorageBase'): # Define dummy base classes - for t in ['Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Half', 'Bool']: + for t in ['Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Half', 'Bool', 'BFloat16']: storage_name = 'Cuda{0}StorageBase'.format(t) tensor_name = 'Cuda{0}TensorBase'.format(t) @@ -661,6 +661,10 @@ class HalfStorage(_CudaBase, torch._C.CudaHalfStorageBase, _StorageBase): class BoolStorage(_CudaBase, torch._C.CudaBoolStorageBase, _StorageBase): pass + +class BFloat16Storage(_CudaBase, torch._C.CudaBFloat16StorageBase, _StorageBase): + pass + torch._storage_classes.add(DoubleStorage) torch._storage_classes.add(FloatStorage) torch._storage_classes.add(LongStorage) @@ -670,6 +674,7 @@ torch._storage_classes.add(CharStorage) torch._storage_classes.add(ByteStorage) torch._storage_classes.add(HalfStorage) torch._storage_classes.add(BoolStorage) +torch._storage_classes.add(BFloat16Storage) from . import sparse # noqa: F401 from . import profiler # noqa: F401 diff --git a/torch/storage.py b/torch/storage.py index 68caff85a19..7727379391e 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -87,6 +87,10 @@ class _StorageBase(object): """Casts this storage to bool type""" return self.type(type(self).__module__ + '.BoolStorage') + def bfloat16(self): + """Casts this storage to bfloat16 type""" + return self.type(type(self).__module__ + '.BFloat16Storage') + def pin_memory(self): """Copies the storage to pinned memory, if it's not already pinned.""" if self.is_cuda: