mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add backwards compatibility and other fixes to Dispatch macros. (#17996)
Summary: Changes: 1) https://github.com/pytorch/pytorch/pull/17527 changed dispatch macros to be ScalarType based instead of at::Type based. This broke cpp extensions that relied on dispatch macros. Since IMO these should be ScalarType based (and some extensions have already updated), we allow either at::Type or at::ScalarType to be passed, but passing at::Type will result in a deprecated warning. 2) Reintroduce macros that were deleted (AT_DISPATCH_ALL_TYPES_AND_HALF, AT_DISPATCH_COMPLEX_TYPES, AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX, AT_DISPATCH_ALL_TYPES_AND_COMPLEX); the AND_HALF ones now give a deprecated warning because there are more extensible macros that were introduced in their place. 3) Makes AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND into a ScalarType based macro (and updates usages). This was the result of a logical merge conflicts. 4) Adds a new macro, C10_DEPRECATED_MESSAGE for passing a deprecated message to the compiler. I didn't spend much time seeing if this can be enabled for versions before C++14. Pull Request resolved: https://github.com/pytorch/pytorch/pull/17996 Reviewed By: ezyang Differential Revision: D14446203 Pulled By: gchanan fbshipit-source-id: 1da56e2e9c15aa8f913ebbf6bf1110c5b6dc375e
This commit is contained in:
parent
f3806094d5
commit
d1843d4173
3 changed files with 129 additions and 14 deletions
|
|
@ -10,30 +10,59 @@
|
|||
return __VA_ARGS__(); \
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
inline at::ScalarType scalar_type(at::ScalarType s) {
|
||||
return s;
|
||||
}
|
||||
|
||||
C10_DEPRECATED_MESSAGE("passing at::Type to an AT_DISPATCH macro is deprecated, " \
|
||||
"pass an at::ScalarType instead")
|
||||
inline at::ScalarType scalar_type(const at::Type &t) {
|
||||
return t.scalarType();
|
||||
}
|
||||
|
||||
C10_DEPRECATED_MESSAGE("AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, " \
|
||||
"use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead")
|
||||
inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {}
|
||||
|
||||
C10_DEPRECATED_MESSAGE("AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, " \
|
||||
"use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) " \
|
||||
"instead")
|
||||
inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
|
||||
|
||||
}
|
||||
|
||||
#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
[&] { \
|
||||
switch (TYPE) { \
|
||||
const auto& the_type = TYPE; \
|
||||
at::ScalarType _st = ::detail::scalar_type(TYPE); \
|
||||
switch (_st) { \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
|
||||
[&] { \
|
||||
switch (TYPE) { \
|
||||
const auto& the_type = TYPE; \
|
||||
at::ScalarType _st = ::detail::scalar_type(TYPE); \
|
||||
switch (_st) { \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
|
||||
[&] { \
|
||||
switch (TYPE) { \
|
||||
const auto& the_type = TYPE; \
|
||||
at::ScalarType _st = ::detail::scalar_type(TYPE); \
|
||||
switch (_st) { \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
|
||||
|
|
@ -44,26 +73,30 @@
|
|||
AT_PRIVATE_CASE_TYPE( \
|
||||
at::ScalarType::ComplexHalf, std::complex<at::Half>, __VA_ARGS__) \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
||||
[&] { \
|
||||
switch (TYPE) { \
|
||||
const auto& the_type = TYPE; \
|
||||
at::ScalarType _st = ::detail::scalar_type(TYPE); \
|
||||
switch (_st) { \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
|
||||
[&] { \
|
||||
switch (TYPE) { \
|
||||
const auto& the_type = TYPE; \
|
||||
at::ScalarType _st = ::detail::scalar_type(TYPE); \
|
||||
switch (_st) { \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
|
||||
|
|
@ -72,10 +105,88 @@
|
|||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
|
||||
[&] { \
|
||||
detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \
|
||||
const auto& the_type = TYPE; \
|
||||
at::ScalarType _st = ::detail::scalar_type(TYPE); \
|
||||
switch (_st) { \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
|
||||
[&] { \
|
||||
const auto& the_type = TYPE; \
|
||||
at::ScalarType _st = ::detail::scalar_type(TYPE); \
|
||||
switch (_st) { \
|
||||
AT_PRIVATE_CASE_TYPE( \
|
||||
at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE( \
|
||||
at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
|
||||
[&] { \
|
||||
const auto& the_type = TYPE; \
|
||||
at::ScalarType _st = ::detail::scalar_type(TYPE); \
|
||||
switch (_st) { \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE( \
|
||||
at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE( \
|
||||
at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX(TYPE, NAME, ...) \
|
||||
[&] { \
|
||||
detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() \
|
||||
const auto& the_type = TYPE; \
|
||||
at::ScalarType _st = ::detail::scalar_type(TYPE); \
|
||||
switch (_st) { \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE( \
|
||||
at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE( \
|
||||
at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
|
||||
} \
|
||||
}()
|
||||
|
||||
|
||||
template <at::ScalarType N>
|
||||
struct MyTemplate;
|
||||
|
||||
|
|
@ -107,8 +218,7 @@ struct MyTemplate<at::ScalarType::Bool> {
|
|||
|
||||
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
|
||||
[&] { \
|
||||
const at::Type& the_type = TYPE; \
|
||||
switch (the_type.scalarType()) { \
|
||||
switch (TYPE) { \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
|
||||
|
|
@ -123,6 +233,6 @@ struct MyTemplate<at::ScalarType::Bool> {
|
|||
AT_PRIVATE_CASE_TYPE( \
|
||||
at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
|
||||
AT_ERROR(#NAME, " not implemented for '", TYPE, "'"); \
|
||||
} \
|
||||
}()
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ Scalar item(const Tensor& self) {
|
|||
Scalar _local_scalar_dense_cpu(const Tensor& self) {
|
||||
Scalar r;
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(
|
||||
at::ScalarType::Half, at::ScalarType::Bool, self.type(), "_local_scalar_dense_cpu", [&] {
|
||||
at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "_local_scalar_dense_cpu", [&] {
|
||||
scalar_t value = *self.data<scalar_t>();
|
||||
r = Scalar(value);
|
||||
});
|
||||
|
|
|
|||
|
|
@ -21,10 +21,15 @@
|
|||
// portable way to declare something deprecated.
|
||||
#if defined(__cplusplus) && __cplusplus > 201402L
|
||||
# define C10_DEPRECATED [[deprecated]]
|
||||
# define C10_DEPRECATED_MESSAGE(message) [[deprecated(message)]]
|
||||
#elif defined(__GNUC__)
|
||||
# define C10_DEPRECATED __attribute__((deprecated))
|
||||
// TODO: is there some way to implement this?
|
||||
# define C10_DEPRECATED_MESSAGE(message) __attribute__((deprecated))
|
||||
#elif defined(_MSC_VER)
|
||||
# define C10_DEPRECATED __declspec(deprecated)
|
||||
// TODO: is there some way to implement this?
|
||||
# define C10_DEPRECATED_MESSAGE(message) __declspec(deprecated)
|
||||
#else
|
||||
# warning "You need to implement C10_DEPRECATED for this compiler"
|
||||
# define C10_DEPRECATED
|
||||
|
|
|
|||
Loading…
Reference in a new issue