mirror of https://github.com/saymrwulf/onnxruntime.git (synced 2026-05-16 21:00:14 +00:00)
Addressing review comments (#3690)
- https://github.com/microsoft/onnxruntime/pull/3681#discussion_r414359326
- https://github.com/microsoft/onnxruntime/pull/3681#discussion_r414359463
- https://github.com/microsoft/onnxruntime/pull/3681#discussion_r414360023
- https://github.com/microsoft/onnxruntime/pull/3681#discussion_r414361667
- https://github.com/microsoft/onnxruntime/pull/3681#discussion_r414368707
- https://github.com/microsoft/onnxruntime/pull/3681#discussion_r414371480
- https://github.com/microsoft/onnxruntime/pull/3681#discussion_r414379362
- https://github.com/microsoft/onnxruntime/pull/3681#discussion_r414374516
- https://github.com/microsoft/onnxruntime/pull/3681#discussion_r414801087
parent 7347c73139
commit 4aa033b99e
6 changed files with 43 additions and 67 deletions
@@ -825,18 +825,22 @@ class Graph {
   // Options to control Graph::Resolve.
   struct ResolveOptions {
     // Whether to override existing types with inferred types.
     bool override_types = false;
     // Names of initializers to keep even if unused (optional).
     const std::unordered_set<std::string>* initializer_names_to_preserve = nullptr;
     // Whether to set that no proto sync is required after resolving.
     // Useful for resolving right after loading from a GraphProto.
     bool no_proto_sync_required = false;
   };
   /**
   Resolve this Graph to ensure it is completely valid, fully initialized, and able to be executed.
   1. Run through all validation rules.
      a. Node names and node outputs' names should be unique.
      b. Attributes should match between node and op definition.
      c. Inputs/outputs should match between node and op definition.
      d. Graph should be acyclic; sort nodes in topological order.
   2. Check & set up inner nodes' dependencies.
   3. Clean up function definition lists.
   Note: the weights for training can't be cleaned during resolve.
@@ -849,12 +853,6 @@ class Graph {
     return Resolve(default_options);
   }

-  common::Status ResolveAfterTypeTransformation() {
-    ResolveOptions options;
-    options.override_types = true;
-    return Resolve(options);
-  }
-
   /** Returns the Node containing the GraphProto for this Graph instance if IsSubgraph is true */
   const Node* ParentNode() const { return parent_node_; }
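With the ResolveAfterTypeTransformation convenience wrapper removed, callers that transform tensor types re-resolve through the options struct directly. A minimal sketch of such a call site (the surrounding function and the `graph` variable are assumed here, not part of this commit):

  // After a type-changing transformation, re-resolve and let newly
  // inferred types override the previously recorded ones.
  Graph::ResolveOptions options;
  options.override_types = true;
  ORT_RETURN_IF_ERROR(graph.Resolve(options));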
@@ -1031,8 +1029,7 @@ class Graph {

   template <typename TInstance>
   static auto GetProducerNodeImpl(
-      TInstance& instance, const std::string& node_arg_name)
-      -> decltype(instance.GetNode(0)) {
+      TInstance& instance, const std::string& node_arg_name) -> decltype(instance.GetNode(0)) {
     auto iter = instance.node_arg_to_producer_node_.find(node_arg_name);
     if (iter != instance.node_arg_to_producer_node_.end()) {
       auto node_index = iter->second;
@@ -1043,8 +1040,7 @@ class Graph {

   template <typename TInstance>
   static auto GetConsumerNodesImpl(
-      TInstance& instance, const std::string& node_arg_name)
-      -> std::vector<decltype(instance.GetNode(0))> {
+      TInstance& instance, const std::string& node_arg_name) -> std::vector<decltype(instance.GetNode(0))> {
     std::vector<decltype(instance.GetNode(0))> results;
     auto iter = instance.node_arg_to_consumer_nodes_.find(node_arg_name);
     if (iter != instance.node_arg_to_consumer_nodes_.end()) {
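An aside on the pattern these two helpers use: a single static template, instantiated with either `Graph` or `const Graph`, deduces its return type from `decltype(instance.GetNode(0))`, so the const and non-const public accessors share one lookup implementation instead of two copies that can drift apart. A self-contained sketch of the same idiom, with hypothetical `Registry`/`Item` names that are not from onnxruntime:

  #include <string>
  #include <unordered_map>

  struct Item { std::string name; };

  class Registry {
   public:
    // Non-const overload: returns Item*.
    Item* Find(const std::string& key) { return FindImpl(*this, key); }
    // Const overload: returns const Item*.
    const Item* Find(const std::string& key) const { return FindImpl(*this, key); }

   private:
    // TInstance is deduced as Registry or const Registry; the return type
    // (Item* or const Item*) follows from the deduced constness of items_.
    template <typename TInstance>
    static auto FindImpl(TInstance& instance, const std::string& key)
        -> decltype(&instance.items_.find(key)->second) {
      auto iter = instance.items_.find(key);
      return iter != instance.items_.end() ? &iter->second : nullptr;
    }

    std::unordered_map<std::string, Item> items_;
  };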
@@ -2,10 +2,11 @@
 // Licensed under the MIT License.

 #include "core/providers/common.h"
 #include "core/providers/cuda/cuda_common.h"
 #include "core/providers/cuda/cudnn_common.h"
 #include "core/framework/tensorprotoutils.h"
 #include "fast_gelu.h"
 #include "fast_gelu_impl.h"
+#include "contrib_ops/cpu/bert/bias_gelu_helper.h"

 namespace onnxruntime {
 namespace contrib {
@@ -32,50 +33,23 @@ FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel
 }

 template <typename T>
-Status FastGelu<T>::ComputeInternal(OpKernelContext* ctx) const {
-  const Tensor* input = ctx->Input<Tensor>(0);
+Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
+  ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context));

-  const auto input_dims = input->Shape().GetDims();
-  if (input_dims.size() < 1) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "Input 0 is expected to have 1 or more dimensions, got ", input_dims.size());
-  }
-
-  size_t num_inputs = OpKernel::Node().InputDefs().size();
-  bool has_bias = (num_inputs == 2);
-
-  int input_length = 1;
-  for (size_t i = 0; i < input_dims.size(); i++) {
-    input_length *= static_cast<int>(input_dims[i]);
-  }
-
-  int bias_length = 0;
-  const Tensor* bias = nullptr;
-  if (has_bias) {
-    bias = ctx->Input<Tensor>(1);
-    const auto bias_dims = bias->Shape().GetDims();
-    if (bias_dims.size() != 1) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "Input 1 is expected to have 1 dimensions, got ", bias_dims.size());
-    }
-    if (bias_dims[0] != input_dims[input_dims.size() - 1]) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "Input 1 dimension 0 should have same length as the last dimension of input 0");
-    }
-    bias_length = static_cast<int>(bias_dims[0]);
-  }
-
-  Tensor* output = ctx->Output(0, input->Shape());
+  const Tensor* input = context->Input<Tensor>(0);
+  const Tensor* bias = context->Input<Tensor>(1);
+  Tensor* output = context->Output(0, input->Shape());
+
+  int64_t input_length = input->Shape().Size();
+  int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
   typedef typename ToCudaType<T>::MappedType CudaT;
-  if (!LaunchFastGeluKernel<CudaT>(
-      GetDeviceProp(),
-      nullptr,
-      input_length,
-      bias_length,
-      reinterpret_cast<const CudaT*>(input->template Data<T>()),
-      has_bias ? reinterpret_cast<const CudaT*>(bias->template Data<T>()) : nullptr,
-      reinterpret_cast<CudaT*>(output->template MutableData<T>()))) {
+  if (!LaunchFastGeluKernel<CudaT>(GetDeviceProp(),
+                                   nullptr,
+                                   static_cast<int>(input_length),
+                                   static_cast<int>(bias_length),
+                                   reinterpret_cast<const CudaT*>(input->template Data<T>()),
+                                   (nullptr != bias) ? reinterpret_cast<const CudaT*>(bias->template Data<T>()) : nullptr,
+                                   reinterpret_cast<CudaT*>(output->template MutableData<T>()))) {
     CUDA_CALL(cudaGetLastError());
     return Status(common::ONNXRUNTIME, common::FAIL);
   }
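For context on what LaunchFastGeluKernel computes: FastGelu is the tanh approximation of GELU, with the optional bias broadcast-added along the last dimension before the activation. A scalar CPU sketch of the math (illustrative only; the actual CUDA kernel is declared in fast_gelu_impl.h):

  #include <cmath>
  #include <vector>

  // Scalar reference for the tanh-based GELU approximation:
  //   fast_gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
  float FastGeluScalar(float x) {
    const float kAlpha = 0.7978845608028654f;  // sqrt(2 / pi)
    return 0.5f * x * (1.0f + std::tanh(kAlpha * (x + 0.044715f * x * x * x)));
  }

  // The bias has the length of the input's last dimension, so on the
  // flattened input it repeats every bias.size() elements.
  void FastGeluWithBias(const std::vector<float>& input,
                        const std::vector<float>& bias,
                        std::vector<float>& output) {
    output.resize(input.size());
    for (size_t i = 0; i < input.size(); ++i) {
      float x = input[i] + (bias.empty() ? 0.0f : bias[i % bias.size()]);
      output[i] = FastGeluScalar(x);
    }
  }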
@@ -50,7 +50,9 @@ class ScopedOrtCallbackInvoker {
   }

   ScopedOrtCallbackInvoker& operator=(ScopedOrtCallbackInvoker&& other) {
-    if (callback_.f) callback_.f(callback_.param);
+    if (callback_.f) {
+      callback_.f(callback_.param);
+    }

     callback_ = other.callback_;
     other.callback_.f = nullptr;
@@ -60,7 +62,9 @@ class ScopedOrtCallbackInvoker {
   }

   ~ScopedOrtCallbackInvoker() {
-    if (callback_.f) callback_.f(callback_.param);
+    if (callback_.f) {
+      callback_.f(callback_.param);
+    }
   }

 private:
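Stepping back, ScopedOrtCallbackInvoker is a small RAII guard: it owns an OrtCallback and fires it exactly once, either when it is overwritten by move assignment (as above) or at destruction. A stripped-down sketch of the pattern, with a hypothetical `Callback` struct standing in for OrtCallback:

  // Stand-in for OrtCallback: a C-style function pointer plus its argument.
  struct Callback {
    void (*f)(void* param) = nullptr;
    void* param = nullptr;
  };

  // RAII guard: invokes the callback exactly once, either when a new
  // callback is move-assigned over it or when the guard is destroyed.
  class ScopedCallbackInvoker {
   public:
    explicit ScopedCallbackInvoker(Callback cb) : callback_(cb) {}

    ScopedCallbackInvoker(ScopedCallbackInvoker&& other) noexcept
        : callback_(other.callback_) {
      other.callback_.f = nullptr;  // the moved-from guard must not fire
    }

    ScopedCallbackInvoker& operator=(ScopedCallbackInvoker&& other) noexcept {
      if (callback_.f) {
        callback_.f(callback_.param);  // flush the callback currently owned
      }
      callback_ = other.callback_;
      other.callback_.f = nullptr;
      return *this;
    }

    ~ScopedCallbackInvoker() {
      if (callback_.f) {
        callback_.f(callback_.param);
      }
    }

   private:
    Callback callback_;
  };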
@@ -538,7 +538,7 @@ common::Status ExecuteGraph(const SessionState& session_state,
                               execution_mode, terminate_flag, logger, only_execute_path_to_fetches);

   return status;
-}  // namespace utils
+}

 common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFetchesManager& feeds_fetches_manager,
                                const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
@@ -719,7 +719,9 @@ void Node::ReplaceDefs(const std::map<const onnxruntime::NodeArg*, onnxruntime::
 //
 //  // as we just loaded from file we want to fully initialize/Resolve, but not let that change
 //  // the proto sync flag
-//  auto status = new_graph->Resolve(/* no_proto_sync_required */ true);
+//  ResolveOptions options;
+//  options.no_proto_sync_required = true;
+//  auto status = new_graph->Resolve(options);
 //  return status;
 //}
 using google::protobuf::RepeatedPtrField;
@@ -2873,10 +2875,10 @@ Status Graph::InlineFunction(Node& node) {
   for (const auto& subgraph_node : subgraph.Nodes()) {
     if (subgraph_node.OpType() == kConstant) {
       // Copy constant nodes _value to name_to_initial_tensor_
-      const gsl::not_null<TensorProto*>
-          tensor{graph_proto_->add_initializer()};
-      *tensor = subgraph_node.GetAttributes().at("value").t();
-      *(tensor->mutable_name()) = subgraph_node.OutputDefs()[0]->Name();
+      ONNX_NAMESPACE::NodeProto subgraph_node_proto{};
+      subgraph_node.ToProto(subgraph_node_proto);
+      const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
+      ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(subgraph_node_proto, *tensor));
       name_to_initial_tensor_[tensor->name()] = tensor;
     } else {
       std::vector<NodeArg*> inputs, outputs;
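The replacement code serializes the Constant node and delegates to utils::ConstantNodeProtoToTensorProto rather than hand-copying the 'value' attribute and patching the name. As a rough illustration of the minimum such a conversion involves (a hypothetical sketch, not the actual helper, which validates more and returns a Status; ONNX_NAMESPACE is the build's alias for the onnx protobuf namespace):

  #include <string>
  #include "onnx/onnx_pb.h"  // NodeProto / TensorProto / AttributeProto

  // Minimal illustration: extract the 'value' tensor attribute of a
  // Constant NodeProto into a TensorProto named after the node's output.
  bool ConstantToTensor(const ONNX_NAMESPACE::NodeProto& node,
                        ONNX_NAMESPACE::TensorProto& tensor) {
    if (node.op_type() != "Constant" || node.output_size() != 1) {
      return false;
    }
    for (const auto& attr : node.attribute()) {
      if (attr.name() == "value" &&
          attr.type() == ONNX_NAMESPACE::AttributeProto::TENSOR) {
        tensor = attr.t();                // copy the tensor payload
        tensor.set_name(node.output(0));  // initializer takes the output's name
        return true;
      }
    }
    return false;  // e.g. a value_* variant this sketch doesn't handle
  }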
@@ -8,7 +8,7 @@ import os

 from onnxruntime.capi import _pybind_state as C

-def getOrtDeviceType(device):
+def get_ort_device_type(device):
     if device == 'cuda':
         return C.OrtDevice.cuda()
     elif device == 'cpu':
@@ -196,12 +196,12 @@ class IOBinding:

     def bind_input(self, name, device_type, device_id, element_type, shape, buffer_ptr):
         self._iobinding.bind_input(name,
-                                   C.OrtDevice(getOrtDeviceType(device_type), C.OrtDevice.default_memory(), device_id),
+                                   C.OrtDevice(get_ort_device_type(device_type), C.OrtDevice.default_memory(), device_id),
                                    element_type, shape, buffer_ptr)

     def bind_output(self, name, device_type, device_id, element_type, shape, buffer_ptr):
         self._iobinding.bind_output(name,
-                                    C.OrtDevice(getOrtDeviceType(device_type), C.OrtDevice.default_memory(), device_id),
+                                    C.OrtDevice(get_ort_device_type(device_type), C.OrtDevice.default_memory(), device_id),
                                     element_type, shape, buffer_ptr)

     def clear_binding_inputs(self):