Eliminate memory allocations per recent profiling (#12225)

* Alloc begin

FeedsFetches refactoring
Refactor Tensor class
Fix buffer deleter
Remove new/delete deleted
Adjust alloc move
Fix up xnnpack provider
Clarifying the comment on Create()
This commit is contained in:
Dmitri Smirnov 2022-07-25 14:14:38 -07:00 committed by GitHub
parent 972bb9676c
commit 3bf614fd47
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
50 changed files with 289 additions and 227 deletions

View file

@ -64,4 +64,7 @@ constexpr auto AsSpan(const T (&arr)[N]) {
return details::AsSpanImpl(arr, N);
}
// Produces a read-only span over T that refers to no elements.
template <class T>
inline gsl::span<const T> EmptySpan() {
  gsl::span<const T> no_elements;
  return no_elements;
}
}

View file

@ -12,7 +12,7 @@ class BufferDeleter {
public:
BufferDeleter() : alloc_(nullptr) {}
BufferDeleter(AllocatorPtr alloc)
: alloc_(alloc) {}
: alloc_(std::move(alloc)) {}
void operator()(void* p) const {
if (alloc_)

View file

@ -36,15 +36,10 @@ namespace onnxruntime {
*/
class Tensor final {
public:
static std::unique_ptr<Tensor> Create(MLDataType p_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator) {
return std::make_unique<Tensor>(p_type, shape, std::move(allocator));
}
static std::unique_ptr<Tensor> Create(MLDataType p_type, const TensorShape& shape, void* p_data,
const OrtMemoryInfo& alloc, ptrdiff_t offset = 0,
gsl::span<const int64_t> strides = {}) {
return std::make_unique<Tensor>(p_type, shape, p_data, alloc, offset, strides);
}
// NB! Removing Create() methods returning unique_ptr<Tensor>. Still available in other EPs that are dynamically linked.
// Strive not to allocate Tensor with new/delete as it is a shallow class and using it by value is just fine.
// Use InitOrtValue() methods to allocate for OrtValue.
Tensor() = default; // to allow creating vector<Tensor> to support seq(tensor)

View file

@ -302,7 +302,7 @@ bool Attention<T>::IsPackWeightsSuccessful(int qkv_index,
// buffer memory and we do not want it uninitialized and generate different hashes
// if and when we try to cache this pre-packed buffer for sharing between sessions.
memset(packed_weights_data, 0, packed_weights_data_size);
packed_weights_[qkv_index] = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
packed_weights_[qkv_index] = BufferUniquePtr(packed_weights_data, BufferDeleter(std::move(alloc)));
packed_weights_size_[qkv_index] = packb_size;
for (size_t i = 0; i < loop_len; i++) {
@ -470,7 +470,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
// D (input_hidden_size) is hidden dimension of input, where D could be larger than any of the hidden_sizes
// (NH) when model is pruned. T = H1 + H2 + H3, where H1, H2, H3 are head sizes of Q, K, V respectively
auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * (q_hidden_size + k_hidden_size + v_hidden_size) * element_size);
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator));
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(std::move(allocator)));
auto Q = reinterpret_cast<T*>(gemm_data);
auto K = Q + static_cast<size_t>(batch_size) * sequence_length * q_hidden_size;

View file

@ -83,7 +83,7 @@ class AttentionCPUBase : public AttentionBase {
// Compute the attentionScore * Value. It does: out_tmp(B, N, S, H) = attention_probs(B, N, S, S*) x V(B, N, S*, H)
auto out_tmp_data =
allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * v_head_size * sizeof(T));
BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(allocator));
BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(std::move(allocator)));
ComputeVxAttentionScore(output->template MutableData<T>(), static_cast<T*>(out_tmp_data), static_cast<T*>(attention_probs), V,
batch_size, sequence_length, past_sequence_length, v_head_size, v_hidden_size,

View file

@ -96,7 +96,7 @@ Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr
// if and when we try to cache this pre-packed buffer for sharing between sessions.
memset(packed_weights_data, 0, packed_weights_data_size);
packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(std::move(alloc)));
for (size_t i = 0; i < loop_len; i++) {
MlasGemmPackB(head_size, input_hidden_size, weights_data, hidden_size_x3, false /*AIsSigned*/, weights_is_signed_, packed_weights_data);
@ -212,7 +212,7 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
// STEP.1: gemm_data(BS, 3NH) = Scale(input(BS, D) x weights(D, 3NH)) + bias(3NH)
// D is hidden dimension of input, where input_hidden_size (D) could be larger than hidden_size (NH) when model is pruned.
auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * 3 * hidden_size * element_size);
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator));
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(std::move(allocator)));
auto Q = reinterpret_cast<T*>(gemm_data);
auto K = Q + static_cast<int64_t>(batch_size) * sequence_length * hidden_size;

View file

@ -215,7 +215,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
uint8_t* a_data_quant = static_cast<uint8_t*>(allocator->Alloc(SafeInt<size_t>(num_of_elements) * sizeof(uint8_t)));
BufferUniquePtr a_buffer_quant_holder(a_data_quant, BufferDeleter(allocator));
BufferUniquePtr a_buffer_quant_holder(a_data_quant, BufferDeleter(std::move(allocator)));
ParQuantizeLinear(a_data, a_data_quant, num_of_elements, a_scale, a_zero_point, ctx->GetOperatorThreadPool());

View file

@ -76,7 +76,7 @@ Status NhwcMaxPool<T8Bits>::Compute(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
int64_t col_buffer_batch_count = std::min(output_image_size, output_batch_count);
auto* col_data = alloc->Alloc(SafeInt<size_t>(sizeof(const T8Bits*)) * kernel_size * col_buffer_batch_count);
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
BufferUniquePtr col_buffer(col_data, BufferDeleter(std::move(alloc)));
std::vector<T8Bits> padding_data(static_cast<size_t>(C), std::numeric_limits<T8Bits>::lowest());
const auto* Xdata = X->template Data<T8Bits>();

View file

@ -583,7 +583,7 @@ Status QLinearAveragePool::ComputeImpl(OpKernelContext* context) const {
BufferUniquePtr x_data_fp32_guard;
if (kernel_shape.size() <= 3) {
x_data_fp32 = (float*)allocator->Alloc(SafeInt<size_t>(x_shape.Size()) * sizeof(float));
x_data_fp32_guard = BufferUniquePtr(x_data_fp32, BufferDeleter(allocator));
x_data_fp32_guard = BufferUniquePtr(x_data_fp32, BufferDeleter(std::move(allocator)));
dequantize_array(x_shape.Size(), X_data, x_scale, x_zero_point, x_data_fp32, tp);
}

View file

@ -22,8 +22,8 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
AllocatorPtr allocator,
void* /*stream*/,
onnxruntime::concurrency::ThreadPool* threadpool,
std::unique_ptr<Tensor>& output_values,
std::unique_ptr<Tensor>& output_indices) {
Tensor& output_values,
Tensor& output_indices) {
if (input->IsDataType<float>()) {
return GetTopK<float>(input, axis, k, largest, sorted, allocator, threadpool, output_values, output_indices);
}
@ -343,20 +343,20 @@ Status ProcessLogits(const OrtValue& logits, //
constexpr bool largest = true;
constexpr bool sorted = true; // results returned in sorted order.
std::unique_ptr<Tensor> topk_scores;
std::unique_ptr<Tensor> topk_indices;
Tensor topk_scores;
Tensor topk_indices;
ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool,
topk_scores, topk_indices));
#ifdef DEBUG_GENERATION
dumper->Print("topk_scores", *(topk_scores.get()));
dumper->Print("topk_indices", *(topk_indices.get()));
dumper->Print("topk_scores", topk_scores);
dumper->Print("topk_indices", topk_indices);
#endif
// Convert indices in range [0, num_beams * vocab_size) to token ID of range [0, vocab_size) like the following:
// next_indices = (next_tokens / vocab_size).long()
// next_tokens = next_tokens % vocab_size
gsl::span<const int64_t> next_token_indices = topk_indices->DataAsSpan<int64_t>();
gsl::span<const int64_t> next_token_indices = topk_indices.DataAsSpan<int64_t>();
offset = 0;
for (int i = 0; i < batch_size; i++) {
for (unsigned int j = 0; j < top_k; j++, offset++) {
@ -365,7 +365,7 @@ Status ProcessLogits(const OrtValue& logits, //
}
}
gsl::span<const T> next_scores = topk_scores->DataAsSpan<T>();
gsl::span<const T> next_scores = topk_scores.DataAsSpan<T>();
gsl::span<const int32_t> next_tokens(beam_state->next_tokens.data(), beam_state->next_tokens.size());
gsl::span<const int32_t> next_indices(beam_state->next_indices.data(), beam_state->next_indices.size());
@ -453,8 +453,8 @@ Status GreedySearchProcessLogits(
constexpr bool largest = true;
constexpr bool sorted = false;
std::unique_ptr<Tensor> topk_scores;
std::unique_ptr<Tensor> topk_indices;
Tensor topk_scores;
Tensor topk_indices;
ORT_RETURN_IF_ERROR(
TopK(&input,
axis,
@ -472,7 +472,7 @@ Status GreedySearchProcessLogits(
dumper->Print("topk_indices", *(topk_indices.get()));
#endif
gsl::span<const int64_t> next_token_indices = topk_indices->DataAsSpan<int64_t>();
gsl::span<const int64_t> next_token_indices = topk_indices.DataAsSpan<int64_t>();
gsl::copy(next_token_indices, greedy_state->next_tokens_cpu);
#ifdef DEBUG_GENERATION

View file

@ -37,8 +37,8 @@ using TopkFunc = std::function<Status(
AllocatorPtr allocator,
void* stream, // cudaStream_t
onnxruntime::concurrency::ThreadPool* threadpool,
std::unique_ptr<Tensor>& output_values,
std::unique_ptr<Tensor>& output_indices)>;
Tensor& output_values,
Tensor& output_indices)>;
// Create subgraph inputs: input_ids, position_ids and attention_mask (for GPT-2).
using CreateGptInputsFunc = std::function<Status(
@ -168,8 +168,8 @@ Status TopK(
AllocatorPtr allocator,
void* stream,
onnxruntime::concurrency::ThreadPool* threadpool,
std::unique_ptr<Tensor>& output_values,
std::unique_ptr<Tensor>& output_indices);
Tensor& output_values,
Tensor& output_indices);
Status AddToFeeds(
const IExecutionProvider* execution_provider,

View file

@ -55,7 +55,7 @@ Status Subgraph::Setup(const SessionState& session_state,
session_state_ = &session_state;
subgraph_session_state_ = &subgraph_session_state;
std::vector<std::string> feed_names;
InlinedVector<std::string_view> feed_names;
feed_names.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
// Use the first output (logits) to find device location.
@ -69,25 +69,24 @@ Status Subgraph::Setup(const SessionState& session_state,
feed_names.push_back(entry->Name());
}
std::vector<OrtDevice> feed_locations;
feed_locations.resize(feed_names.size());
InlinedVector<OrtDevice> feed_locations;
feed_locations.reserve(feed_names.size());
for (size_t i = 0, end = feed_names.size(); i < end; ++i) {
if (i >= subgraph_input_names.size()) { // Implicit inputs
const auto& location = utils::FindMemoryInfoForValue(session_state, feed_names[i]);
feed_locations[i] = location.device;
feed_locations.push_back(location.device);
} else {
feed_locations[i] = default_location.device;
feed_locations.push_back(default_location.device);
}
}
std::unique_ptr<FeedsFetchesManager> ffm;
ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, subgraph_output_names,
subgraph_session_state.GetOrtValueNameIdxMap(), ffm));
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm));
subgraph_session_state.GetOrtValueNameIdxMap(), feeds_fetches_manager_));
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *feeds_fetches_manager_));
// Setup the locations where we want the subgraph output to end up on
std::vector<const OrtMemoryInfo*> fetch_locations;
InlinedVector<const OrtMemoryInfo*> fetch_locations;
fetch_locations.reserve(num_subgraph_outputs);
// Past states need to be where we can feed them into the next iteration, so set the location to match the feed.
@ -95,9 +94,7 @@ Status Subgraph::Setup(const SessionState& session_state,
fetch_locations.push_back(&default_location);
}
utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations);
feeds_fetches_manager_ = std::move(ffm);
utils::FinalizeFeedFetchCopyInfo(*feeds_fetches_manager_, feed_locations, fetch_locations);
// The subgraph only needs to be checked once, so do it in the Setup function.
auto& inputs = subgraph.GetInputs();

