Enabled BFloat16 storage (#21523)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21523
ghimport-source-id: 698b3cbd6b21c09b9ff8bf8011980df8e35c33b0

Test Plan: Imported from OSS

Differential Revision: D15819368

Pulled By: izdeby

fbshipit-source-id: f6b3bba7b3ca8ee677bd80a231dbb3920c07d61c
This commit is contained in:
Iurii Zdebskyi 2019-07-09 21:47:47 -07:00 committed by Facebook Github Bot
parent 932ec8aa9f
commit 3a8d7463bd
59 changed files with 264 additions and 26 deletions

View file

@ -39,6 +39,9 @@ static DLDataType getDLDataType(const Tensor& t) {
case ScalarType::Bool:
dtype.code = DLDataTypeCode::kDLUInt;
break;
case ScalarType::BFloat16:
throw std::logic_error("BFloat16 is not supported by dlpack");
break;
case ScalarType::QInt8:
throw std::logic_error("QInt8 is not supported by dlpack");
break;

View file

@ -209,6 +209,7 @@ scalar_types = [
('Long', 'int64_t', 'Long', False),
('Short', 'int16_t', 'Long', False),
('Half', 'Half', 'Double', True),
('BFloat16', 'BFloat16', 'BFloat16AccrealNotDefined', True),
]

View file

@ -64,6 +64,7 @@ INSTALL(FILES
THFilePrivate.h
${CMAKE_CURRENT_BINARY_DIR}/THGeneral.h
THGenerateAllTypes.h
THGenerateBFloat16Type.h
THGenerateBoolType.h
THGenerateDoubleType.h
THGenerateFloatType.h

View file

@ -0,0 +1,19 @@
#ifndef TH_GENERIC_FILE
#error "You must define TH_GENERIC_FILE before including THGenerateBFloat16Type.h"
#endif
#include <c10/util/BFloat16.h>
// Stamps out one BFloat16 instantiation of the generic TH source file.
// The TH build includes headers like this once per scalar type: each header
// binds the generic type macros (scalar_t / Real / TH_REAL_IS_*) and then
// re-includes TH_GENERIC_FILE so its generic definitions are compiled for
// that concrete type.
#define scalar_t at::BFloat16
// NOTE(review): converts accreal to BFloat16 with a plain C-style cast;
// assumes the accreal chosen for this type is cast-convertible to
// at::BFloat16 -- confirm against the scalar_types table entry for BFloat16.
#define TH_CONVERT_ACCREAL_TO_REAL(_val) (scalar_t)(_val)
#define Real BFloat16
#define TH_REAL_IS_BFLOAT16
// Point compiler diagnostics at the generic file rather than this header.
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef scalar_t
#undef Real
#undef TH_REAL_IS_BFLOAT16
#undef TH_CONVERT_ACCREAL_TO_REAL
// When driven by a multi-type generator (THGenerateManyTypes), the driver
// owns TH_GENERIC_FILE and undefines it after the last type; otherwise this
// header is the sole consumer and must clean it up itself.
#ifndef THGenerateManyTypes
#undef TH_GENERIC_FILE
#endif

View file

@ -15,6 +15,9 @@
#include <TH/generic/THStorage.cpp>
#include <TH/THGenerateQTypes.h>
#include <TH/generic/THStorage.cpp>
#include <TH/THGenerateBFloat16Type.h>
#include <TH/generic/THStorageCopy.cpp>
#include <TH/THGenerateAllTypes.h>
@ -27,6 +30,9 @@
#include <TH/generic/THStorageCopy.cpp>
#include <TH/THGenerateQTypes.h>
#include <TH/generic/THStorageCopy.cpp>
#include <TH/THGenerateBFloat16Type.h>
THStorage* THStorage_new(caffe2::TypeMeta data_type) {
THStorage* storage = c10::make_intrusive<at::StorageImpl>(
data_type,

View file

@ -17,6 +17,9 @@
#include <TH/generic/THStorage.h>
#include <TH/THGenerateQTypes.h>
#include <TH/generic/THStorage.h>
#include <TH/THGenerateBFloat16Type.h>
#include <TH/generic/THStorageCopy.h>
#include <TH/THGenerateAllTypes.h>
@ -29,5 +32,8 @@
#include <TH/generic/THStorageCopy.h>
#include <TH/THGenerateQTypes.h>
#include <TH/generic/THStorageCopy.h>
#include <TH/THGenerateBFloat16Type.h>
// This exists to have a data-type independent way of freeing (necessary for THPPointer).
TH_API void THStorage_free(THStorage *storage);

View file

@ -9,6 +9,9 @@
#include <TH/generic/THTensor.cpp>
#include <TH/THGenerateBoolType.h>
#include <TH/generic/THTensor.cpp>
#include <TH/THGenerateBFloat16Type.h>
#include <ATen/native/Resize.h>
#include <numeric>

View file

@ -16,6 +16,9 @@
#include <TH/generic/THTensor.h>
#include <TH/THGenerateBoolType.h>
#include <TH/generic/THTensor.h>
#include <TH/THGenerateBFloat16Type.h>
/* random numbers */
#include <TH/generic/THTensorRandom.h>
#include <TH/THGenerateAllTypes.h>

View file

@ -130,3 +130,6 @@ TH_CPP_API c10::optional<std::vector<int64_t>> THTensor_compute_stride(
#include <TH/generic/THTensor.hpp>
#include <TH/THGenerateBoolType.h>
#include <TH/generic/THTensor.hpp>
#include <TH/THGenerateBFloat16Type.h>

View file

@ -34,6 +34,7 @@
#define THIntStorage THStorage
#define THLongStorage THStorage
#define THBoolStorage THStorage
#define THBFloat16Storage THStorage
TH_API scalar_t* THStorage_(data)(const THStorage*);
TH_API ptrdiff_t THStorage_(size)(const THStorage*);

View file

@ -38,5 +38,6 @@ IMPLEMENT_THStorage_COPY(Float)
IMPLEMENT_THStorage_COPY(Double)
IMPLEMENT_THStorage_COPY(Half)
IMPLEMENT_THStorage_COPY(Bool)
IMPLEMENT_THStorage_COPY(BFloat16)
#endif

View file

@ -15,5 +15,6 @@ TH_API void THStorage_(copyFloat)(THStorage *storage, struct THFloatStorage *src
TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *src);
TH_API void THStorage_(copyHalf)(THStorage *storage, struct THHalfStorage *src);
TH_API void THStorage_(copyBool)(THStorage *storage, struct THBoolStorage *src);
TH_API void THStorage_(copyBFloat16)(THStorage *storage, struct THBFloat16Storage *src);
#endif

View file

@ -19,6 +19,7 @@
#define THIntTensor THTensor
#define THLongTensor THTensor
#define THBoolTensor THTensor
#define THBFloat16Tensor THTensor
/**** access methods ****/
TH_API THStorage* THTensor_(storage)(const THTensor *self);

View file

@ -85,6 +85,7 @@ INSTALL(FILES
THCDeviceTensorUtils.cuh
THCDeviceTensorUtils-inl.cuh
THCGenerateAllTypes.h
THCGenerateBFloat16Type.h
THCGenerateBoolType.h
THCGenerateByteType.h
THCGenerateCharType.h

View file

@ -0,0 +1,25 @@
#ifndef THC_GENERIC_FILE
#error "You must define THC_GENERIC_FILE before including THCGenerateBFloat16Type.h"
#endif
#include <c10/util/BFloat16.h>
// CUDA-side counterpart of THGenerateBFloat16Type.h: stamps out one BFloat16
// instantiation of the generic THC source file. Binds the generic type macros
// (scalar_t / Real / CReal / THC_REAL_IS_*) and re-includes THC_GENERIC_FILE
// so its generic definitions are compiled for BFloat16.
#define scalar_t at::BFloat16
#define Real BFloat16
// CReal names the CUDA-flavored type alias used by TH_CONCAT-based naming
// (e.g. THCudaBFloat16Tensor).
#define CReal CudaBFloat16
#define THC_REAL_IS_BFLOAT16
// Point compiler diagnostics at the generic file rather than this header.
#line 1 THC_GENERIC_FILE
#include THC_GENERIC_FILE
#undef scalar_t
#undef Real
#undef CReal
#undef THC_REAL_IS_BFLOAT16
// When driven by a multi-type generator (THCGenerateAllTypes /
// THCGenerateFloatTypes), the driver owns THC_GENERIC_FILE and undefines it
// after the last type; otherwise clean it up here.
#ifndef THCGenerateAllTypes
#ifndef THCGenerateFloatTypes
#undef THC_GENERIC_FILE
#endif
#endif

View file

@ -11,6 +11,9 @@
#include <THC/generic/THCStorage.cpp>
#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCStorage.cpp>
#include <THC/THCGenerateBFloat16Type.h>
#include <c10/util/intrusive_ptr.h>
void THCStorage_resize(THCState *state, THCStorage *self, ptrdiff_t size)

View file

@ -14,3 +14,6 @@
#include <THC/generic/THCStorage.cu>
#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCStorage.cu>
#include <THC/THCGenerateBFloat16Type.h>

View file

@ -12,4 +12,7 @@
#include <THC/generic/THCStorage.h>
#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCStorage.h>
#include <THC/THCGenerateBFloat16Type.h>
#endif

View file

@ -8,3 +8,6 @@
#include <THC/generic/THCStorageCopy.cpp>
#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCStorageCopy.cpp>
#include <THC/THCGenerateBFloat16Type.h>

View file

@ -11,3 +11,6 @@
#include <THC/generic/THCStorageCopy.cu>
#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCStorageCopy.cu>
#include <THC/THCGenerateBFloat16Type.h>

View file

@ -11,4 +11,7 @@
#include <THC/generic/THCStorageCopy.h>
#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCStorageCopy.h>
#include <THC/THCGenerateBFloat16Type.h>
#endif

View file

@ -10,6 +10,9 @@
#include <THC/generic/THCTensor.cpp>
#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCTensor.cpp>
#include <THC/THCGenerateBFloat16Type.h>
#include <THC/THCTensorInfo.cuh>
#include <ATen/native/cuda/Resize.cuh>
@ -66,6 +69,8 @@ THCTensor *THCTensor_new(THCState *state, caffe2::TypeMeta type_meta) {
return THCudaDoubleTensor_new(state);
case at::ScalarType::Bool:
return THCudaBoolTensor_new(state);
case at::ScalarType::BFloat16:
return THCudaBFloat16Tensor_new(state);
default:
AT_ERROR("unexpected ScalarType: ", toString(scalar_type));
}

View file

@ -6,3 +6,6 @@
#include <THC/generic/THCTensor.cu>
#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCTensor.cu>
#include <THC/THCGenerateBFloat16Type.h>

View file

@ -19,4 +19,8 @@ typedef struct THC_CLASS THCDescBuff
#include <THC/generic/THCTensor.h>
#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCTensor.h>
#include <THC/THCGenerateBFloat16Type.h>
#endif

View file

@ -59,3 +59,6 @@ THC_API bool THCTensor_maybeOverlappingIndices(THCState* state, const THCTensor*
#include <THC/generic/THCTensor.hpp>
#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCTensor.hpp>
#include <THC/THCGenerateBFloat16Type.h>

View file

@ -3,6 +3,7 @@
#include <THC/THCNumerics.cuh>
#include <THC/THCTensorCopy.hpp>
#include <type_traits>
#include <c10/util/BFloat16.h>
// Copy operator for the pointwise apply kernel
template <typename T>
@ -23,8 +24,18 @@ struct CopyOp <bool> {
}
};
// Pointwise-apply copy functor specialized for BFloat16.
// Uses ScalarConvert<BFloat16, BFloat16> (an identity conversion here) to
// stay consistent with the other CopyOp specializations in this file.
template <>
struct CopyOp <at::BFloat16> {
__device__ __forceinline__ void operator()(at::BFloat16* dst, at::BFloat16* src) {
*dst = ScalarConvert<at::BFloat16, at::BFloat16>::to(*src);
}
};
#include <THC/generic/THCTensorCopy.cu>
#include <THC/THCGenerateAllTypes.h>
#include <THC/generic/THCTensorCopy.cu>
#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCTensorCopy.cu>
#include <THC/THCGenerateBFloat16Type.h>

View file

@ -12,4 +12,7 @@
#include <THC/generic/THCTensorCopy.h>
#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCTensorCopy.h>
#include <THC/THCGenerateBFloat16Type.h>
#endif

View file

@ -6,15 +6,16 @@
// These used to be distinct types; for some measure of backwards compatibility and documentation
// alias these to the single THCStorage type.
#define THCudaStorage THCStorage
#define THCudaDoubleStorage THCStorage
#define THCudaHalfStorage THCStorage
#define THCudaByteStorage THCStorage
#define THCudaCharStorage THCStorage
#define THCudaShortStorage THCStorage
#define THCudaIntStorage THCStorage
#define THCudaLongStorage THCStorage
#define THCudaBoolStorage THCStorage
#define THCudaStorage THCStorage
#define THCudaDoubleStorage THCStorage
#define THCudaHalfStorage THCStorage
#define THCudaByteStorage THCStorage
#define THCudaCharStorage THCStorage
#define THCudaShortStorage THCStorage
#define THCudaIntStorage THCStorage
#define THCudaLongStorage THCStorage
#define THCudaBoolStorage THCStorage
#define THCudaBFloat16Storage THCStorage
THC_API scalar_t* THCStorage_(data)(THCState *state, const THCStorage*);
THC_API ptrdiff_t THCStorage_(size)(THCState *state, const THCStorage*);

View file

@ -34,6 +34,7 @@ TH_CUDA_STORAGE_IMPLEMENT_COPY(Float)
TH_CUDA_STORAGE_IMPLEMENT_COPY(Half)
TH_CUDA_STORAGE_IMPLEMENT_COPY(Double)
TH_CUDA_STORAGE_IMPLEMENT_COPY(Bool)
TH_CUDA_STORAGE_IMPLEMENT_COPY(BFloat16)
void THStorage_(copyCuda)(THCState *state, THStorage *self, struct THCStorage *src)
{
@ -67,6 +68,7 @@ TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Float)
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Half)
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Double)
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Bool)
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(BFloat16)
#undef TH_CUDA_STORAGE_IMPLEMENT_COPY
#undef TH_CUDA_STORAGE_IMPLEMENT_COPYTO

View file

@ -29,6 +29,7 @@ THC_CUDA_STORAGE_IMPLEMENT_COPY(Float,) // i.e. float
THC_CUDA_STORAGE_IMPLEMENT_COPY(Double,Double)
THC_CUDA_STORAGE_IMPLEMENT_COPY(Half,Half)
THC_CUDA_STORAGE_IMPLEMENT_COPY(Bool,Bool)
THC_CUDA_STORAGE_IMPLEMENT_COPY(BFloat16,BFloat16)
#undef THC_CUDA_STORAGE_IMPLEMENT_COPY

View file

@ -15,6 +15,7 @@ THC_API void THCStorage_(copyFloat)(THCState *state, THCStorage *storage, struct
THC_API void THCStorage_(copyDouble)(THCState *state, THCStorage *storage, struct THDoubleStorage *src);
THC_API void THCStorage_(copyHalf)(THCState *state, THCStorage *storage, struct THHalfStorage *src);
THC_API void THCStorage_(copyBool)(THCState *state, THCStorage *storage, struct THBoolStorage *src);
THC_API void THCStorage_(copyBFloat16)(THCState *state, THCStorage *storage, struct THBFloat16Storage *src);
THC_API void THCStorage_(copyCudaByte)(THCState *state, THCStorage *storage, struct THCudaByteStorage *src);
THC_API void THCStorage_(copyCudaChar)(THCState *state, THCStorage *storage, struct THCudaCharStorage *src);
@ -25,6 +26,7 @@ THC_API void THCStorage_(copyCudaFloat)(THCState *state, THCStorage *storage, st
THC_API void THCStorage_(copyCudaDouble)(THCState *state, THCStorage *storage, struct THCudaDoubleStorage *src);
THC_API void THCStorage_(copyCudaHalf)(THCState *state, THCStorage *storage, struct THCudaHalfStorage *src);
THC_API void THCStorage_(copyCudaBool)(THCState *state, THCStorage *storage, struct THCudaBoolStorage *src);
THC_API void THCStorage_(copyCudaBFloat16)(THCState *state, THCStorage *storage, struct THCudaBFloat16Storage *src);
THC_API void TH_CONCAT_2(THByteStorage_copyCuda , Real)(THCState *state, THByteStorage *self, struct THCStorage *src);
THC_API void TH_CONCAT_2(THCharStorage_copyCuda , Real)(THCState *state, THCharStorage *self, struct THCStorage *src);
@ -35,6 +37,7 @@ THC_API void TH_CONCAT_2(THFloatStorage_copyCuda , Real)(THCState *state, THFloa
THC_API void TH_CONCAT_2(THDoubleStorage_copyCuda, Real)(THCState *state, THDoubleStorage *self, struct THCStorage *src);
THC_API void TH_CONCAT_2(THHalfStorage_copyCuda, Real)(THCState *state, THHalfStorage *self, struct THCStorage *src);
THC_API void TH_CONCAT_2(THBoolStorage_copyCuda, Real)(THCState *state, THBoolStorage *self, struct THCStorage *src);
THC_API void TH_CONCAT_2(THBFloat16Storage_copyCuda, Real)(THCState *state, THBFloat16Storage *self, struct THCStorage *src);
THC_API void THStorage_(copyCuda)(THCState *state, THStorage *self, THCStorage *src);
THC_API void THCStorage_(copyCuda)(THCState *state, THCStorage *self, THCStorage *src);

View file

@ -15,6 +15,7 @@
#define THCudaIntTensor THCTensor
#define THCudaLongTensor THCTensor
#define THCudaBoolTensor THCTensor
#define THCudaBFloat16Tensor THCTensor
/**** access methods ****/
THC_API THCStorage* THCTensor_(storage)(THCState *state, const THCTensor *self);

View file

@ -31,7 +31,8 @@ namespace c10 {
_(bool, Bool, i) /* 11 */ \
_(c10::qint8, QInt8, i) /* 12 */ \
_(c10::quint8, QUInt8, i) /* 13 */ \
_(c10::qint32, QInt32, i) /* 14 */
_(c10::qint32, QInt32, i) /* 14 */ \
_(c10::BFloat16, BFloat16, d) /* 15 */
// If you want to support ComplexHalf for real, replace occurrences
// of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But

View file

@ -33,7 +33,7 @@ namespace {
bfloats[i].x = c10::detail::bits_from_f32(in[i]);
out[i] = c10::detail::f32_from_bits(bfloats[i].x);
// The relative error should be less than 1/(2^7) since bfloat16
// The relative error should be less than 1/(2^7) since BFloat16
// has 7 bits mantissa.
EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128);
}

View file

@ -53,11 +53,10 @@ struct alignas(2) BFloat16 {
C10_HOST_DEVICE BFloat16() = default;
#else
BFloat16() = default;
#endif
explicit inline C10_HOST_DEVICE BFloat16(float value);
explicit inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE BFloat16(float value);
inline C10_HOST_DEVICE operator float() const;
};
} // namespace c10

View file

@ -8975,6 +8975,7 @@ class _TestTorchMixin(object):
float = torch.FloatStorage().element_size()
double = torch.DoubleStorage().element_size()
bool = torch.BoolStorage().element_size()
bfloat16 = torch.BFloat16Storage().element_size()
self.assertEqual(byte, torch.ByteTensor().element_size())
self.assertEqual(char, torch.CharTensor().element_size())
@ -8983,6 +8984,7 @@ class _TestTorchMixin(object):
self.assertEqual(long, torch.LongTensor().element_size())
self.assertEqual(float, torch.FloatTensor().element_size())
self.assertEqual(double, torch.DoubleTensor().element_size())
self.assertEqual(bool, torch.BoolTensor().element_size())
self.assertGreater(byte, 0)
self.assertGreater(char, 0)
@ -8992,6 +8994,7 @@ class _TestTorchMixin(object):
self.assertGreater(float, 0)
self.assertGreater(double, 0)
self.assertGreater(bool, 0)
self.assertGreater(bfloat16, 0)
# These tests are portable, not necessarily strict for your system.
self.assertEqual(byte, 1)
@ -10303,6 +10306,13 @@ class _TestTorchMixin(object):
self.assertEqual(halfStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
self.assertIs(halfStorage.dtype, torch.float16)
bfloat16Storage = storage.bfloat16()
self.assertEqual(bfloat16Storage.size(), 6)
self.assertEqual(bfloat16Storage.tolist(), [-1, 0, 1, 2, 3, 4])
self.assertEqual(bfloat16Storage.type(), 'torch.BFloat16Storage')
self.assertEqual(bfloat16Storage.int().tolist(), [-1, 0, 1, 2, 3, 4])
self.assertIs(bfloat16Storage.dtype, torch.bfloat16)
longStorage = storage.long()
self.assertEqual(longStorage.size(), 6)
self.assertEqual(longStorage.tolist(), [-1, 0, 1, 2, 3, 4])
@ -10419,6 +10429,8 @@ class _TestTorchMixin(object):
obj.__repr__()
str(obj)
for t in torch._storage_classes:
if t == torch.BFloat16Storage:
continue # Fix once fill is enabled for bfloat16
if t.is_cuda and not torch.cuda.is_available():
continue
if t == torch.BoolStorage or t == torch.cuda.BoolStorage:

View file

@ -222,6 +222,11 @@ class ByteStorage(_C.ByteStorageBase, _StorageBase):
class BoolStorage(_C.BoolStorageBase, _StorageBase):
pass
class BFloat16Storage(_C.BFloat16StorageBase, _StorageBase):
pass
class QUInt8Storage(_C.QUInt8StorageBase, _StorageBase):
pass
@ -235,7 +240,7 @@ class QInt32Storage(_C.QInt32StorageBase, _StorageBase):
_storage_classes = {
DoubleStorage, FloatStorage, LongStorage, IntStorage, ShortStorage,
CharStorage, ByteStorage, HalfStorage, BoolStorage, QUInt8Storage, QInt8Storage,
QInt32Storage
QInt32Storage, BFloat16Storage
}
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
@ -286,6 +291,7 @@ del CharStorageBase
del ByteStorageBase
del BoolStorageBase
del QUInt8StorageBase
del BFloat16StorageBase
################################################################################
# Import most common subpackages

View file

@ -13,6 +13,7 @@ storage_classes = [
'CharStorageBase',
'ByteStorageBase',
'BoolStorageBase',
'BFloat16StorageBase',
]

View file

@ -125,6 +125,7 @@ static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manag
THPQUInt8Storage_postInit(module);
THPQInt8Storage_postInit(module);
THPQInt32Storage_postInit(module);
THPBFloat16Storage_postInit(module);
THPAutograd_initFunctions();
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
@ -514,6 +515,7 @@ bool THCPShortStorage_init(PyObject *module);
bool THCPCharStorage_init(PyObject *module);
bool THCPByteStorage_init(PyObject *module);
bool THCPBoolStorage_init(PyObject *module);
bool THCPBFloat16Storage_init(PyObject *module);
void THCPStream_init(PyObject *module);
void THCPEvent_init(PyObject *module);
@ -536,6 +538,17 @@ void init__THCUNN(PyObject*);
}} // namespace torch::nn
bool THDPDoubleStorage_init(PyObject *module);
bool THDPFloatStorage_init(PyObject *module);
//bool THDPHalfStorage_init(PyObject *module);
bool THDPLongStorage_init(PyObject *module);
bool THDPIntStorage_init(PyObject *module);
bool THDPShortStorage_init(PyObject *module);
bool THDPCharStorage_init(PyObject *module);
bool THDPByteStorage_init(PyObject *module);
bool THDPBoolStorage_init(PyObject *module);
bool THDPBFloat16Storage_init(PyObject *module);
static std::vector<PyMethodDef> methods;
// TODO: Refactor this in some less manual way
@ -664,6 +677,7 @@ PyObject* initModule() {
ASSERT_TRUE(THPQUInt8Storage_init(module));
ASSERT_TRUE(THPQInt8Storage_init(module));
ASSERT_TRUE(THPQInt32Storage_init(module));
ASSERT_TRUE(THPBFloat16Storage_init(module));
#ifdef USE_CUDA
// This will only initialise base classes and attach them to library namespace
@ -679,6 +693,7 @@ PyObject* initModule() {
ASSERT_TRUE(THCPCharStorage_init(module));
ASSERT_TRUE(THCPByteStorage_init(module));
ASSERT_TRUE(THCPBoolStorage_init(module));
ASSERT_TRUE(THCPBFloat16Storage_init(module));
THCPStream_init(module);
THCPEvent_init(module);

View file

@ -29,6 +29,9 @@
#include <torch/csrc/generic/Storage.cpp>
#include <TH/THGenerateBoolType.h>
#include <torch/csrc/generic/Storage.cpp>
#include <TH/THGenerateBFloat16Type.h>
#include <torch/csrc/generic/Storage.cpp>
#include <TH/THGenerateQTypes.h>

View file

@ -29,6 +29,8 @@
PyObject_IsInstance(obj, THPQInt8StorageClass)
#define THPQInt32Storage_Check(obj) \
PyObject_IsInstance(obj, THPQInt32StorageClass)
#define THPBFloat16Storage_Check(obj) \
PyObject_IsInstance(obj, THPBFloat16StorageClass)
#define THPDoubleStorage_CData(obj) (obj)->cdata
@ -43,6 +45,7 @@
#define THPQUInt8Storage_CData(obj) (obj)->cdata
#define THPQInt8Storage_CData(obj) (obj)->cdata
#define THPQInt32Storage_CData(obj) (obj)->cdata
#define THPBFloat16Storage_CData(obj) (obj)->cdata
#ifdef _THP_CORE
#define THPStorageType TH_CONCAT_3(THP,Real,StorageType)
@ -58,6 +61,9 @@
#include <torch/csrc/generic/Storage.h>
#include <TH/THGenerateBoolType.h>
#include <torch/csrc/generic/Storage.h>
#include <TH/THGenerateBFloat16Type.h>
#include <torch/csrc/generic/Storage.h>
#include <TH/THGenerateQTypes.h>

View file

@ -1,5 +1,5 @@
#include <torch/csrc/byte_order.h>
#include <c10/util/BFloat16.h>
#include <cstring>
#if defined(_MSC_VER)
@ -140,6 +140,15 @@ void THP_decodeHalfBuffer(THHalf* dst, const uint8_t* src, THPByteOrder order, s
}
}
// Decodes `len` raw 16-bit values from `src` into `dst` as BFloat16.
// Each element is read as a uint16_t honoring the requested byte order,
// then bit-copied (no numeric conversion) into the destination element.
void THP_decodeBFloat16Buffer(at::BFloat16* dst, const uint8_t* src, THPByteOrder order, size_t len)
{
  const bool big_endian = (order == THP_BIG_ENDIAN);
  const uint8_t* cursor = src;
  for (size_t i = 0; i != len; ++i) {
    const uint16_t bits = big_endian ? decodeUInt16BE(cursor) : decodeUInt16LE(cursor);
    // BFloat16 shares the uint16_t bit layout; memcpy avoids aliasing UB.
    std::memcpy(&dst[i], &bits, sizeof(dst[i]));
    cursor += sizeof(uint16_t);
  }
}
void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, THPByteOrder order, size_t len)
{
for (size_t i = 0; i < len; i++) {

View file

@ -4,6 +4,7 @@
#include <cstdint>
#include <cstddef>
#include <THHalf.h>
#include <c10/util/BFloat16.h>
enum THPByteOrder {
THP_LITTLE_ENDIAN = 0,
@ -19,6 +20,7 @@ void THP_decodeHalfBuffer(THHalf* dst, const uint8_t* src, THPByteOrder order, s
void THP_decodeFloatBuffer(float* dst, const uint8_t* src, THPByteOrder order, size_t len);
void THP_decodeDoubleBuffer(double* dst, const uint8_t* src, THPByteOrder order, size_t len);
void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, THPByteOrder order, size_t len);
void THP_decodeBFloat16Buffer(at::BFloat16* dst, const uint8_t* src, THPByteOrder order, size_t len);
void THP_encodeInt16Buffer(uint8_t* dst, const int16_t* src, THPByteOrder order, size_t len);
void THP_encodeInt32Buffer(uint8_t* dst, const int32_t* src, THPByteOrder order, size_t len);

View file

@ -316,6 +316,7 @@ static PyObject * THCPModule_initExtension(PyObject *self)
THCPCharStorage_postInit(m);
THCPByteStorage_postInit(m);
THCPBoolStorage_postInit(m);
THCPBFloat16Storage_postInit(m);
bool has_half = true;

View file

@ -21,3 +21,6 @@
#define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
#include <THC/THCGenerateBoolType.h>
#define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
#include <THC/THCGenerateBFloat16Type.h>

View file

@ -23,15 +23,18 @@
PyObject_IsInstance(obj, THCPByteStorageClass)
#define THCPBoolStorage_Check(obj) \
PyObject_IsInstance(obj, THCPBoolStorageClass)
#define THCPBFloat16Storage_Check(obj) \
PyObject_IsInstance(obj, THCPBFloat16StorageClass)
#define THCPDoubleStorage_CData(obj) (obj)->cdata
#define THCPFloatStorage_CData(obj) (obj)->cdata
#define THCPLongStorage_CData(obj) (obj)->cdata
#define THCPIntStorage_CData(obj) (obj)->cdata
#define THCPShortStorage_CData(obj) (obj)->cdata
#define THCPCharStorage_CData(obj) (obj)->cdata
#define THCPByteStorage_CData(obj) (obj)->cdata
#define THCPBoolStorage_CData(obj) (obj)->cdata
#define THCPDoubleStorage_CData(obj) (obj)->cdata
#define THCPFloatStorage_CData(obj) (obj)->cdata
#define THCPLongStorage_CData(obj) (obj)->cdata
#define THCPIntStorage_CData(obj) (obj)->cdata
#define THCPShortStorage_CData(obj) (obj)->cdata
#define THCPCharStorage_CData(obj) (obj)->cdata
#define THCPByteStorage_CData(obj) (obj)->cdata
#define THCPBoolStorage_CData(obj) (obj)->cdata
#define THCPBFloat16Storage_CData(obj) (obj)->cdata
#ifdef _THP_CORE
#define THCPStorageType TH_CONCAT_3(THCP,Real,StorageType)
@ -46,4 +49,7 @@
#define THC_GENERIC_FILE "torch/csrc/generic/Storage.h"
#include <THC/THCGenerateBoolType.h>
#define THC_GENERIC_FILE "torch/csrc/generic/Storage.h"
#include <THC/THCGenerateBFloat16Type.h>
#endif

View file

@ -12,3 +12,6 @@
#define THC_GENERIC_FILE "torch/csrc/generic/serialization.cpp"
#include <THC/THCGenerateBoolType.h>
#define THC_GENERIC_FILE "torch/csrc/generic/serialization.cpp"
#include <THC/THCGenerateBFloat16Type.h>

View file

@ -9,4 +9,7 @@
#define THC_GENERIC_FILE "torch/csrc/generic/serialization.h"
#include <THC/THCGenerateBoolType.h>
#define THC_GENERIC_FILE "torch/csrc/generic/serialization.h"
#include <THC/THCGenerateBFloat16Type.h>
#endif

View file

@ -11,6 +11,9 @@
#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
#include <THC/THCGenerateBoolType.h>
#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
#include <THC/THCGenerateBFloat16Type.h>
#ifdef USE_CUDA
// NB: It's a list of *optional* CUDAStream; when nullopt, that means to use
// whatever the current stream of the device the input is associated with was.

View file

@ -18,4 +18,7 @@
#define THC_GENERIC_FILE "torch/csrc/generic/utils.h"
#include <THC/THCGenerateBoolType.h>
#define THC_GENERIC_FILE "torch/csrc/generic/utils.h"
#include <THC/THCGenerateBFloat16Type.h>
#endif

View file

@ -317,6 +317,7 @@ void THPStorage_(initCopyMethods)()
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPFloatStorageType, h, &THWStorage_(copyFloat));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPDoubleStorageType, h, &THWStorage_(copyDouble));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPBoolStorageType, h, &THWStorage_(copyBool));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPBFloat16StorageType, h, &THWStorage_(copyBFloat16));
#ifdef THC_GENERIC_FILE
// copy from GPU types
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPByteStorageType, h, &THWStorage_(copyCudaByte));
@ -328,6 +329,7 @@ void THPStorage_(initCopyMethods)()
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPDoubleStorageType, h, &THWStorage_(copyCudaDouble));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPHalfStorageType, h, &THWStorage_(copyCudaHalf));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPBoolStorageType, h, &THWStorage_(copyCudaBool));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPBFloat16StorageType, h, &THWStorage_(copyCudaBFloat16));
// add CPU <- GPU copies to base type
/// #define THPCpuStorage TH_CONCAT_3(THP, Real, Storage)
#define THCpuStorage_(name) TH_CONCAT_4(TH, Real, Storage_, name)
@ -342,6 +344,7 @@ void THPStorage_(initCopyMethods)()
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPDoubleStorageType, b, &THCpuStorage_(copyCudaDouble));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPHalfStorageType, b, &THCpuStorage_(copyCudaHalf));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPBoolStorageType, b, &THCpuStorage_(copyCudaBool));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPBFloat16StorageType, b, &THCpuStorage_(copyCudaBFloat16));
#undef THCpuStorage
#undef THCpuStorage_
#endif

