ConstantOfShape CUDA implementation (#1168)

* ConstantOfShape CUDA implementation

* Enhance the fallback logic so that, in the case of Shape -> ... -> ConstantOfShape, ConstantOfShape does not fall back to the CPU provider

* move shared code to cpu implementation

* do the fill based on sizeof(data_type)

* update method access level
This commit is contained in:
Hector Li 2019-06-07 11:41:58 -07:00 committed by GitHub
parent e43e64bf84
commit 8d68098c20
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 218 additions and 134 deletions

View file

@ -29,159 +29,131 @@ ONNX_CPU_OPERATOR_KERNEL(
DataTypeImpl::GetTensorType<bool>()}),
ConstantOfShape);
#define FETCH_VALUE_DATA(field, c_type) \
{ \
c_type t; \
auto unpack_status = UnpackTensor(t_proto, raw_data, raw_data_len, &t, 1); \
ORT_ENFORCE(unpack_status.IsOK(), "Value attribute unpacking failed:", unpack_status.ErrorMessage()); \
field = t; \
#define FETCH_VALUE_DATA(c_type) \
{ \
c_type val; \
auto unpack_status = UnpackTensor(t_proto, raw_data, raw_data_len, &val, 1); \
ORT_ENFORCE(unpack_status.IsOK(), "Value attribute unpacking failed:", unpack_status.ErrorMessage()); \
SetValue(sizeof(c_type), reinterpret_cast<void*>(&val)); \
}
void onnxruntime::ConstantOfShape::SetValue(const ONNX_NAMESPACE::TensorProto& t_proto) {
void onnxruntime::ConstantOfShapeBase::SetValueFromTensorProto(const ONNX_NAMESPACE::TensorProto& t_proto) {
using namespace utils;
ORT_ENFORCE(t_proto.has_data_type());
ORT_ENFORCE(TensorProto::DataType_IsValid(t_proto.data_type()));
tensor_type_ = static_cast<TensorProto_DataType>(t_proto.data_type());
const auto tensor_type = static_cast<TensorProto_DataType>(t_proto.data_type());
const void* const raw_data = t_proto.has_raw_data() ? t_proto.raw_data().data() : nullptr;
const size_t raw_data_len = t_proto.has_raw_data() ? t_proto.raw_data().size() : 0;
switch (tensor_type_) {
switch (tensor_type) {
case TensorProto::BOOL:
FETCH_VALUE_DATA(value_.ui64_, bool);
FETCH_VALUE_DATA(bool);
break;
case TensorProto::FLOAT:
FETCH_VALUE_DATA(value_.fl_, float);
FETCH_VALUE_DATA(float);
break;
case TensorProto::FLOAT16:
FETCH_VALUE_DATA(value_.fl16_, MLFloat16);
FETCH_VALUE_DATA(MLFloat16);
break;
case TensorProto::DOUBLE:
FETCH_VALUE_DATA(value_.dbl_, double);
FETCH_VALUE_DATA(double);
break;
case TensorProto::INT8:
FETCH_VALUE_DATA(value_.i64_, int8_t);
FETCH_VALUE_DATA(int8_t);
break;
case TensorProto::INT16:
FETCH_VALUE_DATA(value_.i64_, int16_t);
FETCH_VALUE_DATA(int16_t);
break;
case TensorProto::INT32:
FETCH_VALUE_DATA(value_.i64_, int32_t);
FETCH_VALUE_DATA(int32_t);
break;
case TensorProto::INT64:
FETCH_VALUE_DATA(value_.i64_, int64_t);
FETCH_VALUE_DATA(int64_t);
break;
case TensorProto::UINT8:
FETCH_VALUE_DATA(value_.ui64_, uint8_t);
FETCH_VALUE_DATA(uint8_t);
break;
case TensorProto::UINT16:
FETCH_VALUE_DATA(value_.ui64_, uint16_t);
FETCH_VALUE_DATA(uint16_t);
break;
case TensorProto::UINT32:
FETCH_VALUE_DATA(value_.ui64_, uint32_t);
FETCH_VALUE_DATA(uint32_t);
break;
case TensorProto::UINT64:
FETCH_VALUE_DATA(value_.ui64_, uint64_t);
FETCH_VALUE_DATA(uint64_t);
break;
default:
ORT_THROW("Unsupported value attribute datatype: ", tensor_type_);
ORT_THROW("Unsupported value attribute datatype: ", tensor_type);
break;
}
}
#undef FETCH_VALUE_DATA
template <class T>
inline T onnxruntime::ConstantOfShape::Value::GetFromSigned() const {
return static_cast<T>(i64_);
}
template <class T>
inline T onnxruntime::ConstantOfShape::Value::GetFromUnsigned() const {
return static_cast<T>(ui64_);
}
template <class T>
inline void FilloutOutput(T value, Tensor* output_tensor) {
auto out = gsl::make_span(output_tensor->template MutableData<T>(), output_tensor->Shape().Size());
inline void FilloutOutput(T value, void* output_data, size_t size) {
auto out = gsl::make_span(reinterpret_cast<T*>(output_data), size);
std::fill(out.begin(), out.end(), value);
}
void onnxruntime::ConstantOfShape::DispatchTypeAndFillOutput(Tensor* output_tensor) const {
switch (tensor_type_) {
case TensorProto::BOOL:
FilloutOutput(value_.GetFromUnsigned<bool>(), output_tensor);
break;
case TensorProto::FLOAT:
FilloutOutput(value_.GetFloat(), output_tensor);
break;
case TensorProto::FLOAT16:
FilloutOutput(value_.GetFloat16(), output_tensor);
break;
case TensorProto::DOUBLE:
FilloutOutput(value_.GetDouble(), output_tensor);
break;
case TensorProto::INT8:
FilloutOutput(value_.GetFromSigned<int8_t>(), output_tensor);
break;
case TensorProto::INT16:
FilloutOutput(value_.GetFromSigned<int16_t>(), output_tensor);
break;
case TensorProto::INT32:
FilloutOutput(value_.GetFromSigned<int32_t>(), output_tensor);
break;
case TensorProto::INT64:
FilloutOutput(value_.GetFromSigned<int64_t>(), output_tensor);
break;
case TensorProto::UINT8:
FilloutOutput(value_.GetFromUnsigned<uint8_t>(), output_tensor);
break;
case TensorProto::UINT16:
FilloutOutput(value_.GetFromUnsigned<uint16_t>(), output_tensor);
break;
case TensorProto::UINT32:
FilloutOutput(value_.GetFromUnsigned<uint32_t>(), output_tensor);
break;
case TensorProto::UINT64:
FilloutOutput(value_.GetFromUnsigned<uint64_t>(), output_tensor);
break;
default:
ORT_THROW("Unsupported value attribute datatype: ", tensor_type_);
break;
}
}
ConstantOfShape::ConstantOfShape(const OpKernelInfo& info) : OpKernel(info) {
ConstantOfShapeBase::ConstantOfShapeBase(const OpKernelInfo& info){
TensorProto t_proto;
if (info.GetAttr<TensorProto>("value", &t_proto).IsOK()) {
ORT_ENFORCE(t_proto.dims_size() == 1, "Must have a single dimension");
ORT_ENFORCE(t_proto.dims()[0] == 1, "Must have a single dimension of 1");
SetValue(t_proto);
SetValueFromTensorProto(t_proto);
} else {
tensor_type_ = TensorProto::FLOAT;
value_.fl_ = 0.f;
float f_value = 0.f;
SetValue(sizeof(float), reinterpret_cast<void*>(&f_value));
}
}
Status ConstantOfShape::Compute(OpKernelContext* ctx) const {
auto shape_tensor = ctx->Input<Tensor>(0);
if (shape_tensor->DataType() != DataTypeImpl::GetType<int64_t>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input tensor expected to contain int64 data");
}
auto& input_shape = shape_tensor->Shape();
Status ConstantOfShapeBase::PrepareCompute(OpKernelContext* ctx, Tensor** output_tensor) const {
const auto shape_tensor = ctx->Input<Tensor>(0);
const auto& input_shape = shape_tensor->Shape();
// If empty the output is a scalar with empty shape
// TensorShape::Size() will still return 1 and we will output
// one value
std::vector<int64_t> output_dims;
if (input_shape.NumDimensions() > 0) {
auto span = gsl::make_span(shape_tensor->Data<int64_t>(), input_shape.Size());
output_dims.insert(output_dims.end(), span.cbegin(), span.cend());
}
ORT_ENFORCE(input_shape.NumDimensions() > 0, "Must have a valid input shape.");
const auto span = gsl::make_span(shape_tensor->Data<int64_t>(), input_shape.Size());
output_dims.insert(output_dims.end(), span.cbegin(), span.cend());
TensorShape output_shape(output_dims);
auto output_tensor = ctx->Output(0, output_shape);
DispatchTypeAndFillOutput(output_tensor);
(*output_tensor) = ctx->Output(0, output_shape);
return Status::OK();
}
// Allocates the output via PrepareCompute and fills every element with the
// stored constant. The fill is dispatched purely on the byte width of the
// output element type (1, 2, 4 or 8 bytes), so the constant is copied as an
// opaque bit pattern through the same-sized integer type.
// Throws via ORT_THROW when the element width is none of the supported sizes.
Status ConstantOfShape::Compute(OpKernelContext* ctx) const {
  Tensor* result = nullptr;
  ORT_RETURN_IF_ERROR(PrepareCompute(ctx, &result));

  auto* const destination = result->MutableDataRaw();
  const void* const fill_value = GetValuePtr();
  const auto element_count = result->Shape().Size();
  const auto bytes_per_element = result->DataType()->Size();

  if (bytes_per_element == sizeof(int8_t)) {
    FilloutOutput(*static_cast<const int8_t*>(fill_value), destination, element_count);
  } else if (bytes_per_element == sizeof(int16_t)) {
    FilloutOutput(*static_cast<const int16_t*>(fill_value), destination, element_count);
  } else if (bytes_per_element == sizeof(int32_t)) {
    FilloutOutput(*static_cast<const int32_t*>(fill_value), destination, element_count);
  } else if (bytes_per_element == sizeof(int64_t)) {
    FilloutOutput(*static_cast<const int64_t*>(fill_value), destination, element_count);
  } else {
    ORT_THROW("Unsupported value attribute datatype with sizeof=: ", bytes_per_element);
  }

  return Status::OK();
}
} // namespace onnxruntime

View file

@ -9,45 +9,56 @@
namespace onnxruntime {
class ConstantOfShape final : public OpKernel {
public:
explicit ConstantOfShape(const OpKernelInfo& info);
class ConstantOfShapeBase {
Status Compute(OpKernelContext* ctx) const override;
protected:
ConstantOfShapeBase(const OpKernelInfo& info);
Status PrepareCompute(OpKernelContext* ctx, Tensor** output_tensor) const;
void* GetValuePtr() const { return p_value_; }
private:
ONNX_NAMESPACE::TensorProto_DataType tensor_type_;
union Value {
float fl_;
MLFloat16 fl16_;
double dbl_;
int64_t i64_;
uint64_t ui64_;
Value() : ui64_(0) {}
union SizeBasedValue {
int8_t int8_;
int16_t int16_;
int32_t int32_;
int64_t int64_;
} s_value_;
void* p_value_;
float GetFloat() const {
return fl_;
// Stores the constant pointed to by `value` into the size-keyed union and
// points p_value_ at the union member that was written.
//   size  - byte width of the constant; must be 1, 2, 4 or 8.
//   value - address of the constant; it is read through the integer type of
//           the same width, i.e. only its bit pattern is kept.
// Throws via ORT_THROW for any other width.
void SetValue(size_t size, void* value) {
  if (size == sizeof(int8_t)) {
    s_value_.int8_ = *static_cast<int8_t*>(value);
    p_value_ = &s_value_.int8_;
  } else if (size == sizeof(int16_t)) {
    s_value_.int16_ = *static_cast<int16_t*>(value);
    p_value_ = &s_value_.int16_;
  } else if (size == sizeof(int32_t)) {
    s_value_.int32_ = *static_cast<int32_t*>(value);
    p_value_ = &s_value_.int32_;
  } else if (size == sizeof(int64_t)) {
    s_value_.int64_ = *static_cast<int64_t*>(value);
    p_value_ = &s_value_.int64_;
  } else {
    ORT_THROW("Unsupported value attribute datatype with sizeof=: ", size);
  }
}
MLFloat16 GetFloat16() const {
return fl16_;
}
void SetValueFromTensorProto(const ONNX_NAMESPACE::TensorProto&);
};
double GetDouble() const {
return dbl_;
}
class ConstantOfShape final : public ConstantOfShapeBase, public OpKernel {
public:
explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), OpKernel(info) {};
template <class T>
T GetFromSigned() const;
template <class T>
T GetFromUnsigned() const;
} value_;
void SetValue(const ONNX_NAMESPACE::TensorProto&);
void DispatchTypeAndFillOutput(Tensor* output_tensor) const;
Status Compute(OpKernelContext* ctx) const override;
};
} // namespace onnxruntime

View file

@ -568,6 +568,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, int32_t, Resize);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, uint8_t, Resize);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, Split);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, ConstantOfShape);
static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
@ -896,6 +897,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, int32_t, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, uint8_t, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, ConstantOfShape)>,
};
for (auto& function_table_entry : function_table) {
@ -1060,19 +1062,24 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
// cast is not compute heavy, and may be placed outside
}
if (!not_supported && !force_inside) {
// Note that nodes with only inputs from initializer would not be placed on CUDA
// Ideally, those nodes should be eliminated in constant folding
bool all_non_initializer_inputs_from_outside = true;
bool should_force_outside = true;
node.ForEachWithIndex(
node.InputDefs(),
[&](const NodeArg& def, size_t) {
[&](const NodeArg& def, size_t index) {
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
if (!graph.GetInitializedTensor(def.Name(), initializer) && !defs_outside_cuda.count(&def))
all_non_initializer_inputs_from_outside = false;
// If the input is neither an initializer nor listed in defs_outside_cuda,
// or it is listed in defs_outside_cuda but the kernel declares that input as CPU memory,
// then we should still keep the node on CUDA
if ((!graph.GetInitializedTensor(def.Name(), initializer) && !defs_outside_cuda.count(&def)) ||
(defs_outside_cuda.count(&def) && cuda_kernel_def->kernel_def->IsInputOnCpu(index)))
should_force_outside = false;
return Status::OK();
});
if (all_non_initializer_inputs_from_outside) {
if (should_force_outside) {
force_outside = true;
}
}

View file

@ -19,6 +19,13 @@ __global__ void _Fill(
output_data[id] = val;
}
template <typename T>
void Fill(T* output, T value, int64_t count) {
int blocksPerGrid = (int)(ceil(static_cast<float>(count) / GridDim::maxThreadsPerBlock));
CUDA_LONG N = static_cast<CUDA_LONG>(count);
_Fill<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(output, value, N);
}
template <typename T>
class ConstantBufferImpl : public IConstantBuffer<T> {
public:
@ -38,9 +45,7 @@ class ConstantBufferImpl : public IConstantBuffer<T> {
CUDA_CALL_THROW(cudaMalloc(&buffer_, count * sizeof(T)));
count_ = count;
int blocksPerGrid = (int)(ceil(static_cast<float>(count) / GridDim::maxThreadsPerBlock));
CUDA_LONG N = static_cast<CUDA_LONG>(count);
_Fill<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(buffer_, val_, N);
Fill(buffer_, val_, count);
}
return buffer_;
}
@ -60,5 +65,12 @@ template std::unique_ptr<IConstantBuffer<float>> CreateConstantOnes<float>();
template std::unique_ptr<IConstantBuffer<double>> CreateConstantOnes<double>();
template std::unique_ptr<IConstantBuffer<half>> CreateConstantOnes<half>();
// Explicit instantiations of Fill for the four element widths (1, 2, 4 and
// 8 bytes) that callers dispatch on.
template void Fill<int8_t>(int8_t* output, int8_t value, int64_t count);
template void Fill<int16_t>(int16_t* output, int16_t value, int64_t count);
template void Fill<int32_t>(int32_t* output, int32_t value, int64_t count);
template void Fill<int64_t>(int64_t* output, int64_t value, int64_t count);
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "constant_of_shape.h"
#include "core/providers/common.h"
#include "gsl/span"
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace cuda {
// Kernel registration for the CUDA ConstantOfShape class: operator name
// ConstantOfShape, kOnnxDomain, version 9, kCudaExecutionProvider.
ONNX_OPERATOR_KERNEL_EX(
ConstantOfShape,
kOnnxDomain,
9,
kCudaExecutionProvider,
KernelDefBuilder()
// Input 0 is tagged OrtMemTypeCPUInput.
// NOTE(review): presumably so the shape values can be read on the host - confirm.
.InputMemoryType<OrtMemTypeCPUInput>(0)
// T1 is restricted to int64 tensors; T2 accepts all fixed-size tensor types.
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("T2", DataTypeImpl::AllFixedSizeTensorTypes()),
ConstantOfShape);
// Allocates the output via PrepareCompute and fills it with the stored
// constant by calling cuda::Fill. The call is selected by the byte width of
// the output element type (1, 2, 4 or 8 bytes); the constant is passed as the
// same-sized integer, i.e. as an opaque bit pattern.
// Throws via ORT_THROW when the element width is none of the supported sizes.
Status ConstantOfShape::Compute(OpKernelContext* ctx) const {
  Tensor* result = nullptr;
  ORT_RETURN_IF_ERROR(PrepareCompute(ctx, &result));

  auto* const destination = result->MutableDataRaw();
  const auto element_count = result->Shape().Size();
  const void* const fill_value = GetValuePtr();
  const auto bytes_per_element = result->DataType()->Size();

  if (bytes_per_element == sizeof(int8_t)) {
    cuda::Fill(reinterpret_cast<int8_t*>(destination), *static_cast<const int8_t*>(fill_value), element_count);
  } else if (bytes_per_element == sizeof(int16_t)) {
    cuda::Fill(reinterpret_cast<int16_t*>(destination), *static_cast<const int16_t*>(fill_value), element_count);
  } else if (bytes_per_element == sizeof(int32_t)) {
    cuda::Fill(reinterpret_cast<int32_t*>(destination), *static_cast<const int32_t*>(fill_value), element_count);
  } else if (bytes_per_element == sizeof(int64_t)) {
    cuda::Fill(reinterpret_cast<int64_t*>(destination), *static_cast<const int64_t*>(fill_value), element_count);
  } else {
    ORT_THROW("Unsupported value attribute datatype with sizeof=: ", bytes_per_element);
  }

  return Status::OK();
}
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/data_types.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cpu/generator/constant_of_shape.h"
#include "core/providers/cuda/shared_inc/cuda_utils.h"
namespace onnxruntime {
namespace cuda {
// CUDA kernel class for ConstantOfShape. It combines ConstantOfShapeBase
// with OpKernel; both bases are constructed from the same OpKernelInfo.
// Copy, assignment and move are disabled.
class ConstantOfShape final : public ConstantOfShapeBase, public OpKernel {
 public:
  explicit ConstantOfShape(const OpKernelInfo& info)
      : ConstantOfShapeBase(info),
        OpKernel(info) {
  }

  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConstantOfShape);

  // Produces the constant-filled output tensor.
  Status Compute(OpKernelContext* ctx) const override;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -30,5 +30,8 @@ class IConstantBuffer {
template <typename T>
std::unique_ptr<IConstantBuffer<T>> CreateConstantOnes();
// Sets `count` elements starting at `output` to `value`.
// NOTE(review): the definition launches a CUDA kernel, so `output` is
// presumably device memory - confirm. Only the element types explicitly
// instantiated in the implementation file can be linked against.
template <typename T>
void Fill(T* output, T value, int64_t count);
} // namespace cuda
} // namespace onnxruntime