View file

@ -48,7 +48,9 @@ class Subgraph {
Status Setup(const SessionState& session_state,
const SessionState& subgraph_session_state);
FeedsFetchesManager* GetFeedsFetchesManager() const { return feeds_fetches_manager_.get(); }
// Returns a pointer to the manager held in feeds_fetches_manager_,
// or nullptr while the optional is still empty.
FeedsFetchesManager* GetFeedsFetchesManager() {
  if (!feeds_fetches_manager_) {
    return nullptr;
  }
  return &feeds_fetches_manager_.value();
}
const IExecutionProvider* GetProvider() const;
@ -65,7 +67,7 @@ class Subgraph {
AllocatorPtr allocator_;
const SessionState* session_state_;
const SessionState* subgraph_session_state_;
std::unique_ptr<FeedsFetchesManager> feeds_fetches_manager_;
std::optional<FeedsFetchesManager> feeds_fetches_manager_;
bool is_output_float16_;
};

View file

@ -31,8 +31,8 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
AllocatorPtr allocator,
void* stream,
onnxruntime::concurrency::ThreadPool* /*threadpool*/,
std::unique_ptr<Tensor>& output_values,
std::unique_ptr<Tensor>& output_indices) {
Tensor& output_values,
Tensor& output_indices) {
ORT_ENFORCE(nullptr != input);
int32_t rank = static_cast<int32_t>(input->Shape().NumDimensions());
@ -51,15 +51,15 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
int64_t dimension = input_shape[axis];
int64_t N = elem_nums_cuda[0] / dimension;
output_values = Tensor::Create(input->DataType(), output_shape, allocator);
output_indices = Tensor::Create(DataTypeImpl::GetType<int64_t>(), output_shape, allocator);
output_values = std::move(*Tensor::Create(input->DataType(), output_shape, allocator));
output_indices = std::move(*Tensor::Create(DataTypeImpl::GetType<int64_t>(), output_shape, std::move(allocator)));
if (input->IsDataType<float>()) {
return TopKImpl<float>(nullptr, // We limit number of beams in BeamSearchParameters, so K <= 256 and use NULL here
reinterpret_cast<cudaStream_t>(stream),
input->Data<float>(),
static_cast<float*>(output_values->MutableDataRaw()),
static_cast<int64_t*>(output_indices->MutableDataRaw()),
static_cast<float*>(output_values.MutableDataRaw()),
static_cast<int64_t*>(output_indices.MutableDataRaw()),
elem_nums_cuda,
static_cast<size_t>(elem_nums_cuda.Size()),
static_cast<int32_t>(axis),
@ -72,8 +72,8 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
return TopKImpl<MLFloat16>(nullptr,
reinterpret_cast<cudaStream_t>(stream),
input->Data<MLFloat16>(),
static_cast<MLFloat16*>(output_values->MutableDataRaw()),
static_cast<int64_t*>(output_indices->MutableDataRaw()),
static_cast<MLFloat16*>(output_values.MutableDataRaw()),
static_cast<int64_t*>(output_indices.MutableDataRaw()),
elem_nums_cuda,
static_cast<size_t>(elem_nums_cuda.Size()),
static_cast<int32_t>(axis),
@ -350,10 +350,10 @@ Status ProcessLogits(const OrtValue& logits, //
constexpr bool largest = true;
constexpr bool sorted = true; // results returned in sorted order.
std::unique_ptr<Tensor> topk_scores;
std::unique_ptr<Tensor> topk_indices;
std::unique_ptr<Tensor> topk_scores = Tensor::CreateDefault();
std::unique_ptr<Tensor> topk_indices = Tensor::CreateDefault();
ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool,
topk_scores, topk_indices));
*topk_scores, *topk_indices));
#ifdef DEBUG_GENERATION
dumper->Print("topk_scores", *(topk_scores.get()));
@ -514,10 +514,10 @@ Status GreedySearchProcessLogits(
constexpr bool largest = true;
constexpr bool sorted = false;
std::unique_ptr<Tensor> topk_scores;
std::unique_ptr<Tensor> topk_indices;
auto topk_scores = Tensor::CreateDefault();
auto topk_indices = Tensor::CreateDefault();
ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool,
topk_scores, topk_indices));
*topk_scores, *topk_indices));
#ifdef DEBUG_GENERATION
dumper->Print("topk_scores", *(topk_scores.get()));

View file

