mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add Scalar::type() (#33603)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33603 This function returns ScalarType based on its value. This is helpful to avoid code generated in aten_op.h has returned Scalars depending on arg self to determine its type. Test Plan: Imported from OSS Differential Revision: D20100218 Pulled By: ezyang fbshipit-source-id: 337729a7559e6abb3a16b2a563a2b92aa96c7016
This commit is contained in:
parent
d41c8d0461
commit
00f685d2d8
2 changed files with 15 additions and 1 deletions
|
|
@ -98,6 +98,20 @@ class C10_API Scalar {
|
|||
|
||||
Scalar operator-() const;
|
||||
|
||||
ScalarType type() const {
|
||||
if (isComplex()) {
|
||||
return ScalarType::ComplexDouble;
|
||||
} else if (isFloatingPoint()) {
|
||||
return ScalarType::Double;
|
||||
} else if (isIntegral(/*includeBool=*/false)) {
|
||||
return ScalarType::Long;
|
||||
} else if (isBoolean()) {
|
||||
return ScalarType::Bool;
|
||||
} else {
|
||||
throw std::runtime_error("Unknown scalar type.");
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
template<typename T,
|
||||
typename std::enable_if<std::numeric_limits<T>::is_integer && ! std::is_same<T, bool>::value, bool>::type* =
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ def value_is_tensor_type(v):
|
|||
# for each aten type, how do we handle a return value of that type?
|
||||
RETURN_MAP = {
|
||||
'Tensor': 'assignTo(Output(${offset}),${output});',
|
||||
'Scalar': 'assignTo(Output(${offset}),self.scalar_type(), ${output});',
|
||||
'Scalar': 'assignTo(Output(${offset}),${output}.type(), ${output});',
|
||||
'bool': 'assignToValue<int64_t>(Output(${offset}),${output});',
|
||||
'int64_t': 'assignToValue<int64_t>(Output(${offset}),${output});',
|
||||
'std::vector<Tensor>': 'assignListStartingAt(${offset}, ${output});',
|
||||
|
|
|
|||
Loading…
Reference in a new issue