diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc b/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc index 8407c78c08..2482096647 100644 --- a/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc @@ -29,159 +29,131 @@ ONNX_CPU_OPERATOR_KERNEL( DataTypeImpl::GetTensorType()}), ConstantOfShape); -#define FETCH_VALUE_DATA(field, c_type) \ - { \ - c_type t; \ - auto unpack_status = UnpackTensor(t_proto, raw_data, raw_data_len, &t, 1); \ - ORT_ENFORCE(unpack_status.IsOK(), "Value attribute unpacking failed:", unpack_status.ErrorMessage()); \ - field = t; \ +#define FETCH_VALUE_DATA(c_type) \ + { \ + c_type val; \ + auto unpack_status = UnpackTensor(t_proto, raw_data, raw_data_len, &val, 1); \ + ORT_ENFORCE(unpack_status.IsOK(), "Value attribute unpacking failed:", unpack_status.ErrorMessage()); \ + SetValue(sizeof(c_type), reinterpret_cast(&val)); \ } -void onnxruntime::ConstantOfShape::SetValue(const ONNX_NAMESPACE::TensorProto& t_proto) { +void onnxruntime::ConstantOfShapeBase::SetValueFromTensorProto(const ONNX_NAMESPACE::TensorProto& t_proto) { using namespace utils; ORT_ENFORCE(t_proto.has_data_type()); ORT_ENFORCE(TensorProto::DataType_IsValid(t_proto.data_type())); - tensor_type_ = static_cast(t_proto.data_type()); + const auto tensor_type = static_cast(t_proto.data_type()); const void* const raw_data = t_proto.has_raw_data() ? t_proto.raw_data().data() : nullptr; const size_t raw_data_len = t_proto.has_raw_data() ? t_proto.raw_data().size() : 0; - switch (tensor_type_) { + switch (tensor_type) { case TensorProto::BOOL: - FETCH_VALUE_DATA(value_.ui64_, bool); + FETCH_VALUE_DATA(bool); break; case TensorProto::FLOAT: - FETCH_VALUE_DATA(value_.fl_, float); + FETCH_VALUE_DATA(float); break; case TensorProto::FLOAT16: - FETCH_VALUE_DATA(value_.fl16_, MLFloat16); + FETCH_VALUE_DATA(MLFloat16); break; case TensorProto::DOUBLE: - FETCH_VALUE_DATA(value_.dbl_, double); + FETCH_VALUE_DATA(double); break; case TensorProto::INT8: - FETCH_VALUE_DATA(value_.i64_, int8_t); + FETCH_VALUE_DATA(int8_t); break; case TensorProto::INT16: - FETCH_VALUE_DATA(value_.i64_, int16_t); + FETCH_VALUE_DATA(int16_t); break; case TensorProto::INT32: - FETCH_VALUE_DATA(value_.i64_, int32_t); + FETCH_VALUE_DATA(int32_t); break; case TensorProto::INT64: - FETCH_VALUE_DATA(value_.i64_, int64_t); + FETCH_VALUE_DATA(int64_t); break; case TensorProto::UINT8: - FETCH_VALUE_DATA(value_.ui64_, uint8_t); + FETCH_VALUE_DATA(uint8_t); break; case TensorProto::UINT16: - FETCH_VALUE_DATA(value_.ui64_, uint16_t); + FETCH_VALUE_DATA(uint16_t); break; case TensorProto::UINT32: - FETCH_VALUE_DATA(value_.ui64_, uint32_t); + FETCH_VALUE_DATA(uint32_t); break; case TensorProto::UINT64: - FETCH_VALUE_DATA(value_.ui64_, uint64_t); + FETCH_VALUE_DATA(uint64_t); break; default: - ORT_THROW("Unsupported value attribute datatype: ", tensor_type_); + ORT_THROW("Unsupported value attribute datatype: ", tensor_type); break; } } #undef FETCH_VALUE_DATA -template -inline T onnxruntime::ConstantOfShape::Value::GetFromSigned() const { - return static_cast(i64_); -} template -inline T onnxruntime::ConstantOfShape::Value::GetFromUnsigned() const { - return static_cast(ui64_); -} - -template -inline void FilloutOutput(T value, Tensor* output_tensor) { - auto out = gsl::make_span(output_tensor->template MutableData(), output_tensor->Shape().Size()); +inline void FilloutOutput(T value, void* output_data, size_t size) { + auto out = gsl::make_span(reinterpret_cast(output_data), size); std::fill(out.begin(), out.end(), value); } -void onnxruntime::ConstantOfShape::DispatchTypeAndFillOutput(Tensor* output_tensor) const { - switch (tensor_type_) { - case TensorProto::BOOL: - FilloutOutput(value_.GetFromUnsigned(), output_tensor); - break; - case TensorProto::FLOAT: - FilloutOutput(value_.GetFloat(), output_tensor); - break; - case TensorProto::FLOAT16: - FilloutOutput(value_.GetFloat16(), output_tensor); - break; - case TensorProto::DOUBLE: - FilloutOutput(value_.GetDouble(), output_tensor); - break; - case TensorProto::INT8: - FilloutOutput(value_.GetFromSigned(), output_tensor); - break; - case TensorProto::INT16: - FilloutOutput(value_.GetFromSigned(), output_tensor); - break; - case TensorProto::INT32: - FilloutOutput(value_.GetFromSigned(), output_tensor); - break; - case TensorProto::INT64: - FilloutOutput(value_.GetFromSigned(), output_tensor); - break; - case TensorProto::UINT8: - FilloutOutput(value_.GetFromUnsigned(), output_tensor); - break; - case TensorProto::UINT16: - FilloutOutput(value_.GetFromUnsigned(), output_tensor); - break; - case TensorProto::UINT32: - FilloutOutput(value_.GetFromUnsigned(), output_tensor); - break; - case TensorProto::UINT64: - FilloutOutput(value_.GetFromUnsigned(), output_tensor); - break; - default: - ORT_THROW("Unsupported value attribute datatype: ", tensor_type_); - break; - } -} - -ConstantOfShape::ConstantOfShape(const OpKernelInfo& info) : OpKernel(info) { +ConstantOfShapeBase::ConstantOfShapeBase(const OpKernelInfo& info){ TensorProto t_proto; if (info.GetAttr("value", &t_proto).IsOK()) { ORT_ENFORCE(t_proto.dims_size() == 1, "Must have a single dimension"); ORT_ENFORCE(t_proto.dims()[0] == 1, "Must have a single dimension of 1"); - SetValue(t_proto); + SetValueFromTensorProto(t_proto); } else { - tensor_type_ = TensorProto::FLOAT; - value_.fl_ = 0.f; + float f_value = 0.f; + SetValue(sizeof(float), reinterpret_cast(&f_value)); } } -Status ConstantOfShape::Compute(OpKernelContext* ctx) const { - auto shape_tensor = ctx->Input(0); - - if (shape_tensor->DataType() != DataTypeImpl::GetType()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input tensor expected to contain int64 data"); - } - - auto& input_shape = shape_tensor->Shape(); +Status ConstantOfShapeBase::PrepareCompute(OpKernelContext* ctx, Tensor** output_tensor) const { + const auto shape_tensor = ctx->Input(0); + const auto& input_shape = shape_tensor->Shape(); // If empty the output is a scalar with empty shape // TensorShape::Size() will still return 1 and we will output // one value std::vector output_dims; - if (input_shape.NumDimensions() > 0) { - auto span = gsl::make_span(shape_tensor->Data(), input_shape.Size()); - output_dims.insert(output_dims.end(), span.cbegin(), span.cend()); - } + ORT_ENFORCE(input_shape.NumDimensions() > 0, "Must have a valid input shape."); + + const auto span = gsl::make_span(shape_tensor->Data(), input_shape.Size()); + output_dims.insert(output_dims.end(), span.cbegin(), span.cend()); TensorShape output_shape(output_dims); - auto output_tensor = ctx->Output(0, output_shape); - DispatchTypeAndFillOutput(output_tensor); + (*output_tensor) = ctx->Output(0, output_shape); + + return Status::OK(); +} + +Status ConstantOfShape::Compute(OpKernelContext* ctx) const { + + Tensor* output_tensor = nullptr; + ORT_RETURN_IF_ERROR(PrepareCompute(ctx, &output_tensor)); + + auto output_data = output_tensor->MutableDataRaw(); + const void* value_ptr = GetValuePtr(); + const auto size = output_tensor->Shape().Size(); + const auto element_size = output_tensor->DataType()->Size(); + switch (element_size) { + case sizeof(int8_t): + FilloutOutput(*(reinterpret_cast(value_ptr)), output_data, size); + break; + case sizeof(int16_t): + FilloutOutput(*(reinterpret_cast(value_ptr)), output_data, size); + break; + case sizeof(int32_t): + FilloutOutput(*(reinterpret_cast(value_ptr)), output_data, size); + break; + case sizeof(int64_t): + FilloutOutput(*(reinterpret_cast(value_ptr)), output_data, size); + break; + default: + ORT_THROW("Unsupported value attribute datatype with sizeof=: ", element_size); + break; + } + return Status::OK(); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape.h b/onnxruntime/core/providers/cpu/generator/constant_of_shape.h index 217e94c8d7..70efadef3f 100644 --- a/onnxruntime/core/providers/cpu/generator/constant_of_shape.h +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape.h @@ -9,45 +9,56 @@ namespace onnxruntime { -class ConstantOfShape final : public OpKernel { - public: - explicit ConstantOfShape(const OpKernelInfo& info); +class ConstantOfShapeBase { - Status Compute(OpKernelContext* ctx) const override; + protected: + ConstantOfShapeBase(const OpKernelInfo& info); + + Status PrepareCompute(OpKernelContext* ctx, Tensor** output_tensor) const; + + void* GetValuePtr() const { return p_value_; } private: - ONNX_NAMESPACE::TensorProto_DataType tensor_type_; - union Value { - float fl_; - MLFloat16 fl16_; - double dbl_; - int64_t i64_; - uint64_t ui64_; - Value() : ui64_(0) {} + union SizeBasedValue { + int8_t int8_; + int16_t int16_; + int32_t int32_; + int64_t int64_; + } s_value_; + void* p_value_; - float GetFloat() const { - return fl_; + void SetValue(size_t size, void* value) { + switch (size) { + case sizeof(int8_t): + s_value_.int8_ = *(reinterpret_cast(value)); + p_value_ = reinterpret_cast(&(s_value_.int8_)); + break; + case sizeof(int16_t): + s_value_.int16_ = *(reinterpret_cast(value)); + p_value_ = reinterpret_cast(&(s_value_.int16_)); + break; + case sizeof(int32_t): + s_value_.int32_ = *(reinterpret_cast(value)); + p_value_ = reinterpret_cast(&(s_value_.int32_)); + break; + case sizeof(int64_t): + s_value_.int64_ = *(reinterpret_cast(value)); + p_value_ = reinterpret_cast(&(s_value_.int64_)); + break; + default: + ORT_THROW("Unsupported value attribute datatype with sizeof=: ", size); + break; } + } - MLFloat16 GetFloat16() const { - return fl16_; - } + void SetValueFromTensorProto(const ONNX_NAMESPACE::TensorProto&); +}; - double GetDouble() const { - return dbl_; - } +class ConstantOfShape final : public ConstantOfShapeBase, public OpKernel { + public: + explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), OpKernel(info) {}; - template - T GetFromSigned() const; - - template - T GetFromUnsigned() const; - - } value_; - - void SetValue(const ONNX_NAMESPACE::TensorProto&); - - void DispatchTypeAndFillOutput(Tensor* output_tensor) const; + Status Compute(OpKernelContext* ctx) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 6abbadc2d2..aa50128fd1 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -568,6 +568,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, int32_t, Resize); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, uint8_t, Resize); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, Split); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, ConstantOfShape); static void RegisterCudaKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -896,6 +897,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { @@ -1060,19 +1062,24 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // cast is not compute heavy, and may be placed outside } + if (!not_supported && !force_inside) { // Note that nodes with only inputs from initializer would not be place on CUDA // Ideally, those nodes should be eliminated in constant folding - bool all_non_initializer_inputs_from_outside = true; + bool should_force_outside = true; node.ForEachWithIndex( node.InputDefs(), - [&](const NodeArg& def, size_t) { + [&](const NodeArg& def, size_t index) { const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (!graph.GetInitializedTensor(def.Name(), initializer) && !defs_outside_cuda.count(&def)) - all_non_initializer_inputs_from_outside = false; + // The input is not a initializer and the input is from CPU + // or the input declared as CPU memory and is from CPU + // in that case we should still keep the node on CUDA + if ((!graph.GetInitializedTensor(def.Name(), initializer) && !defs_outside_cuda.count(&def)) || + (defs_outside_cuda.count(&def) && cuda_kernel_def->kernel_def->IsInputOnCpu(index))) + should_force_outside = false; return Status::OK(); }); - if (all_non_initializer_inputs_from_outside) { + if (should_force_outside) { force_outside = true; } } diff --git a/onnxruntime/core/providers/cuda/cuda_utils.cu b/onnxruntime/core/providers/cuda/cuda_utils.cu index d6cdebb17c..bbfbefac9b 100644 --- a/onnxruntime/core/providers/cuda/cuda_utils.cu +++ b/onnxruntime/core/providers/cuda/cuda_utils.cu @@ -19,6 +19,13 @@ __global__ void _Fill( output_data[id] = val; } +template +void Fill(T* output, T value, int64_t count) { + int blocksPerGrid = (int)(ceil(static_cast(count) / GridDim::maxThreadsPerBlock)); + CUDA_LONG N = static_cast(count); + _Fill<<>>(output, value, N); +} + template class ConstantBufferImpl : public IConstantBuffer { public: @@ -38,9 +45,7 @@ class ConstantBufferImpl : public IConstantBuffer { CUDA_CALL_THROW(cudaMalloc(&buffer_, count * sizeof(T))); count_ = count; - int blocksPerGrid = (int)(ceil(static_cast(count) / GridDim::maxThreadsPerBlock)); - CUDA_LONG N = static_cast(count); - _Fill<<>>(buffer_, val_, N); + Fill(buffer_, val_, count); } return buffer_; } @@ -60,5 +65,12 @@ template std::unique_ptr> CreateConstantOnes(); template std::unique_ptr> CreateConstantOnes(); template std::unique_ptr> CreateConstantOnes(); +#define SPECIALIZED_FILL(T) \ +template void Fill(T* output, T value, int64_t count); + +SPECIALIZED_FILL(int8_t) +SPECIALIZED_FILL(int16_t) +SPECIALIZED_FILL(int32_t) +SPECIALIZED_FILL(int64_t) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/generator/constant_of_shape.cc b/onnxruntime/core/providers/cuda/generator/constant_of_shape.cc new file mode 100644 index 0000000000..45548da1bd --- /dev/null +++ b/onnxruntime/core/providers/cuda/generator/constant_of_shape.cc @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "constant_of_shape.h" +#include "core/providers/common.h" +#include "gsl/span" + +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +namespace onnxruntime { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + ConstantOfShape, + kOnnxDomain, + 9, + kCudaExecutionProvider, + KernelDefBuilder() + .InputMemoryType(0) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::AllFixedSizeTensorTypes()), + ConstantOfShape); + +Status ConstantOfShape::Compute(OpKernelContext* ctx) const { + Tensor* output_tensor = nullptr; + ORT_RETURN_IF_ERROR(PrepareCompute(ctx, &output_tensor)); + auto output_data = output_tensor->MutableDataRaw(); + const auto size = output_tensor->Shape().Size(); + const void* value_ptr = GetValuePtr(); + const auto element_size = output_tensor->DataType()->Size(); + switch (element_size) { + case sizeof(int8_t): + cuda::Fill(reinterpret_cast(output_data), *(reinterpret_cast(value_ptr)), size); + break; + case sizeof(int16_t): + cuda::Fill(reinterpret_cast(output_data), *(reinterpret_cast(value_ptr)), size); + break; + case sizeof(int32_t): + cuda::Fill(reinterpret_cast(output_data), *(reinterpret_cast(value_ptr)), size); + break; + case sizeof(int64_t): + cuda::Fill(reinterpret_cast(output_data), *(reinterpret_cast(value_ptr)), size); + break; + default: + ORT_THROW("Unsupported value attribute datatype with sizeof=: ", element_size); + break; + } + + return Status::OK(); +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/generator/constant_of_shape.h b/onnxruntime/core/providers/cuda/generator/constant_of_shape.h new file mode 100644 index 0000000000..b31585fbfd --- /dev/null +++ b/onnxruntime/core/providers/cuda/generator/constant_of_shape.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/data_types.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cpu/generator/constant_of_shape.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace cuda { + +class ConstantOfShape final : public ConstantOfShapeBase, public OpKernel { + public: + explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), OpKernel(info) {}; + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConstantOfShape); + + Status Compute(OpKernelContext* ctx) const override; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h index 741087685f..faa041ee27 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h @@ -30,5 +30,8 @@ class IConstantBuffer { template std::unique_ptr> CreateConstantOnes(); +template +void Fill(T* output, T value, int64_t count); + } // namespace cuda } // namespace onnxruntime