Add backwards compatibility and other fixes to Dispatch macros. (#17996)

Summary:
Changes:
1) https://github.com/pytorch/pytorch/pull/17527 changed dispatch macros to be ScalarType based instead of at::Type based.  This broke cpp extensions that relied on dispatch macros.  Since IMO these should be ScalarType based (and some extensions have already updated), we allow either at::Type or at::ScalarType to be passed, but passing at::Type will result in a deprecated warning.

2) Reintroduce macros that were deleted (AT_DISPATCH_ALL_TYPES_AND_HALF, AT_DISPATCH_COMPLEX_TYPES, AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX, AT_DISPATCH_ALL_TYPES_AND_COMPLEX); the AND_HALF ones now give a deprecated warning because there are more extensible macros that were introduced in their place.

3) Makes AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND into a ScalarType based macro (and updates usages).  This was the result of a logical merge conflicts.

4) Adds a new macro, C10_DEPRECATED_MESSAGE for passing a deprecated message to the compiler.  I didn't spend much time seeing if this can be enabled for versions before C++14.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17996

Reviewed By: ezyang

Differential Revision: D14446203

Pulled By: gchanan

fbshipit-source-id: 1da56e2e9c15aa8f913ebbf6bf1110c5b6dc375e
This commit is contained in:
Gregory Chanan 2019-03-15 14:16:22 -07:00 committed by Facebook Github Bot
parent f3806094d5
commit d1843d4173
3 changed files with 129 additions and 14 deletions

View file

@ -10,30 +10,59 @@
return __VA_ARGS__(); \
}
namespace detail {
inline at::ScalarType scalar_type(at::ScalarType s) {
return s;
}
C10_DEPRECATED_MESSAGE("passing at::Type to an AT_DISPATCH macro is deprecated, " \
"pass an at::ScalarType instead")
inline at::ScalarType scalar_type(const at::Type &t) {
return t.scalarType();
}
C10_DEPRECATED_MESSAGE("AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, " \
"use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead")
inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {}
C10_DEPRECATED_MESSAGE("AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, " \
"use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) " \
"instead")
inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
}
#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
[&] { \
switch (TYPE) { \
const auto& the_type = TYPE; \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
[&] { \
switch (TYPE) { \
const auto& the_type = TYPE; \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
switch (TYPE) { \
const auto& the_type = TYPE; \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
@ -44,26 +73,30 @@
AT_PRIVATE_CASE_TYPE( \
at::ScalarType::ComplexHalf, std::complex<at::Half>, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
#define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
[&] { \
switch (TYPE) { \
const auto& the_type = TYPE; \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
[&] { \
switch (TYPE) { \
const auto& the_type = TYPE; \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
@ -72,10 +105,88 @@
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
[&] { \
detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \
const auto& the_type = TYPE; \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE( \
at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE( \
at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE( \
at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE( \
at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
#define AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX(TYPE, NAME, ...) \
[&] { \
detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() \
const auto& the_type = TYPE; \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE( \
at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE( \
at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
template <at::ScalarType N>
struct MyTemplate;
@ -107,8 +218,7 @@ struct MyTemplate<at::ScalarType::Bool> {
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
[&] { \
const at::Type& the_type = TYPE; \
switch (the_type.scalarType()) { \
switch (TYPE) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
@ -123,6 +233,6 @@ struct MyTemplate<at::ScalarType::Bool> {
AT_PRIVATE_CASE_TYPE( \
at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
AT_ERROR(#NAME, " not implemented for '", TYPE, "'"); \
} \
}()

View file

@ -19,7 +19,7 @@ Scalar item(const Tensor& self) {
Scalar _local_scalar_dense_cpu(const Tensor& self) {
Scalar r;
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(
at::ScalarType::Half, at::ScalarType::Bool, self.type(), "_local_scalar_dense_cpu", [&] {
at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "_local_scalar_dense_cpu", [&] {
scalar_t value = *self.data<scalar_t>();
r = Scalar(value);
});

View file

@ -21,10 +21,15 @@
// portable way to declare something deprecated.
#if defined(__cplusplus) && __cplusplus > 201402L
# define C10_DEPRECATED [[deprecated]]
# define C10_DEPRECATED_MESSAGE(message) [[deprecated(message)]]
#elif defined(__GNUC__)
# define C10_DEPRECATED __attribute__((deprecated))
// TODO: is there some way to implement this?
# define C10_DEPRECATED_MESSAGE(message) __attribute__((deprecated))
#elif defined(_MSC_VER)
# define C10_DEPRECATED __declspec(deprecated)
// TODO: is there some way to implement this?
# define C10_DEPRECATED_MESSAGE(message) __declspec(deprecated)
#else
# warning "You need to implement C10_DEPRECATED for this compiler"
# define C10_DEPRECATED