View file

@ -170,6 +170,8 @@ static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyO
THP_decodeInt64Buffer((int64_t*) THWStorage_(data)(storage), src + offset, byte_order, count);
#elif defined(TH_REAL_IS_HALF)
THP_decodeHalfBuffer(THWStorage_(data)(storage), src + offset, byte_order, count);
#elif defined(TH_REAL_IS_BFLOAT16)
THP_decodeBFloat16Buffer(THWStorage_(data)(storage), src + offset, byte_order, count);
#elif defined(TH_REAL_IS_FLOAT)
THP_decodeFloatBuffer(THWStorage_(data)(storage), src + offset, byte_order, count);
#elif defined(TH_REAL_IS_DOUBLE)

View file

@ -183,6 +183,9 @@ void doWrite(io fildes, void* raw_buf, size_t nbytes) {
#include <torch/csrc/generic/serialization.cpp>
#include <TH/THGenerateHalfType.h>
#include <torch/csrc/generic/serialization.cpp>
#include <TH/THGenerateBFloat16Type.h>
#include <torch/csrc/generic/serialization.cpp>
#include <TH/THGenerateBoolType.h>

View file

@ -10,6 +10,9 @@
#include <torch/csrc/generic/serialization.h>
#include <TH/THGenerateBoolType.h>
#include <torch/csrc/generic/serialization.h>
#include <TH/THGenerateBFloat16Type.h>
#include <torch/csrc/generic/serialization.h>
#include <TH/THGenerateQTypes.h>

View file

@ -17,6 +17,9 @@
#include <torch/csrc/generic/utils.cpp>
#include <TH/THGenerateHalfType.h>
#include <torch/csrc/generic/utils.cpp>
#include <TH/THGenerateBFloat16Type.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/generic/utils.cpp>
#include <TH/THGenerateBoolType.h>

View file

@ -81,6 +81,10 @@
#define THPHalfUtils_unpackReal(object) (at::Half)THPUtils_unpackReal_FLOAT(object)
#define THPHalfUtils_newReal(value) PyFloat_FromDouble(value)
#define THPHalfUtils_newAccreal(value) THPUtils_newReal_FLOAT(value)
#define THPBFloat16Utils_checkReal(object) THPUtils_checkReal_FLOAT(object)
#define THPBFloat16Utils_unpackReal(object) (at::BFloat16)THPUtils_unpackReal_FLOAT(object)
#define THPBFloat16Utils_newReal(value) PyFloat_FromDouble(value)
#define THPBFloat16Utils_newAccreal(value) THPUtils_newReal_FLOAT(value)
#define THPBoolUtils_checkReal(object) THPUtils_checkReal_BOOL(object)
#define THPBoolUtils_unpackReal(object) THPUtils_unpackReal_BOOL(object)
@ -149,6 +153,9 @@ struct THPUtils_typeTraits {};
#include <torch/csrc/generic/utils.h>
#include <TH/THGenerateHalfType.h>
#include <torch/csrc/generic/utils.h>
#include <TH/THGenerateBFloat16Type.h>
#include <torch/csrc/generic/utils.h>
#include <TH/THGenerateBoolType.h>

View file

@ -47,6 +47,8 @@ static std::pair<std::string, std::string> getDtypeNames(
return std::make_pair("quint8", "");
case at::ScalarType::QInt32:
return std::make_pair("qint32", "");
case at::ScalarType::BFloat16:
return std::make_pair("bfloat16", "");
default:
throw std::runtime_error("Unimplemented scalar type");
}

View file

@ -32,7 +32,7 @@ def find_cuda_windows_lib():
# Override the default search process
# Fixes https://github.com/pytorch/pytorch/issues/20202
# The library selection will be done in these directories one by one
# 1. [Package Root]\Lib
# 1. [Package Root]\Lib
# That's where our libraries are in, which should be loaded first.
# 2. [Python Root]\Library\bin
# That's where `cudatoolkit` store the cuda libraries.
@ -596,7 +596,7 @@ def _dummy_type(name):
if not hasattr(torch._C, 'CudaDoubleStorageBase'):
# Define dummy base classes
for t in ['Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Half', 'Bool']:
for t in ['Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Half', 'Bool', 'BFloat16']:
storage_name = 'Cuda{0}StorageBase'.format(t)
tensor_name = 'Cuda{0}TensorBase'.format(t)
@ -661,6 +661,10 @@ class HalfStorage(_CudaBase, torch._C.CudaHalfStorageBase, _StorageBase):
class BoolStorage(_CudaBase, torch._C.CudaBoolStorageBase, _StorageBase):
pass
class BFloat16Storage(_CudaBase, torch._C.CudaBFloat16StorageBase, _StorageBase):
pass
torch._storage_classes.add(DoubleStorage)
torch._storage_classes.add(FloatStorage)
torch._storage_classes.add(LongStorage)
@ -670,6 +674,7 @@ torch._storage_classes.add(CharStorage)
torch._storage_classes.add(ByteStorage)
torch._storage_classes.add(HalfStorage)
torch._storage_classes.add(BoolStorage)
torch._storage_classes.add(BFloat16Storage)
from . import sparse # noqa: F401
from . import profiler # noqa: F401

View file

@ -87,6 +87,10 @@ class _StorageBase(object):
"""Casts this storage to bool type"""
return self.type(type(self).__module__ + '.BoolStorage')
def bfloat16(self):
"""Casts this storage to bfloat16 type"""
return self.type(type(self).__module__ + '.BFloat16Storage')
def pin_memory(self):
"""Copies the storage to pinned memory, if it's not already pinned."""
if self.is_cuda: