onnxruntime/onnxruntime/core/framework/utils.cc
Vincent Wang b7408f7389
[ORTModule] ATen Efficient Attention and Triton Flash Attention (#17959)
This PR is to support efficient attention and flash attention in
ORTModule, including:
- Use ATen to call efficient attention, which requires PyTorch 2.2.0 dev
or newer. ORTMODULE_USE_EFFICIENT_ATTENTION=1 to enable.
- Integrate Triton Flash attention, which requires
triton==2.0.0.dev20221202. Need A100 or H100.
ORTMODULE_USE_FLASH_ATTENTION=1 to enable.
- A python transformer tool to match sub-graph by config and write
transformer quickly.

Current transformers supports attention mask for both efficient attn and
flash attn, and dropout for efficient attn only. To support more
training scenarios (such as causal mask in GPT2), more transformers need
to be added.

The feature is guarded by system environment variables, it won't effect
any current behavior if not enabled. Since it requires specific
PyTorch/Triton versions, related tests is not added for now.
2023-10-27 10:29:27 +08:00

1063 lines
47 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/graph/onnx_protobuf.h"
#include "core/framework/utils.h"
#include <iomanip>
#include "core/graph/graph_viewer.h"
#include "core/framework/data_transfer_manager.h"
#include "core/framework/bfc_arena.h"
#include "core/framework/execution_frame.h"
#include "core/framework/stream_execution_context.h"
#include "core/framework/execution_providers.h"
#include "core/framework/feeds_fetches_manager.h"
#include "core/framework/kernel_def_builder.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/framework/session_state.h"
#include "core/framework/sequential_executor.h"
#include "core/framework/tensorprotoutils.h"
#include "core/mlas/inc/mlas.h"
#include "core/framework/TensorSeq.h"
#include "core/framework/run_options.h"
#include "core/session/onnxruntime_run_options_config_keys.h"
#ifdef ENABLE_TRAINING
#include "core/framework/partial_graph_execution_state.h"
#endif
#ifdef ENABLE_ATEN
#include "contrib_ops/cpu/aten_ops/aten_op_executor.h"
#endif
namespace onnxruntime {
namespace utils {
void* DefaultAlloc(size_t size) {
return onnxruntime::AllocatorDefaultAlloc(size);
}
void DefaultFree(void* p) {
onnxruntime::AllocatorDefaultFree(p);
}
void ConstructStrings(void* p_data, int64_t elements) {
auto* ptr = static_cast<std::string*>(p_data);
for (int64_t i = 0; i < elements; ++i) {
new (ptr + i) std::string();
}
}
void DestroyStrings(void* p_data, int64_t elements) {
using string = std::string;
auto* ptr = static_cast<std::string*>(p_data);
for (int64_t i = 0; i < elements; i++)
ptr[i].~string();
}
bool ProviderIsCpuBased(const std::string& provider_type) {
return provider_type == onnxruntime::kCpuExecutionProvider ||
provider_type == onnxruntime::kDnnlExecutionProvider ||
provider_type == onnxruntime::kTvmExecutionProvider ||
provider_type == onnxruntime::kVitisAIExecutionProvider ||
provider_type == onnxruntime::kOpenVINOExecutionProvider ||
provider_type == onnxruntime::kNnapiExecutionProvider ||
provider_type == onnxruntime::kAclExecutionProvider ||
provider_type == onnxruntime::kArmNNExecutionProvider ||
provider_type == onnxruntime::kRknpuExecutionProvider ||
provider_type == onnxruntime::kCoreMLExecutionProvider ||
provider_type == onnxruntime::kSnpeExecutionProvider ||
provider_type == onnxruntime::kQnnExecutionProvider ||
provider_type == onnxruntime::kXnnpackExecutionProvider ||
provider_type == onnxruntime::kAzureExecutionProvider ||
provider_type == onnxruntime::utils::kInternalTestingExecutionProvider;
}
static common::Status AllocateHelper(const AllocatorPtr& allocator,
Stream* target_stream,
const OrtValue& source_mlvalue,
OrtValue& target_mlvalue) {
if (!allocator) {
return Status(common::ONNXRUNTIME, common::FAIL, "invalid allocator.");
}
if (source_mlvalue.IsTensor()) {
const Tensor& source_tensor = source_mlvalue.Get<Tensor>();
if (allocator->Info().alloc_type == OrtArenaAllocator) {
void* p_data = nullptr;
#ifdef ORT_ENABLE_STREAM
BFCArena* arena_ptr = static_cast<BFCArena*>(allocator.get());
auto* stream_aware_alloc = StreamAwareArena::FromBFCArena(*arena_ptr);
if (stream_aware_alloc && target_stream) {
size_t len = Tensor::CalculateTensorStorageSize(source_tensor.DataType(), source_tensor.Shape());
p_data = stream_aware_alloc->AllocOnStream(len, target_stream, nullptr);
}
#else
ORT_UNUSED_PARAMETER(target_stream);
#endif // ORT_ENABLE_STREAM
if (p_data == nullptr) {
Tensor::InitOrtValue(source_tensor.DataType(),
source_tensor.Shape(),
allocator, target_mlvalue);
} else {
Tensor::InitOrtValue(source_tensor.DataType(),
source_tensor.Shape(),
p_data,
allocator, target_mlvalue);
}
} else {
Tensor::InitOrtValue(source_tensor.DataType(),
source_tensor.Shape(),
allocator, target_mlvalue);
}
} else if (source_mlvalue.IsSparseTensor()) {
#if !defined(DISABLE_SPARSE_TENSORS)
const SparseTensor& source_tensor = source_mlvalue.Get<SparseTensor>();
SparseTensor::InitOrtValue(source_tensor.DataType(), source_tensor.DenseShape(), allocator, target_mlvalue);
#endif
} else if (source_mlvalue.IsTensorSequence()) {
const TensorSeq& source_tensor_seq = source_mlvalue.Get<TensorSeq>();
TensorSeq::InitOrtValue(source_tensor_seq, allocator, target_mlvalue);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported OrtValue type.");
}
return Status::OK();
}
const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info) {
// the input index will be std::numeric_limits<size_t>::max() if it's an implicit input to a control flow node.
// the input will be processed fully when executing the subgraph that consumes the implicit input.
bool implicit_input = info.index == std::numeric_limits<size_t>::max();
// node may declare input_mem_type to be on CPU explicitly
// skip implicit inputs as they don't have a valid 'index' value
bool node_input_on_cpu = !implicit_input && info.kci && info.kci->kernel_def->IsInputOnCpu(info.index);
// need a std::string that doesn't go away for kCpuExecutionProvider so we can return a reference.
static const std::string cpu_execution_provider{onnxruntime::kCpuExecutionProvider};
auto& required_provider_type = node_input_on_cpu ? cpu_execution_provider
: info.p_node->GetExecutionProviderType();
return required_provider_type;
}
// Copy MLValue. Uses DataTransferManager for device copy if necessary. If copy_tensor_pairs/copy_sparse_pairs is provided,
// src/dst pairs that need a device copy are added to copy_pairs so copying can be batches by the DataTransferManager
// implementation for performance reasons.
static Status BatchOrCopyMLValue(const SessionState& session_state,
const MLValueCopyInfo& copy_info,
const OrtValue& source_mlvalue,
OrtValue& target_mlvalue,
Stream* stream,
#if !defined(DISABLE_SPARSE_TENSORS)
std::vector<IDataTransfer::SrcDstPair>* copy_tensor_pairs = nullptr,
std::vector<IDataTransfer::SparseSrcDstPair>* copy_sparse_pairs = nullptr)
#else
std::vector<IDataTransfer::SrcDstPair>* copy_tensor_pairs = nullptr)
#endif
{
// same device so direct copy
if (copy_info.source_device == copy_info.target_device) {
target_mlvalue = source_mlvalue;
return Status::OK();
}
auto allocator = session_state.GetAllocator(copy_info.target_device);
if (!target_mlvalue.IsAllocated()) {
ORT_ENFORCE(allocator != nullptr, "Failed to find allocator for device ", copy_info.target_device.ToString());
ORT_RETURN_IF_ERROR(utils::AllocateHelper(allocator, stream, source_mlvalue, target_mlvalue));
}
if (source_mlvalue.IsTensor()) {
const auto& source_tensor = source_mlvalue.Get<Tensor>();
Tensor* p_output_tensor = target_mlvalue.GetMutable<Tensor>();
if (copy_tensor_pairs != nullptr) {
copy_tensor_pairs->push_back({source_tensor, *p_output_tensor, stream});
} else {
ORT_RETURN_IF_ERROR(stream ? session_state.GetDataTransferMgr().CopyTensorAsync(source_tensor, *p_output_tensor, *stream) : session_state.GetDataTransferMgr().CopyTensor(source_tensor, *p_output_tensor));
}
} else if (source_mlvalue.IsSparseTensor()) {
#if !defined(DISABLE_SPARSE_TENSORS)
const auto& source_tensor = source_mlvalue.Get<SparseTensor>();
SparseTensor* p_output_tensor = target_mlvalue.GetMutable<SparseTensor>();
if (copy_sparse_pairs != nullptr) {
copy_sparse_pairs->push_back({source_tensor, *p_output_tensor, 0});
} else {
ORT_RETURN_IF_ERROR(session_state.GetDataTransferMgr().CopySparseTensor(source_tensor, *p_output_tensor));
}
#endif
} else if (source_mlvalue.IsTensorSequence()) {
const TensorSeq& source_tensor_seq = source_mlvalue.Get<TensorSeq>();
TensorSeq& target_tensor_seq = const_cast<TensorSeq&>(target_mlvalue.Get<TensorSeq>());
size_t size = 0;
while ((size = target_tensor_seq.Size()) < source_tensor_seq.Size()) {
if (0 == size) {
target_tensor_seq.SetType(source_tensor_seq.DataType());
}
const Tensor& source_tensor = source_tensor_seq.Get(size);
std::unique_ptr<Tensor> target_tensor = std::make_unique<Tensor>(source_tensor.DataType(), source_tensor.Shape(), allocator);
target_tensor_seq.Add(std::move(*target_tensor));
}
const auto& data_transfer_mgr = session_state.GetDataTransferMgr();
auto source_iter = source_tensor_seq.begin();
auto target_iter = target_tensor_seq.begin();
while (source_iter != source_tensor_seq.end() &&
target_iter != target_tensor_seq.end()) {
if (copy_tensor_pairs != nullptr) {
copy_tensor_pairs->push_back({source_iter->Get<Tensor>(), *target_iter->GetMutable<Tensor>(), stream});
} else {
if (stream)
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensorAsync(source_iter->Get<Tensor>(), *target_iter->GetMutable<Tensor>(), *stream));
else
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_iter->Get<Tensor>(), *target_iter->GetMutable<Tensor>()));
}
++source_iter;
++target_iter;
} // while
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported OrtValue type to copy between device.");
}
return Status::OK();
} // namespace utils
static bool HaveCpuExecutionProvidersOnly(const ExecutionProviders& execution_providers) {
for (const auto& execution_provider : execution_providers) {
if (!ProviderIsCpuBased(execution_provider->Type())) {
return false;
}
}
return true;
}
static const OrtDevice& FindDeviceForValue(const OrtValueNameIdxMap& map,
const SequentialExecutionPlan& plan,
std::string_view name) {
int idx = -1;
auto status = map.GetIdx(name, idx);
ORT_THROW_IF_ERROR(status);
const auto& location = plan.GetLocation(idx);
return location;
}
const OrtDevice& FindDeviceForValue(const SessionState& session_state, std::string_view name) {
const auto* exec_plan_ptr = session_state.GetExecutionPlan();
ORT_ENFORCE(exec_plan_ptr);
return FindDeviceForValue(session_state.GetOrtValueNameIdxMap(), *exec_plan_ptr, name);
}
// get the target device info for the node consuming each input provided in the feeds.
// source_device info is not known until runtime
static common::Status CalculateStaticCopyInfoForFeed(const SessionState& session_state,
const std::string& input_name,
MLValueCopyInfo& copy_info) {
InlinedVector<SessionState::NodeInfo> node_info_vec;
#ifdef ENABLE_TRAINING
if (session_state.GetInputNodeInfo(input_name, node_info_vec) == Status::OK()) {
#else
ORT_RETURN_IF_ERROR(session_state.GetInputNodeInfo(input_name, node_info_vec));
#endif
const auto& node_info = node_info_vec.front(); // all consumers of a feed have the same device so first entry is fine
if (node_info.p_node == nullptr) {
// ignore dummy entry for an input that we didn't find a use of in the graph.
return Status::OK();
}
copy_info.target_device = *node_info.device;
#ifdef ENABLE_TRAINING
} else {
// This input might be for an intermediate tensor for partial graph execution.
const auto* exec_plan = session_state.GetExecutionPlan();
const auto& name_to_id = session_state.GetOrtValueNameIdxMap();
int index;
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(input_name, index));
const auto& device = exec_plan->GetLocation(index);
copy_info.target_device = device;
}
#endif
return Status::OK();
}
static common::Status CalculateStaticCopyInfoForFeeds(const SessionState& session_state,
gsl::span<const std::string> feed_names,
std::vector<MLValueCopyInfo>& copy_info) {
for (size_t idx = 0, end = feed_names.size(); idx < end; ++idx) {
ORT_RETURN_IF_ERROR(CalculateStaticCopyInfoForFeed(session_state, feed_names[idx], copy_info[idx]));
}
return Status::OK();
}
// get the source device info for the node producing each output that we will return in the fetches.
// target device info is not known until runtime.
static common::Status CalculateStaticCopyInfoForFetches(const SessionState& session_state,
gsl::span<const std::string> fetch_names,
std::vector<MLValueCopyInfo>& copy_info) {
for (size_t idx = 0, end = fetch_names.size(); idx < end; ++idx) {
const std::string& output_name = fetch_names[idx];
const auto& info = FindDeviceForValue(session_state, output_name);
copy_info[idx].source_device = info;
}
return Status::OK();
}
common::Status InitializeFeedFetchCopyInfo(const SessionState& session_state,
FeedsFetchesManager& feeds_fetches_manager) {
// if we only have CPU based EPs we can skip all the copy logic
auto cpu_only = HaveCpuExecutionProvidersOnly(session_state.GetExecutionProviders());
if (cpu_only) {
feeds_fetches_manager.SetDeviceCopyChecks(DeviceCopyCheck::NoCopy, DeviceCopyCheck::NoCopy);
} else {
// setup all the static info about where the graph inputs and outputs are located
auto info = feeds_fetches_manager.GetFeedsFetchesInfo();
auto& feed_copy_info = feeds_fetches_manager.GetMutableFeedsDeviceCopyInfo();
auto& fetch_copy_info = feeds_fetches_manager.GetMutableFetchesDeviceCopyInfo();
ORT_RETURN_IF_ERROR(utils::CalculateStaticCopyInfoForFeeds(session_state, info.feed_names, feed_copy_info));
ORT_RETURN_IF_ERROR(utils::CalculateStaticCopyInfoForFetches(session_state, info.output_names, fetch_copy_info));
}
return Status::OK();
}
// update the allocation_provider in the copy info based on the actual feeds
static bool FinalizeCopyInfoForFeeds(gsl::span<const OrtDevice> feed_locations,
std::vector<MLValueCopyInfo>& copy_info) {
ORT_ENFORCE(feed_locations.size() == copy_info.size());
bool copy_needed = false;
for (size_t i = 0, end = feed_locations.size(); i < end; ++i) {
copy_info[i].source_device = feed_locations[i];
if (copy_info[i].source_device != copy_info[i].target_device) {
copy_needed = true;
}
}
return copy_needed;
}
static bool FinalizeCopyInfoForFetches(gsl::span<const OrtDevice* const>& fetch_alloc_info,
std::vector<MLValueCopyInfo>& copy_info) {
ORT_ENFORCE(fetch_alloc_info.size() == copy_info.size());
bool copy_needed = false;
auto num_outputs = fetch_alloc_info.size();
for (size_t i = 0; i < num_outputs; ++i) {
const OrtDevice* alloc_info = fetch_alloc_info[i];
if (alloc_info != nullptr) {
copy_info[i].target_device = *alloc_info;
}
if (copy_info[i].source_device != copy_info[i].target_device) {
copy_needed = true;
}
}
return copy_needed;
}
// Finalize the copy info using the OrtDevice and OrtMemoryInfo for the feeds and fetches
// This can be used by control flow nodes prior to the execution of the overall graph.
void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtDevice> feed_locations,
gsl::span<const OrtDevice* const> fetch_alloc_info) {
if (feeds_fetches_manager.GetDeviceCopyChecks().status == DeviceCopyCheck::NoCopy)
return;
bool need_copy = FinalizeCopyInfoForFeeds(feed_locations, feeds_fetches_manager.GetMutableFeedsDeviceCopyInfo());
DeviceCopyCheck input_copy = need_copy ? DeviceCopyCheck::Copy : DeviceCopyCheck::NoCopy;
need_copy = FinalizeCopyInfoForFetches(fetch_alloc_info, feeds_fetches_manager.GetMutableFetchesDeviceCopyInfo());
DeviceCopyCheck output_copy = need_copy ? DeviceCopyCheck::Copy : DeviceCopyCheck::NoCopy;
feeds_fetches_manager.SetDeviceCopyChecks(input_copy, output_copy);
}
// Finalize the copy info using the OrtValue instances for the feeds and fetches
static void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds,
std::vector<OrtValue>& fetches) {
if (feeds_fetches_manager.GetDeviceCopyChecks().status == DeviceCopyCheck::NoCopy)
return;
auto num_inputs = feeds.size();
auto num_outputs = feeds_fetches_manager.GetFeedsFetchesInfo().output_names.size();
std::vector<OrtDevice> feed_locations(num_inputs);
std::vector<const OrtDevice*> fetch_alloc_info(num_outputs, nullptr);
for (size_t i = 0; i < num_inputs; ++i) {
const auto& feed = feeds[i];
if (feed.IsTensor()) {
feed_locations[i] = feed.Get<Tensor>().Location().device;
} else if (feed.IsTensorSequence()) {
const auto& tensor_seq = feed.Get<TensorSeq>();
if (tensor_seq.Size() != std::size_t{0}) {
feed_locations[i] = tensor_seq.Get(0).Location().device;
}
} else if (feed.IsSparseTensor()) {
#if !defined(DISABLE_SPARSE_TENSORS)
feed_locations[i] = feed.Get<SparseTensor>().Location().device;
#endif
}
}
// create default instances if needed
fetches.resize(num_outputs);
for (size_t i = 0; i < num_outputs; ++i) {
const auto& fetch = fetches[i];
if (fetch.IsAllocated()) {
if (fetch.IsTensor()) {
fetch_alloc_info[i] = &fetch.Get<Tensor>().Location().device;
} else if (fetch.IsTensorSequence()) {
const auto& tensor_seq = fetch.Get<TensorSeq>();
if (tensor_seq.Size() != std::size_t{0}) {
fetch_alloc_info[i] = &tensor_seq.Get(0).Location().device;
}
} else if (fetch.IsSparseTensor()) {
#if !defined(DISABLE_SPARSE_TENSORS)
fetch_alloc_info[i] = &fetch.Get<SparseTensor>().Location().device;
#endif
}
}
}
FinalizeFeedFetchCopyInfo(feeds_fetches_manager, feed_locations, fetch_alloc_info);
}
static common::Status CopyInputsAcrossDevices(const SessionState& session_state,
gsl::span<const OrtValue> orig_feeds,
std::vector<OrtValue>& new_feeds,
gsl::span<const MLValueCopyInfo> copy_info,
gsl::span<Stream* const> feed_streams) {
size_t num_feeds = orig_feeds.size();
ORT_ENFORCE(copy_info.size() == num_feeds);
ORT_ENFORCE(feed_streams.size() == num_feeds);
new_feeds.resize(num_feeds);
std::vector<IDataTransfer::SrcDstPair> batched_data_transfers;
#if !defined(DISABLE_SPARSE_TENSORS)
std::vector<IDataTransfer::SparseSrcDstPair> batched_sparse_data_transfers;
#endif
for (size_t idx = 0; idx < num_feeds; ++idx) {
#if !defined(DISABLE_SPARSE_TENSORS)
ORT_RETURN_IF_ERROR(BatchOrCopyMLValue(session_state, copy_info[idx], orig_feeds[idx], new_feeds[idx],
feed_streams[idx],
&batched_data_transfers, &batched_sparse_data_transfers));
#else
ORT_RETURN_IF_ERROR(BatchOrCopyMLValue(session_state, copy_info[idx], orig_feeds[idx], new_feeds[idx],
feed_streams[idx],
&batched_data_transfers));
#endif
}
if (!batched_data_transfers.empty()) {
ORT_RETURN_IF_ERROR(session_state.GetDataTransferMgr().CopyTensors(batched_data_transfers));
}
#if !defined(DISABLE_SPARSE_TENSORS)
if (!batched_sparse_data_transfers.empty()) {
ORT_RETURN_IF_ERROR(session_state.GetDataTransferMgr().CopySparseTensors(batched_sparse_data_transfers));
}
#endif
// flush the stream to make sure the inputs are ready before launch the inference.
// TODO: this sync is because the graph inputs can be consumed by multiple stream,
// but we can only place the MemCpyAsync on one of the stream. Ideally we should make
// other stream wait on the event of the memory copy stream, instead of host sync stream.
std::unordered_set<Stream*> visited;
for (auto* stream : feed_streams) {
if (stream && visited.insert(stream).second) stream->Flush();
}
return Status::OK();
}
#ifdef ORT_ENABLE_STREAM
static void UpdateWithParentStream(DeviceStreamCollection& device_stream_collection,
Stream* parent_stream) {
if (parent_stream) {
// TODO: in theory, we should make current subgraph's stream depends on parent stream.
// but in current code structure, it causing issues with the resource sharing and stream
// lifetime. it also may cause additional cost of stream sync for single stream case.
// In first phase, let's just put all the subgraph execution on the parent stream.
for (size_t i = 0; i < device_stream_collection.NumStreams(); ++i) {
auto* stream = device_stream_collection.GetStream(i);
if (stream) {
// if current logic stream is not on the same EP instance as parent stream
// and the EP instance does have async streams (not EP like CPU)
// throw error as we don't have the code to setup the dependency at this moment.
if (stream->GetDevice() != parent_stream->GetDevice()) {
ORT_THROW("Subgraph has nodes running on device: ", stream->GetDevice().Type(),
" while parent graph node running on device: ", parent_stream->GetDevice().Type(),
", this is not supported yet.");
}
device_stream_collection.SetDeviceStream(i, parent_stream);
}
}
}
}
#endif
// public method to do a single copy. used by external partners
common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
const OrtValue& orig_mlvalue, OrtValue& new_mlvalue) {
if (!orig_mlvalue.IsTensor() && !orig_mlvalue.IsSparseTensor()) {
new_mlvalue = orig_mlvalue;
return Status::OK();
}
MLValueCopyInfo copy_info;
// Sets copy_info.target_device.
ORT_RETURN_IF_ERROR(CalculateStaticCopyInfoForFeed(session_state, input_name, copy_info));
#if !defined(DISABLE_SPARSE_TENSORS)
copy_info.source_device = (orig_mlvalue.IsTensor())
? orig_mlvalue.Get<Tensor>().Location().device
: orig_mlvalue.Get<SparseTensor>().Location().device;
#else
copy_info.source_device = orig_mlvalue.Get<Tensor>().Location().device;
#endif
Stream* device_stream = nullptr;
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder device_stream_collection_holder(&session_state);
if (device_stream_collection_holder.p_ != nullptr) {
DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
size_t num_streams = device_stream_collection->NumStreams();
for (size_t i = 0; i < num_streams; i++) {
Stream* stream = device_stream_collection->GetStream(i);
if (stream && stream->GetDevice().Type() != OrtDevice::CPU) {
device_stream = stream;
break;
}
}
}
#endif
// copy_info.target_device is not set leaving to be equal to CPU.
return BatchOrCopyMLValue(session_state, copy_info, orig_mlvalue, new_mlvalue, device_stream);
}
static common::Status CopyOutputsAcrossDevices(const SessionState& session_state,
gsl::span<const OrtValue> fetches,
std::vector<OrtValue>& user_fetches,
gsl::span<const MLValueCopyInfo> copy_info,
gsl::span<Stream* const> fetch_streams) {
auto num_outputs = fetches.size();
user_fetches.resize(num_outputs);
std::vector<IDataTransfer::SrcDstPair> batched_data_transfers;
#if !defined(DISABLE_SPARSE_TENSORS)
std::vector<IDataTransfer::SparseSrcDstPair> batched_sparse_data_transfers;
#endif
for (size_t idx = 0; idx < num_outputs; ++idx) {
#if !defined(DISABLE_SPARSE_TENSORS)
ORT_RETURN_IF_ERROR(BatchOrCopyMLValue(session_state, copy_info[idx], fetches[idx], user_fetches[idx], fetch_streams[idx],
&batched_data_transfers, &batched_sparse_data_transfers));
#else
ORT_RETURN_IF_ERROR(BatchOrCopyMLValue(session_state, copy_info[idx], fetches[idx], user_fetches[idx], fetch_streams[idx],
&batched_data_transfers));
#endif
}
if (!batched_data_transfers.empty()) {
ORT_RETURN_IF_ERROR(session_state.GetDataTransferMgr().CopyTensors(batched_data_transfers));
}
#if !defined(DISABLE_SPARSE_TENSORS)
if (!batched_sparse_data_transfers.empty()) {
ORT_RETURN_IF_ERROR(session_state.GetDataTransferMgr().CopySparseTensors(batched_sparse_data_transfers));
}
#endif
return Status::OK();
}
static common::Status
ExecuteGraphImpl(const SessionState& session_state,
const FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
ExecutionMode execution_mode, const bool& terminate_flag,
const logging::Logger& logger,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollection* device_stream_collection,
#endif
const bool only_execute_path_to_fetches = false,
Stream* parent_stream = nullptr) {
const auto& feeds_fetches_info = feeds_fetches_manager.GetFeedsFetchesInfo();
const auto& device_copy_checks = feeds_fetches_manager.GetDeviceCopyChecks();
#ifdef ORT_ENABLE_STREAM
auto* execution_plan = session_state.GetExecutionPlan();
if (device_stream_collection)
UpdateWithParentStream(*device_stream_collection, parent_stream);
#else
ORT_UNUSED_PARAMETER(parent_stream);
#endif
bool is_subgraph = session_state.GetGraphViewer().ParentNode() != nullptr;
// in following two cases, we execute the workload in main thread:
// 1. execution mode is sequential.
// 2. execute a subgraph. Because in current implementation, the execute of subgraph is launched through parent kernel.
// the parent kernel will occupy a thread in thread pool. if we use multiple threads to execute subgraph, it may cause
// deadlock when we reach the limitation of thread pool.
bool single_thread_mode = execution_mode == ExecutionMode::ORT_SEQUENTIAL || is_subgraph;
// see if we can skip copies due to the types of execution providers available
if (device_copy_checks.status == DeviceCopyCheck::NoCopy) {
// no device copies are needed so simple execute
auto status = (ExecuteThePlan(session_state,
feeds_fetches_info.feeds_mlvalue_idxs, feeds,
feeds_fetches_info.fetches_mlvalue_idxs, fetches, fetch_allocators,
logger,
#ifdef ORT_ENABLE_STREAM
device_stream_collection,
#endif
terminate_flag,
only_execute_path_to_fetches,
// single thread mode
single_thread_mode));
ORT_RETURN_IF_ERROR(status);
} else {
auto feeds_to_use = feeds;
std::vector<OrtValue>* p_fetches = &fetches;
std::vector<OrtValue> device_feeds;
std::vector<OrtValue> device_fetches;
if (device_copy_checks.input_copy_needed == DeviceCopyCheck::Copy) {
const auto& feed_copy_info = feeds_fetches_manager.GetFeedsDeviceCopyInfo();
InlinedVector<Stream*> feed_streams;
feed_streams.reserve(feed_copy_info.size());
// TODO: we can pre-calculate the stream index for graph inputs in execution plan
#ifdef ORT_ENABLE_STREAM
for (auto& copy_info : feed_copy_info) {
auto& device = copy_info.target_device;
bool found = false;
if (device_stream_collection) {
size_t num_streams = device_stream_collection->NumStreams();
for (size_t i = 0; i < num_streams; i++) {
Stream* stream = device_stream_collection->GetStream(i);
if (stream && stream->GetDevice().Type() == device.Type()) {
feed_streams.push_back(stream);
found = true;
break;
}
}
}
if (!found)
feed_streams.push_back(nullptr);
}
#else
for (size_t i = 0; i < feed_copy_info.size(); ++i) {
feed_streams.push_back(nullptr);
}
#endif
ORT_RETURN_IF_ERROR(CopyInputsAcrossDevices(session_state, feeds, device_feeds, feed_copy_info, feed_streams));
feeds_to_use = device_feeds;
}
auto num_outputs = fetches.size();
const auto& fetch_copy_info = feeds_fetches_manager.GetFetchesDeviceCopyInfo();
if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) {
// need intermediate fetches. use pre-allocated fetches where possible.
device_fetches.reserve(num_outputs);
for (size_t i = 0; i < num_outputs; ++i) {
if (fetch_copy_info[i].source_device == fetch_copy_info[i].target_device && fetches[i].IsAllocated()) {
device_fetches.push_back(fetches[i]);
} else {
// use temporary value
device_fetches.push_back({});
}
}
p_fetches = &device_fetches;
}
// no device copies are needed so simple execute
auto status = (ExecuteThePlan(session_state,
feeds_fetches_info.feeds_mlvalue_idxs, feeds_to_use,
feeds_fetches_info.fetches_mlvalue_idxs, *p_fetches, fetch_allocators,
logger,
#ifdef ORT_ENABLE_STREAM
device_stream_collection,
#endif
terminate_flag,
only_execute_path_to_fetches,
single_thread_mode));
ORT_RETURN_IF_ERROR(status);
InlinedVector<Stream*> fetches_streams;
fetches_streams.reserve(feeds_fetches_info.fetches_mlvalue_idxs.size());
#ifdef ORT_ENABLE_STREAM
auto& value_to_stream_map = execution_plan->value_to_stream_map;
for (auto fetch_idx : feeds_fetches_info.fetches_mlvalue_idxs) {
auto it = value_to_stream_map.find(fetch_idx);
if (it != value_to_stream_map.end()) {
fetches_streams.push_back(device_stream_collection ? device_stream_collection->GetStream(it->second) : nullptr);
} else {
// for subgraph, it is possible the graph is empty,
// the fetches are come from parent graph.
fetches_streams.push_back(parent_stream);
}
}
#else
for (size_t i = 0; i < feeds_fetches_info.fetches_mlvalue_idxs.size(); ++i) {
fetches_streams.push_back(nullptr);
}
#endif
if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) {
ORT_RETURN_IF_ERROR(CopyOutputsAcrossDevices(session_state, *p_fetches, fetches, fetch_copy_info, fetches_streams));
}
}
return Status::OK();
}
common::Status ExecuteGraph(const SessionState& session_state,
FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const bool& terminate_flag,
const logging::Logger& logger,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
bool only_execute_path_to_fetches,
Stream* parent_stream) {
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(session_state, feeds_fetches_manager));
// finalize the copy info using the provided feeds and fetches. will update device_copy_checks in the background
FinalizeFeedFetchCopyInfo(feeds_fetches_manager, feeds, fetches);
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
auto retval = ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {},
execution_mode, terminate_flag, logger,
device_stream_collection,
only_execute_path_to_fetches,
parent_stream);
return retval;
#else
return ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {},
execution_mode, terminate_flag, logger,
only_execute_path_to_fetches,
parent_stream);
#endif
}
common::Status ExecuteGraph(const SessionState& session_state,
FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const RunOptions& run_options,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
const logging::Logger& logger) {
return ExecuteGraph(session_state,
feeds_fetches_manager,
feeds, fetches,
execution_mode,
run_options.terminate,
logger,
#ifdef ORT_ENABLE_STREAM
device_stream_collection_holder,
#endif
run_options.only_execute_path_to_fetches);
}
#ifdef ENABLE_TRAINING
common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
const logging::Logger& logger, PartialGraphExecutionState& state,
const OrtValueCachePtr& cache, const bool& terminate_flag,
DeviceStreamCollection* device_stream_collection,
int32_t partial_graph_index,
Stream* parent_stream) {
// finalize the copy info using the provided feeds and fetches. will update device_copy_checks in the background
FinalizeFeedFetchCopyInfo(feeds_fetches_manager, feeds, fetches);
const auto& feeds_fetches_info = feeds_fetches_manager.GetFeedsFetchesInfo();
const auto& device_copy_checks = feeds_fetches_manager.GetDeviceCopyChecks();
// always use single_stream mode for training, to have a stable execution order
bool single_thread_mode = true;
auto* execution_plan = session_state.GetExecutionPlan();
if (device_stream_collection)
UpdateWithParentStream(*device_stream_collection, parent_stream);
// see if we can skip copies due to the types of execution providers available
if (device_copy_checks.status == DeviceCopyCheck::NoCopy) {
// no device copies are needed so simple execute
ORT_RETURN_IF_ERROR(PartialExecuteThePlan(session_state,
feeds_fetches_info.feeds_mlvalue_idxs, feeds,
feeds_fetches_info.fetches_mlvalue_idxs, fetches, {},
logger,
device_stream_collection,
terminate_flag,
// single thread mode
single_thread_mode,
state,
cache,
partial_graph_index));
} else {
auto p_feeds = feeds;
std::vector<OrtValue>* p_fetches = &fetches;
std::vector<OrtValue> device_feeds;
std::vector<OrtValue> device_fetches;
if (device_copy_checks.input_copy_needed == DeviceCopyCheck::Copy) {
const auto& feed_copy_info = feeds_fetches_manager.GetFeedsDeviceCopyInfo();
InlinedVector<Stream*> feed_streams;
feed_streams.reserve(feed_copy_info.size());
// TODO: we can pre-calculate the stream index for graph inputs in execution plan
for (auto& copy_info : feed_copy_info) {
auto& device = copy_info.target_device;
bool found = false;
if (device_stream_collection) {
size_t num_streams = device_stream_collection->NumStreams();
for (size_t i = 0; i < num_streams; i++) {
Stream* stream = device_stream_collection->GetStream(i);
if (stream && stream->GetDevice().Type() == device.Type()) {
feed_streams.push_back(stream);
found = true;
break;
}
}
}
if (!found)
feed_streams.push_back(nullptr);
}
ORT_RETURN_IF_ERROR(CopyInputsAcrossDevices(session_state, feeds, device_feeds, feed_copy_info, feed_streams));
p_feeds = device_feeds;
}
auto num_outputs = fetches.size();
const auto& fetch_copy_info = feeds_fetches_manager.GetFetchesDeviceCopyInfo();
if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) {
// need intermediate fetches. use pre-allocated fetches where possible.
device_fetches.reserve(num_outputs);
for (size_t i = 0; i < num_outputs; ++i) {
if (fetch_copy_info[i].source_device == fetch_copy_info[i].target_device && fetches[i].IsAllocated()) {
device_fetches.push_back(fetches[i]);
} else {
// use temporary value
device_fetches.push_back({});
}
}
p_fetches = &device_fetches;
}
ORT_RETURN_IF_ERROR(PartialExecuteThePlan(session_state,
feeds_fetches_info.feeds_mlvalue_idxs, p_feeds,
feeds_fetches_info.fetches_mlvalue_idxs, *p_fetches, {},
logger,
device_stream_collection,
terminate_flag,
// single thread mode
single_thread_mode,
state,
cache,
partial_graph_index));
InlinedVector<Stream*> fetches_streams;
fetches_streams.reserve(feeds_fetches_info.fetches_mlvalue_idxs.size());
auto& value_to_stream_map = execution_plan->value_to_stream_map;
for (auto fetch_idx : feeds_fetches_info.fetches_mlvalue_idxs) {
auto it = value_to_stream_map.find(fetch_idx);
if (it != value_to_stream_map.end()) {
fetches_streams.push_back(device_stream_collection ? device_stream_collection->GetStream(it->second) : nullptr);
} else {
// for subgraph, it is possible the graph is empty,
// the fetches are come from parent graph.
fetches_streams.push_back(parent_stream);
}
}
if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) {
ORT_RETURN_IF_ERROR(CopyOutputsAcrossDevices(session_state, *p_fetches, fetches, fetch_copy_info, fetches_streams));
}
// training don't want to flush the stream
}
return Status::OK();
}
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
const logging::Logger& logger, PartialGraphExecutionState& state,
const OrtValueCachePtr& cache, const bool& terminate_flag,
int32_t partial_graph_index,
Stream* parent_stream) {
DeviceStreamCollection* device_stream_collection = state.GetDeviceStreamCollection(session_state);
auto retval = ExecutePartialGraphImpl(session_state, feeds_fetches_manager, feeds, fetches,
logger, state, cache, terminate_flag, device_stream_collection,
partial_graph_index, parent_stream);
if (device_stream_collection)
ORT_CHECK_AND_SET_RETVAL(device_stream_collection->CleanUp(false));
return retval;
}
#endif
common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger,
Stream* parent_stream,
bool sync_subgraph_fetches) {
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder device_stream_collection_holder(&session_state);
DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
auto retval = ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, fetch_allocators,
execution_mode, terminate_flag, logger, device_stream_collection, false, parent_stream);
if (device_stream_collection)
ORT_CHECK_AND_SET_RETVAL(device_stream_collection->CleanUp(false));
#else
auto retval = ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, fetch_allocators,
execution_mode, terminate_flag, logger, false, parent_stream);
#endif
if (retval.IsOK() && sync_subgraph_fetches && parent_stream) {
parent_stream->Flush();
}
return retval;
}
int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType onnx_enum) {
switch (onnx_enum) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
return onnx::TensorProto_DataType::TensorProto_DataType_FLOAT;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
return onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
return onnx::TensorProto_DataType::TensorProto_DataType_INT8;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return onnx::TensorProto_DataType::TensorProto_DataType_UINT8;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
return onnx::TensorProto_DataType::TensorProto_DataType_INT16;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
return onnx::TensorProto_DataType::TensorProto_DataType_UINT16;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
return onnx::TensorProto_DataType::TensorProto_DataType_INT32;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
return onnx::TensorProto_DataType::TensorProto_DataType_UINT32;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
return onnx::TensorProto_DataType::TensorProto_DataType_INT64;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
return onnx::TensorProto_DataType::TensorProto_DataType_UINT64;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
return onnx::TensorProto_DataType::TensorProto_DataType_STRING;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
return onnx::TensorProto_DataType::TensorProto_DataType_BOOL;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
return onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
return onnx::TensorProto_DataType::TensorProto_DataType_BFLOAT16;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
return onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
return onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED:
default:
assert(false);
return onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED;
}
}
#ifdef ENABLE_TRAINING
// Needed only when NCCL kernels are enabled.
common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context) {
const Tensor* prev_input = context->Input<Tensor>(0);
for (int i = 1; i < context->InputCount(); i++) {
const Tensor* curr_input = context->Input<Tensor>(i);
ORT_ENFORCE(prev_input->Shape().Size() >= 0);
const void* curr_address = curr_input->DataRaw();
const void* prev_address = prev_input->DataRaw();
const void* prev_end_address = reinterpret_cast<const char*>(prev_address) + prev_input->SizeInBytes();
void* aligned_address = const_cast<void*>(prev_end_address);
size_t dummy_space = kAllocAlignment * 2;
std::align(kAllocAlignment, 1, aligned_address, dummy_space);
if (!(curr_address == prev_end_address || curr_address == aligned_address)) {
const std::string node = context->GetNodeName().empty() ? context->GetOpType() : context->GetNodeName();
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Contiguous memory checking failed on node ", node, ": ",
"input #", i - 1, " address is ", prev_address, " and #bytes = ", prev_input->SizeInBytes(),
", input #", i, " address is ", curr_address);
}
prev_input = curr_input;
}
return Status::OK();
}
#endif
bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) {
if (p_kci && p_kci->kernel_def->IsInputOnCpu(index)) {
return true;
}
#ifdef ENABLE_ATEN
if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" &&
node.Domain() == kPytorchAtenDomain) {
const auto& attrs = node.GetAttributes();
ORT_ENFORCE(utils::HasString(attrs.at("operator")));
std::string op_name = attrs.at("operator").s();
std::string overload_name = "";
if (attrs.find("overload_name") != attrs.end() && utils::HasString(attrs.at("overload_name"))) {
overload_name = attrs.at("overload_name").s();
}
return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true);
}
#else
ORT_UNUSED_PARAMETER(node);
#endif
return false;
}
bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) {
if (p_kci && p_kci->kernel_def->IsOutputOnCpu(index)) {
return true;
}
#ifdef ENABLE_ATEN
if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" &&
node.Domain() == kPytorchAtenDomain) {
const auto& attrs = node.GetAttributes();
ORT_ENFORCE(utils::HasString(attrs.at("operator")));
std::string op_name = attrs.at("operator").s();
std::string overload_name = "";
if (attrs.find("overload_name") != attrs.end() && utils::HasString(attrs.at("overload_name"))) {
overload_name = attrs.at("overload_name").s();
}
return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false);
}
#else
ORT_UNUSED_PARAMETER(node);
#endif
return false;
}
} // namespace utils
} // namespace onnxruntime