From 00f685d2d8e2df080bc62fb37cdb551bb4637630 Mon Sep 17 00:00:00 2001 From: Hong Xu Date: Wed, 26 Feb 2020 22:21:31 -0800 Subject: [PATCH] Add Scalar::type() (#33603) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33603 This function returns ScalarType based on its value. This is helpful to avoid code generated in aten_op.h has returned Scalars depending on arg self to determine its type. Test Plan: Imported from OSS Differential Revision: D20100218 Pulled By: ezyang fbshipit-source-id: 337729a7559e6abb3a16b2a563a2b92aa96c7016 --- c10/core/Scalar.h | 14 ++++++++++++++ caffe2/contrib/aten/gen_op.py | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index cc1ee4354fd..25cb0101a63 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -98,6 +98,20 @@ class C10_API Scalar { Scalar operator-() const; + ScalarType type() const { + if (isComplex()) { + return ScalarType::ComplexDouble; + } else if (isFloatingPoint()) { + return ScalarType::Double; + } else if (isIntegral(/*includeBool=*/false)) { + return ScalarType::Long; + } else if (isBoolean()) { + return ScalarType::Bool; + } else { + throw std::runtime_error("Unknown scalar type."); + } + } + private: template::is_integer && ! std::is_same::value, bool>::type* = diff --git a/caffe2/contrib/aten/gen_op.py b/caffe2/contrib/aten/gen_op.py index bfd2501a16b..39e47062848 100755 --- a/caffe2/contrib/aten/gen_op.py +++ b/caffe2/contrib/aten/gen_op.py @@ -73,7 +73,7 @@ def value_is_tensor_type(v): # for each aten type, how do we handle a return value of that type? RETURN_MAP = { 'Tensor': 'assignTo(Output(${offset}),${output});', - 'Scalar': 'assignTo(Output(${offset}),self.scalar_type(), ${output});', + 'Scalar': 'assignTo(Output(${offset}),${output}.type(), ${output});', 'bool': 'assignToValue(Output(${offset}),${output});', 'int64_t': 'assignToValue(Output(${offset}),${output});', 'std::vector': 'assignListStartingAt(${offset}, ${output});',