@ -25,8 +25,8 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
AllocatorPtr allocator,
void* stream,
onnxruntime::concurrency::ThreadPool* threadpool,
std::unique_ptr<Tensor>& output_values,
std::unique_ptr<Tensor>& output_indices);
Tensor& output_values,
Tensor& output_indices);
Status AddToFeeds(const IExecutionProvider* execution_provider,
std::initializer_list<OrtValue> inputs,

View file

@ -8,9 +8,9 @@
#include "core/framework/utils.h"
namespace onnxruntime {
common::Status FeedsFetchesInfo::MapNamesToMLValueIdxs(const std::vector<std::string>& names,
common::Status FeedsFetchesInfo::MapNamesToMLValueIdxs(gsl::span<const std::string> names,
const OrtValueNameIdxMap& ort_value_name_idx_map,
std::vector<int>& ort_value_idxs) {
InlinedVector<int>& ort_value_idxs) {
auto status = Status::OK();
ort_value_idxs.reserve(names.size());
@ -40,8 +40,8 @@ Status FeedsFetchesInfo::SetMLValueIdxs(const OrtValueNameIdxMap& ort_value_name
return status;
}
Status FeedsFetchesManager::Create(const std::vector<std::string>& feed_names,
const std::vector<std::string>& output_names,
Status FeedsFetchesManager::Create(gsl::span<const std::string> feed_names,
gsl::span<const std::string> output_names,
const OrtValueNameIdxMap& ort_value_name_idx_map,
std::unique_ptr<FeedsFetchesManager>& feed_fetch_manager) {
FeedsFetchesInfo info{feed_names, output_names, ort_value_name_idx_map};
@ -51,11 +51,22 @@ Status FeedsFetchesManager::Create(const std::vector<std::string>& feed_names,
return Status::OK();
}
// Builds a FeedsFetchesInfo from the given feed/output names and name-to-index map,
// then constructs the manager in place inside the caller-provided optional.
// Always returns Status::OK() when it returns normally.
Status FeedsFetchesManager::Create(gsl::span<const std::string_view> feed_names,
                                   gsl::span<const std::string> output_names,
                                   const OrtValueNameIdxMap& ort_value_name_idx_map,
                                   std::optional<FeedsFetchesManager>& feed_fetch_manager) {
  // The temporary binds to the FeedsFetchesInfo&& constructor parameter.
  feed_fetch_manager.emplace(FeedsFetchesInfo{feed_names, output_names, ort_value_name_idx_map});
  return Status::OK();
}
FeedsFetchesManager::FeedsFetchesManager(FeedsFetchesInfo&& info)
: feeds_fetches_info_{info} {
: feeds_fetches_info_(std::move(info)) {
// init with default values
feeds_device_copy_info_.resize(info.feed_names.size());
fetches_device_copy_info_.resize(info.output_names.size());
feeds_device_copy_info_.resize(feeds_fetches_info_.feed_names.size());
fetches_device_copy_info_.resize(feeds_fetches_info_.output_names.size());
}
void FeedsFetchesManager::SetDeviceCopyChecks(DeviceCopyCheck input_copy_needed, DeviceCopyCheck output_copy_needed) {

View file

@ -5,6 +5,8 @@
#include <string>
#include <vector>
#include <optional>
#include "core/common/inlined_containers_fwd.h"
#ifndef SHARED_PROVIDER
#include "core/framework/ort_value.h"
@ -30,25 +32,43 @@ struct DeviceCopyChecks {
struct FeedsFetchesInfo {
FeedsFetchesInfo() = default;
FeedsFetchesInfo(const std::vector<std::string>& feed_names_in,
const std::vector<std::string>& output_names_in,
FeedsFetchesInfo(gsl::span<const std::string> feed_names_in,
gsl::span<const std::string> output_names_in,
const OrtValueNameIdxMap& ort_value_name_idx_map)
: feed_names{feed_names_in}, output_names{output_names_in} {
: feed_names(),
output_names() {
feed_names.reserve(feed_names_in.size());
feed_names.assign(feed_names_in.begin(), feed_names_in.end());
output_names.reserve(output_names_in.size());
output_names.assign(output_names_in.begin(), output_names_in.end());
ORT_THROW_IF_ERROR(SetMLValueIdxs(ort_value_name_idx_map));
}
static Status MapNamesToMLValueIdxs(const std::vector<std::string>& names,
// Overload taking feed names as string views: copies every feed and output name
// into the owned containers, then resolves the value indexes via SetMLValueIdxs().
// A failure from SetMLValueIdxs() is reported through ORT_THROW_IF_ERROR.
FeedsFetchesInfo(gsl::span<const std::string_view> feed_names_in,
                 gsl::span<const std::string> output_names_in,
                 const OrtValueNameIdxMap& ort_value_name_idx_map) {
  feed_names.reserve(feed_names_in.size());
  for (const std::string_view feed : feed_names_in) {
    // emplace_back is required: std::string's constructor from string_view is explicit.
    feed_names.emplace_back(feed);
  }

  output_names.reserve(output_names_in.size());
  for (const std::string& output : output_names_in) {
    output_names.emplace_back(output);
  }

  ORT_THROW_IF_ERROR(SetMLValueIdxs(ort_value_name_idx_map));
}
static Status MapNamesToMLValueIdxs(gsl::span<const std::string> names,
const OrtValueNameIdxMap& ort_value_name_idx_map,
std::vector<int>& ort_value_idxs);
InlinedVector<int>& ort_value_idxs);
// set the ort_value_idxs for the current values in feed_names and output_names
Status SetMLValueIdxs(const OrtValueNameIdxMap& ort_value_name_idx_map);
std::vector<std::string> feed_names;
std::vector<std::string> output_names;
InlinedVector<std::string> feed_names;
InlinedVector<std::string> output_names;
std::vector<int> feeds_mlvalue_idxs;
std::vector<int> fetches_mlvalue_idxs;
InlinedVector<int> feeds_mlvalue_idxs;
InlinedVector<int> fetches_mlvalue_idxs;
};
struct MLValueCopyInfo {
@ -58,10 +78,14 @@ struct MLValueCopyInfo {
class FeedsFetchesManager {
public:
static Status Create(const std::vector<std::string>& feed_names, const std::vector<std::string>& output_names,
static Status Create(gsl::span<const std::string> feed_names, gsl::span<const std::string> output_names,
const OrtValueNameIdxMap& ort_value_name_idx_map,
std::unique_ptr<FeedsFetchesManager>& feeds_fetches_manager);
static Status Create(gsl::span<const std::string_view> feed_names, gsl::span<const std::string> output_names,
const OrtValueNameIdxMap& ort_value_name_idx_map,
std::optional<FeedsFetchesManager>& feeds_fetches_manager);
FeedsFetchesManager(FeedsFetchesInfo&& info);
const FeedsFetchesInfo& GetFeedsFetchesInfo() const { return feeds_fetches_info_; }

View file

@ -28,9 +28,9 @@ class IExecutor {
* The lifetime of 'fetches' is limited by 'session_state'
*/
common::Status Execute(const SessionState& session_state,
const std::vector<int>& feed_mlvalue_idxs,
const std::vector<OrtValue>& feeds,
const std::vector<int>& fetch_mlvalue_idxs,
gsl::span<const int> feed_mlvalue_idxs,
gsl::span<const OrtValue> feeds,
gsl::span<const int> fetch_mlvalue_idxs,
std::vector<OrtValue>& fetches,
const logging::Logger& logger) {
std::unordered_map<size_t, CustomAllocator> fetch_allocators;
@ -38,8 +38,8 @@ class IExecutor {
}
// TODO: as fetch_allocators is optional, it should be a pointer instead of reference
virtual common::Status Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
virtual common::Status Execute(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
std::vector<OrtValue>& fetches,
// optional custom allocators. key is index in fetches
const std::unordered_map<size_t, CustomAllocator>& fetch_allocators,

View file

@ -30,7 +30,7 @@ class OrtValueNameIdxMap {
return p.first->second;
}
common::Status GetIdx(const std::string& name, int& idx) const {
common::Status GetIdx(std::string_view name, int& idx) const {
idx = -1;
auto it = map_.find(name);

View file

@ -80,9 +80,9 @@ void OrtValueTensorSlicer<T>::Iterator::MaterializeMLValue() const {
//
// TODO: Ideally we could avoid the overhead of creating a new Tensor (mainly cost of copying type and shape info)
// and would simply update Tensor::p_data_ given all other info remains constant for each slice.
auto sub_tensor = Tensor::Create(tensor_data_type_, per_iteration_shape_, const_cast<void*>(tensor_slice_data_raw), *tensor_location_);
auto ml_tensor = DataTypeImpl::GetType<Tensor>();
current_ = OrtValue{sub_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()};
OrtValue val;
Tensor::InitOrtValue(tensor_data_type_, per_iteration_shape_, const_cast<void*>(tensor_slice_data_raw), *tensor_location_, val);
current_ = std::move(val);
}
template class OrtValueTensorSlicer<OrtValue>;

View file

@ -130,8 +130,8 @@ static Status ReleaseNodeMLValues(ExecutionFrame& frame,
return Status::OK();
}
Status PartialExecutor::Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
Status PartialExecutor::Execute(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, CustomAllocator>& fetch_allocators,
const logging::Logger& logger) {

View file

@ -28,8 +28,8 @@ class PartialExecutor : public IExecutor {
ORT_UNUSED_PARAMETER(partial_graph_index_);
}
common::Status Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
common::Status Execute(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, CustomAllocator>& fetch_allocators,
const logging::Logger& logger) override;

View file

@ -27,8 +27,8 @@ ParallelExecutor::ParallelExecutor(const SessionState& session_state, const bool
}
}
Status ParallelExecutor::Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
Status ParallelExecutor::Execute(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, CustomAllocator>& fetch_allocators,
const logging::Logger& logger) {

View file

@ -22,8 +22,8 @@ class ParallelExecutor : public IExecutor {
public:
ParallelExecutor(const SessionState& session_state, const bool& terminate_flag = false);
common::Status Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
common::Status Execute(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, CustomAllocator>& fetch_allocators,
const logging::Logger& logger) override;

View file

@ -26,8 +26,9 @@ struct PartialGraphExecutionState {
size_t GetProgramCounterStart() { return program_counter_start_; }
size_t GetProgramCounterEnd() { return program_counter_end_; }
ExecutionFrame& GetExecutionFrame(const std::vector<int>& feed_mlvalue_idxs, const std::vector<OrtValue>& feeds,
const std::vector<int>& fetch_mlvalue_idxs, const std::vector<OrtValue>& fetches,
ExecutionFrame& GetExecutionFrame(gsl::span<const int> feed_mlvalue_idxs,
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
gsl::span<const OrtValue> fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
const SessionState& session_state) {
if (execution_frame_ == nullptr) {

View file

@ -141,8 +141,8 @@ static Status ReleaseNodeMLValues(ExecutionFrame& frame,
const SequentialExecutionPlan::NodeExecutionPlan& node_exec_plan,
const logging::Logger& logger);
Status SequentialExecutor::Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
Status SequentialExecutor::Execute(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, CustomAllocator>& fetch_allocators,
const logging::Logger& logger) {

View file

@ -21,8 +21,8 @@ class SequentialExecutor : public IExecutor {
SequentialExecutor(const bool& terminate_flag = false, const bool only_execute_path_to_fetches = false)
: terminate_flag_{terminate_flag}, only_execute_path_to_fetches_(only_execute_path_to_fetches) {}
common::Status Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
common::Status Execute(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, CustomAllocator>& fetch_allocators,
const logging::Logger& logger) override;

View file

@ -251,7 +251,7 @@ static bool HaveCpuExecutionProvidersOnly(const ExecutionProviders& execution_pr
static const OrtMemoryInfo& FindMemoryInfoForValue(const OrtValueNameIdxMap& map,
const SequentialExecutionPlan& plan,
const std::string& name) {
std::string_view name) {
int idx = -1;
auto status = map.GetIdx(name, idx);
ORT_THROW_IF_ERROR(status);
@ -261,7 +261,7 @@ static const OrtMemoryInfo& FindMemoryInfoForValue(const OrtValueNameIdxMap& map
}
const OrtMemoryInfo& FindMemoryInfoForValue(const SessionState& session_state,
const std::string& name) {
std::string_view name) {
const auto* exec_plan_ptr = session_state.GetExecutionPlan();
ORT_ENFORCE(exec_plan_ptr);
@ -304,7 +304,7 @@ static common::Status CalculateStaticCopyInfoForFeed(const SessionState& session
}
static common::Status CalculateStaticCopyInfoForFeeds(const SessionState& session_state,
const std::vector<std::string>& feed_names,
gsl::span<const std::string> feed_names,
std::vector<MLValueCopyInfo>& copy_info) {
for (size_t idx = 0, end = feed_names.size(); idx < end; ++idx) {
ORT_RETURN_IF_ERROR(CalculateStaticCopyInfoForFeed(session_state, feed_names[idx], copy_info[idx]));
@ -316,7 +316,7 @@ static common::Status CalculateStaticCopyInfoForFeeds(const SessionState& sessio
// get the source device info for the node producing each output that we will return in the fetches.
// target device info is not known until runtime.
static common::Status CalculateStaticCopyInfoForFetches(const SessionState& session_state,
const std::vector<std::string>& fetch_names,
gsl::span<const std::string> fetch_names,
std::vector<MLValueCopyInfo>& copy_info) {
for (size_t idx = 0, end = fetch_names.size(); idx < end; ++idx) {
const std::string& output_name = fetch_names[idx];
@ -362,7 +362,7 @@ common::Status InitializeFeedFetchCopyInfo(const SessionState& session_state,
}
// update the allocation_provider in the copy info based on the actual feeds
static bool FinalizeCopyInfoForFeeds(const std::vector<OrtDevice>& feed_locations,
static bool FinalizeCopyInfoForFeeds(gsl::span<const OrtDevice> feed_locations,
std::vector<MLValueCopyInfo>& copy_info) {
ORT_ENFORCE(feed_locations.size() == copy_info.size());
bool copy_needed = false;
@ -378,7 +378,7 @@ static bool FinalizeCopyInfoForFeeds(const std::vector<OrtDevice>& feed_location
return copy_needed;
}
static bool FinalizeCopyInfoForFetches(const std::vector<const OrtMemoryInfo*>& fetch_alloc_info,
static bool FinalizeCopyInfoForFetches(gsl::span<const OrtMemoryInfo* const>& fetch_alloc_info,
std::vector<MLValueCopyInfo>& copy_info) {
ORT_ENFORCE(fetch_alloc_info.size() == copy_info.size());
bool copy_needed = false;
@ -402,8 +402,8 @@ static bool FinalizeCopyInfoForFetches(const std::vector<const OrtMemoryInfo*>&
// Finalize the copy info using the OrtDevice and OrtMemoryInfo for the feeds and fetches
// This can be used by control flow nodes prior to the execution of the overall graph.
void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
const std::vector<OrtDevice>& feed_locations,
const std::vector<const OrtMemoryInfo*>& fetch_alloc_info) {
gsl::span<const OrtDevice> feed_locations,
gsl::span<const OrtMemoryInfo* const> fetch_alloc_info) {
if (feeds_fetches_manager.GetDeviceCopyChecks().status == DeviceCopyCheck::NoCopy)
return;
@ -418,7 +418,7 @@ void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
// Finalize the copy info using the OrtValue instances for the feeds and fetches
static void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
const std::vector<OrtValue>& feeds,
gsl::span<const OrtValue> feeds,
std::vector<OrtValue>& fetches) {
if (feeds_fetches_manager.GetDeviceCopyChecks().status == DeviceCopyCheck::NoCopy)
return;
@ -465,9 +465,9 @@ static void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager
}
static common::Status CopyInputsAcrossDevices(const SessionState& session_state,
const std::vector<OrtValue>& orig_feeds,
gsl::span<const OrtValue> orig_feeds,
std::vector<OrtValue>& new_feeds,
const std::vector<MLValueCopyInfo>& copy_info) {
gsl::span<const MLValueCopyInfo> copy_info) {
size_t num_feeds = orig_feeds.size();
ORT_ENFORCE(copy_info.size() == num_feeds);
@ -560,20 +560,26 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state
static common::Status ExecuteGraphImpl(const SessionState& session_state,
const FeedsFetchesManager& feeds_fetches_manager,
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
ExecutionMode execution_mode, const bool& terminate_flag,
const logging::Logger& logger, const bool only_execute_path_to_fetches = false) {
std::unique_ptr<IExecutor> p_exec;
// avoid memory allocations
std::optional<SequentialExecutor> seq_executor;
std::optional<ParallelExecutor> par_executor;
IExecutor* p_exec = nullptr;
if (execution_mode == ExecutionMode::ORT_SEQUENTIAL) {
p_exec = std::make_unique<SequentialExecutor>(terminate_flag, only_execute_path_to_fetches);
seq_executor.emplace(terminate_flag, only_execute_path_to_fetches);
p_exec = &seq_executor.value();
} else if (execution_mode == ExecutionMode::ORT_PARALLEL) {
auto* p_inter_op_thread_pool = session_state.GetInterOpThreadPool();
if (!p_inter_op_thread_pool) {
LOGS(logger, WARNING) << "Only one thread was configured for parallel execution. Hence will use sequential execution.";
p_exec = std::make_unique<SequentialExecutor>(terminate_flag, only_execute_path_to_fetches);
seq_executor.emplace(terminate_flag, only_execute_path_to_fetches);
p_exec = &seq_executor.value();
} else {
p_exec = std::make_unique<ParallelExecutor>(session_state, terminate_flag);
par_executor.emplace(session_state, terminate_flag);
p_exec = &par_executor.value();
}
}
@ -588,7 +594,7 @@ static common::Status ExecuteGraphImpl(const SessionState& session_state,
feeds_fetches_info.fetches_mlvalue_idxs, fetches, fetch_allocators,
logger));
} else {
const std::vector<OrtValue>* p_feeds = &feeds;
auto feeds_to_use = feeds;
std::vector<OrtValue>* p_fetches = &fetches;
std::vector<OrtValue> device_feeds;
std::vector<OrtValue> device_fetches;
@ -596,7 +602,7 @@ static common::Status ExecuteGraphImpl(const SessionState& session_state,
if (device_copy_checks.input_copy_needed == DeviceCopyCheck::Copy) {
const auto& feed_copy_info = feeds_fetches_manager.GetFeedsDeviceCopyInfo();
ORT_RETURN_IF_ERROR(CopyInputsAcrossDevices(session_state, feeds, device_feeds, feed_copy_info));
p_feeds = &device_feeds;
feeds_to_use = device_feeds;
}
auto num_outputs = fetches.size();
@ -619,7 +625,7 @@ static common::Status ExecuteGraphImpl(const SessionState& session_state,
}
ORT_RETURN_IF_ERROR(p_exec->Execute(session_state,
feeds_fetches_info.feeds_mlvalue_idxs, *p_feeds,
feeds_fetches_info.feeds_mlvalue_idxs, feeds_to_use,
feeds_fetches_info.fetches_mlvalue_idxs, *p_fetches, fetch_allocators,
logger));
@ -633,7 +639,7 @@ static common::Status ExecuteGraphImpl(const SessionState& session_state,
common::Status ExecuteGraph(const SessionState& session_state,
FeedsFetchesManager& feeds_fetches_manager,
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const bool& terminate_flag,
const logging::Logger& logger, bool only_execute_path_to_fetches) {
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(session_state, feeds_fetches_manager));
@ -649,7 +655,7 @@ common::Status ExecuteGraph(const SessionState& session_state,
#ifdef ENABLE_TRAINING
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
const logging::Logger& logger, PartialGraphExecutionState& state,
const OrtValueCachePtr& cache,
int32_t partial_graph_index) {
@ -667,7 +673,7 @@ common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetch
feeds_fetches_info.fetches_mlvalue_idxs, fetches, {},
logger));
} else {
const std::vector<OrtValue>* p_feeds = &feeds;
auto p_feeds = feeds;
std::vector<OrtValue>* p_fetches = &fetches;
std::vector<OrtValue> device_feeds;
std::vector<OrtValue> device_fetches;
@ -675,7 +681,7 @@ common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetch
if (device_copy_checks.input_copy_needed == DeviceCopyCheck::Copy) {
const auto& feed_copy_info = feeds_fetches_manager.GetFeedsDeviceCopyInfo();
ORT_RETURN_IF_ERROR(CopyInputsAcrossDevices(session_state, feeds, device_feeds, feed_copy_info));
p_feeds = &device_feeds;
p_feeds = device_feeds;
}
auto num_outputs = fetches.size();
@ -698,7 +704,7 @@ common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetch
}
ORT_RETURN_IF_ERROR(executor.Execute(session_state,
feeds_fetches_info.feeds_mlvalue_idxs, *p_feeds,
feeds_fetches_info.feeds_mlvalue_idxs, p_feeds,
feeds_fetches_info.fetches_mlvalue_idxs, *p_fetches, {},
logger));
@ -712,7 +718,7 @@ common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetch
#endif
common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFetchesManager& feeds_fetches_manager,
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger) {
auto status = ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, fetch_allocators,

View file

@ -71,7 +71,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
// Searches the allocation plan from the session_state to find the OrtMemoryInfo for the value 'name'.
const OrtMemoryInfo& FindMemoryInfoForValue(const SessionState& session_state,
const std::string& name);
std::string_view name);
// Initialize the feed and fetch copy info using session_state.
// Determines the device that each graph input that will be fed will be consumed on,
@ -82,18 +82,18 @@ common::Status InitializeFeedFetchCopyInfo(const SessionState& session_state,
// Finalize the feed and fetch copy info using session_state and the device and location information from the feeds
// and fetches that will be used in graph execution.
void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
const std::vector<OrtDevice>& feed_locations,
const std::vector<const OrtMemoryInfo*>& fetch_alloc_info);
gsl::span<const OrtDevice> feed_locations,
gsl::span<const OrtMemoryInfo* const> fetch_alloc_info);
// Execute the main graph. The feed_fetches_manager will be finalized based on the provided feeds and fetches.
common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger,
bool only_execute_path_to_fetches = false);
#ifdef ENABLE_TRAINING
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
const logging::Logger& logger, PartialGraphExecutionState& state,
const OrtValueCachePtr& cache,
int32_t partial_graph_index);
@ -102,7 +102,7 @@ common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetch
// Execute a subgraph. The feeds_fetches_manager should have been finalized prior to calling this function.
// See IControlFlowNode::SetupSubgraphExecutionInfo usage in the control flow kernels.
common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFetchesManager& feeds_fetches_manager,
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger);

View file

@ -262,11 +262,11 @@ std::vector<uint8_t> ApiTensor::Data() const {
const DataTypeImpl* tensor_dtype = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto_.data_type())->GetElementType();
auto tensor_shape_dims = utils::GetTensorShapeFromTensorProto(tensor_proto_);
TensorShape tensor_shape{std::move(tensor_shape_dims)};
auto tensor = onnxruntime::Tensor::Create(tensor_dtype, tensor_shape, cpu_allocator_);
onnxruntime::Tensor tensor(tensor_dtype, tensor_shape, cpu_allocator_);
ORT_THROW_IF_ERROR(utils::TensorProtoToTensor(Env::Default(), model_path_.ToPathString().c_str(),
tensor_proto_, *tensor));
size_t num_bytes = gsl::narrow_cast<size_t>(tensor->SizeInBytes());
const uint8_t* data = static_cast<const uint8_t*>(tensor->DataRaw());
tensor_proto_, tensor));
size_t num_bytes = gsl::narrow_cast<size_t>(tensor.SizeInBytes());
const uint8_t* data = static_cast<const uint8_t*>(tensor.DataRaw());
return std::vector<uint8_t>(data, data + num_bytes);
}
// </ApiTensor>
@ -515,7 +515,7 @@ void ApiGraph::TransposeInitializer(std::string_view name, const std::vector<int
const DataTypeImpl* tensor_dtype = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto->data_type())->GetElementType();
auto tensor_shape_dims = utils::GetTensorShapeFromTensorProto(*tensor_proto);
TensorShape tensor_shape{tensor_shape_dims};
std::unique_ptr<Tensor> in_tensor = Tensor::Create(tensor_dtype, tensor_shape, cpu_allocator_);
Tensor in_tensor(tensor_dtype, tensor_shape, cpu_allocator_);
std::vector<int64_t> new_tensor_shape_dims;
std::vector<size_t> permutations;
@ -528,13 +528,12 @@ void ApiGraph::TransposeInitializer(std::string_view name, const std::vector<int
}
TensorShape new_tensor_shape(new_tensor_shape_dims);
std::unique_ptr<Tensor> out_tensor = Tensor::Create(tensor_dtype, new_tensor_shape, cpu_allocator_);
Tensor out_tensor(tensor_dtype, new_tensor_shape, cpu_allocator_);
ORT_THROW_IF_ERROR(utils::TensorProtoToTensor(Env::Default(), graph_.ModelPath().ToPathString().c_str(),
*tensor_proto, *in_tensor));
*tensor_proto, in_tensor));
ORT_THROW_IF_ERROR(Transpose::DoTranspose(permutations, *in_tensor, *out_tensor));
ORT_THROW_IF_ERROR(Transpose::DoTranspose(permutations, in_tensor, out_tensor));
auto& node_arg = *graph_.GetNodeArg(name_str);
TensorShapeProto new_shape;
@ -544,7 +543,7 @@ void ApiGraph::TransposeInitializer(std::string_view name, const std::vector<int
node_arg.SetShape(new_shape);
ONNX_NAMESPACE::TensorProto new_tensor_proto = utils::TensorToTensorProto(*out_tensor, name_str);
ONNX_NAMESPACE::TensorProto new_tensor_proto = utils::TensorToTensorProto(out_tensor, name_str);
graph_.RemoveInitializedTensor(name_str);
graph_.AddInitializedTensor(new_tensor_proto);
}

View file

@ -247,7 +247,7 @@ static Status MultinomialCompute(OpKernelContext* ctx,
AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
auto cdf_data = static_cast<double*>(alloc->Alloc(SafeInt<size_t>(sizeof(double)) * num_classes));
BufferUniquePtr cdf_buffer(cdf_data, BufferDeleter(alloc));
BufferUniquePtr cdf_buffer(cdf_data, BufferDeleter(std::move(alloc)));
Eigen::array<int64_t, 1> cdf_dims = {{num_classes}};
auto cdf = EigenVector<double>(cdf_data, cdf_dims);
// END create temporary tensor

View file

@ -26,7 +26,6 @@
#include <algorithm>
#include <cmath>
using namespace std;
namespace onnxruntime {
template <typename T>
@ -123,7 +122,7 @@ static void HeapifyIthPosition(int64_t* heap, size_t i, size_t k, const HeapCmp&
template <class Comparator>
static void SelectTopK(const Comparator& comparer,
int64_t row_offset, int64_t num_blocks, int64_t block_slice, int64_t inter_block_offset,
const unsigned k, bool sort_top_k, vector<int64_t>& data_holder) {
const unsigned k, bool sort_top_k, std::vector<int64_t>& data_holder) {
for (int64_t l = 0; l < num_blocks; ++l) {
data_holder[l] = (row_offset + (l * block_slice + inter_block_offset));
}
@ -375,8 +374,8 @@ template <typename T>
Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
AllocatorPtr allocator,
onnxruntime::concurrency::ThreadPool* threadpool,
std::unique_ptr<Tensor>& output_values,
std::unique_ptr<Tensor>& output_indices) {
Tensor& output_values,
Tensor& output_indices) {
const TensorShape& input_shape = input->Shape();
// Will return axis_ as is if positive or fixes it in case it is negative
@ -394,8 +393,8 @@ Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool large
TensorShape output_shape = input_shape;
output_shape[axis_parsed] = k;
output_values = Tensor::Create(input->DataType(), output_shape, allocator);
output_indices = Tensor::Create(DataTypeImpl::GetType<int64_t>(), output_shape, allocator);
output_values = Tensor(input->DataType(), output_shape, allocator);
output_indices = Tensor(DataTypeImpl::GetType<int64_t>(), output_shape, allocator);
// no-op - no output buffers to fill - return silently
if (k == 0) {
@ -403,10 +402,10 @@ Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool large
}
if (largest) {
FindTopKElements<GreaterValueCmp<T>>(input, input_shape, output_values.get(), output_indices.get(), output_shape, k, sorted,
FindTopKElements<GreaterValueCmp<T>>(input, input_shape, &output_values, &output_indices, output_shape, k, sorted,
gsl::narrow_cast<unsigned>(axis_parsed), threadpool);
} else {
FindTopKElements<LesserValueCmp<T>>(input, input_shape, output_values.get(), output_indices.get(), output_shape, k, sorted,
FindTopKElements<LesserValueCmp<T>>(input, input_shape, &output_values, &output_indices, output_shape, k, sorted,
gsl::narrow_cast<unsigned>(axis_parsed), threadpool);
}
@ -417,8 +416,8 @@ Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool large
template Status GetTopK<float>(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
AllocatorPtr allocator,
onnxruntime::concurrency::ThreadPool* threadpool,
std::unique_ptr<Tensor>& output_values,
std::unique_ptr<Tensor>& output_indices);
Tensor& output_values,
Tensor& output_indices);
// Opset ver - 1 to 9

View file

@ -24,6 +24,6 @@ template <typename T>
Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
AllocatorPtr allocator,
onnxruntime::concurrency::ThreadPool* threadpool,
std::unique_ptr<Tensor>& output_values,
std::unique_ptr<Tensor>& output_indices);
Tensor& output_values,
Tensor& output_indices);
} // namespace onnxruntime

View file

@ -80,7 +80,7 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
auto* col_data = alloc->Alloc(SafeInt<size_t>(sizeof(T)) * col_buffer_size);
col_buffer = BufferUniquePtr(col_data, BufferDeleter(alloc));
col_buffer = BufferUniquePtr(col_data, BufferDeleter(std::move(alloc)));
}
T* col_buffer_data = static_cast<T*>(col_buffer.get());
@ -234,7 +234,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
auto* working_data = WorkingBufferSize > 0 ? alloc->Alloc(SafeInt<size_t>(sizeof(float)) * WorkingBufferSize)
: nullptr;
BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc));
BufferUniquePtr working_buffer(working_data, BufferDeleter(std::move(alloc)));
MlasConv(&Parameters,
Xdata,
@ -254,7 +254,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
const int64_t col_buffer_size = kernel_dim * output_image_size;
auto* col_data = alloc->Alloc(SafeInt<size_t>(sizeof(float)) * col_buffer_size);
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
BufferUniquePtr col_buffer(col_data, BufferDeleter(std::move(alloc)));
auto* col_buffer_data = static_cast<float*>(col_buffer.get());
for (int image_id = 0; image_id < N; ++image_id) {

View file

@ -73,7 +73,7 @@ Status ConvTranspose<float>::PrePack(const Tensor& tensor, int input_idx, Alloca
// if and when we try to cache this pre-packed buffer for sharing between sessions.
memset(packed_filter_data, 0, packed_filter_data_size);
transposed_filter_ = BufferUniquePtr(packed_filter_data, BufferDeleter(alloc));
transposed_filter_ = BufferUniquePtr(packed_filter_data, BufferDeleter(std::move(alloc)));
for (int64_t group_id = 0; group_id < conv_transpose_attrs_.group; ++group_id) {
MlasTranspose(tensor.Data<float>() + (group_id * N * K),
@ -146,7 +146,7 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
const int64_t col_buffer_size = kernel_dim * p.input_shape.Size();
auto col_data = alloc->Alloc(SafeInt<size_t>(sizeof(T)) * col_buffer_size);
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
BufferUniquePtr col_buffer(col_data, BufferDeleter(std::move(alloc)));
T* col_buffer_data = static_cast<T*>(col_buffer.get());
const T* Xdata = p.X->template Data<T>();
@ -246,7 +246,7 @@ Status ConvTranspose<float>::DoConvTranspose(OpKernelContext* context, bool dyna
const int64_t col_buffer_size = kernel_dim * p.input_shape.Size();
auto col_data = alloc->Alloc(SafeInt<size_t>(sizeof(float)) * col_buffer_size);
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
BufferUniquePtr col_buffer(col_data, BufferDeleter(std::move(alloc)));
float* col_buffer_data = static_cast<float*>(col_buffer.get());
const float* Xdata = p.X->template Data<float>();

View file

@ -80,7 +80,7 @@ Status LRN<float>::Compute(OpKernelContext* context) const {
const size_t padded_square_size = (static_cast<size_t>(C) + size_ - 1) * H * W;
auto psdata = alloc->Alloc(SafeInt<size_t>(sizeof(float)) * padded_square_size);
BufferUniquePtr padded_square_buffer(psdata, BufferDeleter(alloc));
BufferUniquePtr padded_square_buffer(psdata, BufferDeleter(std::move(alloc)));
auto* padded_square_data = static_cast<float*>(padded_square_buffer.get());
math::Set<float, CPUMathUtil>(padded_square_size, 0.0f, padded_square_data, &CPUMathUtil::Instance());

View file

@ -102,7 +102,7 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
auto* col_data = alloc->Alloc(SafeInt<size_t>(sizeof(uint8_t)) * col_buffer_size);
col_buffer = BufferUniquePtr(col_data, BufferDeleter(alloc));
col_buffer = BufferUniquePtr(col_data, BufferDeleter(std::move(alloc)));
}
auto* col_buffer_data = static_cast<uint8_t*>(col_buffer.get());

View file

@ -53,7 +53,7 @@ class MatMulIntegerBase : public OpKernel {
// if and when we try to cache this pre-packed buffer for sharing between sessions.
memset(packed_b_data, 0, packed_b_size);
packed_b_ = BufferUniquePtr(packed_b_data, BufferDeleter(alloc));
packed_b_ = BufferUniquePtr(packed_b_data, BufferDeleter(std::move(alloc)));
MlasGemmPackB(N, K, b_data, N, a_is_signed, b_is_signed_, packed_b_data);
bool share_prepacked_weights = (prepacked_weights != nullptr);

View file

@ -107,7 +107,7 @@ Status QLinearMatMul::Compute(OpKernelContext* ctx) const {
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
auto gemm_output_data = alloc->Alloc(SafeInt<size_t>(gemm_shape.M) *
gemm_shape.N * sizeof(int32_t) * num_gemms);
BufferUniquePtr gemm_output_buffer(gemm_output_data, BufferDeleter(alloc));
BufferUniquePtr gemm_output_buffer(gemm_output_data, BufferDeleter(std::move(alloc)));
auto* gemm_output = static_cast<int32_t*>(gemm_output_buffer.get());
std::vector<MLAS_GEMM_QUANT_DATA_PARAMS> gemm_params(num_gemms);

View file

@ -734,7 +734,9 @@ struct ProviderHost {
// Tensor
virtual std::unique_ptr<Tensor> Tensor__construct(MLDataType p_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator) = 0;
virtual std::unique_ptr<Tensor> Tensor__construct(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, ptrdiff_t offset) = 0;
virtual void Tensor__operator_delete(Tensor* p) = 0;
virtual std::unique_ptr<Tensor> Tensor__construct_default() = 0;
virtual void Tensor__move_assign(Tensor& lhs, Tensor&& rhs) noexcept = 0;
virtual void Tensor__operator_delete(Tensor* p) noexcept = 0;
virtual void Tensor__InitOrtValue(MLDataType elt_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator, OrtValue& ort_value) = 0;
virtual void Tensor__InitOrtValue(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, OrtValue& ort_value) = 0;

View file

@ -882,10 +882,11 @@ class SessionState {
};
struct Tensor final {
// Asks the provider host for a Tensor built by Tensor__construct_default() and returns the owning pointer.
static std::unique_ptr<Tensor> CreateDefault() {
  auto tensor = g_host->Tensor__construct_default();
  return tensor;
}
static std::unique_ptr<Tensor> Create(MLDataType p_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator) { return g_host->Tensor__construct(p_type, shape, std::move(allocator)); }
static std::unique_ptr<Tensor> Create(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, ptrdiff_t offset = 0) { return g_host->Tensor__construct(p_type, shape, p_data, alloc, offset); }
static void operator delete(void* p) { g_host->Tensor__operator_delete(reinterpret_cast<Tensor*>(p)); }
static void operator delete(void* p) noexcept { g_host->Tensor__operator_delete(reinterpret_cast<Tensor*>(p)); }
static void InitOrtValue(MLDataType elt_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator, OrtValue& ort_value) {
g_host->Tensor__InitOrtValue(elt_type, shape, std::move(allocator), ort_value);
@ -935,6 +936,10 @@ struct Tensor final {
Tensor() = delete;
Tensor(const Tensor&) = delete;
void operator=(const Tensor&) = delete;
// Move assignment: forwards to the provider host's Tensor__move_assign with this object as the
// destination and `o` as the source, then returns this object.
Tensor& operator=(Tensor&& o) noexcept {
  Tensor& self = *this;
  g_host->Tensor__move_assign(self, std::move(o));
  return self;
}
};
template <>

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "conv.h"
#include "core/common/inlined_containers_fwd.h"
#include "core/graph/constants.h"
#include "core/graph/graph.h"
#include "core/graph/graph_utils.h"
@ -230,22 +231,22 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
// Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group}
auto orig_shape = tensor.Shape();
std::vector<size_t> perm{0, 2, 3, 1};
std::vector<int64_t> new_dims{orig_shape[0],
orig_shape[2],
orig_shape[3],
orig_shape[1]};
InlinedVector<size_t> perm{0, 2, 3, 1};
TensorShapeVector new_dims{orig_shape[0],
orig_shape[2],
orig_shape[3],
orig_shape[1]};
packed_w_ = Tensor::Create(tensor.DataType(), TensorShape(new_dims), alloc);
packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc));
SingleAxisTranspose(perm, tensor, *packed_w_, /*from*/ 1, /*to*/ 3);
SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 3);
is_packed = true;
// we can create the kernel now
struct xnn_operator* p = nullptr;
ORT_RETURN_IF_ERROR(CreateXnnpackKernel(conv_attrs_, C_, M_, kernel_shape_, clip_min_max_, IsDepthwise(),
*packed_w_, B_ ? B_->Data<float>() : nullptr, p));
packed_w_, B_ ? B_->Data<float>() : nullptr, p));
op0_.reset(p);
}

View file

@ -38,7 +38,7 @@ class Conv : public OpKernel {
TensorShapeVector kernel_shape_;
int64_t C_;
int64_t M_;
std::unique_ptr<Tensor> packed_w_;
Tensor packed_w_;
const Tensor* B_{nullptr};
std::optional<std::pair<float, float>> clip_min_max_;

View file

@ -1636,8 +1636,8 @@ static common::Status CheckTypes(MLDataType actual, MLDataType expected, const s
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str());
}
common::Status InferenceSession::ValidateInputs(const std::vector<std::string>& feed_names,
const std::vector<OrtValue>& feeds) const {
common::Status InferenceSession::ValidateInputs(gsl::span<const std::string> feed_names,
gsl::span<const OrtValue> feeds) const {
if (feed_names.size() != feeds.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Size mismatch: feed_names has ", feed_names.size(),
"elements, but feeds has ", feeds.size(), " elements.");
@ -1735,7 +1735,7 @@ common::Status InferenceSession::ValidateInputs(const std::vector<std::string>&
return Status::OK();
}
common::Status InferenceSession::ValidateOutputs(const std::vector<std::string>& output_names,
common::Status InferenceSession::ValidateOutputs(gsl::span<const std::string> output_names,
const std::vector<OrtValue>* p_fetches) const {
if (p_fetches == nullptr) {
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Output vector pointer is NULL");
@ -1863,8 +1863,8 @@ struct ThreadPoolSpinningSwitch {
} // namespace
Status InferenceSession::Run(const RunOptions& run_options,
const std::vector<std::string>& feed_names, const std::vector<OrtValue>& feeds,
const std::vector<std::string>& output_names, std::vector<OrtValue>* p_fetches,
gsl::span<const std::string> feed_names, gsl::span<const OrtValue> feeds,
gsl::span<const std::string> output_names, std::vector<OrtValue>* p_fetches,
const std::vector<OrtDevice>* p_fetches_device_info) {
TimePoint tp;
if (session_profiler_.IsEnabled()) {
@ -1895,10 +1895,10 @@ Status InferenceSession::Run(const RunOptions& run_options,
<< " CUDA Graph for this model with tag: " << run_options.run_tag;
ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph());
} else {
std::vector<IExecutionProvider*> exec_providers_to_stop;
InlinedVector<IExecutionProvider*> exec_providers_to_stop;
exec_providers_to_stop.reserve(execution_providers_.NumProviders());
std::vector<AllocatorPtr> arenas_to_shrink;
InlinedVector<AllocatorPtr> arenas_to_shrink;
ORT_TRY {
if (!is_inited_) {
@ -2037,17 +2037,17 @@ Status InferenceSession::Run(const RunOptions& run_options,
return retval;
}
common::Status InferenceSession::Run(const NameMLValMap& feeds, const std::vector<std::string>& output_names,
common::Status InferenceSession::Run(const NameMLValMap& feeds, gsl::span<const std::string> output_names,
std::vector<OrtValue>* p_fetches) {
return Run(RunOptions(), feeds, output_names, p_fetches);
}
common::Status InferenceSession::Run(const RunOptions& run_options, const NameMLValMap& feeds_map,
const std::vector<std::string>& output_names, std::vector<OrtValue>* p_fetches) {
std::vector<std::string> feed_names;
std::vector<OrtValue> feeds;
gsl::span<const std::string> output_names, std::vector<OrtValue>* p_fetches) {
InlinedVector<std::string> feed_names;
InlinedVector<OrtValue> feeds;
auto num_feeds = feeds_map.size();
const auto num_feeds = feeds_map.size();
feed_names.reserve(num_feeds);
feeds.reserve(num_feeds);
@ -2177,7 +2177,7 @@ AllocatorPtr InferenceSession::GetAllocator(const OrtMemoryInfo& mem_info) const
}
common::Status InferenceSession::ValidateAndParseShrinkArenaString(const std::string& ort_device_list,
/*out*/ std::vector<AllocatorPtr>& arenas_to_shrink) const {
/*out*/ InlinedVector<AllocatorPtr>& arenas_to_shrink) const {
arenas_to_shrink.reserve(5); // Allocate some memory for the container (we are unlikely to see more than 5 memory arena shrink requests)
std::istringstream ss_1(ort_device_list);
@ -2234,7 +2234,7 @@ common::Status InferenceSession::ValidateAndParseShrinkArenaString(const std::st
return Status::OK();
}
void InferenceSession::ShrinkMemoryArenas(const std::vector<AllocatorPtr>& arenas_to_shrink) {
void InferenceSession::ShrinkMemoryArenas(gsl::span<const AllocatorPtr> arenas_to_shrink) {
for (auto& alloc : arenas_to_shrink) {
auto status = static_cast<BFCArena*>(alloc.get())->Shrink();

View file

@ -300,8 +300,8 @@ class InferenceSession {
*/
common::Status Initialize() ORT_MUST_USE_RESULT;
common::Status Run(const RunOptions& run_options, const std::vector<std::string>& feed_names,
const std::vector<OrtValue>& feeds, const std::vector<std::string>& output_names,
common::Status Run(const RunOptions& run_options, gsl::span<const std::string> feed_names,
gsl::span<const OrtValue> feeds, gsl::span<const std::string> output_names,
std::vector<OrtValue>* p_fetches,
const std::vector<OrtDevice>* p_fetches_device_info = nullptr) ORT_MUST_USE_RESULT;
@ -315,7 +315,7 @@ class InferenceSession {
* This should not be changed during execution of this function.
* @return OK if success.
*/
common::Status Run(const NameMLValMap& feeds, const std::vector<std::string>& output_names,
common::Status Run(const NameMLValMap& feeds, gsl::span<const std::string> output_names,
std::vector<OrtValue>* p_fetches) ORT_MUST_USE_RESULT;
/**
@ -324,7 +324,7 @@ class InferenceSession {
* @param run_options use this to tune the Run call to your needs.
*/
common::Status Run(const RunOptions& run_options, const NameMLValMap& feeds,
const std::vector<std::string>& output_names,
gsl::span<const std::string> output_names,
std::vector<OrtValue>* p_fetches) ORT_MUST_USE_RESULT;
/**
@ -595,10 +595,10 @@ class InferenceSession {
common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape,
const TensorShape& expected_shape) const ORT_MUST_USE_RESULT;
common::Status ValidateInputs(const std::vector<std::string>& feed_names,
const std::vector<OrtValue>& feeds) const ORT_MUST_USE_RESULT;
common::Status ValidateInputs(gsl::span<const std::string> feed_names,
gsl::span<const OrtValue> feeds) const ORT_MUST_USE_RESULT;
common::Status ValidateOutputs(const std::vector<std::string>& output_names,
common::Status ValidateOutputs(gsl::span<const std::string> output_names,
const std::vector<OrtValue>* p_fetches) const ORT_MUST_USE_RESULT;
common::Status WaitForNotification(Notification* p_executor_done, int64_t timeout_in_ms) ORT_MUST_USE_RESULT;
@ -617,13 +617,13 @@ class InferenceSession {
*/
common::Status ValidateAndParseShrinkArenaString(const std::string& ort_device_list,
/*out*/ std::vector<AllocatorPtr>& arenas_to_shrink) const ORT_MUST_USE_RESULT;
/*out*/ InlinedVector<AllocatorPtr>& arenas_to_shrink) const ORT_MUST_USE_RESULT;
/*
* Performs the shrinkage of arenas requested to be shrunk by the user
* The `arenas_to_shrink` parameter is obtained from ValidateAndParseShrinkArenaString()
*/
void ShrinkMemoryArenas(const std::vector<AllocatorPtr>& arenas_to_shrink);
void ShrinkMemoryArenas(gsl::span<const AllocatorPtr> arenas_to_shrink);
#if !defined(ORT_MINIMAL_BUILD)
virtual common::Status AddPredefinedTransformers(

View file

@ -830,9 +830,23 @@ struct ProviderHostImpl : ProviderHost {
const DataTransferManager& SessionState__GetDataTransferMgr(const SessionState* p) override { return p->GetDataTransferMgr(); }
// Tensor (wrapped)
std::unique_ptr<Tensor> Tensor__construct(MLDataType p_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator) override { return std::make_unique<Tensor>(p_type, shape, std::move(allocator)); }
std::unique_ptr<Tensor> Tensor__construct(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, ptrdiff_t offset) override { return std::make_unique<Tensor>(p_type, shape, p_data, alloc, offset); }
void Tensor__operator_delete(Tensor* p) override { delete p; }
// Heap-allocates a Tensor of element type `p_type` and shape `shape`.
// `allocator` is taken by value and moved into the Tensor constructor.
std::unique_ptr<Tensor> Tensor__construct(MLDataType p_type, const TensorShape& shape,
                                          std::shared_ptr<IAllocator> allocator) override {
  auto tensor = std::make_unique<Tensor>(p_type, shape, std::move(allocator));
  return tensor;
}
// Heap-allocates a Tensor, forwarding `p_type`, `shape`, `p_data`, `alloc` and
// `offset` unchanged to the matching Tensor constructor.
std::unique_ptr<Tensor> Tensor__construct(MLDataType p_type, const TensorShape& shape, void* p_data,
                                          const OrtMemoryInfo& alloc, ptrdiff_t offset) override {
  auto tensor = std::make_unique<Tensor>(p_type, shape, p_data, alloc, offset);
  return tensor;
}
std::unique_ptr<Tensor> Tensor__construct_default() override {
return std::make_unique<Tensor>();
}
// Move-assigns `rhs` into `lhs` using Tensor's move-assignment operator.
void Tensor__move_assign(Tensor& lhs, Tensor&& rhs) noexcept override {
  lhs = std::move(rhs);
}
// Destroys the Tensor pointed to by `p` with `delete`.
void Tensor__operator_delete(Tensor* p) noexcept override {
  delete p;
}
void Tensor__InitOrtValue(MLDataType elt_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator, OrtValue& ort_value) override {
Tensor::InitOrtValue(elt_type, shape, std::move(allocator), ort_value);

View file

@ -439,7 +439,8 @@ TEST(ExecutionFrameTestInit, InitializerAsOutput) {
const void* orig_buffer = results[0].Get<Tensor>().DataRaw();
RunOptions ro;
ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results, nullptr));
ASSERT_STATUS_OK(session.Run(ro, EmptySpan<const std::string>(),
EmptySpan<const OrtValue>(), AsSpan({std::string("values")}), &results, nullptr));
EXPECT_EQ(results[0].Get<Tensor>().DataRaw(), orig_buffer);
EXPECT_THAT(results[0].Get<Tensor>().DataAsSpan<float>(), ::testing::ContainerEq(gsl::make_span(expected)));
@ -453,7 +454,8 @@ TEST(ExecutionFrameTestInit, InitializerAsOutput) {
std::vector<OrtValue> results;
RunOptions ro;
ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results, nullptr));
ASSERT_STATUS_OK(session.Run(ro, EmptySpan<std::string>(),
EmptySpan<OrtValue>(), AsSpan({std::string("values")}), &results, nullptr));
// output buffer should not be the same as the initializer in SessionState
const auto& initializers = session.GetSessionState().GetInitializedTensors();
@ -464,7 +466,6 @@ TEST(ExecutionFrameTestInit, InitializerAsOutput) {
#if !defined(DISABLE_SPARSE_TENSORS)
TEST(ExecutionFrameTestInit, SparseInitializerAsOutput) {
const std::vector<int64_t> dense_shape{3, 3};
std::vector<float> dense_data = {
0, 0, 1.764052391052246f,
@ -491,7 +492,7 @@ TEST(ExecutionFrameTestInit, SparseInitializerAsOutput) {
auto ml_type = DataTypeImpl::GetType<SparseTensor>();
results[0].Init(p_tensor.release(), ml_type, ml_type->GetDeleteFunc());
RunOptions ro;
ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results, nullptr));
ASSERT_STATUS_OK(session.Run(ro, EmptySpan<std::string>(), EmptySpan<OrtValue>(), AsSpan<std::string>({"values"}), &results, nullptr));
ASSERT_TRUE(results[0].IsAllocated());
ASSERT_TRUE(results[0].IsSparseTensor());
@ -504,7 +505,7 @@ TEST(ExecutionFrameTestInit, SparseInitializerAsOutput) {
EXPECT_THAT(coo_view.Indices().DataAsSpan<int64_t>(), ::testing::ContainerEq(gsl::make_span(expected_linear_indices)));
}
}
#endif // !defined(DISABLE_SPARSE_TENSORS)
#endif // !defined(DISABLE_SPARSE_TENSORS)
} // namespace test
} // namespace onnxruntime

View file

@ -5,6 +5,7 @@
#include "onnx/defs/parser.h"
#include "core/common/span_utils.h"
#include "core/graph/model.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/session/inference_session.h"
@ -13,6 +14,7 @@
#include "test/framework/test_utils.h"
#include "test/common/tensor_op_test_utils.h"
// Unit tests to check the implementation of functions, model-local functions,
// function-inlining etc.
@ -56,7 +58,7 @@ static void Check(const char* source,
std::vector<OrtValue> fetches;
status = session_object.Run(run_options, feeds, {output_name}, &fetches);
status = session_object.Run(run_options, feeds, AsSpan({std::string(output_name)}), &fetches);
ASSERT_TRUE(status.IsOK()) << "Session Run failed.";
auto& tensor = fetches[0].Get<Tensor>();

View file

@ -86,11 +86,11 @@ Status SoftmaxGrad<T>::Compute(OpKernelContext* context) const {
bool is_transpose_required = opset_ >= 13 && axis != (rank - 1);
std::unique_ptr<Tensor> transposed_dY;
std::unique_ptr<Tensor> transposed_Y;
std::vector<int64_t> transposed_input_dims;
std::unique_ptr<Tensor> intermediate_output; // output that the softmax implementation will write into while using transposed input
std::vector<size_t> permutation(rank);
Tensor transposed_dY;
Tensor transposed_Y;
TensorShapeVector transposed_input_dims;
Tensor intermediate_output; // output that the softmax implementation will write into while using transposed input
InlinedVector<size_t> permutation(rank);
if (is_transpose_required) {
AllocatorPtr alloc;
@ -112,26 +112,26 @@ Status SoftmaxGrad<T>::Compute(OpKernelContext* context) const {
D = TensorShape(transposed_input_dims).SizeFromDimension(rank - 1);
// Allocate a temporary tensor to hold transposed input
auto temp_input0 = Tensor::Create(Y.DataType(), TensorShape(transposed_input_dims), alloc);
auto temp_input0 = Tensor(Y.DataType(), TensorShape(transposed_input_dims), alloc);
// Perform the transpose
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutation, Y, *temp_input0));
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutation, Y, temp_input0));
transposed_Y = std::move(temp_input0);
auto temp_input1 = Tensor::Create(Y.DataType(), TensorShape(transposed_input_dims), alloc);
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutation, dY, *temp_input1));
auto temp_input1 = Tensor(Y.DataType(), TensorShape(transposed_input_dims), alloc);
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutation, dY, temp_input1));
transposed_dY = std::move(temp_input1);
// Allocate memory for the intermediate output
intermediate_output = Tensor::Create(dX.DataType(), TensorShape(transposed_input_dims), alloc);
intermediate_output = Tensor(dX.DataType(), TensorShape(transposed_input_dims), alloc);
}
const int n = gsl::narrow_cast<int>(N);
const int d = gsl::narrow_cast<int>(D);
const int nd = gsl::narrow_cast<int>(N * D);
const float* Ydata = is_transpose_required ? transposed_Y->template Data<T>() : Y.template Data<float>();
const float* dYdata = is_transpose_required ? transposed_dY->template Data<T>() : dY.template Data<float>();
float* dXdata = is_transpose_required ? intermediate_output->template MutableData<T>() : dX.template MutableData<float>();
const float* Ydata = is_transpose_required ? transposed_Y.template Data<T>() : Y.template Data<float>();
const float* dYdata = is_transpose_required ? transposed_dY.template Data<T>() : dY.template Data<float>();
float* dXdata = is_transpose_required ? intermediate_output.template MutableData<T>() : dX.template MutableData<float>();
gsl::copy(gsl::make_span(dYdata, nd), gsl::make_span(dXdata, nd));
if (is_logsoftmaxgrad_) {
@ -164,7 +164,7 @@ Status SoftmaxGrad<T>::Compute(OpKernelContext* context) const {
}
if (is_transpose_required) {
// Perform the transpose to get the axes back to the original ordering
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutation, *intermediate_output, dX));
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutation, intermediate_output, dX));
}
return Status::OK();

View file

@ -87,10 +87,10 @@ Status AdamWOptimizerBase::GenerateOutputs(OpKernelContext* ctx, size_t number_o
updated_values->Reserve(number_of_values);
for (size_t input_idx = 0; input_idx < number_of_values; ++input_idx) {
const Tensor& source_tensor = values->Get(input_idx);
std::unique_ptr<Tensor> target_tensor = Tensor::Create(source_tensor.DataType(),
source_tensor.Shape(), alloc);
ORT_RETURN_IF_ERROR(CopyInputTensorToOutputTensor(source_tensor, *target_tensor));
updated_values->Add(std::move(*target_tensor)); // Add will check for type consistency
Tensor target_tensor(source_tensor.DataType(),
source_tensor.Shape(), alloc);
ORT_RETURN_IF_ERROR(CopyInputTensorToOutputTensor(source_tensor, target_tensor));
updated_values->Add(std::move(target_tensor)); // Add will check for type consistency
}
}