mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
enable more unit tests for ROCM EP (#7307)
This commit is contained in:
parent
f27f5afd8a
commit
75c0192e4f
8 changed files with 73 additions and 30 deletions
|
|
@ -31,6 +31,8 @@
|
|||
#include "core/providers/cpu/math/element_wise_ops.h"
|
||||
#ifdef USE_CUDA
|
||||
#include "core/providers/cuda/gpu_data_transfer.h"
|
||||
#elif USE_ROCM
|
||||
#include "core/providers/rocm/gpu_data_transfer.h"
|
||||
#endif
|
||||
#include "core/session/environment.h"
|
||||
#include "core/session/IOBinding.h"
|
||||
|
|
@ -293,6 +295,11 @@ void RunModelWithBindingMatMul(InferenceSession& session_object,
|
|||
#ifdef USE_CUDA
|
||||
AllocateMLValue<float>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), expected_output_dims,
|
||||
&output_ml_value);
|
||||
#endif
|
||||
} else if (allocation_provider == kRocmExecutionProvider) {
|
||||
#ifdef USE_ROCM
|
||||
AllocateMLValue<float>(TestRocmExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), expected_output_dims,
|
||||
&output_ml_value);
|
||||
#endif
|
||||
} else {
|
||||
ORT_THROW("Unsupported provider");
|
||||
|
|
@ -317,9 +324,9 @@ void RunModelWithBindingMatMul(InferenceSession& session_object,
|
|||
std::cout << "Run returned status: " << st.ErrorMessage() << std::endl;
|
||||
ASSERT_TRUE(st.IsOK());
|
||||
|
||||
if ((is_preallocate_output_vec && allocation_provider == kCudaExecutionProvider) ||
|
||||
if ((is_preallocate_output_vec && (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider)) ||
|
||||
(output_device && output_device->Type() == OrtDevice::GPU)) {
|
||||
#ifdef USE_CUDA
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
// in this case we need to copy the tensor from cuda to cpu
|
||||
vector<OrtValue>& outputs = io_binding->GetOutputs();
|
||||
ASSERT_EQ(1, outputs.size());
|
||||
|
|
@ -330,7 +337,11 @@ void RunModelWithBindingMatMul(InferenceSession& session_object,
|
|||
std::unique_ptr<Tensor> cpu_tensor = onnxruntime::make_unique<Tensor>(element_type,
|
||||
shape,
|
||||
cpu_allocator);
|
||||
#ifdef USE_CUDA
|
||||
cudaStream_t stream = static_cast<cudaStream_t>(static_cast<const onnxruntime::CUDAExecutionProvider*>(TestCudaExecutionProvider())->GetComputeStream());
|
||||
#elif USE_ROCM
|
||||
hipStream_t stream = static_cast<hipStream_t>(static_cast<const onnxruntime::ROCMExecutionProvider*>(TestRocmExecutionProvider())->GetComputeStream());
|
||||
#endif
|
||||
st = GPUDataTransfer(stream).CopyTensor(rtensor, *cpu_tensor.get(), 0);
|
||||
ASSERT_TRUE(st.IsOK());
|
||||
OrtValue ml_value;
|
||||
|
|
@ -343,6 +354,10 @@ void RunModelWithBindingMatMul(InferenceSession& session_object,
|
|||
if (allocation_provider == kCudaExecutionProvider) {
|
||||
#ifdef USE_CUDA
|
||||
TestCudaExecutionProvider()->Sync();
|
||||
#endif
|
||||
} else if (allocation_provider == kRocmExecutionProvider) {
|
||||
#ifdef USE_ROCM
|
||||
TestRocmExecutionProvider()->Sync();
|
||||
#endif
|
||||
}
|
||||
VerifyOutputs(io_binding->GetOutputs(), expected_output_dims, expected_values_mul_y);
|
||||
|
|
@ -832,11 +847,15 @@ static void TestBindHelper(const std::string& log_str,
|
|||
|
||||
InferenceSession session_object{so, GetEnvironment()};
|
||||
|
||||
if (bind_provider_type == kCudaExecutionProvider || run_provider_type == kCudaExecutionProvider) {
|
||||
if (bind_provider_type == kCudaExecutionProvider || bind_provider_type == kRocmExecutionProvider) {
|
||||
#ifdef USE_CUDA
|
||||
CUDAExecutionProviderInfo epi;
|
||||
epi.device_id = 0;
|
||||
EXPECT_TRUE(session_object.RegisterExecutionProvider(onnxruntime::make_unique<CUDAExecutionProvider>(epi)).IsOK());
|
||||
#elif USE_ROCM
|
||||
ROCMExecutionProviderInfo epi;
|
||||
epi.device_id = 0;
|
||||
EXPECT_TRUE(session_object.RegisterExecutionProvider(onnxruntime::make_unique<ROCMExecutionProvider>(epi)).IsOK());
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -944,34 +963,38 @@ TEST(InferenceSessionTests, InvalidInputTypeOfTensorElement) {
|
|||
ASSERT_TRUE(!st.IsOK());
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
#if USE_CUDA
|
||||
constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider;
|
||||
#elif USE_ROCM
|
||||
constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider;
|
||||
#endif
|
||||
TEST(InferenceSessionTests, TestBindCuda) {
|
||||
TestBindHelper("TestBindCuda",
|
||||
kCudaExecutionProvider,
|
||||
kCudaExecutionProvider,
|
||||
kGpuExecutionProvider,
|
||||
kGpuExecutionProvider,
|
||||
false /* don't preallocate output */);
|
||||
}
|
||||
|
||||
TEST(InferenceSessionTests, TestBindCudaPreallocateOutputOnCuda) {
|
||||
TestBindHelper("TestBindCudaPreallocateOutputOnCuda",
|
||||
kCudaExecutionProvider,
|
||||
kCudaExecutionProvider,
|
||||
kGpuExecutionProvider,
|
||||
kGpuExecutionProvider,
|
||||
true /* preallocate output on GPU */,
|
||||
kCudaExecutionProvider);
|
||||
kGpuExecutionProvider);
|
||||
}
|
||||
|
||||
TEST(InferenceSessionTests, TestBindCudaPreallocateOutputOnCpu) {
|
||||
TestBindHelper("TestBindCudaPreallocateOutputOnCpu",
|
||||
kCudaExecutionProvider,
|
||||
kCudaExecutionProvider,
|
||||
kGpuExecutionProvider,
|
||||
kGpuExecutionProvider,
|
||||
true /* preallocate output on CPU */,
|
||||
kCpuExecutionProvider);
|
||||
}
|
||||
|
||||
TEST(InferenceSessionTests, TestBindCudaPreallocateOutputOnCpu2) {
|
||||
TestBindHelper("TestBindCudaPreallocateOutputOnCpu2",
|
||||
kCudaExecutionProvider,
|
||||
kGpuExecutionProvider,
|
||||
kCpuExecutionProvider,
|
||||
true /* preallocate output on CPU */,
|
||||
kCpuExecutionProvider);
|
||||
|
|
@ -981,10 +1004,10 @@ TEST(InferenceSessionTests, TestBindCudaSpecifyOutputDeviceOnCuda) {
|
|||
OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0);
|
||||
|
||||
TestBindHelper("TestBindCudaPreallocateOutputOnCuda",
|
||||
kCudaExecutionProvider,
|
||||
kCudaExecutionProvider,
|
||||
kGpuExecutionProvider,
|
||||
kGpuExecutionProvider,
|
||||
false /* preallocate output on GPU */,
|
||||
kCudaExecutionProvider,
|
||||
kGpuExecutionProvider,
|
||||
&device /* specify output device */);
|
||||
}
|
||||
|
||||
|
|
@ -1390,6 +1413,10 @@ TEST(InferenceSessionTests, Test3LayerNestedSubgraph) {
|
|||
CUDAExecutionProviderInfo epi;
|
||||
epi.device_id = 0;
|
||||
EXPECT_TRUE(session_object.RegisterExecutionProvider(onnxruntime::make_unique<CUDAExecutionProvider>(epi)).IsOK());
|
||||
#elif USE_ROCM
|
||||
ROCMExecutionProviderInfo epi;
|
||||
epi.device_id = 0;
|
||||
EXPECT_TRUE(session_object.RegisterExecutionProvider(onnxruntime::make_unique<ROCMExecutionProvider>(epi)).IsOK());
|
||||
#endif
|
||||
|
||||
status = session_object.Load(model_file_name);
|
||||
|
|
@ -1526,6 +1553,10 @@ TEST(InferenceSessionTests, Test2LayerNestedSubgraph) {
|
|||
CUDAExecutionProviderInfo epi;
|
||||
epi.device_id = 0;
|
||||
EXPECT_TRUE(session_object.RegisterExecutionProvider(onnxruntime::make_unique<CUDAExecutionProvider>(epi)).IsOK());
|
||||
#elif USE_ROCM
|
||||
ROCMExecutionProviderInfo epi;
|
||||
epi.device_id = 0;
|
||||
EXPECT_TRUE(session_object.RegisterExecutionProvider(onnxruntime::make_unique<ROCMExecutionProvider>(epi)).IsOK());
|
||||
#endif
|
||||
|
||||
status = session_object.Load(model_file_name);
|
||||
|
|
|
|||
|
|
@ -20,6 +20,14 @@ IExecutionProvider* TestCudaExecutionProvider() {
|
|||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
IExecutionProvider* TestRocmExecutionProvider() {
|
||||
static ROCMExecutionProviderInfo info;
|
||||
static ROCMExecutionProvider rocm_provider(info);
|
||||
return &rocm_provider;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_TENSORRT
|
||||
#if 0 // TODO: TensorRT is shared, can't access these directly anymore
|
||||
IExecutionProvider* TestTensorrtExecutionProvider() {
|
||||
|
|
|
|||
|
|
@ -15,6 +15,9 @@
|
|||
#ifdef USE_CUDA
|
||||
#include "core/providers/cuda/cuda_execution_provider.h"
|
||||
#endif
|
||||
#ifdef USE_ROCM
|
||||
#include "core/providers/rocm/rocm_execution_provider.h"
|
||||
#endif
|
||||
#ifdef USE_NNAPI
|
||||
#include "core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h"
|
||||
#endif
|
||||
|
|
@ -37,6 +40,10 @@ IExecutionProvider* TestCPUExecutionProvider();
|
|||
IExecutionProvider* TestCudaExecutionProvider();
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
IExecutionProvider* TestRocmExecutionProvider();
|
||||
#endif
|
||||
|
||||
#ifdef USE_TENSORRT
|
||||
// Doesn't work with ExecutionProviders class and KernelRegistryManager
|
||||
IExecutionProvider* TestTensorrtExecutionProvider();
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
#ifdef USE_CUDA
|
||||
#include "core/providers/cuda/cuda_execution_provider.h"
|
||||
#elif USE_ROCM
|
||||
#include "core/providers/rocm/rocm_execution_provider.h"
|
||||
#endif
|
||||
|
||||
using namespace onnxruntime::logging;
|
||||
|
|
@ -62,7 +64,7 @@ TEST(TrainingSessionTest, LoadOptimState_FullPrecision_FP32Moments_Adam) {
|
|||
RunTrainingSessionLoadOptimTests(k_adam_optimizer_op_name, false, false);
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
TEST(TrainingSessionTest, LoadOptimState_MixedPrecision_FP32Moments_Adam) {
|
||||
RunTrainingSessionLoadOptimTests(k_adam_optimizer_op_name, true, false);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ void VerifyState(const DataTransferManager& data_transfer_mgr, const NameMLValMa
|
|||
const auto& e_state_it = expected_state.find(key);
|
||||
ORT_ENFORCE(e_state_it != expected_state.end());
|
||||
auto& expected_tensor = e_state_it->second.Get<Tensor>();
|
||||
#ifdef USE_CUDA
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
auto& actual_gpu_tensor = a_state_it.second.Get<Tensor>();
|
||||
|
||||
// Copying tensor to CPU when cuda is enabled.
|
||||
|
|
@ -182,6 +182,9 @@ std::unique_ptr<TrainingSession> BuildAndRunTrainingSessionWithChecks(
|
|||
#ifdef USE_CUDA
|
||||
CUDAExecutionProviderInfo xp_info;
|
||||
ORT_THROW_IF_ERROR(training_session->RegisterExecutionProvider(onnxruntime::make_unique<CUDAExecutionProvider>(xp_info)));
|
||||
#elif USE_ROCM
|
||||
ROCMExecutionProviderInfo xp_info;
|
||||
ORT_THROW_IF_ERROR(training_session->RegisterExecutionProvider(onnxruntime::make_unique<ROCMExecutionProvider>(xp_info)));
|
||||
#endif
|
||||
ORT_THROW_IF_ERROR(training_session->Initialize());
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@
|
|||
|
||||
#ifdef USE_CUDA
|
||||
#include "core/providers/cuda/cuda_execution_provider.h"
|
||||
#elif USE_ROCM
|
||||
#include "core/providers/rocm/rocm_execution_provider.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
|
|
@ -322,7 +322,7 @@ __global__ void LambMultiTensorComputeDirectionImpl(
|
|||
const float alpha,
|
||||
const float beta,
|
||||
const float epsilon,
|
||||
const T1 max_norm,
|
||||
const float max_norm,
|
||||
const float alpha_correction,
|
||||
const float beta_correction) {
|
||||
const int group_index = chunk_group.block_index_to_tensor_group_index[blockIdx.x];
|
||||
|
|
|
|||
|
|
@ -39,23 +39,13 @@ ReductionOpTest.ReduceMin_int32
|
|||
ReductionOpTest.ReduceMin_int8
|
||||
ReductionOpTest.ReduceSum_double
|
||||
ReductionOpTest.ReduceSumSquare_double
|
||||
ReductionOpTest.ReduceInfMax
|
||||
ReductionOpTest.ReduceInfMax_double
|
||||
ReductionOpTest.ReduceInfMin
|
||||
ReductionOpTest.ReduceInfMin_double
|
||||
ReductionOpTest.ReduceInfSum
|
||||
ReductionOpTest.ReduceInfLogSum
|
||||
ReductionOpTest.ReduceInfLogSumExp
|
||||
ReductionOpTest.ReduceInfLogSumExp_double
|
||||
GatherOpTest.Gather_invalid_index_cpu
|
||||
Scatter.InvalidIndex
|
||||
LogSoftmaxOperator.LargeNumber
|
||||
MathOpTest.Pow_int64_float
|
||||
MathOpTest.Pow_int32_float
|
||||
MathOpTest.Pow_int64_double
|
||||
Scatter.InvalidIndex
|
||||
GradientCheckerTest.AddGrad
|
||||
GradientCheckerTest.SubGrad
|
||||
GradientCheckerTest.MulGrad
|
||||
GradientCheckerTest.ReduceMeanGrad
|
||||
GradientCheckerTest.ReduceL2Grad
|
||||
GradientCheckerTest.DivGrad
|
||||
|
|
|
|||
Loading…
Reference in a new issue