diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index 764cbc610f1..44a9a558a44 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -10,30 +10,59 @@ return __VA_ARGS__(); \ } +namespace detail { + +inline at::ScalarType scalar_type(at::ScalarType s) { + return s; +} + +C10_DEPRECATED_MESSAGE("passing at::Type to an AT_DISPATCH macro is deprecated, " \ + "pass an at::ScalarType instead") +inline at::ScalarType scalar_type(const at::Type &t) { + return t.scalarType(); +} + +C10_DEPRECATED_MESSAGE("AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, " \ + "use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead") +inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {} + +C10_DEPRECATED_MESSAGE("AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, " \ + "use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) " \ + "instead") +inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} + +} + #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ + const auto& the_type = TYPE; \ + at::ScalarType _st = ::detail::scalar_type(TYPE); \ + switch (_st) { \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() #define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ + const auto& the_type = TYPE; \ + at::ScalarType _st = ::detail::scalar_type(TYPE); \ + switch (_st) { \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ + const auto& the_type = TYPE; \ + at::ScalarType _st = ::detail::scalar_type(TYPE); \ + switch (_st) { \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ @@ -44,26 +73,30 @@ AT_PRIVATE_CASE_TYPE( \ at::ScalarType::ComplexHalf, std::complex, __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() #define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ + const auto& the_type = TYPE; \ + at::ScalarType _st = ::detail::scalar_type(TYPE); \ + switch (_st) { \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() #define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ + const auto& the_type = TYPE; \ + at::ScalarType _st = ::detail::scalar_type(TYPE); \ + switch (_st) { \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ @@ -72,10 +105,88 @@ AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() +#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \ + [&] { \ + detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \ + const auto& the_type = TYPE; \ + at::ScalarType _st = ::detail::scalar_type(TYPE); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ + } \ + }() + +#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + at::ScalarType _st = ::detail::scalar_type(TYPE); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE( \ + at::ScalarType::ComplexFloat, std::complex, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + at::ScalarType::ComplexDouble, std::complex, __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ + } \ + }() + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + at::ScalarType _st = ::detail::scalar_type(TYPE); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + at::ScalarType::ComplexFloat, std::complex, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + at::ScalarType::ComplexDouble, std::complex, __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ + } \ + }() + +#define AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX(TYPE, NAME, ...) \ + [&] { \ + detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() \ + const auto& the_type = TYPE; \ + at::ScalarType _st = ::detail::scalar_type(TYPE); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + at::ScalarType::ComplexFloat, std::complex, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + at::ScalarType::ComplexDouble, std::complex, __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ + } \ + }() + + template struct MyTemplate; @@ -107,8 +218,7 @@ struct MyTemplate { #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ [&] { \ - const at::Type& the_type = TYPE; \ - switch (the_type.scalarType()) { \ + switch (TYPE) { \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ @@ -123,6 +233,6 @@ struct MyTemplate { AT_PRIVATE_CASE_TYPE( \ at::ScalarType::ComplexDouble, std::complex, __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ + AT_ERROR(#NAME, " not implemented for '", TYPE, "'"); \ } \ }() diff --git a/aten/src/ATen/native/Scalar.cpp b/aten/src/ATen/native/Scalar.cpp index 8e4ae10fde5..918f4c3da84 100644 --- a/aten/src/ATen/native/Scalar.cpp +++ b/aten/src/ATen/native/Scalar.cpp @@ -19,7 +19,7 @@ Scalar item(const Tensor& self) { Scalar _local_scalar_dense_cpu(const Tensor& self) { Scalar r; AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND( - at::ScalarType::Half, at::ScalarType::Bool, self.type(), "_local_scalar_dense_cpu", [&] { + at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "_local_scalar_dense_cpu", [&] { scalar_t value = *self.data(); r = Scalar(value); }); diff --git a/c10/util/Deprecated.h b/c10/util/Deprecated.h index d2c3776dbd3..294fb3a8881 100644 --- a/c10/util/Deprecated.h +++ b/c10/util/Deprecated.h @@ -21,10 +21,15 @@ // portable way to declare something deprecated. #if defined(__cplusplus) && __cplusplus > 201402L # define C10_DEPRECATED [[deprecated]] +# define C10_DEPRECATED_MESSAGE(message) [[deprecated(message)]] #elif defined(__GNUC__) # define C10_DEPRECATED __attribute__((deprecated)) +// TODO: is there some way to implement this? +# define C10_DEPRECATED_MESSAGE(message) __attribute__((deprecated)) #elif defined(_MSC_VER) # define C10_DEPRECATED __declspec(deprecated) +// TODO: is there some way to implement this? +# define C10_DEPRECATED_MESSAGE(message) __declspec(deprecated) #else # warning "You need to implement C10_DEPRECATED for this compiler" # define C10_DEPRECATED