edgchen1 2020-04-24 14:57:18 -07:00 committed by GitHub
parent 7347c73139
commit 4aa033b99e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 43 additions and 67 deletions

View file

@ -825,18 +825,22 @@ class Graph {
// Options controlling the behavior of Graph::Resolve.
struct ResolveOptions {
// If true, existing types are overridden with types produced by
// type/shape inference during Resolve; if false, existing types win.
bool override_types = false;
// Optional set of initializer names to keep even if no node consumes them.
// Non-owning pointer — presumably the set must outlive the Resolve call
// (TODO confirm); nullptr means no unused initializers are preserved.
const std::unordered_set<std::string>* initializer_names_to_preserve = nullptr;
// Whether to set that no proto sync is required after resolving.
// Useful for resolving right after loading from a GraphProto, since the
// proto is already up to date in that case.
bool no_proto_sync_required = false;
};
/**
Resolve this Graph to ensure it is completely valid, fully initialized, and able to be executed.
1. Run through all validation rules.
a. Node name and node output's names should be unique.
b. Attribute match between node and op definition.
c. Input/Output match between node and op definition.
d. Graph is acyclic and sort nodes in topological order.
a. Node name and node output's names should be unique.
b. Attribute match between node and op definition.
c. Input/Output match between node and op definition.
d. Graph is acyclic and sort nodes in topological order.
2. Check & Setup inner nodes' dependency.
3. Cleanup function definition lists.
Note: the weights for training can't be cleaned during resolve.
@ -849,12 +853,6 @@ class Graph {
return Resolve(default_options);
}
// Resolves the graph with type overriding enabled, so that types produced
// by inference replace any previously assigned types. Returns the status
// of the underlying Resolve call.
common::Status ResolveAfterTypeTransformation() {
  ResolveOptions opts{};
  opts.override_types = true;
  return Resolve(opts);
}
/** Returns the Node containing the GraphProto for this Graph instance if IsSubgraph is true */
const Node* ParentNode() const { return parent_node_; }
@ -1031,8 +1029,7 @@ class Graph {
template <typename TInstance>
static auto GetProducerNodeImpl(
TInstance& instance, const std::string& node_arg_name)
-> decltype(instance.GetNode(0)) {
TInstance& instance, const std::string& node_arg_name) -> decltype(instance.GetNode(0)) {
auto iter = instance.node_arg_to_producer_node_.find(node_arg_name);
if (iter != instance.node_arg_to_producer_node_.end()) {
auto node_index = iter->second;
@ -1043,8 +1040,7 @@ class Graph {
template <typename TInstance>
static auto GetConsumerNodesImpl(
TInstance& instance, const std::string& node_arg_name)
-> std::vector<decltype(instance.GetNode(0))> {
TInstance& instance, const std::string& node_arg_name) -> std::vector<decltype(instance.GetNode(0))> {
std::vector<decltype(instance.GetNode(0))> results;
auto iter = instance.node_arg_to_consumer_nodes_.find(node_arg_name);
if (iter != instance.node_arg_to_consumer_nodes_.end()) {

View file

@ -2,10 +2,11 @@
// Licensed under the MIT License.
#include "core/providers/common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cudnn_common.h"
#include "core/framework/tensorprotoutils.h"
#include "fast_gelu.h"
#include "fast_gelu_impl.h"
#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
namespace onnxruntime {
namespace contrib {
@ -32,50 +33,23 @@ FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel
}
template <typename T>
Status FastGelu<T>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* input = ctx->Input<Tensor>(0);
Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context));
const auto input_dims = input->Shape().GetDims();
if (input_dims.size() < 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 0 is expected to have 1 or more dimensions, got ", input_dims.size());
}
size_t num_inputs = OpKernel::Node().InputDefs().size();
bool has_bias = (num_inputs == 2);
int input_length = 1;
for (size_t i = 0; i < input_dims.size(); i++) {
input_length *= static_cast<int>(input_dims[i]);
}
int bias_length = 0;
const Tensor* bias = nullptr;
if (has_bias) {
bias = ctx->Input<Tensor>(1);
const auto bias_dims = bias->Shape().GetDims();
if (bias_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 1 is expected to have 1 dimensions, got ", bias_dims.size());
}
if (bias_dims[0] != input_dims[input_dims.size() - 1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 1 dimension 0 should have same length as the last dimension of input 0");
}
bias_length = static_cast<int>(bias_dims[0]);
}
Tensor* output = ctx->Output(0, input->Shape());
const Tensor* input = context->Input<Tensor>(0);
const Tensor* bias = context->Input<Tensor>(1);
Tensor* output = context->Output(0, input->Shape());
int64_t input_length = input->Shape().Size();
int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
typedef typename ToCudaType<T>::MappedType CudaT;
if (!LaunchFastGeluKernel<CudaT>(
GetDeviceProp(),
nullptr,
input_length,
bias_length,
reinterpret_cast<const CudaT*>(input->template Data<T>()),
has_bias ? reinterpret_cast<const CudaT*>(bias->template Data<T>()) : nullptr,
reinterpret_cast<CudaT*>(output->template MutableData<T>()))) {
if (!LaunchFastGeluKernel<CudaT>(GetDeviceProp(),
nullptr,
static_cast<int>(input_length),
static_cast<int>(bias_length),
reinterpret_cast<const CudaT*>(input->template Data<T>()),
(nullptr != bias) ? reinterpret_cast<const CudaT*>(bias->template Data<T>()) : nullptr,
reinterpret_cast<CudaT*>(output->template MutableData<T>()))) {
CUDA_CALL(cudaGetLastError());
return Status(common::ONNXRUNTIME, common::FAIL);
}

View file

@ -50,7 +50,9 @@ class ScopedOrtCallbackInvoker {
}
ScopedOrtCallbackInvoker& operator=(ScopedOrtCallbackInvoker&& other) {
if (callback_.f) callback_.f(callback_.param);
if (callback_.f) {
callback_.f(callback_.param);
}
callback_ = other.callback_;
other.callback_.f = nullptr;
@ -60,7 +62,9 @@ class ScopedOrtCallbackInvoker {
}
~ScopedOrtCallbackInvoker() {
if (callback_.f) callback_.f(callback_.param);
if (callback_.f) {
callback_.f(callback_.param);
}
}
private:

View file

@ -538,7 +538,7 @@ common::Status ExecuteGraph(const SessionState& session_state,
execution_mode, terminate_flag, logger, only_execute_path_to_fetches);
return status;
} // namespace utils
}
common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFetchesManager& feeds_fetches_manager,
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,

View file

@ -719,7 +719,9 @@ void Node::ReplaceDefs(const std::map<const onnxruntime::NodeArg*, onnxruntime::
//
// // as we just loaded from file we want to fully initialize/Resolve, but not let that change
// // the proto sync flag
// auto status = new_graph->Resolve(/* no_proto_sync_required */ true);
// ResolveOptions options;
// options.no_proto_sync_required = true;
// auto status = new_graph->Resolve(options);
// return status;
//}
using google::protobuf::RepeatedPtrField;
@ -2873,10 +2875,10 @@ Status Graph::InlineFunction(Node& node) {
for (const auto& subgraph_node : subgraph.Nodes()) {
if (subgraph_node.OpType() == kConstant) {
// Copy constant nodes _value to name_to_initial_tensor_
const gsl::not_null<TensorProto*>
tensor{graph_proto_->add_initializer()};
*tensor = subgraph_node.GetAttributes().at("value").t();
*(tensor->mutable_name()) = subgraph_node.OutputDefs()[0]->Name();
ONNX_NAMESPACE::NodeProto subgraph_node_proto{};
subgraph_node.ToProto(subgraph_node_proto);
const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(subgraph_node_proto, *tensor));
name_to_initial_tensor_[tensor->name()] = tensor;
} else {
std::vector<NodeArg*> inputs, outputs;

View file

@ -8,7 +8,7 @@ import os
from onnxruntime.capi import _pybind_state as C
def getOrtDeviceType(device):
def get_ort_device_type(device):
if device == 'cuda':
return C.OrtDevice.cuda()
elif device == 'cpu':
@ -196,12 +196,12 @@ class IOBinding:
def bind_input(self, name, device_type, device_id, element_type, shape, buffer_ptr):
self._iobinding.bind_input(name,
C.OrtDevice(getOrtDeviceType(device_type), C.OrtDevice.default_memory(), device_id),
C.OrtDevice(get_ort_device_type(device_type), C.OrtDevice.default_memory(), device_id),
element_type, shape, buffer_ptr)
def bind_output(self, name, device_type, device_id, element_type, shape, buffer_ptr):
self._iobinding.bind_output(name,
C.OrtDevice(getOrtDeviceType(device_type), C.OrtDevice.default_memory(), device_id),
C.OrtDevice(get_ort_device_type(device_type), C.OrtDevice.default_memory(), device_id),
element_type, shape, buffer_ptr)
def clear_binding_inputs(self):