Switch AT_DISPATCH_COMPLEX_TYPES_AND and AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX to c10::complex (#37697)

Summary:
These two macros only appear in `Dispatch.h`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37697

Differential Revision: D21666340

Pulled By: anjali411

fbshipit-source-id: 1f31ab46c08b77f1011367e471874d390ffa70fb
This commit is contained in:
Xiang Gao 2020-05-21 14:57:20 -07:00 committed by Facebook GitHub Bot
parent 0e2a0478af
commit 9b656dac7f

View file

@ -433,22 +433,6 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
} \
}()
#define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
[&] { \
switch (TYPE) { \
AT_PRIVATE_CASE_TYPE( \
at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE( \
at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE( \
SCALARTYPE, \
decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE>::t), \
__VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
} \
}()
#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
[&] { \
switch (TYPE) { \
@ -583,27 +567,3 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
#define AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX(TYPE, NAME, ...) \
[&] { \
detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() \
const auto& the_type = TYPE; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE( \
at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE( \
at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()