mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-15 20:50:42 +00:00
70 lines
No EOL
2.9 KiB
C++
70 lines
No EOL
2.9 KiB
C++
#include "testPch.h"
|
|
#include "compare_feature_value.h"
|
|
#include "StringHelpers.h"
|
|
#include <core/framework/ort_value.h>
|
|
#include "ort_value_helper.h"
|
|
|
|
using namespace winrt::Windows::Foundation::Collections;
|
|
using namespace winml;
|
|
|
|
namespace CompareFeatureValuesHelper{
|
|
|
|
// Decides whether an actual value is close enough to an expected value, given their
// precomputed absolute difference `diff` and the allowed tolerance `tol`.
//  - A difference strictly greater than the tolerance is a mismatch.
//  - A non-NaN difference within tolerance is a match.
//  - A NaN difference (e.g. a NaN operand, or inf - inf) is accepted only when both
//    values are NaN, or both values are infinite; otherwise it is a mismatch.
template <typename T>
bool IsResultCloselyMatch(const T& outvalue, const T& expected_value, const double diff, const double tol) {
  if (diff > tol) {
    return false;
  }
  if (!std::isnan(diff)) {
    return true;
  }
  // The difference is NaN: only tolerate it when both sides are "the same kind" of special value.
  const bool bothAreNan = std::isnan(outvalue) && std::isnan(expected_value);
  const bool bothAreInf = std::isinf(outvalue) && std::isinf(expected_value);
  return bothAreNan || bothAreInf;
}
|
|
|
|
// Compares a WinML sequence of string->float maps (the actual output, `featureValue`)
// against an ORT sequence-of-maps value (the expected output, `val`).
//
// For sequence element i, `val.GetValue(i)` is read as a map whose GetValue(0) is the key
// tensor (loaded as TensorString) and GetValue(1) is the value tensor (loaded as TensorFloat).
//
// A value passes when
//   |expected - actual| <= perSampleTolerance + relativePerSampleTolerance * |expected|
// with NaN/Inf handling delegated to IsResultCloselyMatch.
//
// Returns true only when the sequences have the same length and every actual map contains
// exactly the expected keys, each with a value inside tolerance. On the first discrepancy a
// diagnostic is printed and false is returned.
bool CompareSequenceOfMapsStringToFloat(
  IVectorView<IMap<winrt::hstring, float>> featureValue,
  const Ort::Value& val,
  double perSampleTolerance,
  double relativePerSampleTolerance) {
  if (val.GetCount() != featureValue.Size()) {
    printf("Map lengths are not the same! Got %d, expected %d\n", static_cast<int>(featureValue.Size()), static_cast<int>(val.GetCount()));
    // Continuing would index past the end of `val` (extra actual elements) or silently
    // skip comparing expected elements that are missing from the actual output.
    return false;
  }

  int expectedValSequenceIndex = 0;
  Ort::AllocatorWithDefaultOptions allocator;
  for (const auto& mapVal : featureValue) {
    Ort::Value mapExpectedOutput(nullptr);
    Ort::Value mapExpectedOutputKeys(nullptr);
    Ort::Value mapExpectedOutputValues(nullptr);
    WINML_EXPECT_NO_THROW(mapExpectedOutput = val.GetValue(expectedValSequenceIndex, allocator));
    WINML_EXPECT_NO_THROW(mapExpectedOutputKeys = mapExpectedOutput.GetValue(0, allocator));
    WINML_EXPECT_NO_THROW(mapExpectedOutputValues = mapExpectedOutput.GetValue(1, allocator));

    auto expectedOutputKeys = OrtValueHelpers::LoadTensorFromOrtValue(mapExpectedOutputKeys).as<TensorString>().GetAsVectorView();
    auto expectedOutputValues = OrtValueHelpers::LoadTensorFromOrtValue(mapExpectedOutputValues).as<TensorFloat>().GetAsVectorView();
    if (expectedOutputKeys.Size() != expectedOutputValues.Size()) {
      // Guards the GetAt(i) calls below against reading past the shorter tensor.
      printf("Expected map is malformed: %d keys but %d values\n", static_cast<int>(expectedOutputKeys.Size()), static_cast<int>(expectedOutputValues.Size()));
      return false;
    }

    std::map<winrt::hstring, float> expectedKvp;
    for (uint32_t i = 0; i < expectedOutputKeys.Size(); i++) {
      expectedKvp[expectedOutputKeys.GetAt(i)] = expectedOutputValues.GetAt(i);
    }

    for (const auto& kvp : mapVal) {
      const winrt::hstring actualKey = kvp.Key();
      const float actualValue = kvp.Value();

      const auto expectedIt = expectedKvp.find(actualKey);
      if (expectedIt == expectedKvp.end()) {
        printf("Unexpected key in actual output: %ws\n", actualKey.c_str());
        return false;
      }

      // Verify that the value is within tolerable ranges; compute in double to match `tol`.
      const float expectedValue = expectedIt->second;
      const double diff = std::fabs(static_cast<double>(expectedValue) - static_cast<double>(actualValue));
      const double tol = perSampleTolerance + relativePerSampleTolerance * std::fabs(static_cast<double>(expectedValue));
      if (!IsResultCloselyMatch<double>(actualValue, expectedValue, diff, tol)) {
        printf("expected (%f), actual (%f), diff: %f, tol= %f .\n", expectedValue, actualValue, diff, tol);
        return false;
      }
    }

    // Every actual key was found among the expected keys (and actual keys are unique), so a
    // size difference means some expected keys are missing from the actual map.
    if (mapVal.Size() != expectedKvp.size()) {
      printf("Missing keys in actual output! Got %d keys, expected %d\n", static_cast<int>(mapVal.Size()), static_cast<int>(expectedKvp.size()));
      return false;
    }

    expectedValSequenceIndex++;
  }

  // If errors or discrepancies are not found, then return true
  return true;
}
|
|
|
|
} |