mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Enabled BFloat16 storage (#21523)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/21523 ghimport-source-id: 698b3cbd6b21c09b9ff8bf8011980df8e35c33b0 Test Plan: Imported from OSS Differential Revision: D15819368 Pulled By: izdeby fbshipit-source-id: f6b3bba7b3ca8ee677bd80a231dbb3920c07d61c
This commit is contained in:
parent
932ec8aa9f
commit
3a8d7463bd
59 changed files with 264 additions and 26 deletions
|
|
@ -39,6 +39,9 @@ static DLDataType getDLDataType(const Tensor& t) {
|
|||
case ScalarType::Bool:
|
||||
dtype.code = DLDataTypeCode::kDLUInt;
|
||||
break;
|
||||
case ScalarType::BFloat16:
|
||||
throw std::logic_error("BFloat16 is not supported by dlpack");
|
||||
break;
|
||||
case ScalarType::QInt8:
|
||||
throw std::logic_error("QInt8 is not supported by dlpack");
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -209,6 +209,7 @@ scalar_types = [
|
|||
('Long', 'int64_t', 'Long', False),
|
||||
('Short', 'int16_t', 'Long', False),
|
||||
('Half', 'Half', 'Double', True),
|
||||
('BFloat16', 'BFloat16', 'BFloat16AccrealNotDefined', True),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ INSTALL(FILES
|
|||
THFilePrivate.h
|
||||
${CMAKE_CURRENT_BINARY_DIR}/THGeneral.h
|
||||
THGenerateAllTypes.h
|
||||
THGenerateBFloat16Type.h
|
||||
THGenerateBoolType.h
|
||||
THGenerateDoubleType.h
|
||||
THGenerateFloatType.h
|
||||
|
|
|
|||
19
aten/src/TH/THGenerateBFloat16Type.h
Normal file
19
aten/src/TH/THGenerateBFloat16Type.h
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
#ifndef TH_GENERIC_FILE
|
||||
#error "You must define TH_GENERIC_FILE before including THGenerateBFloat16Type.h"
|
||||
#endif
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#define scalar_t at::BFloat16
|
||||
#define TH_CONVERT_ACCREAL_TO_REAL(_val) (scalar_t)(_val)
|
||||
#define Real BFloat16
|
||||
#define TH_REAL_IS_BFLOAT16
|
||||
#line 1 TH_GENERIC_FILE
|
||||
#include TH_GENERIC_FILE
|
||||
#undef scalar_t
|
||||
#undef Real
|
||||
#undef TH_REAL_IS_BFLOAT16
|
||||
#undef TH_CONVERT_ACCREAL_TO_REAL
|
||||
|
||||
#ifndef THGenerateManyTypes
|
||||
#undef TH_GENERIC_FILE
|
||||
#endif
|
||||
|
|
@ -15,6 +15,9 @@
|
|||
#include <TH/generic/THStorage.cpp>
|
||||
#include <TH/THGenerateQTypes.h>
|
||||
|
||||
#include <TH/generic/THStorage.cpp>
|
||||
#include <TH/THGenerateBFloat16Type.h>
|
||||
|
||||
#include <TH/generic/THStorageCopy.cpp>
|
||||
#include <TH/THGenerateAllTypes.h>
|
||||
|
||||
|
|
@ -27,6 +30,9 @@
|
|||
#include <TH/generic/THStorageCopy.cpp>
|
||||
#include <TH/THGenerateQTypes.h>
|
||||
|
||||
#include <TH/generic/THStorageCopy.cpp>
|
||||
#include <TH/THGenerateBFloat16Type.h>
|
||||
|
||||
THStorage* THStorage_new(caffe2::TypeMeta data_type) {
|
||||
THStorage* storage = c10::make_intrusive<at::StorageImpl>(
|
||||
data_type,
|
||||
|
|
|
|||
|
|
@ -17,6 +17,9 @@
|
|||
#include <TH/generic/THStorage.h>
|
||||
#include <TH/THGenerateQTypes.h>
|
||||
|
||||
#include <TH/generic/THStorage.h>
|
||||
#include <TH/THGenerateBFloat16Type.h>
|
||||
|
||||
#include <TH/generic/THStorageCopy.h>
|
||||
#include <TH/THGenerateAllTypes.h>
|
||||
|
||||
|
|
@ -29,5 +32,8 @@
|
|||
#include <TH/generic/THStorageCopy.h>
|
||||
#include <TH/THGenerateQTypes.h>
|
||||
|
||||
#include <TH/generic/THStorageCopy.h>
|
||||
#include <TH/THGenerateBFloat16Type.h>
|
||||
|
||||
// This exists to have a data-type independent way of freeing (necessary for THPPointer).
|
||||
TH_API void THStorage_free(THStorage *storage);
|
||||
|
|
|
|||
|
|
@ -9,6 +9,9 @@
|
|||
#include <TH/generic/THTensor.cpp>
|
||||
#include <TH/THGenerateBoolType.h>
|
||||
|
||||
#include <TH/generic/THTensor.cpp>
|
||||
#include <TH/THGenerateBFloat16Type.h>
|
||||
|
||||
#include <ATen/native/Resize.h>
|
||||
|
||||
#include <numeric>
|
||||
|
|
|
|||
|
|
@ -16,6 +16,9 @@
|
|||
#include <TH/generic/THTensor.h>
|
||||
#include <TH/THGenerateBoolType.h>
|
||||
|
||||
#include <TH/generic/THTensor.h>
|
||||
#include <TH/THGenerateBFloat16Type.h>
|
||||
|
||||
/* random numbers */
|
||||
#include <TH/generic/THTensorRandom.h>
|
||||
#include <TH/THGenerateAllTypes.h>
|
||||
|
|
|
|||
|
|
@ -130,3 +130,6 @@ TH_CPP_API c10::optional<std::vector<int64_t>> THTensor_compute_stride(
|
|||
|
||||
#include <TH/generic/THTensor.hpp>
|
||||
#include <TH/THGenerateBoolType.h>
|
||||
|
||||
#include <TH/generic/THTensor.hpp>
|
||||
#include <TH/THGenerateBFloat16Type.h>
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@
|
|||
#define THIntStorage THStorage
|
||||
#define THLongStorage THStorage
|
||||
#define THBoolStorage THStorage
|
||||
#define THBFloat16Storage THStorage
|
||||
|
||||
TH_API scalar_t* THStorage_(data)(const THStorage*);
|
||||
TH_API ptrdiff_t THStorage_(size)(const THStorage*);
|
||||
|
|
|
|||
|
|
@ -38,5 +38,6 @@ IMPLEMENT_THStorage_COPY(Float)
|
|||
IMPLEMENT_THStorage_COPY(Double)
|
||||
IMPLEMENT_THStorage_COPY(Half)
|
||||
IMPLEMENT_THStorage_COPY(Bool)
|
||||
IMPLEMENT_THStorage_COPY(BFloat16)
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -15,5 +15,6 @@ TH_API void THStorage_(copyFloat)(THStorage *storage, struct THFloatStorage *src
|
|||
TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *src);
|
||||
TH_API void THStorage_(copyHalf)(THStorage *storage, struct THHalfStorage *src);
|
||||
TH_API void THStorage_(copyBool)(THStorage *storage, struct THBoolStorage *src);
|
||||
TH_API void THStorage_(copyBFloat16)(THStorage *storage, struct THBFloat16Storage *src);
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
#define THIntTensor THTensor
|
||||
#define THLongTensor THTensor
|
||||
#define THBoolTensor THTensor
|
||||
#define THBFloat16Tensor THTensor
|
||||
|
||||
/**** access methods ****/
|
||||
TH_API THStorage* THTensor_(storage)(const THTensor *self);
|
||||
|
|
|
|||
|
|
@ -85,6 +85,7 @@ INSTALL(FILES
|
|||
THCDeviceTensorUtils.cuh
|
||||
THCDeviceTensorUtils-inl.cuh
|
||||
THCGenerateAllTypes.h
|
||||
THCGenerateBFloat16Type.h
|
||||
THCGenerateBoolType.h
|
||||
THCGenerateByteType.h
|
||||
THCGenerateCharType.h
|
||||
|
|
|
|||
25
aten/src/THC/THCGenerateBFloat16Type.h
Normal file
25
aten/src/THC/THCGenerateBFloat16Type.h
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
#ifndef THC_GENERIC_FILE
|
||||
#error "You must define THC_GENERIC_FILE before including THCGenerateBFloat16Type.h"
|
||||
#endif
|
||||
#include <c10/util/BFloat16.h>
|
||||
|
||||
#define scalar_t at::BFloat16
|
||||
#define Real BFloat16
|
||||
|
||||
#define CReal CudaBFloat16
|
||||
|
||||
#define THC_REAL_IS_BFLOAT16
|
||||
#line 1 THC_GENERIC_FILE
|
||||
#include THC_GENERIC_FILE
|
||||
#undef scalar_t
|
||||
#undef Real
|
||||
|
||||
#undef CReal
|
||||
|
||||
#undef THC_REAL_IS_BFLOAT16
|
||||
|
||||
#ifndef THCGenerateAllTypes
|
||||
#ifndef THCGenerateFloatTypes
|
||||
#undef THC_GENERIC_FILE
|
||||
#endif
|
||||
#endif
|
||||
|
|
@ -11,6 +11,9 @@
|
|||
#include <THC/generic/THCStorage.cpp>
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#include <THC/generic/THCStorage.cpp>
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
|
||||
void THCStorage_resize(THCState *state, THCStorage *self, ptrdiff_t size)
|
||||
|
|
|
|||
|
|
@ -14,3 +14,6 @@
|
|||
|
||||
#include <THC/generic/THCStorage.cu>
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#include <THC/generic/THCStorage.cu>
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
|
|
|||
|
|
@ -12,4 +12,7 @@
|
|||
#include <THC/generic/THCStorage.h>
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#include <THC/generic/THCStorage.h>
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -8,3 +8,6 @@
|
|||
|
||||
#include <THC/generic/THCStorageCopy.cpp>
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#include <THC/generic/THCStorageCopy.cpp>
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
|
|
|||
|
|
@ -11,3 +11,6 @@
|
|||
|
||||
#include <THC/generic/THCStorageCopy.cu>
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#include <THC/generic/THCStorageCopy.cu>
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
|
|
|||
|
|
@ -11,4 +11,7 @@
|
|||
#include <THC/generic/THCStorageCopy.h>
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#include <THC/generic/THCStorageCopy.h>
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@
|
|||
#include <THC/generic/THCTensor.cpp>
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#include <THC/generic/THCTensor.cpp>
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
||||
#include <THC/THCTensorInfo.cuh>
|
||||
|
||||
#include <ATen/native/cuda/Resize.cuh>
|
||||
|
|
@ -66,6 +69,8 @@ THCTensor *THCTensor_new(THCState *state, caffe2::TypeMeta type_meta) {
|
|||
return THCudaDoubleTensor_new(state);
|
||||
case at::ScalarType::Bool:
|
||||
return THCudaBoolTensor_new(state);
|
||||
case at::ScalarType::BFloat16:
|
||||
return THCudaBFloat16Tensor_new(state);
|
||||
default:
|
||||
AT_ERROR("unexpected ScalarType: ", toString(scalar_type));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,3 +6,6 @@
|
|||
|
||||
#include <THC/generic/THCTensor.cu>
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#include <THC/generic/THCTensor.cu>
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
|
|
|||
|
|
@ -19,4 +19,8 @@ typedef struct THC_CLASS THCDescBuff
|
|||
|
||||
#include <THC/generic/THCTensor.h>
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#include <THC/generic/THCTensor.h>
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -59,3 +59,6 @@ THC_API bool THCTensor_maybeOverlappingIndices(THCState* state, const THCTensor*
|
|||
|
||||
#include <THC/generic/THCTensor.hpp>
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#include <THC/generic/THCTensor.hpp>
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include <THC/THCNumerics.cuh>
|
||||
#include <THC/THCTensorCopy.hpp>
|
||||
#include <type_traits>
|
||||
#include <c10/util/BFloat16.h>
|
||||
|
||||
// Copy operator for the pointwise apply kernel
|
||||
template <typename T>
|
||||
|
|
@ -23,8 +24,18 @@ struct CopyOp <bool> {
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CopyOp <at::BFloat16> {
|
||||
__device__ __forceinline__ void operator()(at::BFloat16* dst, at::BFloat16* src) {
|
||||
*dst = ScalarConvert<at::BFloat16, at::BFloat16>::to(*src);
|
||||
}
|
||||
};
|
||||
|
||||
#include <THC/generic/THCTensorCopy.cu>
|
||||
#include <THC/THCGenerateAllTypes.h>
|
||||
|
||||
#include <THC/generic/THCTensorCopy.cu>
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#include <THC/generic/THCTensorCopy.cu>
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
|
|
|||
|
|
@ -12,4 +12,7 @@
|
|||
#include <THC/generic/THCTensorCopy.h>
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#include <THC/generic/THCTensorCopy.h>
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -6,15 +6,16 @@
|
|||
|
||||
// These used to be distinct types; for some measure of backwards compatibility and documentation
|
||||
// alias these to the single THCStorage type.
|
||||
#define THCudaStorage THCStorage
|
||||
#define THCudaDoubleStorage THCStorage
|
||||
#define THCudaHalfStorage THCStorage
|
||||
#define THCudaByteStorage THCStorage
|
||||
#define THCudaCharStorage THCStorage
|
||||
#define THCudaShortStorage THCStorage
|
||||
#define THCudaIntStorage THCStorage
|
||||
#define THCudaLongStorage THCStorage
|
||||
#define THCudaBoolStorage THCStorage
|
||||
#define THCudaStorage THCStorage
|
||||
#define THCudaDoubleStorage THCStorage
|
||||
#define THCudaHalfStorage THCStorage
|
||||
#define THCudaByteStorage THCStorage
|
||||
#define THCudaCharStorage THCStorage
|
||||
#define THCudaShortStorage THCStorage
|
||||
#define THCudaIntStorage THCStorage
|
||||
#define THCudaLongStorage THCStorage
|
||||
#define THCudaBoolStorage THCStorage
|
||||
#define THCudaBFloat16Storage THCStorage
|
||||
|
||||
THC_API scalar_t* THCStorage_(data)(THCState *state, const THCStorage*);
|
||||
THC_API ptrdiff_t THCStorage_(size)(THCState *state, const THCStorage*);
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ TH_CUDA_STORAGE_IMPLEMENT_COPY(Float)
|
|||
TH_CUDA_STORAGE_IMPLEMENT_COPY(Half)
|
||||
TH_CUDA_STORAGE_IMPLEMENT_COPY(Double)
|
||||
TH_CUDA_STORAGE_IMPLEMENT_COPY(Bool)
|
||||
TH_CUDA_STORAGE_IMPLEMENT_COPY(BFloat16)
|
||||
|
||||
void THStorage_(copyCuda)(THCState *state, THStorage *self, struct THCStorage *src)
|
||||
{
|
||||
|
|
@ -67,6 +68,7 @@ TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Float)
|
|||
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Half)
|
||||
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Double)
|
||||
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Bool)
|
||||
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(BFloat16)
|
||||
|
||||
#undef TH_CUDA_STORAGE_IMPLEMENT_COPY
|
||||
#undef TH_CUDA_STORAGE_IMPLEMENT_COPYTO
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ THC_CUDA_STORAGE_IMPLEMENT_COPY(Float,) // i.e. float
|
|||
THC_CUDA_STORAGE_IMPLEMENT_COPY(Double,Double)
|
||||
THC_CUDA_STORAGE_IMPLEMENT_COPY(Half,Half)
|
||||
THC_CUDA_STORAGE_IMPLEMENT_COPY(Bool,Bool)
|
||||
THC_CUDA_STORAGE_IMPLEMENT_COPY(BFloat16,BFloat16)
|
||||
|
||||
#undef THC_CUDA_STORAGE_IMPLEMENT_COPY
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ THC_API void THCStorage_(copyFloat)(THCState *state, THCStorage *storage, struct
|
|||
THC_API void THCStorage_(copyDouble)(THCState *state, THCStorage *storage, struct THDoubleStorage *src);
|
||||
THC_API void THCStorage_(copyHalf)(THCState *state, THCStorage *storage, struct THHalfStorage *src);
|
||||
THC_API void THCStorage_(copyBool)(THCState *state, THCStorage *storage, struct THBoolStorage *src);
|
||||
THC_API void THCStorage_(copyBFloat16)(THCState *state, THCStorage *storage, struct THBFloat16Storage *src);
|
||||
|
||||
THC_API void THCStorage_(copyCudaByte)(THCState *state, THCStorage *storage, struct THCudaByteStorage *src);
|
||||
THC_API void THCStorage_(copyCudaChar)(THCState *state, THCStorage *storage, struct THCudaCharStorage *src);
|
||||
|
|
@ -25,6 +26,7 @@ THC_API void THCStorage_(copyCudaFloat)(THCState *state, THCStorage *storage, st
|
|||
THC_API void THCStorage_(copyCudaDouble)(THCState *state, THCStorage *storage, struct THCudaDoubleStorage *src);
|
||||
THC_API void THCStorage_(copyCudaHalf)(THCState *state, THCStorage *storage, struct THCudaHalfStorage *src);
|
||||
THC_API void THCStorage_(copyCudaBool)(THCState *state, THCStorage *storage, struct THCudaBoolStorage *src);
|
||||
THC_API void THCStorage_(copyCudaBFloat16)(THCState *state, THCStorage *storage, struct THCudaBFloat16Storage *src);
|
||||
|
||||
THC_API void TH_CONCAT_2(THByteStorage_copyCuda , Real)(THCState *state, THByteStorage *self, struct THCStorage *src);
|
||||
THC_API void TH_CONCAT_2(THCharStorage_copyCuda , Real)(THCState *state, THCharStorage *self, struct THCStorage *src);
|
||||
|
|
@ -35,6 +37,7 @@ THC_API void TH_CONCAT_2(THFloatStorage_copyCuda , Real)(THCState *state, THFloa
|
|||
THC_API void TH_CONCAT_2(THDoubleStorage_copyCuda, Real)(THCState *state, THDoubleStorage *self, struct THCStorage *src);
|
||||
THC_API void TH_CONCAT_2(THHalfStorage_copyCuda, Real)(THCState *state, THHalfStorage *self, struct THCStorage *src);
|
||||
THC_API void TH_CONCAT_2(THBoolStorage_copyCuda, Real)(THCState *state, THBoolStorage *self, struct THCStorage *src);
|
||||
THC_API void TH_CONCAT_2(THBFloat16Storage_copyCuda, Real)(THCState *state, THBFloat16Storage *self, struct THCStorage *src);
|
||||
|
||||
THC_API void THStorage_(copyCuda)(THCState *state, THStorage *self, THCStorage *src);
|
||||
THC_API void THCStorage_(copyCuda)(THCState *state, THCStorage *self, THCStorage *src);
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
#define THCudaIntTensor THCTensor
|
||||
#define THCudaLongTensor THCTensor
|
||||
#define THCudaBoolTensor THCTensor
|
||||
#define THCudaBFloat16Tensor THCTensor
|
||||
|
||||
/**** access methods ****/
|
||||
THC_API THCStorage* THCTensor_(storage)(THCState *state, const THCTensor *self);
|
||||
|
|
|
|||
|
|
@ -31,7 +31,8 @@ namespace c10 {
|
|||
_(bool, Bool, i) /* 11 */ \
|
||||
_(c10::qint8, QInt8, i) /* 12 */ \
|
||||
_(c10::quint8, QUInt8, i) /* 13 */ \
|
||||
_(c10::qint32, QInt32, i) /* 14 */
|
||||
_(c10::qint32, QInt32, i) /* 14 */ \
|
||||
_(c10::BFloat16, BFloat16, d) /* 15 */
|
||||
|
||||
// If you want to support ComplexHalf for real, replace occurrences
|
||||
// of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ namespace {
|
|||
bfloats[i].x = c10::detail::bits_from_f32(in[i]);
|
||||
out[i] = c10::detail::f32_from_bits(bfloats[i].x);
|
||||
|
||||
// The relative error should be less than 1/(2^7) since bfloat16
|
||||
// The relative error should be less than 1/(2^7) since BFloat16
|
||||
// has 7 bits mantissa.
|
||||
EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -53,11 +53,10 @@ struct alignas(2) BFloat16 {
|
|||
C10_HOST_DEVICE BFloat16() = default;
|
||||
#else
|
||||
BFloat16() = default;
|
||||
|
||||
#endif
|
||||
|
||||
explicit inline C10_HOST_DEVICE BFloat16(float value);
|
||||
explicit inline C10_HOST_DEVICE operator float() const;
|
||||
inline C10_HOST_DEVICE BFloat16(float value);
|
||||
inline C10_HOST_DEVICE operator float() const;
|
||||
};
|
||||
|
||||
} // namespace c10
|
||||
|
|
|
|||
|
|
@ -8975,6 +8975,7 @@ class _TestTorchMixin(object):
|
|||
float = torch.FloatStorage().element_size()
|
||||
double = torch.DoubleStorage().element_size()
|
||||
bool = torch.BoolStorage().element_size()
|
||||
bfloat16 = torch.BFloat16Storage().element_size()
|
||||
|
||||
self.assertEqual(byte, torch.ByteTensor().element_size())
|
||||
self.assertEqual(char, torch.CharTensor().element_size())
|
||||
|
|
@ -8983,6 +8984,7 @@ class _TestTorchMixin(object):
|
|||
self.assertEqual(long, torch.LongTensor().element_size())
|
||||
self.assertEqual(float, torch.FloatTensor().element_size())
|
||||
self.assertEqual(double, torch.DoubleTensor().element_size())
|
||||
self.assertEqual(bool, torch.BoolTensor().element_size())
|
||||
|
||||
self.assertGreater(byte, 0)
|
||||
self.assertGreater(char, 0)
|
||||
|
|
@ -8992,6 +8994,7 @@ class _TestTorchMixin(object):
|
|||
self.assertGreater(float, 0)
|
||||
self.assertGreater(double, 0)
|
||||
self.assertGreater(bool, 0)
|
||||
self.assertGreater(bfloat16, 0)
|
||||
|
||||
# These tests are portable, not necessarily strict for your system.
|
||||
self.assertEqual(byte, 1)
|
||||
|
|
@ -10303,6 +10306,13 @@ class _TestTorchMixin(object):
|
|||
self.assertEqual(halfStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
|
||||
self.assertIs(halfStorage.dtype, torch.float16)
|
||||
|
||||
bfloat16Storage = storage.bfloat16()
|
||||
self.assertEqual(bfloat16Storage.size(), 6)
|
||||
self.assertEqual(bfloat16Storage.tolist(), [-1, 0, 1, 2, 3, 4])
|
||||
self.assertEqual(bfloat16Storage.type(), 'torch.BFloat16Storage')
|
||||
self.assertEqual(bfloat16Storage.int().tolist(), [-1, 0, 1, 2, 3, 4])
|
||||
self.assertIs(bfloat16Storage.dtype, torch.bfloat16)
|
||||
|
||||
longStorage = storage.long()
|
||||
self.assertEqual(longStorage.size(), 6)
|
||||
self.assertEqual(longStorage.tolist(), [-1, 0, 1, 2, 3, 4])
|
||||
|
|
@ -10419,6 +10429,8 @@ class _TestTorchMixin(object):
|
|||
obj.__repr__()
|
||||
str(obj)
|
||||
for t in torch._storage_classes:
|
||||
if t == torch.BFloat16Storage:
|
||||
continue # Fix once fill is enabled for bfloat16
|
||||
if t.is_cuda and not torch.cuda.is_available():
|
||||
continue
|
||||
if t == torch.BoolStorage or t == torch.cuda.BoolStorage:
|
||||
|
|
|
|||
|
|
@ -222,6 +222,11 @@ class ByteStorage(_C.ByteStorageBase, _StorageBase):
|
|||
class BoolStorage(_C.BoolStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class BFloat16Storage(_C.BFloat16StorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class QUInt8Storage(_C.QUInt8StorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
|
@ -235,7 +240,7 @@ class QInt32Storage(_C.QInt32StorageBase, _StorageBase):
|
|||
_storage_classes = {
|
||||
DoubleStorage, FloatStorage, LongStorage, IntStorage, ShortStorage,
|
||||
CharStorage, ByteStorage, HalfStorage, BoolStorage, QUInt8Storage, QInt8Storage,
|
||||
QInt32Storage
|
||||
QInt32Storage, BFloat16Storage
|
||||
}
|
||||
|
||||
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
|
||||
|
|
@ -286,6 +291,7 @@ del CharStorageBase
|
|||
del ByteStorageBase
|
||||
del BoolStorageBase
|
||||
del QUInt8StorageBase
|
||||
del BFloat16StorageBase
|
||||
|
||||
################################################################################
|
||||
# Import most common subpackages
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ storage_classes = [
|
|||
'CharStorageBase',
|
||||
'ByteStorageBase',
|
||||
'BoolStorageBase',
|
||||
'BFloat16StorageBase',
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -125,6 +125,7 @@ static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manag
|
|||
THPQUInt8Storage_postInit(module);
|
||||
THPQInt8Storage_postInit(module);
|
||||
THPQInt32Storage_postInit(module);
|
||||
THPBFloat16Storage_postInit(module);
|
||||
THPAutograd_initFunctions();
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
|
|
@ -514,6 +515,7 @@ bool THCPShortStorage_init(PyObject *module);
|
|||
bool THCPCharStorage_init(PyObject *module);
|
||||
bool THCPByteStorage_init(PyObject *module);
|
||||
bool THCPBoolStorage_init(PyObject *module);
|
||||
bool THCPBFloat16Storage_init(PyObject *module);
|
||||
|
||||
void THCPStream_init(PyObject *module);
|
||||
void THCPEvent_init(PyObject *module);
|
||||
|
|
@ -536,6 +538,17 @@ void init__THCUNN(PyObject*);
|
|||
|
||||
}} // namespace torch::nn
|
||||
|
||||
bool THDPDoubleStorage_init(PyObject *module);
|
||||
bool THDPFloatStorage_init(PyObject *module);
|
||||
//bool THDPHalfStorage_init(PyObject *module);
|
||||
bool THDPLongStorage_init(PyObject *module);
|
||||
bool THDPIntStorage_init(PyObject *module);
|
||||
bool THDPShortStorage_init(PyObject *module);
|
||||
bool THDPCharStorage_init(PyObject *module);
|
||||
bool THDPByteStorage_init(PyObject *module);
|
||||
bool THDPBoolStorage_init(PyObject *module);
|
||||
bool THDPBFloat16Storage_init(PyObject *module);
|
||||
|
||||
static std::vector<PyMethodDef> methods;
|
||||
|
||||
// TODO: Refactor this in some less manual way
|
||||
|
|
@ -664,6 +677,7 @@ PyObject* initModule() {
|
|||
ASSERT_TRUE(THPQUInt8Storage_init(module));
|
||||
ASSERT_TRUE(THPQInt8Storage_init(module));
|
||||
ASSERT_TRUE(THPQInt32Storage_init(module));
|
||||
ASSERT_TRUE(THPBFloat16Storage_init(module));
|
||||
|
||||
#ifdef USE_CUDA
|
||||
// This will only initialise base classes and attach them to library namespace
|
||||
|
|
@ -679,6 +693,7 @@ PyObject* initModule() {
|
|||
ASSERT_TRUE(THCPCharStorage_init(module));
|
||||
ASSERT_TRUE(THCPByteStorage_init(module));
|
||||
ASSERT_TRUE(THCPBoolStorage_init(module));
|
||||
ASSERT_TRUE(THCPBFloat16Storage_init(module));
|
||||
|
||||
THCPStream_init(module);
|
||||
THCPEvent_init(module);
|
||||
|
|
|
|||
|
|
@ -29,6 +29,9 @@
|
|||
#include <torch/csrc/generic/Storage.cpp>
|
||||
#include <TH/THGenerateBoolType.h>
|
||||
|
||||
#include <torch/csrc/generic/Storage.cpp>
|
||||
#include <TH/THGenerateBFloat16Type.h>
|
||||
|
||||
#include <torch/csrc/generic/Storage.cpp>
|
||||
#include <TH/THGenerateQTypes.h>
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@
|
|||
PyObject_IsInstance(obj, THPQInt8StorageClass)
|
||||
#define THPQInt32Storage_Check(obj) \
|
||||
PyObject_IsInstance(obj, THPQInt32StorageClass)
|
||||
#define THPBFloat16Storage_Check(obj) \
|
||||
PyObject_IsInstance(obj, THPBFloat16StorageClass)
|
||||
|
||||
|
||||
#define THPDoubleStorage_CData(obj) (obj)->cdata
|
||||
|
|
@ -43,6 +45,7 @@
|
|||
#define THPQUInt8Storage_CData(obj) (obj)->cdata
|
||||
#define THPQInt8Storage_CData(obj) (obj)->cdata
|
||||
#define THPQInt32Storage_CData(obj) (obj)->cdata
|
||||
#define THPBFloat16Storage_CData(obj) (obj)->cdata
|
||||
|
||||
#ifdef _THP_CORE
|
||||
#define THPStorageType TH_CONCAT_3(THP,Real,StorageType)
|
||||
|
|
@ -58,6 +61,9 @@
|
|||
#include <torch/csrc/generic/Storage.h>
|
||||
#include <TH/THGenerateBoolType.h>
|
||||
|
||||
#include <torch/csrc/generic/Storage.h>
|
||||
#include <TH/THGenerateBFloat16Type.h>
|
||||
|
||||
#include <torch/csrc/generic/Storage.h>
|
||||
#include <TH/THGenerateQTypes.h>
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
#include <torch/csrc/byte_order.h>
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <cstring>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
|
|
@ -140,6 +140,15 @@ void THP_decodeHalfBuffer(THHalf* dst, const uint8_t* src, THPByteOrder order, s
|
|||
}
|
||||
}
|
||||
|
||||
void THP_decodeBFloat16Buffer(at::BFloat16* dst, const uint8_t* src, THPByteOrder order, size_t len)
|
||||
{
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
uint16_t x = (order == THP_BIG_ENDIAN ? decodeUInt16BE(src) : decodeUInt16LE(src));
|
||||
std::memcpy(&dst[i], &x, sizeof(dst[i]));
|
||||
src += sizeof(uint16_t);
|
||||
}
|
||||
}
|
||||
|
||||
void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, THPByteOrder order, size_t len)
|
||||
{
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include <cstdint>
|
||||
#include <cstddef>
|
||||
#include <THHalf.h>
|
||||
#include <c10/util/BFloat16.h>
|
||||
|
||||
enum THPByteOrder {
|
||||
THP_LITTLE_ENDIAN = 0,
|
||||
|
|
@ -19,6 +20,7 @@ void THP_decodeHalfBuffer(THHalf* dst, const uint8_t* src, THPByteOrder order, s
|
|||
void THP_decodeFloatBuffer(float* dst, const uint8_t* src, THPByteOrder order, size_t len);
|
||||
void THP_decodeDoubleBuffer(double* dst, const uint8_t* src, THPByteOrder order, size_t len);
|
||||
void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, THPByteOrder order, size_t len);
|
||||
void THP_decodeBFloat16Buffer(at::BFloat16* dst, const uint8_t* src, THPByteOrder order, size_t len);
|
||||
|
||||
void THP_encodeInt16Buffer(uint8_t* dst, const int16_t* src, THPByteOrder order, size_t len);
|
||||
void THP_encodeInt32Buffer(uint8_t* dst, const int32_t* src, THPByteOrder order, size_t len);
|
||||
|
|
|
|||
|
|
@ -316,6 +316,7 @@ static PyObject * THCPModule_initExtension(PyObject *self)
|
|||
THCPCharStorage_postInit(m);
|
||||
THCPByteStorage_postInit(m);
|
||||
THCPBoolStorage_postInit(m);
|
||||
THCPBFloat16Storage_postInit(m);
|
||||
|
||||
bool has_half = true;
|
||||
|
||||
|
|
|
|||
|
|
@ -21,3 +21,6 @@
|
|||
|
||||
#define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
|
|
|||
|
|
@ -23,15 +23,18 @@
|
|||
PyObject_IsInstance(obj, THCPByteStorageClass)
|
||||
#define THCPBoolStorage_Check(obj) \
|
||||
PyObject_IsInstance(obj, THCPBoolStorageClass)
|
||||
#define THCPBFloat16Storage_Check(obj) \
|
||||
PyObject_IsInstance(obj, THCPBFloat16StorageClass)
|
||||
|
||||
#define THCPDoubleStorage_CData(obj) (obj)->cdata
|
||||
#define THCPFloatStorage_CData(obj) (obj)->cdata
|
||||
#define THCPLongStorage_CData(obj) (obj)->cdata
|
||||
#define THCPIntStorage_CData(obj) (obj)->cdata
|
||||
#define THCPShortStorage_CData(obj) (obj)->cdata
|
||||
#define THCPCharStorage_CData(obj) (obj)->cdata
|
||||
#define THCPByteStorage_CData(obj) (obj)->cdata
|
||||
#define THCPBoolStorage_CData(obj) (obj)->cdata
|
||||
#define THCPDoubleStorage_CData(obj) (obj)->cdata
|
||||
#define THCPFloatStorage_CData(obj) (obj)->cdata
|
||||
#define THCPLongStorage_CData(obj) (obj)->cdata
|
||||
#define THCPIntStorage_CData(obj) (obj)->cdata
|
||||
#define THCPShortStorage_CData(obj) (obj)->cdata
|
||||
#define THCPCharStorage_CData(obj) (obj)->cdata
|
||||
#define THCPByteStorage_CData(obj) (obj)->cdata
|
||||
#define THCPBoolStorage_CData(obj) (obj)->cdata
|
||||
#define THCPBFloat16Storage_CData(obj) (obj)->cdata
|
||||
|
||||
#ifdef _THP_CORE
|
||||
#define THCPStorageType TH_CONCAT_3(THCP,Real,StorageType)
|
||||
|
|
@ -46,4 +49,7 @@
|
|||
#define THC_GENERIC_FILE "torch/csrc/generic/Storage.h"
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#define THC_GENERIC_FILE "torch/csrc/generic/Storage.h"
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -12,3 +12,6 @@
|
|||
|
||||
#define THC_GENERIC_FILE "torch/csrc/generic/serialization.cpp"
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#define THC_GENERIC_FILE "torch/csrc/generic/serialization.cpp"
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
|
|
|||
|
|
@ -9,4 +9,7 @@
|
|||
#define THC_GENERIC_FILE "torch/csrc/generic/serialization.h"
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#define THC_GENERIC_FILE "torch/csrc/generic/serialization.h"
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@
|
|||
#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
|
||||
#ifdef USE_CUDA
|
||||
// NB: It's a list of *optional* CUDAStream; when nullopt, that means to use
|
||||
// whatever the current stream of the device the input is associated with was.
|
||||
|
|
|
|||
|
|
@ -18,4 +18,7 @@
|
|||
|
||||
#define THC_GENERIC_FILE "torch/csrc/generic/utils.h"
|
||||
#include <THC/THCGenerateBoolType.h>
|
||||
|
||||
#define THC_GENERIC_FILE "torch/csrc/generic/utils.h"
|
||||
#include <THC/THCGenerateBFloat16Type.h>
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -317,6 +317,7 @@ void THPStorage_(initCopyMethods)()
|
|||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPFloatStorageType, h, &THWStorage_(copyFloat));
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPDoubleStorageType, h, &THWStorage_(copyDouble));
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPBoolStorageType, h, &THWStorage_(copyBool));
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPBFloat16StorageType, h, &THWStorage_(copyBFloat16));
|
||||
#ifdef THC_GENERIC_FILE
|
||||
// copy from GPU types
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPByteStorageType, h, &THWStorage_(copyCudaByte));
|
||||
|
|
@ -328,6 +329,7 @@ void THPStorage_(initCopyMethods)()
|
|||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPDoubleStorageType, h, &THWStorage_(copyCudaDouble));
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPHalfStorageType, h, &THWStorage_(copyCudaHalf));
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPBoolStorageType, h, &THWStorage_(copyCudaBool));
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPBFloat16StorageType, h, &THWStorage_(copyCudaBFloat16));
|
||||
// add CPU <- GPU copies to base type
|
||||
/// #define THPCpuStorage TH_CONCAT_3(THP, Real, Storage)
|
||||
#define THCpuStorage_(name) TH_CONCAT_4(TH, Real, Storage_, name)
|
||||
|
|
@ -342,6 +344,7 @@ void THPStorage_(initCopyMethods)()
|
|||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPDoubleStorageType, b, &THCpuStorage_(copyCudaDouble));
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPHalfStorageType, b, &THCpuStorage_(copyCudaHalf));
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPBoolStorageType, b, &THCpuStorage_(copyCudaBool));
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPBFloat16StorageType, b, &THCpuStorage_(copyCudaBFloat16));
|
||||
#undef THCpuStorage
|
||||
#undef THCpuStorage_
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -170,6 +170,8 @@ static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyO
|
|||
THP_decodeInt64Buffer((int64_t*) THWStorage_(data)(storage), src + offset, byte_order, count);
|
||||
#elif defined(TH_REAL_IS_HALF)
|
||||
THP_decodeHalfBuffer(THWStorage_(data)(storage), src + offset, byte_order, count);
|
||||
#elif defined(TH_REAL_IS_BFLOAT16)
|
||||
THP_decodeBFloat16Buffer(THWStorage_(data)(storage), src + offset, byte_order, count);
|
||||
#elif defined(TH_REAL_IS_FLOAT)
|
||||
THP_decodeFloatBuffer(THWStorage_(data)(storage), src + offset, byte_order, count);
|
||||
#elif defined(TH_REAL_IS_DOUBLE)
|
||||
|
|
|
|||
|
|
@ -183,6 +183,9 @@ void doWrite(io fildes, void* raw_buf, size_t nbytes) {
|
|||
#include <torch/csrc/generic/serialization.cpp>
|
||||
#include <TH/THGenerateHalfType.h>
|
||||
|
||||
#include <torch/csrc/generic/serialization.cpp>
|
||||
#include <TH/THGenerateBFloat16Type.h>
|
||||
|
||||
#include <torch/csrc/generic/serialization.cpp>
|
||||
#include <TH/THGenerateBoolType.h>
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@
|
|||
#include <torch/csrc/generic/serialization.h>
|
||||
#include <TH/THGenerateBoolType.h>
|
||||
|
||||
#include <torch/csrc/generic/serialization.h>
|
||||
#include <TH/THGenerateBFloat16Type.h>
|
||||
|
||||
#include <torch/csrc/generic/serialization.h>
|
||||
#include <TH/THGenerateQTypes.h>
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,9 @@
|
|||
#include <torch/csrc/generic/utils.cpp>
|
||||
#include <TH/THGenerateHalfType.h>
|
||||
|
||||
#include <torch/csrc/generic/utils.cpp>
|
||||
#include <TH/THGenerateBFloat16Type.h>
|
||||
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <torch/csrc/generic/utils.cpp>
|
||||
#include <TH/THGenerateBoolType.h>
|
||||
|
|
|
|||
|
|
@ -81,6 +81,10 @@
|
|||
#define THPHalfUtils_unpackReal(object) (at::Half)THPUtils_unpackReal_FLOAT(object)
|
||||
#define THPHalfUtils_newReal(value) PyFloat_FromDouble(value)
|
||||
#define THPHalfUtils_newAccreal(value) THPUtils_newReal_FLOAT(value)
|
||||
#define THPBFloat16Utils_checkReal(object) THPUtils_checkReal_FLOAT(object)
|
||||
#define THPBFloat16Utils_unpackReal(object) (at::BFloat16)THPUtils_unpackReal_FLOAT(object)
|
||||
#define THPBFloat16Utils_newReal(value) PyFloat_FromDouble(value)
|
||||
#define THPBFloat16Utils_newAccreal(value) THPUtils_newReal_FLOAT(value)
|
||||
|
||||
#define THPBoolUtils_checkReal(object) THPUtils_checkReal_BOOL(object)
|
||||
#define THPBoolUtils_unpackReal(object) THPUtils_unpackReal_BOOL(object)
|
||||
|
|
@ -149,6 +153,9 @@ struct THPUtils_typeTraits {};
|
|||
#include <torch/csrc/generic/utils.h>
|
||||
#include <TH/THGenerateHalfType.h>
|
||||
|
||||
#include <torch/csrc/generic/utils.h>
|
||||
#include <TH/THGenerateBFloat16Type.h>
|
||||
|
||||
#include <torch/csrc/generic/utils.h>
|
||||
#include <TH/THGenerateBoolType.h>
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ static std::pair<std::string, std::string> getDtypeNames(
|
|||
return std::make_pair("quint8", "");
|
||||
case at::ScalarType::QInt32:
|
||||
return std::make_pair("qint32", "");
|
||||
case at::ScalarType::BFloat16:
|
||||
return std::make_pair("bfloat16", "");
|
||||
default:
|
||||
throw std::runtime_error("Unimplemented scalar type");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ def find_cuda_windows_lib():
|
|||
# Override the default search process
|
||||
# Fixes https://github.com/pytorch/pytorch/issues/20202
|
||||
# The libary selection will be done in these directories one by one
|
||||
# 1. [Package Root]\Lib
|
||||
# 1. [Package Root]\Lib
|
||||
# That's where our libraries are in, which should be loaded first.
|
||||
# 2. [Python Root]\Library\bin
|
||||
# That's where `cudatoolkit` store the cuda libraries.
|
||||
|
|
@ -596,7 +596,7 @@ def _dummy_type(name):
|
|||
|
||||
if not hasattr(torch._C, 'CudaDoubleStorageBase'):
|
||||
# Define dummy base classes
|
||||
for t in ['Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Half', 'Bool']:
|
||||
for t in ['Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Half', 'Bool', 'BFloat16']:
|
||||
storage_name = 'Cuda{0}StorageBase'.format(t)
|
||||
tensor_name = 'Cuda{0}TensorBase'.format(t)
|
||||
|
||||
|
|
@ -661,6 +661,10 @@ class HalfStorage(_CudaBase, torch._C.CudaHalfStorageBase, _StorageBase):
|
|||
class BoolStorage(_CudaBase, torch._C.CudaBoolStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class BFloat16Storage(_CudaBase, torch._C.CudaBFloat16StorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
torch._storage_classes.add(DoubleStorage)
|
||||
torch._storage_classes.add(FloatStorage)
|
||||
torch._storage_classes.add(LongStorage)
|
||||
|
|
@ -670,6 +674,7 @@ torch._storage_classes.add(CharStorage)
|
|||
torch._storage_classes.add(ByteStorage)
|
||||
torch._storage_classes.add(HalfStorage)
|
||||
torch._storage_classes.add(BoolStorage)
|
||||
torch._storage_classes.add(BFloat16Storage)
|
||||
|
||||
from . import sparse # noqa: F401
|
||||
from . import profiler # noqa: F401
|
||||
|
|
|
|||
|
|
@ -87,6 +87,10 @@ class _StorageBase(object):
|
|||
"""Casts this storage to bool type"""
|
||||
return self.type(type(self).__module__ + '.BoolStorage')
|
||||
|
||||
def bfloat16(self):
|
||||
"""Casts this storage to bfloat16 type"""
|
||||
return self.type(type(self).__module__ + '.BFloat16Storage')
|
||||
|
||||
def pin_memory(self):
|
||||
"""Copies the storage to pinned memory, if it's not already pinned."""
|
||||
if self.is_cuda:
|
||||
|
|
|
|||
Loading…
Reference in a new issue