mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Eliminate memory allocations per recent profiling (#12225)
* Alloc begin; FeedsFetches refactoring; Refactor Tensor class; Fix buffer deleter; Remove new/delete deleted; Adjust alloc move; Fix up xnnpack provider; Clarifying the comment on Create()
This commit is contained in:
parent
972bb9676c
commit
3bf614fd47
50 changed files with 289 additions and 227 deletions
|
|
@ -64,4 +64,7 @@ constexpr auto AsSpan(const T (&arr)[N]) {
|
|||
return details::AsSpanImpl(arr, N);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
inline gsl::span<const T> EmptySpan() { return gsl::span<const T>(); }
|
||||
|
||||
}
|
||||
|
|
@ -12,7 +12,7 @@ class BufferDeleter {
|
|||
public:
|
||||
BufferDeleter() : alloc_(nullptr) {}
|
||||
BufferDeleter(AllocatorPtr alloc)
|
||||
: alloc_(alloc) {}
|
||||
: alloc_(std::move(alloc)) {}
|
||||
|
||||
void operator()(void* p) const {
|
||||
if (alloc_)
|
||||
|
|
|
|||
|
|
@ -36,15 +36,10 @@ namespace onnxruntime {
|
|||
*/
|
||||
class Tensor final {
|
||||
public:
|
||||
static std::unique_ptr<Tensor> Create(MLDataType p_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator) {
|
||||
return std::make_unique<Tensor>(p_type, shape, std::move(allocator));
|
||||
}
|
||||
|
||||
static std::unique_ptr<Tensor> Create(MLDataType p_type, const TensorShape& shape, void* p_data,
|
||||
const OrtMemoryInfo& alloc, ptrdiff_t offset = 0,
|
||||
gsl::span<const int64_t> strides = {}) {
|
||||
return std::make_unique<Tensor>(p_type, shape, p_data, alloc, offset, strides);
|
||||
}
|
||||
// NB! Removing Create() methods returning unique_ptr<Tensor>. Still available in other EPs that are dynamically linked.
|
||||
// Strive not to allocate Tensor with new/delete as it is a shallow class and using it by value is just fine.
|
||||
// Use InitOrtValue() methods to allocate for OrtValue.
|
||||
|
||||
Tensor() = default; // to allow creating vector<Tensor> to support seq(tensor)
|
||||
|
||||
|
|
|
|||
|
|
@ -302,7 +302,7 @@ bool Attention<T>::IsPackWeightsSuccessful(int qkv_index,
|
|||
// buffer memory and we do not want it uninitialized and generate different hashes
|
||||
// if and when we try to cache this pre-packed buffer for sharing between sessions.
|
||||
memset(packed_weights_data, 0, packed_weights_data_size);
|
||||
packed_weights_[qkv_index] = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
|
||||
packed_weights_[qkv_index] = BufferUniquePtr(packed_weights_data, BufferDeleter(std::move(alloc)));
|
||||
packed_weights_size_[qkv_index] = packb_size;
|
||||
|
||||
for (size_t i = 0; i < loop_len; i++) {
|
||||
|
|
@ -470,7 +470,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
// D (input_hidden_size) is hidden dimension of input, where D could be larger than any of the hidden_sizes
|
||||
// (NH) when model is pruned. T = H1 + H2 + H3, where H1, H2, H3 are head sizes of Q, K, V respectively
|
||||
auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * (q_hidden_size + k_hidden_size + v_hidden_size) * element_size);
|
||||
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator));
|
||||
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(std::move(allocator)));
|
||||
|
||||
auto Q = reinterpret_cast<T*>(gemm_data);
|
||||
auto K = Q + static_cast<size_t>(batch_size) * sequence_length * q_hidden_size;
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class AttentionCPUBase : public AttentionBase {
|
|||
// Compute the attentionScore * Value. It does: out_tmp(B, N, S, H) = attention_probs(B, N, S, S*) x V(B, N, S*, H)
|
||||
auto out_tmp_data =
|
||||
allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * v_head_size * sizeof(T));
|
||||
BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(allocator));
|
||||
BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(std::move(allocator)));
|
||||
|
||||
ComputeVxAttentionScore(output->template MutableData<T>(), static_cast<T*>(out_tmp_data), static_cast<T*>(attention_probs), V,
|
||||
batch_size, sequence_length, past_sequence_length, v_head_size, v_hidden_size,
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr
|
|||
// if and when we try to cache this pre-packed buffer for sharing between sessions.
|
||||
memset(packed_weights_data, 0, packed_weights_data_size);
|
||||
|
||||
packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
|
||||
packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(std::move(alloc)));
|
||||
|
||||
for (size_t i = 0; i < loop_len; i++) {
|
||||
MlasGemmPackB(head_size, input_hidden_size, weights_data, hidden_size_x3, false /*AIsSigned*/, weights_is_signed_, packed_weights_data);
|
||||
|
|
@ -212,7 +212,7 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
|
|||
// STEP.1: gemm_data(BS, 3NH) = Scale(input(BS, D) x weights(D, 3NH)) + bias(3NH)
|
||||
// D is hidden dimension of input, where input_hidden_size (D) could be larger than hidden_size (NH) when model is pruned.
|
||||
auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * 3 * hidden_size * element_size);
|
||||
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator));
|
||||
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(std::move(allocator)));
|
||||
|
||||
auto Q = reinterpret_cast<T*>(gemm_data);
|
||||
auto K = Q + static_cast<int64_t>(batch_size) * sequence_length * hidden_size;
|
||||
|
|
|
|||
|
|
@ -215,7 +215,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
|
|||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
|
||||
uint8_t* a_data_quant = static_cast<uint8_t*>(allocator->Alloc(SafeInt<size_t>(num_of_elements) * sizeof(uint8_t)));
|
||||
BufferUniquePtr a_buffer_quant_holder(a_data_quant, BufferDeleter(allocator));
|
||||
BufferUniquePtr a_buffer_quant_holder(a_data_quant, BufferDeleter(std::move(allocator)));
|
||||
|
||||
ParQuantizeLinear(a_data, a_data_quant, num_of_elements, a_scale, a_zero_point, ctx->GetOperatorThreadPool());
|
||||
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ Status NhwcMaxPool<T8Bits>::Compute(OpKernelContext* context) const {
|
|||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
|
||||
int64_t col_buffer_batch_count = std::min(output_image_size, output_batch_count);
|
||||
auto* col_data = alloc->Alloc(SafeInt<size_t>(sizeof(const T8Bits*)) * kernel_size * col_buffer_batch_count);
|
||||
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
|
||||
BufferUniquePtr col_buffer(col_data, BufferDeleter(std::move(alloc)));
|
||||
std::vector<T8Bits> padding_data(static_cast<size_t>(C), std::numeric_limits<T8Bits>::lowest());
|
||||
|
||||
const auto* Xdata = X->template Data<T8Bits>();
|
||||
|
|
|
|||
|
|
@ -583,7 +583,7 @@ Status QLinearAveragePool::ComputeImpl(OpKernelContext* context) const {
|
|||
BufferUniquePtr x_data_fp32_guard;
|
||||
if (kernel_shape.size() <= 3) {
|
||||
x_data_fp32 = (float*)allocator->Alloc(SafeInt<size_t>(x_shape.Size()) * sizeof(float));
|
||||
x_data_fp32_guard = BufferUniquePtr(x_data_fp32, BufferDeleter(allocator));
|
||||
x_data_fp32_guard = BufferUniquePtr(x_data_fp32, BufferDeleter(std::move(allocator)));
|
||||
dequantize_array(x_shape.Size(), X_data, x_scale, x_zero_point, x_data_fp32, tp);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -22,8 +22,8 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
|
|||
AllocatorPtr allocator,
|
||||
void* /*stream*/,
|
||||
onnxruntime::concurrency::ThreadPool* threadpool,
|
||||
std::unique_ptr<Tensor>& output_values,
|
||||
std::unique_ptr<Tensor>& output_indices) {
|
||||
Tensor& output_values,
|
||||
Tensor& output_indices) {
|
||||
if (input->IsDataType<float>()) {
|
||||
return GetTopK<float>(input, axis, k, largest, sorted, allocator, threadpool, output_values, output_indices);
|
||||
}
|
||||
|
|
@ -343,20 +343,20 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
constexpr bool largest = true;
|
||||
constexpr bool sorted = true; // results returned in sorted order.
|
||||
|
||||
std::unique_ptr<Tensor> topk_scores;
|
||||
std::unique_ptr<Tensor> topk_indices;
|
||||
Tensor topk_scores;
|
||||
Tensor topk_indices;
|
||||
ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool,
|
||||
topk_scores, topk_indices));
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("topk_scores", *(topk_scores.get()));
|
||||
dumper->Print("topk_indices", *(topk_indices.get()));
|
||||
dumper->Print("topk_scores", topk_scores);
|
||||
dumper->Print("topk_indices", topk_indices);
|
||||
#endif
|
||||
|
||||
// Convert indices in range [0, num_beams * vocab_size) to token ID of range [0, vocab_size) like the following:
|
||||
// next_indices = (next_tokens / vocab_size).long()
|
||||
// next_tokens = next_tokens % vocab_size
|
||||
gsl::span<const int64_t> next_token_indices = topk_indices->DataAsSpan<int64_t>();
|
||||
gsl::span<const int64_t> next_token_indices = topk_indices.DataAsSpan<int64_t>();
|
||||
offset = 0;
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
for (unsigned int j = 0; j < top_k; j++, offset++) {
|
||||
|
|
@ -365,7 +365,7 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
}
|
||||
}
|
||||
|
||||
gsl::span<const T> next_scores = topk_scores->DataAsSpan<T>();
|
||||
gsl::span<const T> next_scores = topk_scores.DataAsSpan<T>();
|
||||
gsl::span<const int32_t> next_tokens(beam_state->next_tokens.data(), beam_state->next_tokens.size());
|
||||
gsl::span<const int32_t> next_indices(beam_state->next_indices.data(), beam_state->next_indices.size());
|
||||
|
||||
|
|
@ -453,8 +453,8 @@ Status GreedySearchProcessLogits(
|
|||
constexpr bool largest = true;
|
||||
constexpr bool sorted = false;
|
||||
|
||||
std::unique_ptr<Tensor> topk_scores;
|
||||
std::unique_ptr<Tensor> topk_indices;
|
||||
Tensor topk_scores;
|
||||
Tensor topk_indices;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
TopK(&input,
|
||||
axis,
|
||||
|
|
@ -472,7 +472,7 @@ Status GreedySearchProcessLogits(
|
|||
dumper->Print("topk_indices", *(topk_indices.get()));
|
||||
#endif
|
||||
|
||||
gsl::span<const int64_t> next_token_indices = topk_indices->DataAsSpan<int64_t>();
|
||||
gsl::span<const int64_t> next_token_indices = topk_indices.DataAsSpan<int64_t>();
|
||||
gsl::copy(next_token_indices, greedy_state->next_tokens_cpu);
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
|
|
|
|||
|
|
@ -37,8 +37,8 @@ using TopkFunc = std::function<Status(
|
|||
AllocatorPtr allocator,
|
||||
void* stream, // cudaStream_t
|
||||
onnxruntime::concurrency::ThreadPool* threadpool,
|
||||
std::unique_ptr<Tensor>& output_values,
|
||||
std::unique_ptr<Tensor>& output_indices)>;
|
||||
Tensor& output_values,
|
||||
Tensor& output_indices)>;
|
||||
|
||||
// Create subgraph inputs: input_ids, position_ids and attention_mask (for GPT-2).
|
||||
using CreateGptInputsFunc = std::function<Status(
|
||||
|
|
@ -168,8 +168,8 @@ Status TopK(
|
|||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
onnxruntime::concurrency::ThreadPool* threadpool,
|
||||
std::unique_ptr<Tensor>& output_values,
|
||||
std::unique_ptr<Tensor>& output_indices);
|
||||
Tensor& output_values,
|
||||
Tensor& output_indices);
|
||||
|
||||
Status AddToFeeds(
|
||||
const IExecutionProvider* execution_provider,
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ Status Subgraph::Setup(const SessionState& session_state,
|
|||
session_state_ = &session_state;
|
||||
subgraph_session_state_ = &subgraph_session_state;
|
||||
|
||||
std::vector<std::string> feed_names;
|
||||
InlinedVector<std::string_view> feed_names;
|
||||
feed_names.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
|
||||
|
||||
// Use the first output (logits) to find device location.
|
||||
|
|
@ -69,25 +69,24 @@ Status Subgraph::Setup(const SessionState& session_state,
|
|||
feed_names.push_back(entry->Name());
|
||||
}
|
||||
|
||||
std::vector<OrtDevice> feed_locations;
|
||||
feed_locations.resize(feed_names.size());
|
||||
InlinedVector<OrtDevice> feed_locations;
|
||||
feed_locations.reserve(feed_names.size());
|
||||
|
||||
for (size_t i = 0, end = feed_names.size(); i < end; ++i) {
|
||||
if (i >= subgraph_input_names.size()) { // Implicit inputs
|
||||
const auto& location = utils::FindMemoryInfoForValue(session_state, feed_names[i]);
|
||||
feed_locations[i] = location.device;
|
||||
feed_locations.push_back(location.device);
|
||||
} else {
|
||||
feed_locations[i] = default_location.device;
|
||||
feed_locations.push_back(default_location.device);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<FeedsFetchesManager> ffm;
|
||||
ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, subgraph_output_names,
|
||||
subgraph_session_state.GetOrtValueNameIdxMap(), ffm));
|
||||
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm));
|
||||
subgraph_session_state.GetOrtValueNameIdxMap(), feeds_fetches_manager_));
|
||||
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *feeds_fetches_manager_));
|
||||
|
||||
// Setup the locations where we want the subgraph output to end up on
|
||||
std::vector<const OrtMemoryInfo*> fetch_locations;
|
||||
InlinedVector<const OrtMemoryInfo*> fetch_locations;
|
||||
fetch_locations.reserve(num_subgraph_outputs);
|
||||
|
||||
// Past states need to be where we can feed them into the next iteration, so set the location to match the feed.
|
||||
|
|
@ -95,9 +94,7 @@ Status Subgraph::Setup(const SessionState& session_state,
|
|||
fetch_locations.push_back(&default_location);
|
||||
}
|
||||
|
||||
utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations);
|
||||
|
||||
feeds_fetches_manager_ = std::move(ffm);
|
||||
utils::FinalizeFeedFetchCopyInfo(*feeds_fetches_manager_, feed_locations, fetch_locations);
|
||||
|
||||
// The subgraph only needs to be checked once, so do it in the Setup function.
|
||||
auto& inputs = subgraph.GetInputs();
|
||||
|
|
|
|||
|
|
@ -48,7 +48,9 @@ class Subgraph {
|
|||
Status Setup(const SessionState& session_state,
|
||||
const SessionState& subgraph_session_state);
|
||||
|
||||
FeedsFetchesManager* GetFeedsFetchesManager() const { return feeds_fetches_manager_.get(); }
|
||||
FeedsFetchesManager* GetFeedsFetchesManager() {
|
||||
return (feeds_fetches_manager_.has_value()) ? &*feeds_fetches_manager_ : nullptr;
|
||||
}
|
||||
|
||||
const IExecutionProvider* GetProvider() const;
|
||||
|
||||
|
|
@ -65,7 +67,7 @@ class Subgraph {
|
|||
AllocatorPtr allocator_;
|
||||
const SessionState* session_state_;
|
||||
const SessionState* subgraph_session_state_;
|
||||
std::unique_ptr<FeedsFetchesManager> feeds_fetches_manager_;
|
||||
std::optional<FeedsFetchesManager> feeds_fetches_manager_;
|
||||
bool is_output_float16_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -31,8 +31,8 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
|
|||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
onnxruntime::concurrency::ThreadPool* /*threadpool*/,
|
||||
std::unique_ptr<Tensor>& output_values,
|
||||
std::unique_ptr<Tensor>& output_indices) {
|
||||
Tensor& output_values,
|
||||
Tensor& output_indices) {
|
||||
ORT_ENFORCE(nullptr != input);
|
||||
int32_t rank = static_cast<int32_t>(input->Shape().NumDimensions());
|
||||
|
||||
|
|
@ -51,15 +51,15 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
|
|||
int64_t dimension = input_shape[axis];
|
||||
int64_t N = elem_nums_cuda[0] / dimension;
|
||||
|
||||
output_values = Tensor::Create(input->DataType(), output_shape, allocator);
|
||||
output_indices = Tensor::Create(DataTypeImpl::GetType<int64_t>(), output_shape, allocator);
|
||||
output_values = std::move(*Tensor::Create(input->DataType(), output_shape, allocator));
|
||||
output_indices = std::move(*Tensor::Create(DataTypeImpl::GetType<int64_t>(), output_shape, std::move(allocator)));
|
||||
|
||||
if (input->IsDataType<float>()) {
|
||||
return TopKImpl<float>(nullptr, // We limit number of beams in BeamSearchParameters, so K <= 256 and use NULL here
|
||||
reinterpret_cast<cudaStream_t>(stream),
|
||||
input->Data<float>(),
|
||||
static_cast<float*>(output_values->MutableDataRaw()),
|
||||
static_cast<int64_t*>(output_indices->MutableDataRaw()),
|
||||
static_cast<float*>(output_values.MutableDataRaw()),
|
||||
static_cast<int64_t*>(output_indices.MutableDataRaw()),
|
||||
elem_nums_cuda,
|
||||
static_cast<size_t>(elem_nums_cuda.Size()),
|
||||
static_cast<int32_t>(axis),
|
||||
|
|
@ -72,8 +72,8 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
|
|||
return TopKImpl<MLFloat16>(nullptr,
|
||||
reinterpret_cast<cudaStream_t>(stream),
|
||||
input->Data<MLFloat16>(),
|
||||
static_cast<MLFloat16*>(output_values->MutableDataRaw()),
|
||||
static_cast<int64_t*>(output_indices->MutableDataRaw()),
|
||||
static_cast<MLFloat16*>(output_values.MutableDataRaw()),
|
||||
static_cast<int64_t*>(output_indices.MutableDataRaw()),
|
||||
elem_nums_cuda,
|
||||
static_cast<size_t>(elem_nums_cuda.Size()),
|
||||
static_cast<int32_t>(axis),
|
||||
|
|
@ -350,10 +350,10 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
constexpr bool largest = true;
|
||||
constexpr bool sorted = true; // results returned in sorted order.
|
||||
|
||||
std::unique_ptr<Tensor> topk_scores;
|
||||
std::unique_ptr<Tensor> topk_indices;
|
||||
std::unique_ptr<Tensor> topk_scores = Tensor::CreateDefault();
|
||||
std::unique_ptr<Tensor> topk_indices = Tensor::CreateDefault();
|
||||
ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool,
|
||||
topk_scores, topk_indices));
|
||||
*topk_scores, *topk_indices));
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("topk_scores", *(topk_scores.get()));
|
||||
|
|
@ -514,10 +514,10 @@ Status GreedySearchProcessLogits(
|
|||
constexpr bool largest = true;
|
||||
constexpr bool sorted = false;
|
||||
|
||||
std::unique_ptr<Tensor> topk_scores;
|
||||
std::unique_ptr<Tensor> topk_indices;
|
||||
auto topk_scores = Tensor::CreateDefault();
|
||||
auto topk_indices = Tensor::CreateDefault();
|
||||
ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool,
|
||||
topk_scores, topk_indices));
|
||||
*topk_scores, *topk_indices));
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("topk_scores", *(topk_scores.get()));
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
|
|||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
onnxruntime::concurrency::ThreadPool* threadpool,
|
||||
std::unique_ptr<Tensor>& output_values,
|
||||
std::unique_ptr<Tensor>& output_indices);
|
||||
Tensor& output_values,
|
||||
Tensor& output_indices);
|
||||
|
||||
Status AddToFeeds(const IExecutionProvider* execution_provider,
|
||||
std::initializer_list<OrtValue> inputs,
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@
|
|||
#include "core/framework/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
common::Status FeedsFetchesInfo::MapNamesToMLValueIdxs(const std::vector<std::string>& names,
|
||||
common::Status FeedsFetchesInfo::MapNamesToMLValueIdxs(gsl::span<const std::string> names,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
std::vector<int>& ort_value_idxs) {
|
||||
InlinedVector<int>& ort_value_idxs) {
|
||||
auto status = Status::OK();
|
||||
|
||||
ort_value_idxs.reserve(names.size());
|
||||
|
|
@ -40,8 +40,8 @@ Status FeedsFetchesInfo::SetMLValueIdxs(const OrtValueNameIdxMap& ort_value_name
|
|||
return status;
|
||||
}
|
||||
|
||||
Status FeedsFetchesManager::Create(const std::vector<std::string>& feed_names,
|
||||
const std::vector<std::string>& output_names,
|
||||
Status FeedsFetchesManager::Create(gsl::span<const std::string> feed_names,
|
||||
gsl::span<const std::string> output_names,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
std::unique_ptr<FeedsFetchesManager>& feed_fetch_manager) {
|
||||
FeedsFetchesInfo info{feed_names, output_names, ort_value_name_idx_map};
|
||||
|
|
@ -51,11 +51,22 @@ Status FeedsFetchesManager::Create(const std::vector<std::string>& feed_names,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FeedsFetchesManager::Create(gsl::span<const std::string_view> feed_names,
|
||||
gsl::span<const std::string> output_names,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
std::optional<FeedsFetchesManager>& feed_fetch_manager) {
|
||||
FeedsFetchesInfo info{feed_names, output_names, ort_value_name_idx_map};
|
||||
|
||||
feed_fetch_manager.emplace(std::move(info));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
FeedsFetchesManager::FeedsFetchesManager(FeedsFetchesInfo&& info)
|
||||
: feeds_fetches_info_{info} {
|
||||
: feeds_fetches_info_(std::move(info)) {
|
||||
// init with default values
|
||||
feeds_device_copy_info_.resize(info.feed_names.size());
|
||||
fetches_device_copy_info_.resize(info.output_names.size());
|
||||
feeds_device_copy_info_.resize(feeds_fetches_info_.feed_names.size());
|
||||
fetches_device_copy_info_.resize(feeds_fetches_info_.output_names.size());
|
||||
}
|
||||
|
||||
void FeedsFetchesManager::SetDeviceCopyChecks(DeviceCopyCheck input_copy_needed, DeviceCopyCheck output_copy_needed) {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <optional>
|
||||
#include "core/common/inlined_containers_fwd.h"
|
||||
|
||||
#ifndef SHARED_PROVIDER
|
||||
#include "core/framework/ort_value.h"
|
||||
|
|
@ -30,25 +32,43 @@ struct DeviceCopyChecks {
|
|||
|
||||
struct FeedsFetchesInfo {
|
||||
FeedsFetchesInfo() = default;
|
||||
FeedsFetchesInfo(const std::vector<std::string>& feed_names_in,
|
||||
const std::vector<std::string>& output_names_in,
|
||||
FeedsFetchesInfo(gsl::span<const std::string> feed_names_in,
|
||||
gsl::span<const std::string> output_names_in,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map)
|
||||
: feed_names{feed_names_in}, output_names{output_names_in} {
|
||||
: feed_names(),
|
||||
output_names() {
|
||||
feed_names.reserve(feed_names_in.size());
|
||||
feed_names.assign(feed_names_in.begin(), feed_names_in.end());
|
||||
output_names.reserve(output_names_in.size());
|
||||
output_names.assign(output_names_in.begin(), output_names_in.end());
|
||||
ORT_THROW_IF_ERROR(SetMLValueIdxs(ort_value_name_idx_map));
|
||||
}
|
||||
|
||||
static Status MapNamesToMLValueIdxs(const std::vector<std::string>& names,
|
||||
FeedsFetchesInfo(gsl::span<const std::string_view> feed_names_in,
|
||||
gsl::span<const std::string> output_names_in,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map)
|
||||
: feed_names(),
|
||||
output_names() {
|
||||
feed_names.reserve(feed_names_in.size());
|
||||
feed_names.assign(feed_names_in.begin(), feed_names_in.end());
|
||||
output_names.reserve(output_names_in.size());
|
||||
output_names.assign(output_names_in.begin(), output_names_in.end());
|
||||
ORT_THROW_IF_ERROR(SetMLValueIdxs(ort_value_name_idx_map));
|
||||
}
|
||||
|
||||
|
||||
static Status MapNamesToMLValueIdxs(gsl::span<const std::string> names,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
std::vector<int>& ort_value_idxs);
|
||||
InlinedVector<int>& ort_value_idxs);
|
||||
|
||||
// set the ort_value_idxs for the current values in feed_names and output_names
|
||||
Status SetMLValueIdxs(const OrtValueNameIdxMap& ort_value_name_idx_map);
|
||||
|
||||
std::vector<std::string> feed_names;
|
||||
std::vector<std::string> output_names;
|
||||
InlinedVector<std::string> feed_names;
|
||||
InlinedVector<std::string> output_names;
|
||||
|
||||
std::vector<int> feeds_mlvalue_idxs;
|
||||
std::vector<int> fetches_mlvalue_idxs;
|
||||
InlinedVector<int> feeds_mlvalue_idxs;
|
||||
InlinedVector<int> fetches_mlvalue_idxs;
|
||||
};
|
||||
|
||||
struct MLValueCopyInfo {
|
||||
|
|
@ -58,10 +78,14 @@ struct MLValueCopyInfo {
|
|||
|
||||
class FeedsFetchesManager {
|
||||
public:
|
||||
static Status Create(const std::vector<std::string>& feed_names, const std::vector<std::string>& output_names,
|
||||
static Status Create(gsl::span<const std::string> feed_names, gsl::span<const std::string> output_names,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
std::unique_ptr<FeedsFetchesManager>& feeds_fetches_manager);
|
||||
|
||||
static Status Create(gsl::span<const std::string_view> feed_names, gsl::span<const std::string> output_names,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
std::optional<FeedsFetchesManager>& feeds_fetches_manager);
|
||||
|
||||
FeedsFetchesManager(FeedsFetchesInfo&& info);
|
||||
|
||||
const FeedsFetchesInfo& GetFeedsFetchesInfo() const { return feeds_fetches_info_; }
|
||||
|
|
|
|||
|
|
@ -28,9 +28,9 @@ class IExecutor {
|
|||
* The lifetime of 'fetches' is limited by 'session_state'
|
||||
*/
|
||||
common::Status Execute(const SessionState& session_state,
|
||||
const std::vector<int>& feed_mlvalue_idxs,
|
||||
const std::vector<OrtValue>& feeds,
|
||||
const std::vector<int>& fetch_mlvalue_idxs,
|
||||
gsl::span<const int> feed_mlvalue_idxs,
|
||||
gsl::span<const OrtValue> feeds,
|
||||
gsl::span<const int> fetch_mlvalue_idxs,
|
||||
std::vector<OrtValue>& fetches,
|
||||
const logging::Logger& logger) {
|
||||
std::unordered_map<size_t, CustomAllocator> fetch_allocators;
|
||||
|
|
@ -38,8 +38,8 @@ class IExecutor {
|
|||
}
|
||||
|
||||
// TODO: as fetch_allocators is optional, it should be a pointer instead of reference
|
||||
virtual common::Status Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
|
||||
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
|
||||
virtual common::Status Execute(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
|
||||
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
|
||||
std::vector<OrtValue>& fetches,
|
||||
// optional custom allocators. key is index in fetches
|
||||
const std::unordered_map<size_t, CustomAllocator>& fetch_allocators,
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class OrtValueNameIdxMap {
|
|||
return p.first->second;
|
||||
}
|
||||
|
||||
common::Status GetIdx(const std::string& name, int& idx) const {
|
||||
common::Status GetIdx(std::string_view name, int& idx) const {
|
||||
idx = -1;
|
||||
|
||||
auto it = map_.find(name);
|
||||
|
|
|
|||
|
|
@ -80,9 +80,9 @@ void OrtValueTensorSlicer<T>::Iterator::MaterializeMLValue() const {
|
|||
//
|
||||
// TODO: Ideally we could avoid the overhead of creating a new Tensor (mainly cost of copying type and shape info)
|
||||
// and would simply update Tensor::p_data_ given all other info remains constant for each slice.
|
||||
auto sub_tensor = Tensor::Create(tensor_data_type_, per_iteration_shape_, const_cast<void*>(tensor_slice_data_raw), *tensor_location_);
|
||||
auto ml_tensor = DataTypeImpl::GetType<Tensor>();
|
||||
current_ = OrtValue{sub_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()};
|
||||
OrtValue val;
|
||||
Tensor::InitOrtValue(tensor_data_type_, per_iteration_shape_, const_cast<void*>(tensor_slice_data_raw), *tensor_location_, val);
|
||||
current_ = std::move(val);
|
||||
}
|
||||
|
||||
template class OrtValueTensorSlicer<OrtValue>;
|
||||
|
|
|
|||
|
|
@ -130,8 +130,8 @@ static Status ReleaseNodeMLValues(ExecutionFrame& frame,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PartialExecutor::Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
|
||||
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
|
||||
Status PartialExecutor::Execute(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
|
||||
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
|
||||
std::vector<OrtValue>& fetches,
|
||||
const std::unordered_map<size_t, CustomAllocator>& fetch_allocators,
|
||||
const logging::Logger& logger) {
|
||||
|
|
|
|||
|
|
@ -28,8 +28,8 @@ class PartialExecutor : public IExecutor {
|
|||
ORT_UNUSED_PARAMETER(partial_graph_index_);
|
||||
}
|
||||
|
||||
common::Status Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
|
||||
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
|
||||
common::Status Execute(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
|
||||
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
|
||||
std::vector<OrtValue>& fetches,
|
||||
const std::unordered_map<size_t, CustomAllocator>& fetch_allocators,
|
||||
const logging::Logger& logger) override;
|
||||
|
|
|
|||
|
|
@ -27,8 +27,8 @@ ParallelExecutor::ParallelExecutor(const SessionState& session_state, const bool
|
|||
}
|
||||
}
|
||||
|
||||
Status ParallelExecutor::Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
|
||||
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
|
||||
Status ParallelExecutor::Execute(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
|
||||
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
|
||||
std::vector<OrtValue>& fetches,
|
||||
const std::unordered_map<size_t, CustomAllocator>& fetch_allocators,
|
||||
const logging::Logger& logger) {
|
||||
|
|
|
|||
|
|
@ -22,8 +22,8 @@ class ParallelExecutor : public IExecutor {
|
|||
public:
|
||||
ParallelExecutor(const SessionState& session_state, const bool& terminate_flag = false);
|
||||
|
||||
common::Status Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
|
||||
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
|
||||
common::Status Execute(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
|
||||
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
|
||||
std::vector<OrtValue>& fetches,
|
||||
const std::unordered_map<size_t, CustomAllocator>& fetch_allocators,
|
||||
const logging::Logger& logger) override;
|
||||
|
|
|
|||
|
|
@ -26,8 +26,9 @@ struct PartialGraphExecutionState {
|
|||
size_t GetProgramCounterStart() { return program_counter_start_; }
|
||||
size_t GetProgramCounterEnd() { return program_counter_end_; }
|
||||
|
||||
ExecutionFrame& GetExecutionFrame(const std::vector<int>& feed_mlvalue_idxs, const std::vector<OrtValue>& feeds,
|
||||
const std::vector<int>& fetch_mlvalue_idxs, const std::vector<OrtValue>& fetches,
|
||||
ExecutionFrame& GetExecutionFrame(gsl::span<const int> feed_mlvalue_idxs,
|
||||
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
|
||||
gsl::span<const OrtValue> fetches,
|
||||
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
|
||||
const SessionState& session_state) {
|
||||
if (execution_frame_ == nullptr) {
|
||||
|
|
|
|||
|
|
@ -141,8 +141,8 @@ static Status ReleaseNodeMLValues(ExecutionFrame& frame,
|
|||
const SequentialExecutionPlan::NodeExecutionPlan& node_exec_plan,
|
||||
const logging::Logger& logger);
|
||||
|
||||
Status SequentialExecutor::Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
|
||||
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
|
||||
Status SequentialExecutor::Execute(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
|
||||
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
|
||||
std::vector<OrtValue>& fetches,
|
||||
const std::unordered_map<size_t, CustomAllocator>& fetch_allocators,
|
||||
const logging::Logger& logger) {
|
||||
|
|
|
|||
|
|
@ -21,8 +21,8 @@ class SequentialExecutor : public IExecutor {
|
|||
SequentialExecutor(const bool& terminate_flag = false, const bool only_execute_path_to_fetches = false)
|
||||
: terminate_flag_{terminate_flag}, only_execute_path_to_fetches_(only_execute_path_to_fetches) {}
|
||||
|
||||
common::Status Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
|
||||
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
|
||||
common::Status Execute(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
|
||||
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
|
||||
std::vector<OrtValue>& fetches,
|
||||
const std::unordered_map<size_t, CustomAllocator>& fetch_allocators,
|
||||
const logging::Logger& logger) override;
|
||||
|
|
|
|||
|
|
@ -251,7 +251,7 @@ static bool HaveCpuExecutionProvidersOnly(const ExecutionProviders& execution_pr
|
|||
|
||||
static const OrtMemoryInfo& FindMemoryInfoForValue(const OrtValueNameIdxMap& map,
|
||||
const SequentialExecutionPlan& plan,
|
||||
const std::string& name) {
|
||||
std::string_view name) {
|
||||
int idx = -1;
|
||||
auto status = map.GetIdx(name, idx);
|
||||
ORT_THROW_IF_ERROR(status);
|
||||
|
|
@ -261,7 +261,7 @@ static const OrtMemoryInfo& FindMemoryInfoForValue(const OrtValueNameIdxMap& map
|
|||
}
|
||||
|
||||
const OrtMemoryInfo& FindMemoryInfoForValue(const SessionState& session_state,
|
||||
const std::string& name) {
|
||||
std::string_view name) {
|
||||
const auto* exec_plan_ptr = session_state.GetExecutionPlan();
|
||||
ORT_ENFORCE(exec_plan_ptr);
|
||||
|
||||
|
|
@ -304,7 +304,7 @@ static common::Status CalculateStaticCopyInfoForFeed(const SessionState& session
|
|||
}
|
||||
|
||||
static common::Status CalculateStaticCopyInfoForFeeds(const SessionState& session_state,
|
||||
const std::vector<std::string>& feed_names,
|
||||
gsl::span<const std::string> feed_names,
|
||||
std::vector<MLValueCopyInfo>& copy_info) {
|
||||
for (size_t idx = 0, end = feed_names.size(); idx < end; ++idx) {
|
||||
ORT_RETURN_IF_ERROR(CalculateStaticCopyInfoForFeed(session_state, feed_names[idx], copy_info[idx]));
|
||||
|
|
@ -316,7 +316,7 @@ static common::Status CalculateStaticCopyInfoForFeeds(const SessionState& sessio
|
|||
// get the source device info for the node producing each output that we will return in the fetches.
|
||||
// target device info is not known until runtime.
|
||||
static common::Status CalculateStaticCopyInfoForFetches(const SessionState& session_state,
|
||||
const std::vector<std::string>& fetch_names,
|
||||
gsl::span<const std::string> fetch_names,
|
||||
std::vector<MLValueCopyInfo>& copy_info) {
|
||||
for (size_t idx = 0, end = fetch_names.size(); idx < end; ++idx) {
|
||||
const std::string& output_name = fetch_names[idx];
|
||||
|
|
@ -362,7 +362,7 @@ common::Status InitializeFeedFetchCopyInfo(const SessionState& session_state,
|
|||
}
|
||||
|
||||
// update the allocation_provider in the copy info based on the actual feeds
|
||||
static bool FinalizeCopyInfoForFeeds(const std::vector<OrtDevice>& feed_locations,
|
||||
static bool FinalizeCopyInfoForFeeds(gsl::span<const OrtDevice> feed_locations,
|
||||
std::vector<MLValueCopyInfo>& copy_info) {
|
||||
ORT_ENFORCE(feed_locations.size() == copy_info.size());
|
||||
bool copy_needed = false;
|
||||
|
|
@ -378,7 +378,7 @@ static bool FinalizeCopyInfoForFeeds(const std::vector<OrtDevice>& feed_location
|
|||
return copy_needed;
|
||||
}
|
||||
|
||||
static bool FinalizeCopyInfoForFetches(const std::vector<const OrtMemoryInfo*>& fetch_alloc_info,
|
||||
static bool FinalizeCopyInfoForFetches(gsl::span<const OrtMemoryInfo* const>& fetch_alloc_info,
|
||||
std::vector<MLValueCopyInfo>& copy_info) {
|
||||
ORT_ENFORCE(fetch_alloc_info.size() == copy_info.size());
|
||||
bool copy_needed = false;
|
||||
|
|
@ -402,8 +402,8 @@ static bool FinalizeCopyInfoForFetches(const std::vector<const OrtMemoryInfo*>&
|
|||
// Finalize the copy info using the OrtDevice and OrtMemoryInfo for the feeds and fetches
|
||||
// This can be used by control flow nodes prior to the execution of the overall graph.
|
||||
void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
|
||||
const std::vector<OrtDevice>& feed_locations,
|
||||
const std::vector<const OrtMemoryInfo*>& fetch_alloc_info) {
|
||||
gsl::span<const OrtDevice> feed_locations,
|
||||
gsl::span<const OrtMemoryInfo* const> fetch_alloc_info) {
|
||||
if (feeds_fetches_manager.GetDeviceCopyChecks().status == DeviceCopyCheck::NoCopy)
|
||||
return;
|
||||
|
||||
|
|
@ -418,7 +418,7 @@ void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
|
|||
|
||||
// Finalize the copy info using the OrtValue instances for the feeds and fetches
|
||||
static void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
|
||||
const std::vector<OrtValue>& feeds,
|
||||
gsl::span<const OrtValue> feeds,
|
||||
std::vector<OrtValue>& fetches) {
|
||||
if (feeds_fetches_manager.GetDeviceCopyChecks().status == DeviceCopyCheck::NoCopy)
|
||||
return;
|
||||
|
|
@ -465,9 +465,9 @@ static void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager
|
|||
}
|
||||
|
||||
static common::Status CopyInputsAcrossDevices(const SessionState& session_state,
|
||||
const std::vector<OrtValue>& orig_feeds,
|
||||
gsl::span<const OrtValue> orig_feeds,
|
||||
std::vector<OrtValue>& new_feeds,
|
||||
const std::vector<MLValueCopyInfo>& copy_info) {
|
||||
gsl::span<const MLValueCopyInfo> copy_info) {
|
||||
size_t num_feeds = orig_feeds.size();
|
||||
ORT_ENFORCE(copy_info.size() == num_feeds);
|
||||
|
||||
|
|
@ -560,20 +560,26 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state
|
|||
|
||||
static common::Status ExecuteGraphImpl(const SessionState& session_state,
|
||||
const FeedsFetchesManager& feeds_fetches_manager,
|
||||
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
||||
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
|
||||
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
|
||||
ExecutionMode execution_mode, const bool& terminate_flag,
|
||||
const logging::Logger& logger, const bool only_execute_path_to_fetches = false) {
|
||||
std::unique_ptr<IExecutor> p_exec;
|
||||
// avoid memory allocations
|
||||
std::optional<SequentialExecutor> seq_executor;
|
||||
std::optional<ParallelExecutor> par_executor;
|
||||
IExecutor* p_exec = nullptr;
|
||||
if (execution_mode == ExecutionMode::ORT_SEQUENTIAL) {
|
||||
p_exec = std::make_unique<SequentialExecutor>(terminate_flag, only_execute_path_to_fetches);
|
||||
seq_executor.emplace(terminate_flag, only_execute_path_to_fetches);
|
||||
p_exec = &seq_executor.value();
|
||||
} else if (execution_mode == ExecutionMode::ORT_PARALLEL) {
|
||||
auto* p_inter_op_thread_pool = session_state.GetInterOpThreadPool();
|
||||
if (!p_inter_op_thread_pool) {
|
||||
LOGS(logger, WARNING) << "Only one thread was configured for parallel execution. Hence will use sequential execution.";
|
||||
p_exec = std::make_unique<SequentialExecutor>(terminate_flag, only_execute_path_to_fetches);
|
||||
seq_executor.emplace(terminate_flag, only_execute_path_to_fetches);
|
||||
p_exec = &seq_executor.value();
|
||||
} else {
|
||||
p_exec = std::make_unique<ParallelExecutor>(session_state, terminate_flag);
|
||||
par_executor.emplace(session_state, terminate_flag);
|
||||
p_exec = &par_executor.value();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -588,7 +594,7 @@ static common::Status ExecuteGraphImpl(const SessionState& session_state,
|
|||
feeds_fetches_info.fetches_mlvalue_idxs, fetches, fetch_allocators,
|
||||
logger));
|
||||
} else {
|
||||
const std::vector<OrtValue>* p_feeds = &feeds;
|
||||
auto feeds_to_use = feeds;
|
||||
std::vector<OrtValue>* p_fetches = &fetches;
|
||||
std::vector<OrtValue> device_feeds;
|
||||
std::vector<OrtValue> device_fetches;
|
||||
|
|
@ -596,7 +602,7 @@ static common::Status ExecuteGraphImpl(const SessionState& session_state,
|
|||
if (device_copy_checks.input_copy_needed == DeviceCopyCheck::Copy) {
|
||||
const auto& feed_copy_info = feeds_fetches_manager.GetFeedsDeviceCopyInfo();
|
||||
ORT_RETURN_IF_ERROR(CopyInputsAcrossDevices(session_state, feeds, device_feeds, feed_copy_info));
|
||||
p_feeds = &device_feeds;
|
||||
feeds_to_use = device_feeds;
|
||||
}
|
||||
|
||||
auto num_outputs = fetches.size();
|
||||
|
|
@ -619,7 +625,7 @@ static common::Status ExecuteGraphImpl(const SessionState& session_state,
|
|||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(p_exec->Execute(session_state,
|
||||
feeds_fetches_info.feeds_mlvalue_idxs, *p_feeds,
|
||||
feeds_fetches_info.feeds_mlvalue_idxs, feeds_to_use,
|
||||
feeds_fetches_info.fetches_mlvalue_idxs, *p_fetches, fetch_allocators,
|
||||
logger));
|
||||
|
||||
|
|
@ -633,7 +639,7 @@ static common::Status ExecuteGraphImpl(const SessionState& session_state,
|
|||
|
||||
common::Status ExecuteGraph(const SessionState& session_state,
|
||||
FeedsFetchesManager& feeds_fetches_manager,
|
||||
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
||||
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
|
||||
ExecutionMode execution_mode, const bool& terminate_flag,
|
||||
const logging::Logger& logger, bool only_execute_path_to_fetches) {
|
||||
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(session_state, feeds_fetches_manager));
|
||||
|
|
@ -649,7 +655,7 @@ common::Status ExecuteGraph(const SessionState& session_state,
|
|||
|
||||
#ifdef ENABLE_TRAINING
|
||||
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
|
||||
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
||||
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
|
||||
const logging::Logger& logger, PartialGraphExecutionState& state,
|
||||
const OrtValueCachePtr& cache,
|
||||
int32_t partial_graph_index) {
|
||||
|
|
@ -667,7 +673,7 @@ common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetch
|
|||
feeds_fetches_info.fetches_mlvalue_idxs, fetches, {},
|
||||
logger));
|
||||
} else {
|
||||
const std::vector<OrtValue>* p_feeds = &feeds;
|
||||
auto p_feeds = feeds;
|
||||
std::vector<OrtValue>* p_fetches = &fetches;
|
||||
std::vector<OrtValue> device_feeds;
|
||||
std::vector<OrtValue> device_fetches;
|
||||
|
|
@ -675,7 +681,7 @@ common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetch
|
|||
if (device_copy_checks.input_copy_needed == DeviceCopyCheck::Copy) {
|
||||
const auto& feed_copy_info = feeds_fetches_manager.GetFeedsDeviceCopyInfo();
|
||||
ORT_RETURN_IF_ERROR(CopyInputsAcrossDevices(session_state, feeds, device_feeds, feed_copy_info));
|
||||
p_feeds = &device_feeds;
|
||||
p_feeds = device_feeds;
|
||||
}
|
||||
|
||||
auto num_outputs = fetches.size();
|
||||
|
|
@ -698,7 +704,7 @@ common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetch
|
|||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(executor.Execute(session_state,
|
||||
feeds_fetches_info.feeds_mlvalue_idxs, *p_feeds,
|
||||
feeds_fetches_info.feeds_mlvalue_idxs, p_feeds,
|
||||
feeds_fetches_info.fetches_mlvalue_idxs, *p_fetches, {},
|
||||
logger));
|
||||
|
||||
|
|
@ -712,7 +718,7 @@ common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetch
|
|||
#endif
|
||||
|
||||
common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFetchesManager& feeds_fetches_manager,
|
||||
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
||||
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
|
||||
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
|
||||
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger) {
|
||||
auto status = ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, fetch_allocators,
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
|
|||
|
||||
// Searches the allocation plan from the session_state to find the OrtMemoryInfo for the value 'name'.
|
||||
const OrtMemoryInfo& FindMemoryInfoForValue(const SessionState& session_state,
|
||||
const std::string& name);
|
||||
std::string_view name);
|
||||
|
||||
// Initialize the feed and fetch copy info using session_state.
|
||||
// Determines the device that each graph input that will be fed will be consumed on,
|
||||
|
|
@ -82,18 +82,18 @@ common::Status InitializeFeedFetchCopyInfo(const SessionState& session_state,
|
|||
// Finalize the feed and fetch copy info using session_state and the device and location information from the feeds
|
||||
// and fetches that will be used in graph execution.
|
||||
void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
|
||||
const std::vector<OrtDevice>& feed_locations,
|
||||
const std::vector<const OrtMemoryInfo*>& fetch_alloc_info);
|
||||
gsl::span<const OrtDevice> feed_locations,
|
||||
gsl::span<const OrtMemoryInfo* const> fetch_alloc_info);
|
||||
|
||||
// Execute the main graph. The feed_fetches_manager will be finalized based on the provided feeds and fetches.
|
||||
common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
|
||||
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
||||
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
|
||||
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger,
|
||||
bool only_execute_path_to_fetches = false);
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
|
||||
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
||||
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
|
||||
const logging::Logger& logger, PartialGraphExecutionState& state,
|
||||
const OrtValueCachePtr& cache,
|
||||
int32_t partial_graph_index);
|
||||
|
|
@ -102,7 +102,7 @@ common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetch
|
|||
// Execute a subgraph. The feeds_fetches_manager should have been finalized prior to calling this function.
|
||||
// See IControlFlowNode::SetupSubgraphExecutionInfo usage in the control flow kernels.
|
||||
common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFetchesManager& feeds_fetches_manager,
|
||||
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
||||
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
|
||||
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
|
||||
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger);
|
||||
|
||||
|
|
|
|||
|
|
@ -262,11 +262,11 @@ std::vector<uint8_t> ApiTensor::Data() const {
|
|||
const DataTypeImpl* tensor_dtype = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto_.data_type())->GetElementType();
|
||||
auto tensor_shape_dims = utils::GetTensorShapeFromTensorProto(tensor_proto_);
|
||||
TensorShape tensor_shape{std::move(tensor_shape_dims)};
|
||||
auto tensor = onnxruntime::Tensor::Create(tensor_dtype, tensor_shape, cpu_allocator_);
|
||||
onnxruntime::Tensor tensor(tensor_dtype, tensor_shape, cpu_allocator_);
|
||||
ORT_THROW_IF_ERROR(utils::TensorProtoToTensor(Env::Default(), model_path_.ToPathString().c_str(),
|
||||
tensor_proto_, *tensor));
|
||||
size_t num_bytes = gsl::narrow_cast<size_t>(tensor->SizeInBytes());
|
||||
const uint8_t* data = static_cast<const uint8_t*>(tensor->DataRaw());
|
||||
tensor_proto_, tensor));
|
||||
size_t num_bytes = gsl::narrow_cast<size_t>(tensor.SizeInBytes());
|
||||
const uint8_t* data = static_cast<const uint8_t*>(tensor.DataRaw());
|
||||
return std::vector<uint8_t>(data, data + num_bytes);
|
||||
}
|
||||
// </ApiTensor>
|
||||
|
|
@ -515,7 +515,7 @@ void ApiGraph::TransposeInitializer(std::string_view name, const std::vector<int
|
|||
const DataTypeImpl* tensor_dtype = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto->data_type())->GetElementType();
|
||||
auto tensor_shape_dims = utils::GetTensorShapeFromTensorProto(*tensor_proto);
|
||||
TensorShape tensor_shape{tensor_shape_dims};
|
||||
std::unique_ptr<Tensor> in_tensor = Tensor::Create(tensor_dtype, tensor_shape, cpu_allocator_);
|
||||
Tensor in_tensor(tensor_dtype, tensor_shape, cpu_allocator_);
|
||||
|
||||
std::vector<int64_t> new_tensor_shape_dims;
|
||||
std::vector<size_t> permutations;
|
||||
|
|
@ -528,13 +528,12 @@ void ApiGraph::TransposeInitializer(std::string_view name, const std::vector<int
|
|||
}
|
||||
|
||||
TensorShape new_tensor_shape(new_tensor_shape_dims);
|
||||
|
||||
std::unique_ptr<Tensor> out_tensor = Tensor::Create(tensor_dtype, new_tensor_shape, cpu_allocator_);
|
||||
Tensor out_tensor(tensor_dtype, new_tensor_shape, cpu_allocator_);
|
||||
|
||||
ORT_THROW_IF_ERROR(utils::TensorProtoToTensor(Env::Default(), graph_.ModelPath().ToPathString().c_str(),
|
||||
*tensor_proto, *in_tensor));
|
||||
*tensor_proto, in_tensor));
|
||||
|
||||
ORT_THROW_IF_ERROR(Transpose::DoTranspose(permutations, *in_tensor, *out_tensor));
|
||||
ORT_THROW_IF_ERROR(Transpose::DoTranspose(permutations, in_tensor, out_tensor));
|
||||
|
||||
auto& node_arg = *graph_.GetNodeArg(name_str);
|
||||
TensorShapeProto new_shape;
|
||||
|
|
@ -544,7 +543,7 @@ void ApiGraph::TransposeInitializer(std::string_view name, const std::vector<int
|
|||
|
||||
node_arg.SetShape(new_shape);
|
||||
|
||||
ONNX_NAMESPACE::TensorProto new_tensor_proto = utils::TensorToTensorProto(*out_tensor, name_str);
|
||||
ONNX_NAMESPACE::TensorProto new_tensor_proto = utils::TensorToTensorProto(out_tensor, name_str);
|
||||
graph_.RemoveInitializedTensor(name_str);
|
||||
graph_.AddInitializedTensor(new_tensor_proto);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -247,7 +247,7 @@ static Status MultinomialCompute(OpKernelContext* ctx,
|
|||
AllocatorPtr alloc;
|
||||
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
|
||||
auto cdf_data = static_cast<double*>(alloc->Alloc(SafeInt<size_t>(sizeof(double)) * num_classes));
|
||||
BufferUniquePtr cdf_buffer(cdf_data, BufferDeleter(alloc));
|
||||
BufferUniquePtr cdf_buffer(cdf_data, BufferDeleter(std::move(alloc)));
|
||||
Eigen::array<int64_t, 1> cdf_dims = {{num_classes}};
|
||||
auto cdf = EigenVector<double>(cdf_data, cdf_dims);
|
||||
// END create temporary tensor
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@
|
|||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
using namespace std;
|
||||
namespace onnxruntime {
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -123,7 +122,7 @@ static void HeapifyIthPosition(int64_t* heap, size_t i, size_t k, const HeapCmp&
|
|||
template <class Comparator>
|
||||
static void SelectTopK(const Comparator& comparer,
|
||||
int64_t row_offset, int64_t num_blocks, int64_t block_slice, int64_t inter_block_offset,
|
||||
const unsigned k, bool sort_top_k, vector<int64_t>& data_holder) {
|
||||
const unsigned k, bool sort_top_k, std::vector<int64_t>& data_holder) {
|
||||
for (int64_t l = 0; l < num_blocks; ++l) {
|
||||
data_holder[l] = (row_offset + (l * block_slice + inter_block_offset));
|
||||
}
|
||||
|
|
@ -375,8 +374,8 @@ template <typename T>
|
|||
Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
|
||||
AllocatorPtr allocator,
|
||||
onnxruntime::concurrency::ThreadPool* threadpool,
|
||||
std::unique_ptr<Tensor>& output_values,
|
||||
std::unique_ptr<Tensor>& output_indices) {
|
||||
Tensor& output_values,
|
||||
Tensor& output_indices) {
|
||||
const TensorShape& input_shape = input->Shape();
|
||||
|
||||
// Will return axis_ as is if positive or fixes it in case it is negative
|
||||
|
|
@ -394,8 +393,8 @@ Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool large
|
|||
TensorShape output_shape = input_shape;
|
||||
output_shape[axis_parsed] = k;
|
||||
|
||||
output_values = Tensor::Create(input->DataType(), output_shape, allocator);
|
||||
output_indices = Tensor::Create(DataTypeImpl::GetType<int64_t>(), output_shape, allocator);
|
||||
output_values = Tensor(input->DataType(), output_shape, allocator);
|
||||
output_indices = Tensor(DataTypeImpl::GetType<int64_t>(), output_shape, allocator);
|
||||
|
||||
// no-op - no output buffers to fill - return silently
|
||||
if (k == 0) {
|
||||
|
|
@ -403,10 +402,10 @@ Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool large
|
|||
}
|
||||
|
||||
if (largest) {
|
||||
FindTopKElements<GreaterValueCmp<T>>(input, input_shape, output_values.get(), output_indices.get(), output_shape, k, sorted,
|
||||
FindTopKElements<GreaterValueCmp<T>>(input, input_shape, &output_values, &output_indices, output_shape, k, sorted,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed), threadpool);
|
||||
} else {
|
||||
FindTopKElements<LesserValueCmp<T>>(input, input_shape, output_values.get(), output_indices.get(), output_shape, k, sorted,
|
||||
FindTopKElements<LesserValueCmp<T>>(input, input_shape, &output_values, &output_indices, output_shape, k, sorted,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed), threadpool);
|
||||
}
|
||||
|
||||
|
|
@ -417,8 +416,8 @@ Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool large
|
|||
template Status GetTopK<float>(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
|
||||
AllocatorPtr allocator,
|
||||
onnxruntime::concurrency::ThreadPool* threadpool,
|
||||
std::unique_ptr<Tensor>& output_values,
|
||||
std::unique_ptr<Tensor>& output_indices);
|
||||
Tensor& output_values,
|
||||
Tensor& output_indices);
|
||||
|
||||
// Opset ver - 1 to 9
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,6 @@ template <typename T>
|
|||
Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
|
||||
AllocatorPtr allocator,
|
||||
onnxruntime::concurrency::ThreadPool* threadpool,
|
||||
std::unique_ptr<Tensor>& output_values,
|
||||
std::unique_ptr<Tensor>& output_indices);
|
||||
Tensor& output_values,
|
||||
Tensor& output_indices);
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -80,7 +80,7 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
|
|||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
|
||||
|
||||
auto* col_data = alloc->Alloc(SafeInt<size_t>(sizeof(T)) * col_buffer_size);
|
||||
col_buffer = BufferUniquePtr(col_data, BufferDeleter(alloc));
|
||||
col_buffer = BufferUniquePtr(col_data, BufferDeleter(std::move(alloc)));
|
||||
}
|
||||
|
||||
T* col_buffer_data = static_cast<T*>(col_buffer.get());
|
||||
|
|
@ -234,7 +234,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
|
||||
auto* working_data = WorkingBufferSize > 0 ? alloc->Alloc(SafeInt<size_t>(sizeof(float)) * WorkingBufferSize)
|
||||
: nullptr;
|
||||
BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc));
|
||||
BufferUniquePtr working_buffer(working_data, BufferDeleter(std::move(alloc)));
|
||||
|
||||
MlasConv(&Parameters,
|
||||
Xdata,
|
||||
|
|
@ -254,7 +254,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
const int64_t col_buffer_size = kernel_dim * output_image_size;
|
||||
|
||||
auto* col_data = alloc->Alloc(SafeInt<size_t>(sizeof(float)) * col_buffer_size);
|
||||
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
|
||||
BufferUniquePtr col_buffer(col_data, BufferDeleter(std::move(alloc)));
|
||||
auto* col_buffer_data = static_cast<float*>(col_buffer.get());
|
||||
|
||||
for (int image_id = 0; image_id < N; ++image_id) {
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ Status ConvTranspose<float>::PrePack(const Tensor& tensor, int input_idx, Alloca
|
|||
// if and when we try to cache this pre-packed buffer for sharing between sessions.
|
||||
memset(packed_filter_data, 0, packed_filter_data_size);
|
||||
|
||||
transposed_filter_ = BufferUniquePtr(packed_filter_data, BufferDeleter(alloc));
|
||||
transposed_filter_ = BufferUniquePtr(packed_filter_data, BufferDeleter(std::move(alloc)));
|
||||
|
||||
for (int64_t group_id = 0; group_id < conv_transpose_attrs_.group; ++group_id) {
|
||||
MlasTranspose(tensor.Data<float>() + (group_id * N * K),
|
||||
|
|
@ -146,7 +146,7 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
|
|||
|
||||
const int64_t col_buffer_size = kernel_dim * p.input_shape.Size();
|
||||
auto col_data = alloc->Alloc(SafeInt<size_t>(sizeof(T)) * col_buffer_size);
|
||||
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
|
||||
BufferUniquePtr col_buffer(col_data, BufferDeleter(std::move(alloc)));
|
||||
T* col_buffer_data = static_cast<T*>(col_buffer.get());
|
||||
|
||||
const T* Xdata = p.X->template Data<T>();
|
||||
|
|
@ -246,7 +246,7 @@ Status ConvTranspose<float>::DoConvTranspose(OpKernelContext* context, bool dyna
|
|||
|
||||
const int64_t col_buffer_size = kernel_dim * p.input_shape.Size();
|
||||
auto col_data = alloc->Alloc(SafeInt<size_t>(sizeof(float)) * col_buffer_size);
|
||||
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
|
||||
BufferUniquePtr col_buffer(col_data, BufferDeleter(std::move(alloc)));
|
||||
float* col_buffer_data = static_cast<float*>(col_buffer.get());
|
||||
|
||||
const float* Xdata = p.X->template Data<float>();
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ Status LRN<float>::Compute(OpKernelContext* context) const {
|
|||
|
||||
const size_t padded_square_size = (static_cast<size_t>(C) + size_ - 1) * H * W;
|
||||
auto psdata = alloc->Alloc(SafeInt<size_t>(sizeof(float)) * padded_square_size);
|
||||
BufferUniquePtr padded_square_buffer(psdata, BufferDeleter(alloc));
|
||||
BufferUniquePtr padded_square_buffer(psdata, BufferDeleter(std::move(alloc)));
|
||||
auto* padded_square_data = static_cast<float*>(padded_square_buffer.get());
|
||||
math::Set<float, CPUMathUtil>(padded_square_size, 0.0f, padded_square_data, &CPUMathUtil::Instance());
|
||||
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
|
|||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
|
||||
|
||||
auto* col_data = alloc->Alloc(SafeInt<size_t>(sizeof(uint8_t)) * col_buffer_size);
|
||||
col_buffer = BufferUniquePtr(col_data, BufferDeleter(alloc));
|
||||
col_buffer = BufferUniquePtr(col_data, BufferDeleter(std::move(alloc)));
|
||||
}
|
||||
|
||||
auto* col_buffer_data = static_cast<uint8_t*>(col_buffer.get());
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class MatMulIntegerBase : public OpKernel {
|
|||
// if and when we try to cache this pre-packed buffer for sharing between sessions.
|
||||
memset(packed_b_data, 0, packed_b_size);
|
||||
|
||||
packed_b_ = BufferUniquePtr(packed_b_data, BufferDeleter(alloc));
|
||||
packed_b_ = BufferUniquePtr(packed_b_data, BufferDeleter(std::move(alloc)));
|
||||
MlasGemmPackB(N, K, b_data, N, a_is_signed, b_is_signed_, packed_b_data);
|
||||
|
||||
bool share_prepacked_weights = (prepacked_weights != nullptr);
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ Status QLinearMatMul::Compute(OpKernelContext* ctx) const {
|
|||
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
|
||||
auto gemm_output_data = alloc->Alloc(SafeInt<size_t>(gemm_shape.M) *
|
||||
gemm_shape.N * sizeof(int32_t) * num_gemms);
|
||||
BufferUniquePtr gemm_output_buffer(gemm_output_data, BufferDeleter(alloc));
|
||||
BufferUniquePtr gemm_output_buffer(gemm_output_data, BufferDeleter(std::move(alloc)));
|
||||
auto* gemm_output = static_cast<int32_t*>(gemm_output_buffer.get());
|
||||
|
||||
std::vector<MLAS_GEMM_QUANT_DATA_PARAMS> gemm_params(num_gemms);
|
||||
|
|
|
|||
|
|
@ -734,7 +734,9 @@ struct ProviderHost {
|
|||
// Tensor
|
||||
virtual std::unique_ptr<Tensor> Tensor__construct(MLDataType p_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator) = 0;
|
||||
virtual std::unique_ptr<Tensor> Tensor__construct(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, ptrdiff_t offset) = 0;
|
||||
virtual void Tensor__operator_delete(Tensor* p) = 0;
|
||||
virtual std::unique_ptr<Tensor> Tensor__construct_default() = 0;
|
||||
virtual void Tensor__move_assign(Tensor& lhs, Tensor&& rhs) noexcept = 0;
|
||||
virtual void Tensor__operator_delete(Tensor* p) noexcept = 0;
|
||||
|
||||
virtual void Tensor__InitOrtValue(MLDataType elt_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator, OrtValue& ort_value) = 0;
|
||||
virtual void Tensor__InitOrtValue(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, OrtValue& ort_value) = 0;
|
||||
|
|
|
|||
|
|
@ -882,10 +882,11 @@ class SessionState {
|
|||
};
|
||||
|
||||
struct Tensor final {
|
||||
static std::unique_ptr<Tensor> CreateDefault() { return g_host->Tensor__construct_default(); }
|
||||
static std::unique_ptr<Tensor> Create(MLDataType p_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator) { return g_host->Tensor__construct(p_type, shape, std::move(allocator)); }
|
||||
static std::unique_ptr<Tensor> Create(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, ptrdiff_t offset = 0) { return g_host->Tensor__construct(p_type, shape, p_data, alloc, offset); }
|
||||
|
||||
static void operator delete(void* p) { g_host->Tensor__operator_delete(reinterpret_cast<Tensor*>(p)); }
|
||||
static void operator delete(void* p) noexcept { g_host->Tensor__operator_delete(reinterpret_cast<Tensor*>(p)); }
|
||||
|
||||
static void InitOrtValue(MLDataType elt_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator, OrtValue& ort_value) {
|
||||
g_host->Tensor__InitOrtValue(elt_type, shape, std::move(allocator), ort_value);
|
||||
|
|
@ -935,6 +936,10 @@ struct Tensor final {
|
|||
Tensor() = delete;
|
||||
Tensor(const Tensor&) = delete;
|
||||
void operator=(const Tensor&) = delete;
|
||||
Tensor& operator=(Tensor&& o) noexcept {
|
||||
g_host->Tensor__move_assign(*this, std::move(o));
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "conv.h"
|
||||
#include "core/common/inlined_containers_fwd.h"
|
||||
#include "core/graph/constants.h"
|
||||
#include "core/graph/graph.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
|
|
@ -230,22 +231,22 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
|
|||
// Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group}
|
||||
auto orig_shape = tensor.Shape();
|
||||
|
||||
std::vector<size_t> perm{0, 2, 3, 1};
|
||||
std::vector<int64_t> new_dims{orig_shape[0],
|
||||
orig_shape[2],
|
||||
orig_shape[3],
|
||||
orig_shape[1]};
|
||||
InlinedVector<size_t> perm{0, 2, 3, 1};
|
||||
TensorShapeVector new_dims{orig_shape[0],
|
||||
orig_shape[2],
|
||||
orig_shape[3],
|
||||
orig_shape[1]};
|
||||
|
||||
packed_w_ = Tensor::Create(tensor.DataType(), TensorShape(new_dims), alloc);
|
||||
packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc));
|
||||
|
||||
SingleAxisTranspose(perm, tensor, *packed_w_, /*from*/ 1, /*to*/ 3);
|
||||
SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 3);
|
||||
|
||||
is_packed = true;
|
||||
|
||||
// we can create the kernel now
|
||||
struct xnn_operator* p = nullptr;
|
||||
ORT_RETURN_IF_ERROR(CreateXnnpackKernel(conv_attrs_, C_, M_, kernel_shape_, clip_min_max_, IsDepthwise(),
|
||||
*packed_w_, B_ ? B_->Data<float>() : nullptr, p));
|
||||
packed_w_, B_ ? B_->Data<float>() : nullptr, p));
|
||||
|
||||
op0_.reset(p);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ class Conv : public OpKernel {
|
|||
TensorShapeVector kernel_shape_;
|
||||
int64_t C_;
|
||||
int64_t M_;
|
||||
std::unique_ptr<Tensor> packed_w_;
|
||||
Tensor packed_w_;
|
||||
const Tensor* B_{nullptr};
|
||||
std::optional<std::pair<float, float>> clip_min_max_;
|
||||
|
||||
|
|
|
|||
|
|
@ -1636,8 +1636,8 @@ static common::Status CheckTypes(MLDataType actual, MLDataType expected, const s
|
|||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str());
|
||||
}
|
||||
|
||||
common::Status InferenceSession::ValidateInputs(const std::vector<std::string>& feed_names,
|
||||
const std::vector<OrtValue>& feeds) const {
|
||||
common::Status InferenceSession::ValidateInputs(gsl::span<const std::string> feed_names,
|
||||
gsl::span<const OrtValue> feeds) const {
|
||||
if (feed_names.size() != feeds.size()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Size mismatch: feed_names has ", feed_names.size(),
|
||||
"elements, but feeds has ", feeds.size(), " elements.");
|
||||
|
|
@ -1735,7 +1735,7 @@ common::Status InferenceSession::ValidateInputs(const std::vector<std::string>&
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status InferenceSession::ValidateOutputs(const std::vector<std::string>& output_names,
|
||||
common::Status InferenceSession::ValidateOutputs(gsl::span<const std::string> output_names,
|
||||
const std::vector<OrtValue>* p_fetches) const {
|
||||
if (p_fetches == nullptr) {
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Output vector pointer is NULL");
|
||||
|
|
@ -1863,8 +1863,8 @@ struct ThreadPoolSpinningSwitch {
|
|||
} // namespace
|
||||
|
||||
Status InferenceSession::Run(const RunOptions& run_options,
|
||||
const std::vector<std::string>& feed_names, const std::vector<OrtValue>& feeds,
|
||||
const std::vector<std::string>& output_names, std::vector<OrtValue>* p_fetches,
|
||||
gsl::span<const std::string> feed_names, gsl::span<const OrtValue> feeds,
|
||||
gsl::span<const std::string> output_names, std::vector<OrtValue>* p_fetches,
|
||||
const std::vector<OrtDevice>* p_fetches_device_info) {
|
||||
TimePoint tp;
|
||||
if (session_profiler_.IsEnabled()) {
|
||||
|
|
@ -1895,10 +1895,10 @@ Status InferenceSession::Run(const RunOptions& run_options,
|
|||
<< " CUDA Graph for this model with tag: " << run_options.run_tag;
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph());
|
||||
} else {
|
||||
std::vector<IExecutionProvider*> exec_providers_to_stop;
|
||||
InlinedVector<IExecutionProvider*> exec_providers_to_stop;
|
||||
exec_providers_to_stop.reserve(execution_providers_.NumProviders());
|
||||
|
||||
std::vector<AllocatorPtr> arenas_to_shrink;
|
||||
InlinedVector<AllocatorPtr> arenas_to_shrink;
|
||||
|
||||
ORT_TRY {
|
||||
if (!is_inited_) {
|
||||
|
|
@ -2037,17 +2037,17 @@ Status InferenceSession::Run(const RunOptions& run_options,
|
|||
return retval;
|
||||
}
|
||||
|
||||
common::Status InferenceSession::Run(const NameMLValMap& feeds, const std::vector<std::string>& output_names,
|
||||
// Convenience overload: forwards to the Run overload that takes RunOptions,
// passing a default-constructed RunOptions along with the caller's arguments.
common::Status InferenceSession::Run(const NameMLValMap& feeds, gsl::span<const std::string> output_names,
|
||||
std::vector<OrtValue>* p_fetches) {
|
||||
return Run(RunOptions(), feeds, output_names, p_fetches);
|
||||
}
|
||||
|
||||
common::Status InferenceSession::Run(const RunOptions& run_options, const NameMLValMap& feeds_map,
|
||||
const std::vector<std::string>& output_names, std::vector<OrtValue>* p_fetches) {
|
||||
std::vector<std::string> feed_names;
|
||||
std::vector<OrtValue> feeds;
|
||||
gsl::span<const std::string> output_names, std::vector<OrtValue>* p_fetches) {
|
||||
InlinedVector<std::string> feed_names;
|
||||
InlinedVector<OrtValue> feeds;
|
||||
|
||||
auto num_feeds = feeds_map.size();
|
||||
const auto num_feeds = feeds_map.size();
|
||||
feed_names.reserve(num_feeds);
|
||||
feeds.reserve(num_feeds);
|
||||
|
||||
|
|
@ -2177,7 +2177,7 @@ AllocatorPtr InferenceSession::GetAllocator(const OrtMemoryInfo& mem_info) const
|
|||
}
|
||||
|
||||
common::Status InferenceSession::ValidateAndParseShrinkArenaString(const std::string& ort_device_list,
|
||||
/*out*/ std::vector<AllocatorPtr>& arenas_to_shrink) const {
|
||||
/*out*/ InlinedVector<AllocatorPtr>& arenas_to_shrink) const {
|
||||
arenas_to_shrink.reserve(5); // Allocate some memory for the container (we are unlikely to see more than 5 memory arena shrink requests)
|
||||
|
||||
std::istringstream ss_1(ort_device_list);
|
||||
|
|
@ -2234,7 +2234,7 @@ common::Status InferenceSession::ValidateAndParseShrinkArenaString(const std::st
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
void InferenceSession::ShrinkMemoryArenas(const std::vector<AllocatorPtr>& arenas_to_shrink) {
|
||||
void InferenceSession::ShrinkMemoryArenas(gsl::span<const AllocatorPtr> arenas_to_shrink) {
|
||||
for (auto& alloc : arenas_to_shrink) {
|
||||
auto status = static_cast<BFCArena*>(alloc.get())->Shrink();
|
||||
|
||||
|
|
|
|||
|
|
@ -300,8 +300,8 @@ class InferenceSession {
|
|||
*/
|
||||
common::Status Initialize() ORT_MUST_USE_RESULT;
|
||||
|
||||
common::Status Run(const RunOptions& run_options, const std::vector<std::string>& feed_names,
|
||||
const std::vector<OrtValue>& feeds, const std::vector<std::string>& output_names,
|
||||
common::Status Run(const RunOptions& run_options, gsl::span<const std::string> feed_names,
|
||||
gsl::span<const OrtValue> feeds, gsl::span<const std::string> output_names,
|
||||
std::vector<OrtValue>* p_fetches,
|
||||
const std::vector<OrtDevice>* p_fetches_device_info = nullptr) ORT_MUST_USE_RESULT;
|
||||
|
||||
|
|
@ -315,7 +315,7 @@ class InferenceSession {
|
|||
* This should not be changed during execution of this function.
|
||||
* @return OK if success.
|
||||
*/
|
||||
common::Status Run(const NameMLValMap& feeds, const std::vector<std::string>& output_names,
|
||||
common::Status Run(const NameMLValMap& feeds, gsl::span<const std::string> output_names,
|
||||
std::vector<OrtValue>* p_fetches) ORT_MUST_USE_RESULT;
|
||||
|
||||
/**
|
||||
|
|
@ -324,7 +324,7 @@ class InferenceSession {
|
|||
* @param run_options use this to tune the Run call to your needs.
|
||||
*/
|
||||
common::Status Run(const RunOptions& run_options, const NameMLValMap& feeds,
|
||||
const std::vector<std::string>& output_names,
|
||||
gsl::span<const std::string> output_names,
|
||||
std::vector<OrtValue>* p_fetches) ORT_MUST_USE_RESULT;
|
||||
|
||||
/**
|
||||
|
|
@ -595,10 +595,10 @@ class InferenceSession {
|
|||
common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape,
|
||||
const TensorShape& expected_shape) const ORT_MUST_USE_RESULT;
|
||||
|
||||
common::Status ValidateInputs(const std::vector<std::string>& feed_names,
|
||||
const std::vector<OrtValue>& feeds) const ORT_MUST_USE_RESULT;
|
||||
common::Status ValidateInputs(gsl::span<const std::string> feed_names,
|
||||
gsl::span<const OrtValue> feeds) const ORT_MUST_USE_RESULT;
|
||||
|
||||
common::Status ValidateOutputs(const std::vector<std::string>& output_names,
|
||||
common::Status ValidateOutputs(gsl::span<const std::string> output_names,
|
||||
const std::vector<OrtValue>* p_fetches) const ORT_MUST_USE_RESULT;
|
||||
|
||||
common::Status WaitForNotification(Notification* p_executor_done, int64_t timeout_in_ms) ORT_MUST_USE_RESULT;
|
||||
|
|
@ -617,13 +617,13 @@ class InferenceSession {
|
|||
*/
|
||||
|
||||
common::Status ValidateAndParseShrinkArenaString(const std::string& ort_device_list,
|
||||
/*out*/ std::vector<AllocatorPtr>& arenas_to_shrink) const ORT_MUST_USE_RESULT;
|
||||
/*out*/ InlinedVector<AllocatorPtr>& arenas_to_shrink) const ORT_MUST_USE_RESULT;
|
||||
|
||||
/*
|
||||
* Performs the shrinkage of arenas requested to be shrunk by the user
|
||||
* The `arenas_to_shrink` parameter is obtained from ValidateAndParseShrinkArenaString()
|
||||
*/
|
||||
void ShrinkMemoryArenas(const std::vector<AllocatorPtr>& arenas_to_shrink);
|
||||
void ShrinkMemoryArenas(gsl::span<const AllocatorPtr> arenas_to_shrink);
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
virtual common::Status AddPredefinedTransformers(
|
||||
|
|
|
|||
|
|
@ -830,9 +830,23 @@ struct ProviderHostImpl : ProviderHost {
|
|||
const DataTransferManager& SessionState__GetDataTransferMgr(const SessionState* p) override { return p->GetDataTransferMgr(); }
|
||||
|
||||
// Tensor (wrapped)
|
||||
std::unique_ptr<Tensor> Tensor__construct(MLDataType p_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator) override { return std::make_unique<Tensor>(p_type, shape, std::move(allocator)); }
|
||||
std::unique_ptr<Tensor> Tensor__construct(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, ptrdiff_t offset) override { return std::make_unique<Tensor>(p_type, shape, p_data, alloc, offset); }
|
||||
void Tensor__operator_delete(Tensor* p) override { delete p; }
|
||||
// Creates a heap-allocated Tensor (std::make_unique) from the element type and
// shape; the allocator shared_ptr is moved into the Tensor constructor.
std::unique_ptr<Tensor> Tensor__construct(MLDataType p_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator) override {
|
||||
return std::make_unique<Tensor>(p_type, shape, std::move(allocator));
|
||||
}
|
||||
|
||||
// Creates a heap-allocated Tensor (std::make_unique) by passing p_type, shape,
// p_data, alloc and offset straight through to the Tensor constructor.
// NOTE(review): presumably this wraps the caller-provided buffer p_data rather
// than allocating new storage — confirm against the Tensor constructor.
std::unique_ptr<Tensor> Tensor__construct(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, ptrdiff_t offset) override {
|
||||
return std::make_unique<Tensor>(p_type, shape, p_data, alloc, offset);
|
||||
}
|
||||
|
||||
// Returns a heap-allocated, default-constructed Tensor.
std::unique_ptr<Tensor> Tensor__construct_default() override {
|
||||
return std::make_unique<Tensor>();
|
||||
}
|
||||
|
||||
// Move-assigns rhs into lhs using Tensor's own move assignment operator.
// `override` already implies virtual, so the redundant `virtual` keyword is
// dropped, as is the stray ';' that followed the function body.
void Tensor__move_assign(Tensor& lhs, Tensor&& rhs) noexcept override {
|
||||
lhs = std::move(rhs);
|
||||
}
|
||||
|
||||
// Destroys and frees a Tensor with `delete`.
// NOTE(review): presumably the counterpart of the Tensor__construct* methods,
// so that allocation and deallocation both happen in the host — confirm.
void Tensor__operator_delete(Tensor* p) noexcept override { delete p; }
|
||||
|
||||
void Tensor__InitOrtValue(MLDataType elt_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator, OrtValue& ort_value) override {
|
||||
Tensor::InitOrtValue(elt_type, shape, std::move(allocator), ort_value);
|
||||
|
|
|
|||
|
|
@ -439,7 +439,8 @@ TEST(ExecutionFrameTestInit, InitializerAsOutput) {
|
|||
const void* orig_buffer = results[0].Get<Tensor>().DataRaw();
|
||||
|
||||
RunOptions ro;
|
||||
ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results, nullptr));
|
||||
ASSERT_STATUS_OK(session.Run(ro, EmptySpan<const std::string>(),
|
||||
EmptySpan<const OrtValue>(), AsSpan({std::string("values")}), &results, nullptr));
|
||||
|
||||
EXPECT_EQ(results[0].Get<Tensor>().DataRaw(), orig_buffer);
|
||||
EXPECT_THAT(results[0].Get<Tensor>().DataAsSpan<float>(), ::testing::ContainerEq(gsl::make_span(expected)));
|
||||
|
|
@ -453,7 +454,8 @@ TEST(ExecutionFrameTestInit, InitializerAsOutput) {
|
|||
|
||||
std::vector<OrtValue> results;
|
||||
RunOptions ro;
|
||||
ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results, nullptr));
|
||||
ASSERT_STATUS_OK(session.Run(ro, EmptySpan<std::string>(),
|
||||
EmptySpan<OrtValue>(), AsSpan({std::string("values")}), &results, nullptr));
|
||||
|
||||
// output buffer should not be the same as the initializer in SessionState
|
||||
const auto& initializers = session.GetSessionState().GetInitializedTensors();
|
||||
|
|
@ -464,7 +466,6 @@ TEST(ExecutionFrameTestInit, InitializerAsOutput) {
|
|||
|
||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
TEST(ExecutionFrameTestInit, SparseInitializerAsOutput) {
|
||||
|
||||
const std::vector<int64_t> dense_shape{3, 3};
|
||||
std::vector<float> dense_data = {
|
||||
0, 0, 1.764052391052246f,
|
||||
|
|
@ -491,7 +492,7 @@ TEST(ExecutionFrameTestInit, SparseInitializerAsOutput) {
|
|||
auto ml_type = DataTypeImpl::GetType<SparseTensor>();
|
||||
results[0].Init(p_tensor.release(), ml_type, ml_type->GetDeleteFunc());
|
||||
RunOptions ro;
|
||||
ASSERT_STATUS_OK(session.Run(ro, {}, {}, {"values"}, &results, nullptr));
|
||||
ASSERT_STATUS_OK(session.Run(ro, EmptySpan<std::string>(), EmptySpan<OrtValue>(), AsSpan<std::string>({"values"}), &results, nullptr));
|
||||
|
||||
ASSERT_TRUE(results[0].IsAllocated());
|
||||
ASSERT_TRUE(results[0].IsSparseTensor());
|
||||
|
|
@ -504,7 +505,7 @@ TEST(ExecutionFrameTestInit, SparseInitializerAsOutput) {
|
|||
EXPECT_THAT(coo_view.Indices().DataAsSpan<int64_t>(), ::testing::ContainerEq(gsl::make_span(expected_linear_indices)));
|
||||
}
|
||||
}
|
||||
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
||||
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
#include "onnx/defs/parser.h"
|
||||
|
||||
#include "core/common/span_utils.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "core/session/inference_session.h"
|
||||
|
|
@ -13,6 +14,7 @@
|
|||
#include "test/framework/test_utils.h"
|
||||
#include "test/common/tensor_op_test_utils.h"
|
||||
|
||||
|
||||
// Unit tests to check the implementation of functions, model-local functions,
|
||||
// function-inlining etc.
|
||||
|
||||
|
|
@ -56,7 +58,7 @@ static void Check(const char* source,
|
|||
|
||||
std::vector<OrtValue> fetches;
|
||||
|
||||
status = session_object.Run(run_options, feeds, {output_name}, &fetches);
|
||||
status = session_object.Run(run_options, feeds, AsSpan({std::string(output_name)}), &fetches);
|
||||
ASSERT_TRUE(status.IsOK()) << "Session Run failed.";
|
||||
|
||||
auto& tensor = fetches[0].Get<Tensor>();
|
||||
|
|
|
|||
|
|
@ -86,11 +86,11 @@ Status SoftmaxGrad<T>::Compute(OpKernelContext* context) const {
|
|||
|
||||
bool is_transpose_required = opset_ >= 13 && axis != (rank - 1);
|
||||
|
||||
std::unique_ptr<Tensor> transposed_dY;
|
||||
std::unique_ptr<Tensor> transposed_Y;
|
||||
std::vector<int64_t> transposed_input_dims;
|
||||
std::unique_ptr<Tensor> intermediate_output; // output that the softmax implementation will write into while using transposed input
|
||||
std::vector<size_t> permutation(rank);
|
||||
Tensor transposed_dY;
|
||||
Tensor transposed_Y;
|
||||
TensorShapeVector transposed_input_dims;
|
||||
Tensor intermediate_output; // output that the softmax implementation will write into while using transposed input
|
||||
InlinedVector<size_t> permutation(rank);
|
||||
|
||||
if (is_transpose_required) {
|
||||
AllocatorPtr alloc;
|
||||
|
|
@ -112,26 +112,26 @@ Status SoftmaxGrad<T>::Compute(OpKernelContext* context) const {
|
|||
D = TensorShape(transposed_input_dims).SizeFromDimension(rank - 1);
|
||||
|
||||
// Allocate a temporary tensor to hold transposed input
|
||||
auto temp_input0 = Tensor::Create(Y.DataType(), TensorShape(transposed_input_dims), alloc);
|
||||
auto temp_input0 = Tensor(Y.DataType(), TensorShape(transposed_input_dims), alloc);
|
||||
|
||||
// Perform the transpose
|
||||
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutation, Y, *temp_input0));
|
||||
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutation, Y, temp_input0));
|
||||
transposed_Y = std::move(temp_input0);
|
||||
|
||||
auto temp_input1 = Tensor::Create(Y.DataType(), TensorShape(transposed_input_dims), alloc);
|
||||
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutation, dY, *temp_input1));
|
||||
auto temp_input1 = Tensor(Y.DataType(), TensorShape(transposed_input_dims), alloc);
|
||||
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutation, dY, temp_input1));
|
||||
transposed_dY = std::move(temp_input1);
|
||||
|
||||
// Allocate memory for the intermediate output
|
||||
intermediate_output = Tensor::Create(dX.DataType(), TensorShape(transposed_input_dims), alloc);
|
||||
intermediate_output = Tensor(dX.DataType(), TensorShape(transposed_input_dims), alloc);
|
||||
}
|
||||
|
||||
const int n = gsl::narrow_cast<int>(N);
|
||||
const int d = gsl::narrow_cast<int>(D);
|
||||
const int nd = gsl::narrow_cast<int>(N * D);
|
||||
const float* Ydata = is_transpose_required ? transposed_Y->template Data<T>() : Y.template Data<float>();
|
||||
const float* dYdata = is_transpose_required ? transposed_dY->template Data<T>() : dY.template Data<float>();
|
||||
float* dXdata = is_transpose_required ? intermediate_output->template MutableData<T>() : dX.template MutableData<float>();
|
||||
const float* Ydata = is_transpose_required ? transposed_Y.template Data<T>() : Y.template Data<float>();
|
||||
const float* dYdata = is_transpose_required ? transposed_dY.template Data<T>() : dY.template Data<float>();
|
||||
float* dXdata = is_transpose_required ? intermediate_output.template MutableData<T>() : dX.template MutableData<float>();
|
||||
|
||||
gsl::copy(gsl::make_span(dYdata, nd), gsl::make_span(dXdata, nd));
|
||||
if (is_logsoftmaxgrad_) {
|
||||
|
|
@ -164,7 +164,7 @@ Status SoftmaxGrad<T>::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
if (is_transpose_required) {
|
||||
// Perform the transpose to get the axes back to the original ordering
|
||||
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutation, *intermediate_output, dX));
|
||||
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutation, intermediate_output, dX));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -87,10 +87,10 @@ Status AdamWOptimizerBase::GenerateOutputs(OpKernelContext* ctx, size_t number_o
|
|||
updated_values->Reserve(number_of_values);
|
||||
for (size_t input_idx = 0; input_idx < number_of_values; ++input_idx) {
|
||||
const Tensor& source_tensor = values->Get(input_idx);
|
||||
std::unique_ptr<Tensor> target_tensor = Tensor::Create(source_tensor.DataType(),
|
||||
source_tensor.Shape(), alloc);
|
||||
ORT_RETURN_IF_ERROR(CopyInputTensorToOutputTensor(source_tensor, *target_tensor));
|
||||
updated_values->Add(std::move(*target_tensor)); // Add will check for type consistency
|
||||
Tensor target_tensor(source_tensor.DataType(),
|
||||
source_tensor.Shape(), alloc);
|
||||
ORT_RETURN_IF_ERROR(CopyInputTensorToOutputTensor(source_tensor, target_tensor));
|
||||
updated_values->Add(std::move(target_tensor)); // Add will check for type consistency
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue