diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 766c6d96b7..f08df584ba 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -825,18 +825,22 @@ class Graph { // Options to control Graph::Resolve. struct ResolveOptions { + // Whether to override existing types with inferred types. bool override_types = false; + // Names of initializers to keep even if unused (optional). const std::unordered_set* initializer_names_to_preserve = nullptr; + // Whether to set that no proto sync is required after resolving. + // Useful for resolving right after loading from a GraphProto. bool no_proto_sync_required = false; }; /** Resolve this Graph to ensure it is completely valid, fully initialized, and able to be executed. 1. Run through all validation rules. - a. Node name and node output's names should be unique. - b. Attribute match between node and op definition. - c. Input/Output match between node and op definition. - d. Graph is acyclic and sort nodes in topological order. + a. Node name and node output's names should be unique. + b. Attribute match between node and op definition. + c. Input/Output match between node and op definition. + d. Graph is acyclic and sort nodes in topological order. 2. Check & Setup inner nodes' dependency. 3. Cleanup function definition lists. Note: the weights for training can't be cleaned during resolve. @@ -849,12 +853,6 @@ class Graph { return Resolve(default_options); } - common::Status ResolveAfterTypeTransformation() { - ResolveOptions options; - options.override_types = true; - return Resolve(options); - } - /** Returns the Node containing the GraphProto for this Graph instance if IsSubgraph is true */ const Node* ParentNode() const { return parent_node_; } @@ -1031,8 +1029,7 @@ class Graph { template static auto GetProducerNodeImpl( - TInstance& instance, const std::string& node_arg_name) - -> decltype(instance.GetNode(0)) { + TInstance& instance, const std::string& node_arg_name) -> decltype(instance.GetNode(0)) { auto iter = instance.node_arg_to_producer_node_.find(node_arg_name); if (iter != instance.node_arg_to_producer_node_.end()) { auto node_index = iter->second; @@ -1043,8 +1040,7 @@ class Graph { template static auto GetConsumerNodesImpl( - TInstance& instance, const std::string& node_arg_name) - -> std::vector { + TInstance& instance, const std::string& node_arg_name) -> std::vector { std::vector results; auto iter = instance.node_arg_to_consumer_nodes_.find(node_arg_name); if (iter != instance.node_arg_to_consumer_nodes_.end()) { diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index 0c965cd468..f7c06636a3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -2,10 +2,11 @@ // Licensed under the MIT License. #include "core/providers/common.h" -#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cudnn_common.h" #include "core/framework/tensorprotoutils.h" #include "fast_gelu.h" #include "fast_gelu_impl.h" +#include "contrib_ops/cpu/bert/bias_gelu_helper.h" namespace onnxruntime { namespace contrib { @@ -32,50 +33,23 @@ FastGelu::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel } template -Status FastGelu::ComputeInternal(OpKernelContext* ctx) const { - const Tensor* input = ctx->Input(0); +Status FastGelu::ComputeInternal(OpKernelContext* context) const { + ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context)); - const auto input_dims = input->Shape().GetDims(); - if (input_dims.size() < 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 0 is expected to have 1 or more dimensions, got ", input_dims.size()); - } - - size_t num_inputs = OpKernel::Node().InputDefs().size(); - bool has_bias = (num_inputs == 2); - - int input_length = 1; - for (size_t i = 0; i < input_dims.size(); i++) { - input_length *= static_cast(input_dims[i]); - } - - int bias_length = 0; - const Tensor* bias = nullptr; - if (has_bias) { - bias = ctx->Input(1); - const auto bias_dims = bias->Shape().GetDims(); - if (bias_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 1 is expected to have 1 dimensions, got ", bias_dims.size()); - } - if (bias_dims[0] != input_dims[input_dims.size() - 1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 1 dimension 0 should have same length as the last dimension of input 0"); - } - bias_length = static_cast(bias_dims[0]); - } - - Tensor* output = ctx->Output(0, input->Shape()); + const Tensor* input = context->Input(0); + const Tensor* bias = context->Input(1); + Tensor* output = context->Output(0, input->Shape()); + int64_t input_length = input->Shape().Size(); + int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size(); typedef typename ToCudaType::MappedType CudaT; - if (!LaunchFastGeluKernel( - GetDeviceProp(), - nullptr, - input_length, - bias_length, - reinterpret_cast(input->template Data()), - has_bias ? reinterpret_cast(bias->template Data()) : nullptr, - reinterpret_cast(output->template MutableData()))) { + if (!LaunchFastGeluKernel(GetDeviceProp(), + nullptr, + static_cast(input_length), + static_cast(bias_length), + reinterpret_cast(input->template Data()), + (nullptr != bias) ? reinterpret_cast(bias->template Data()) : nullptr, + reinterpret_cast(output->template MutableData()))) { CUDA_CALL(cudaGetLastError()); return Status(common::ONNXRUNTIME, common::FAIL); } diff --git a/onnxruntime/core/framework/callback.h b/onnxruntime/core/framework/callback.h index 7bdc6c6248..39e3908af5 100644 --- a/onnxruntime/core/framework/callback.h +++ b/onnxruntime/core/framework/callback.h @@ -50,7 +50,9 @@ class ScopedOrtCallbackInvoker { } ScopedOrtCallbackInvoker& operator=(ScopedOrtCallbackInvoker&& other) { - if (callback_.f) callback_.f(callback_.param); + if (callback_.f) { + callback_.f(callback_.param); + } callback_ = other.callback_; other.callback_.f = nullptr; @@ -60,7 +62,9 @@ class ScopedOrtCallbackInvoker { } ~ScopedOrtCallbackInvoker() { - if (callback_.f) callback_.f(callback_.param); + if (callback_.f) { + callback_.f(callback_.param); + } } private: diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 8851da66c6..f054d1b414 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -538,7 +538,7 @@ common::Status ExecuteGraph(const SessionState& session_state, execution_mode, terminate_flag, logger, only_execute_path_to_fetches); return status; -} // namespace utils +} common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFetchesManager& feeds_fetches_manager, const std::vector& feeds, std::vector& fetches, diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 16e2a504a2..a5e4c3c73f 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -719,7 +719,9 @@ void Node::ReplaceDefs(const std::mapResolve(/* no_proto_sync_required */ true); +// ResolveOptions options; +// options.no_proto_sync_required = true; +// auto status = new_graph->Resolve(options); // return status; //} using google::protobuf::RepeatedPtrField; @@ -2873,10 +2875,10 @@ Status Graph::InlineFunction(Node& node) { for (const auto& subgraph_node : subgraph.Nodes()) { if (subgraph_node.OpType() == kConstant) { // Copy constant nodes _value to name_to_initial_tensor_ - const gsl::not_null - tensor{graph_proto_->add_initializer()}; - *tensor = subgraph_node.GetAttributes().at("value").t(); - *(tensor->mutable_name()) = subgraph_node.OutputDefs()[0]->Name(); + ONNX_NAMESPACE::NodeProto subgraph_node_proto{}; + subgraph_node.ToProto(subgraph_node_proto); + const gsl::not_null tensor{graph_proto_->add_initializer()}; + ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(subgraph_node_proto, *tensor)); name_to_initial_tensor_[tensor->name()] = tensor; } else { std::vector inputs, outputs; diff --git a/onnxruntime/python/session.py b/onnxruntime/python/session.py index f17736ee0a..a769cbb1c5 100644 --- a/onnxruntime/python/session.py +++ b/onnxruntime/python/session.py @@ -8,7 +8,7 @@ import os from onnxruntime.capi import _pybind_state as C -def getOrtDeviceType(device): +def get_ort_device_type(device): if device == 'cuda': return C.OrtDevice.cuda() elif device == 'cpu': @@ -196,12 +196,12 @@ class IOBinding: def bind_input(self, name, device_type, device_id, element_type, shape, buffer_ptr): self._iobinding.bind_input(name, - C.OrtDevice(getOrtDeviceType(device_type), C.OrtDevice.default_memory(), device_id), + C.OrtDevice(get_ort_device_type(device_type), C.OrtDevice.default_memory(), device_id), element_type, shape, buffer_ptr) def bind_output(self, name, device_type, device_id, element_type, shape, buffer_ptr): self._iobinding.bind_output(name, - C.OrtDevice(getOrtDeviceType(device_type), C.OrtDevice.default_memory(), device_id), + C.OrtDevice(get_ort_device_type(device_type), C.OrtDevice.default_memory(), device_id), element_type, shape, buffer_ptr) def clear_binding_inputs(self):