mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
ConstantOfShape CUDA implementation (#1168)
* ConstantOfShape CUDA implementation * Enhance the fallback logic so that, in a Shape -> ... -> ConstantOfShape chain, ConstantOfShape no longer falls back to the CPU provider * Move shared code to the CPU implementation * Do the fill based on sizeof(data_type) * Update method access level
This commit is contained in:
parent
e43e64bf84
commit
8d68098c20
7 changed files with 218 additions and 134 deletions
|
|
@ -29,159 +29,131 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
DataTypeImpl::GetTensorType<bool>()}),
|
||||
ConstantOfShape);
|
||||
|
||||
#define FETCH_VALUE_DATA(field, c_type) \
|
||||
{ \
|
||||
c_type t; \
|
||||
auto unpack_status = UnpackTensor(t_proto, raw_data, raw_data_len, &t, 1); \
|
||||
ORT_ENFORCE(unpack_status.IsOK(), "Value attribute unpacking failed:", unpack_status.ErrorMessage()); \
|
||||
field = t; \
|
||||
#define FETCH_VALUE_DATA(c_type) \
|
||||
{ \
|
||||
c_type val; \
|
||||
auto unpack_status = UnpackTensor(t_proto, raw_data, raw_data_len, &val, 1); \
|
||||
ORT_ENFORCE(unpack_status.IsOK(), "Value attribute unpacking failed:", unpack_status.ErrorMessage()); \
|
||||
SetValue(sizeof(c_type), reinterpret_cast<void*>(&val)); \
|
||||
}
|
||||
|
||||
void onnxruntime::ConstantOfShape::SetValue(const ONNX_NAMESPACE::TensorProto& t_proto) {
|
||||
void onnxruntime::ConstantOfShapeBase::SetValueFromTensorProto(const ONNX_NAMESPACE::TensorProto& t_proto) {
|
||||
using namespace utils;
|
||||
ORT_ENFORCE(t_proto.has_data_type());
|
||||
ORT_ENFORCE(TensorProto::DataType_IsValid(t_proto.data_type()));
|
||||
tensor_type_ = static_cast<TensorProto_DataType>(t_proto.data_type());
|
||||
const auto tensor_type = static_cast<TensorProto_DataType>(t_proto.data_type());
|
||||
const void* const raw_data = t_proto.has_raw_data() ? t_proto.raw_data().data() : nullptr;
|
||||
const size_t raw_data_len = t_proto.has_raw_data() ? t_proto.raw_data().size() : 0;
|
||||
switch (tensor_type_) {
|
||||
switch (tensor_type) {
|
||||
case TensorProto::BOOL:
|
||||
FETCH_VALUE_DATA(value_.ui64_, bool);
|
||||
FETCH_VALUE_DATA(bool);
|
||||
break;
|
||||
case TensorProto::FLOAT:
|
||||
FETCH_VALUE_DATA(value_.fl_, float);
|
||||
FETCH_VALUE_DATA(float);
|
||||
break;
|
||||
case TensorProto::FLOAT16:
|
||||
FETCH_VALUE_DATA(value_.fl16_, MLFloat16);
|
||||
FETCH_VALUE_DATA(MLFloat16);
|
||||
break;
|
||||
case TensorProto::DOUBLE:
|
||||
FETCH_VALUE_DATA(value_.dbl_, double);
|
||||
FETCH_VALUE_DATA(double);
|
||||
break;
|
||||
case TensorProto::INT8:
|
||||
FETCH_VALUE_DATA(value_.i64_, int8_t);
|
||||
FETCH_VALUE_DATA(int8_t);
|
||||
break;
|
||||
case TensorProto::INT16:
|
||||
FETCH_VALUE_DATA(value_.i64_, int16_t);
|
||||
FETCH_VALUE_DATA(int16_t);
|
||||
break;
|
||||
case TensorProto::INT32:
|
||||
FETCH_VALUE_DATA(value_.i64_, int32_t);
|
||||
FETCH_VALUE_DATA(int32_t);
|
||||
break;
|
||||
case TensorProto::INT64:
|
||||
FETCH_VALUE_DATA(value_.i64_, int64_t);
|
||||
FETCH_VALUE_DATA(int64_t);
|
||||
break;
|
||||
case TensorProto::UINT8:
|
||||
FETCH_VALUE_DATA(value_.ui64_, uint8_t);
|
||||
FETCH_VALUE_DATA(uint8_t);
|
||||
break;
|
||||
case TensorProto::UINT16:
|
||||
FETCH_VALUE_DATA(value_.ui64_, uint16_t);
|
||||
FETCH_VALUE_DATA(uint16_t);
|
||||
break;
|
||||
case TensorProto::UINT32:
|
||||
FETCH_VALUE_DATA(value_.ui64_, uint32_t);
|
||||
FETCH_VALUE_DATA(uint32_t);
|
||||
break;
|
||||
case TensorProto::UINT64:
|
||||
FETCH_VALUE_DATA(value_.ui64_, uint64_t);
|
||||
FETCH_VALUE_DATA(uint64_t);
|
||||
break;
|
||||
default:
|
||||
ORT_THROW("Unsupported value attribute datatype: ", tensor_type_);
|
||||
ORT_THROW("Unsupported value attribute datatype: ", tensor_type);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#undef FETCH_VALUE_DATA
|
||||
|
||||
template <class T>
|
||||
inline T onnxruntime::ConstantOfShape::Value::GetFromSigned() const {
|
||||
return static_cast<T>(i64_);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
inline T onnxruntime::ConstantOfShape::Value::GetFromUnsigned() const {
|
||||
return static_cast<T>(ui64_);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
inline void FilloutOutput(T value, Tensor* output_tensor) {
|
||||
auto out = gsl::make_span(output_tensor->template MutableData<T>(), output_tensor->Shape().Size());
|
||||
inline void FilloutOutput(T value, void* output_data, size_t size) {
|
||||
auto out = gsl::make_span(reinterpret_cast<T*>(output_data), size);
|
||||
std::fill(out.begin(), out.end(), value);
|
||||
}
|
||||
|
||||
void onnxruntime::ConstantOfShape::DispatchTypeAndFillOutput(Tensor* output_tensor) const {
|
||||
switch (tensor_type_) {
|
||||
case TensorProto::BOOL:
|
||||
FilloutOutput(value_.GetFromUnsigned<bool>(), output_tensor);
|
||||
break;
|
||||
case TensorProto::FLOAT:
|
||||
FilloutOutput(value_.GetFloat(), output_tensor);
|
||||
break;
|
||||
case TensorProto::FLOAT16:
|
||||
FilloutOutput(value_.GetFloat16(), output_tensor);
|
||||
break;
|
||||
case TensorProto::DOUBLE:
|
||||
FilloutOutput(value_.GetDouble(), output_tensor);
|
||||
break;
|
||||
case TensorProto::INT8:
|
||||
FilloutOutput(value_.GetFromSigned<int8_t>(), output_tensor);
|
||||
break;
|
||||
case TensorProto::INT16:
|
||||
FilloutOutput(value_.GetFromSigned<int16_t>(), output_tensor);
|
||||
break;
|
||||
case TensorProto::INT32:
|
||||
FilloutOutput(value_.GetFromSigned<int32_t>(), output_tensor);
|
||||
break;
|
||||
case TensorProto::INT64:
|
||||
FilloutOutput(value_.GetFromSigned<int64_t>(), output_tensor);
|
||||
break;
|
||||
case TensorProto::UINT8:
|
||||
FilloutOutput(value_.GetFromUnsigned<uint8_t>(), output_tensor);
|
||||
break;
|
||||
case TensorProto::UINT16:
|
||||
FilloutOutput(value_.GetFromUnsigned<uint16_t>(), output_tensor);
|
||||
break;
|
||||
case TensorProto::UINT32:
|
||||
FilloutOutput(value_.GetFromUnsigned<uint32_t>(), output_tensor);
|
||||
break;
|
||||
case TensorProto::UINT64:
|
||||
FilloutOutput(value_.GetFromUnsigned<uint64_t>(), output_tensor);
|
||||
break;
|
||||
default:
|
||||
ORT_THROW("Unsupported value attribute datatype: ", tensor_type_);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
ConstantOfShape::ConstantOfShape(const OpKernelInfo& info) : OpKernel(info) {
|
||||
ConstantOfShapeBase::ConstantOfShapeBase(const OpKernelInfo& info){
|
||||
TensorProto t_proto;
|
||||
if (info.GetAttr<TensorProto>("value", &t_proto).IsOK()) {
|
||||
ORT_ENFORCE(t_proto.dims_size() == 1, "Must have a single dimension");
|
||||
ORT_ENFORCE(t_proto.dims()[0] == 1, "Must have a single dimension of 1");
|
||||
SetValue(t_proto);
|
||||
SetValueFromTensorProto(t_proto);
|
||||
} else {
|
||||
tensor_type_ = TensorProto::FLOAT;
|
||||
value_.fl_ = 0.f;
|
||||
float f_value = 0.f;
|
||||
SetValue(sizeof(float), reinterpret_cast<void*>(&f_value));
|
||||
}
|
||||
}
|
||||
|
||||
Status ConstantOfShape::Compute(OpKernelContext* ctx) const {
|
||||
auto shape_tensor = ctx->Input<Tensor>(0);
|
||||
|
||||
if (shape_tensor->DataType() != DataTypeImpl::GetType<int64_t>()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input tensor expected to contain int64 data");
|
||||
}
|
||||
|
||||
auto& input_shape = shape_tensor->Shape();
|
||||
Status ConstantOfShapeBase::PrepareCompute(OpKernelContext* ctx, Tensor** output_tensor) const {
|
||||
const auto shape_tensor = ctx->Input<Tensor>(0);
|
||||
const auto& input_shape = shape_tensor->Shape();
|
||||
|
||||
// If empty the output is a scalar with empty shape
|
||||
// TensorShape::Size() will still return 1 and we will output
|
||||
// one value
|
||||
std::vector<int64_t> output_dims;
|
||||
if (input_shape.NumDimensions() > 0) {
|
||||
auto span = gsl::make_span(shape_tensor->Data<int64_t>(), input_shape.Size());
|
||||
output_dims.insert(output_dims.end(), span.cbegin(), span.cend());
|
||||
}
|
||||
ORT_ENFORCE(input_shape.NumDimensions() > 0, "Must have a valid input shape.");
|
||||
|
||||
const auto span = gsl::make_span(shape_tensor->Data<int64_t>(), input_shape.Size());
|
||||
output_dims.insert(output_dims.end(), span.cbegin(), span.cend());
|
||||
|
||||
TensorShape output_shape(output_dims);
|
||||
auto output_tensor = ctx->Output(0, output_shape);
|
||||
DispatchTypeAndFillOutput(output_tensor);
|
||||
(*output_tensor) = ctx->Output(0, output_shape);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// CPU kernel entry point for ConstantOfShape.
//
// PrepareCompute() creates output 0; the buffer is then filled with the
// constant returned by GetValuePtr(). The fill is selected by the width of
// one output element (1, 2, 4 or 8 bytes), and the constant is reinterpreted
// as an integer of that same width.
//
// Returns the error status of PrepareCompute() if it fails, Status::OK()
// otherwise. Raises through ORT_THROW when the element width is none of
// 1, 2, 4 or 8 bytes.
Status ConstantOfShape::Compute(OpKernelContext* ctx) const {
  Tensor* result = nullptr;
  ORT_RETURN_IF_ERROR(PrepareCompute(ctx, &result));

  void* const dst = result->MutableDataRaw();
  const void* const src = GetValuePtr();
  const auto count = result->Shape().Size();
  const auto width = result->DataType()->Size();

  if (width == sizeof(int8_t)) {
    FilloutOutput(*static_cast<const int8_t*>(src), dst, count);
  } else if (width == sizeof(int16_t)) {
    FilloutOutput(*static_cast<const int16_t*>(src), dst, count);
  } else if (width == sizeof(int32_t)) {
    FilloutOutput(*static_cast<const int32_t*>(src), dst, count);
  } else if (width == sizeof(int64_t)) {
    FilloutOutput(*static_cast<const int64_t*>(src), dst, count);
  } else {
    ORT_THROW("Unsupported value attribute datatype with sizeof=: ", width);
  }

  return Status::OK();
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -9,45 +9,56 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
class ConstantOfShape final : public OpKernel {
|
||||
public:
|
||||
explicit ConstantOfShape(const OpKernelInfo& info);
|
||||
class ConstantOfShapeBase {
|
||||
|
||||
Status Compute(OpKernelContext* ctx) const override;
|
||||
protected:
|
||||
ConstantOfShapeBase(const OpKernelInfo& info);
|
||||
|
||||
Status PrepareCompute(OpKernelContext* ctx, Tensor** output_tensor) const;
|
||||
|
||||
void* GetValuePtr() const { return p_value_; }
|
||||
|
||||
private:
|
||||
ONNX_NAMESPACE::TensorProto_DataType tensor_type_;
|
||||
union Value {
|
||||
float fl_;
|
||||
MLFloat16 fl16_;
|
||||
double dbl_;
|
||||
int64_t i64_;
|
||||
uint64_t ui64_;
|
||||
Value() : ui64_(0) {}
|
||||
union SizeBasedValue {
|
||||
int8_t int8_;
|
||||
int16_t int16_;
|
||||
int32_t int32_;
|
||||
int64_t int64_;
|
||||
} s_value_;
|
||||
void* p_value_;
|
||||
|
||||
float GetFloat() const {
|
||||
return fl_;
|
||||
void SetValue(size_t size, void* value) {
|
||||
switch (size) {
|
||||
case sizeof(int8_t):
|
||||
s_value_.int8_ = *(reinterpret_cast<int8_t*>(value));
|
||||
p_value_ = reinterpret_cast<void*>(&(s_value_.int8_));
|
||||
break;
|
||||
case sizeof(int16_t):
|
||||
s_value_.int16_ = *(reinterpret_cast<int16_t*>(value));
|
||||
p_value_ = reinterpret_cast<void*>(&(s_value_.int16_));
|
||||
break;
|
||||
case sizeof(int32_t):
|
||||
s_value_.int32_ = *(reinterpret_cast<int32_t*>(value));
|
||||
p_value_ = reinterpret_cast<void*>(&(s_value_.int32_));
|
||||
break;
|
||||
case sizeof(int64_t):
|
||||
s_value_.int64_ = *(reinterpret_cast<int64_t*>(value));
|
||||
p_value_ = reinterpret_cast<void*>(&(s_value_.int64_));
|
||||
break;
|
||||
default:
|
||||
ORT_THROW("Unsupported value attribute datatype with sizeof=: ", size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
MLFloat16 GetFloat16() const {
|
||||
return fl16_;
|
||||
}
|
||||
void SetValueFromTensorProto(const ONNX_NAMESPACE::TensorProto&);
|
||||
};
|
||||
|
||||
double GetDouble() const {
|
||||
return dbl_;
|
||||
}
|
||||
class ConstantOfShape final : public ConstantOfShapeBase, public OpKernel {
|
||||
public:
|
||||
explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), OpKernel(info) {};
|
||||
|
||||
template <class T>
|
||||
T GetFromSigned() const;
|
||||
|
||||
template <class T>
|
||||
T GetFromUnsigned() const;
|
||||
|
||||
} value_;
|
||||
|
||||
void SetValue(const ONNX_NAMESPACE::TensorProto&);
|
||||
|
||||
void DispatchTypeAndFillOutput(Tensor* output_tensor) const;
|
||||
Status Compute(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -568,6 +568,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, int32_t, Resize);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, uint8_t, Resize);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, Split);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, ConstantOfShape);
|
||||
|
||||
static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
|
|
@ -896,6 +897,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, int32_t, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, uint8_t, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, Split)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, ConstantOfShape)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
@ -1060,19 +1062,24 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
|
|||
// cast is not compute heavy, and may be placed outside
|
||||
}
|
||||
|
||||
|
||||
if (!not_supported && !force_inside) {
|
||||
// Note that nodes with only inputs from initializer would not be placed on CUDA
|
||||
// Ideally, those nodes should be eliminated in constant folding
|
||||
bool all_non_initializer_inputs_from_outside = true;
|
||||
bool should_force_outside = true;
|
||||
node.ForEachWithIndex(
|
||||
node.InputDefs(),
|
||||
[&](const NodeArg& def, size_t) {
|
||||
[&](const NodeArg& def, size_t index) {
|
||||
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
|
||||
if (!graph.GetInitializedTensor(def.Name(), initializer) && !defs_outside_cuda.count(&def))
|
||||
all_non_initializer_inputs_from_outside = false;
|
||||
// The input is not an initializer and the input is from CPU
|
||||
// or the input declared as CPU memory and is from CPU
|
||||
// in that case we should still keep the node on CUDA
|
||||
if ((!graph.GetInitializedTensor(def.Name(), initializer) && !defs_outside_cuda.count(&def)) ||
|
||||
(defs_outside_cuda.count(&def) && cuda_kernel_def->kernel_def->IsInputOnCpu(index)))
|
||||
should_force_outside = false;
|
||||
return Status::OK();
|
||||
});
|
||||
if (all_non_initializer_inputs_from_outside) {
|
||||
if (should_force_outside) {
|
||||
force_outside = true;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,13 @@ __global__ void _Fill(
|
|||
output_data[id] = val;
|
||||
}
|
||||
|
||||
// Host-side launcher: writes `value` into the first `count` elements of
// `output` by launching the _Fill kernel.
//
//   output - destination buffer handed to the kernel; must hold at least
//            `count` elements of T.
//   value  - element replicated into every slot.
//   count  - number of elements to write; nothing is launched when <= 0.
//
// No stream is passed to the launch and the launch status is not queried
// here.
template <typename T>
void Fill(T* output, T value, int64_t count) {
  // A grid with zero blocks is an invalid launch configuration, so an empty
  // request must never reach the kernel launch.
  if (count <= 0) {
    return;
  }

  // Ceil-divide in 64-bit integers. Routing `count` through float drops the
  // low bits once count exceeds 2^24, which can under-count the blocks and
  // leave the tail of the buffer unfilled.
  const int64_t threads_per_block = GridDim::maxThreadsPerBlock;
  const int blocksPerGrid = static_cast<int>((count + threads_per_block - 1) / threads_per_block);
  // NOTE(review): presumably CUDA_LONG is narrower than int64_t - confirm that
  // callers never pass a count outside its range.
  CUDA_LONG N = static_cast<CUDA_LONG>(count);
  _Fill<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(output, value, N);
}
|
||||
|
||||
template <typename T>
|
||||
class ConstantBufferImpl : public IConstantBuffer<T> {
|
||||
public:
|
||||
|
|
@ -38,9 +45,7 @@ class ConstantBufferImpl : public IConstantBuffer<T> {
|
|||
CUDA_CALL_THROW(cudaMalloc(&buffer_, count * sizeof(T)));
|
||||
count_ = count;
|
||||
|
||||
int blocksPerGrid = (int)(ceil(static_cast<float>(count) / GridDim::maxThreadsPerBlock));
|
||||
CUDA_LONG N = static_cast<CUDA_LONG>(count);
|
||||
_Fill<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(buffer_, val_, N);
|
||||
Fill(buffer_, val_, count);
|
||||
}
|
||||
return buffer_;
|
||||
}
|
||||
|
|
@ -60,5 +65,12 @@ template std::unique_ptr<IConstantBuffer<float>> CreateConstantOnes<float>();
|
|||
template std::unique_ptr<IConstantBuffer<double>> CreateConstantOnes<double>();
|
||||
template std::unique_ptr<IConstantBuffer<half>> CreateConstantOnes<half>();
|
||||
|
||||
// Explicit instantiations of Fill<T>. One instantiation is provided per
// integer width (1, 2, 4 and 8 bytes); the ConstantOfShape kernel selects
// among them by the byte size of the output element.
template void Fill<int8_t>(int8_t* output, int8_t value, int64_t count);
template void Fill<int16_t>(int16_t* output, int16_t value, int64_t count);
template void Fill<int32_t>(int32_t* output, int32_t value, int64_t count);
template void Fill<int64_t>(int64_t* output, int64_t value, int64_t count);
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "constant_of_shape.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "gsl/span"
|
||||
|
||||
using namespace ::onnxruntime::common;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
// Registers the CUDA ConstantOfShape kernel (ONNX domain, version 9).
// Input 0 is declared as CPU memory through InputMemoryType<OrtMemTypeCPUInput>.
// "T1" is constrained to int64 tensors and "T2" to all fixed-size tensor types.
// NOTE(review): presumably T1 is the shape input and T2 the output element
// type, as in the ONNX operator schema - confirm.
ONNX_OPERATOR_KERNEL_EX(
|
||||
ConstantOfShape,
|
||||
kOnnxDomain,
|
||||
9,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(0)
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>())
|
||||
.TypeConstraint("T2", DataTypeImpl::AllFixedSizeTensorTypes()),
|
||||
ConstantOfShape);
|
||||
|
||||
// CUDA kernel entry point for ConstantOfShape.
//
// PrepareCompute() creates output 0; the buffer is then filled through
// cuda::Fill with the constant returned by GetValuePtr(). The fill is
// dispatched on the width of one output element (1, 2, 4 or 8 bytes) rather
// than on the logical element type: the constant is replicated as an integer
// of matching width.
//
// Returns the error status of PrepareCompute() if it fails, Status::OK()
// otherwise. Raises through ORT_THROW when the element width is none of
// 1, 2, 4 or 8 bytes.
Status ConstantOfShape::Compute(OpKernelContext* ctx) const {
  Tensor* output_tensor = nullptr;
  ORT_RETURN_IF_ERROR(PrepareCompute(ctx, &output_tensor));

  const auto size = output_tensor->Shape().Size();
  // An empty output has nothing to write. Returning here also keeps a
  // zero-element request away from the kernel launch, where it would ask
  // for a grid with zero blocks (an invalid launch configuration).
  if (size == 0) {
    return Status::OK();
  }

  auto output_data = output_tensor->MutableDataRaw();
  const void* value_ptr = GetValuePtr();
  const auto element_size = output_tensor->DataType()->Size();
  switch (element_size) {
    case sizeof(int8_t):
      cuda::Fill(reinterpret_cast<int8_t*>(output_data), *(reinterpret_cast<const int8_t*>(value_ptr)), size);
      break;
    case sizeof(int16_t):
      cuda::Fill(reinterpret_cast<int16_t*>(output_data), *(reinterpret_cast<const int16_t*>(value_ptr)), size);
      break;
    case sizeof(int32_t):
      cuda::Fill(reinterpret_cast<int32_t*>(output_data), *(reinterpret_cast<const int32_t*>(value_ptr)), size);
      break;
    case sizeof(int64_t):
      cuda::Fill(reinterpret_cast<int64_t*>(output_data), *(reinterpret_cast<const int64_t*>(value_ptr)), size);
      break;
    default:
      ORT_THROW("Unsupported value attribute datatype with sizeof=: ", element_size);
      break;
  }

  return Status::OK();
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cpu/generator/constant_of_shape.h"
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
class ConstantOfShape final : public ConstantOfShapeBase, public OpKernel {
|
||||
public:
|
||||
explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), OpKernel(info) {};
|
||||
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConstantOfShape);
|
||||
|
||||
Status Compute(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -30,5 +30,8 @@ class IConstantBuffer {
|
|||
template <typename T>
|
||||
std::unique_ptr<IConstantBuffer<T>> CreateConstantOnes();
|
||||
|
||||
// Writes `value` into `count` elements of `output`.
// Only the declaration lives in this header.
// NOTE(review): presumably only element types that are explicitly
// instantiated next to the definition will link - confirm against the
// implementation file.
template <typename T>
|
||||
void Fill(T* output, T value, int64_t count);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue