Fix build error

This commit is contained in:
Ryan Hill 2021-05-06 18:26:45 -07:00
parent 29288eb480
commit dda221501b
2 changed files with 18 additions and 12 deletions

View file

@ -36,6 +36,11 @@ final class OnnxRuntime {
/** The short name of the ONNX runtime JNI shared library */
static final String ONNXRUNTIME_JNI_LIBRARY_NAME = "onnxruntime4j_jni";
/** The short name of the ONNX runtime shared provider library */
static final String ONNXRUNTIME_LIBRARY_SHARED_NAME = "onnxruntime_providers_shared";
/** The short name of the ONNX runtime cuda provider library */
static final String ONNXRUNTIME_LIBRARY_CUDA_NAME = "onnxruntime_providers_cuda";
private static final String OS_ARCH_STR = initOsArch();
private static boolean loaded = false;
@ -93,8 +98,10 @@ final class OnnxRuntime {
}
Path tempDirectory = isAndroid() ? null : Files.createTempDirectory("onnxruntime-java");
try {
load(tempDirectory, ONNXRUNTIME_LIBRARY_NAME);
load(tempDirectory, ONNXRUNTIME_JNI_LIBRARY_NAME);
load(tempDirectory, ONNXRUNTIME_LIBRARY_SHARED_NAME, false);
load(tempDirectory, ONNXRUNTIME_LIBRARY_CUDA_NAME, false);
load(tempDirectory, ONNXRUNTIME_LIBRARY_NAME, true);
load(tempDirectory, ONNXRUNTIME_JNI_LIBRARY_NAME, true);
ortApiHandle = initialiseAPIBase(ORT_API_VERSION_7);
providers = initialiseProviders(ortApiHandle);
loaded = true;
@ -138,7 +145,7 @@ final class OnnxRuntime {
* @param library The bare name of the library.
* @throws IOException If the file failed to read or write.
*/
private static void load(Path tempDirectory, String library) throws IOException {
private static void load(Path tempDirectory, String library, boolean system_load) throws IOException {
// On Android, we simply use System.loadLibrary
if (isAndroid()) {
System.loadLibrary("onnxruntime4j_jni");
@ -201,7 +208,8 @@ final class OnnxRuntime {
os.write(buffer, 0, readBytes);
}
}
System.load(tempFile.getAbsolutePath());
if (system_load)
System.load(tempFile.getAbsolutePath());
logger.log(Level.FINE, "Loaded native library '" + library + "' from resource path");
}
} finally {

View file

@ -19,10 +19,9 @@
#include "core/util/thread_utils.h"
#include "gtest/gtest.h"
#include "test/test_environment.h"
#include "test/util/include/default_providers.h"
#ifdef USE_CUDA
#include "core/providers/cuda/cuda_execution_provider.h"
#elif USE_ROCM
#ifdef USE_ROCM
#include "core/providers/rocm/rocm_execution_provider.h"
#endif
@ -189,8 +188,7 @@ static void TestCPUNodePlacement(const std::basic_string<ORTCHAR_T>& model_uri,
ExecutionProviders execution_providers;
#if defined(USE_CUDA)
CUDAExecutionProviderInfo cuda_epi;
ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCudaExecutionProvider, std::make_unique<CUDAExecutionProvider>(cuda_epi)));
ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCudaExecutionProvider, DefaultCudaExecutionProvider()));
#elif defined(USE_ROCM)
ROCMExecutionProviderInfo rocm_epi;
ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kRocmExecutionProvider, std::make_unique<ROCMExecutionProvider>(rocm_epi)));
@ -248,9 +246,9 @@ TEST(SessionStateTest, CPUPlacementTest3) {
TestCPUNodePlacement(ORT_TSTR("testdata/cpu_fallback_pattern_3.onnx"), expected_cpu_nodes, expected_gpu_nodes);
}
TEST(SessionStateTest, CPUPlacementTest4) {
// Currently, the behaviour is different for RocM and CUDA EP as Rocm EP is missing a valid kernel
// for ReduceSum for int64 type. This causes the backward trace in GetCpuPreferredNodes to stop
// earlier. The expected values can be modified to match CUDA once the RocM EP kernel is updated
// Currently, the behaviour is different for RocM and CUDA EP as Rocm EP is missing a valid kernel
// for ReduceSum for int64 type. This causes the backward trace in GetCpuPreferredNodes to stop
// earlier. The expected values can be modified to match CUDA once the RocM EP kernel is updated
#if defined(USE_CUDA)
std::unordered_set<std::string> expected_cpu_nodes = {"range", "reduce", "const1"};
std::unordered_set<std::string> expected_gpu_nodes = {"size0", "expand"};