mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
[C# Tests] Add support for double tensor output in TestPreTrainedModels. (#12008)
Add support for double tensor output in TestPreTrainedModels.
This commit is contained in:
parent
7d712c8f8b
commit
466b2d9f3d
2 changed files with 25 additions and 6 deletions
|
|
@ -27,7 +27,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
[Fact(DisplayName = "TestSessionOptions")]
|
||||
public void TestSessionOptions()
|
||||
{
|
||||
// get instance to setup logging
|
||||
// get instance to setup logging
|
||||
var ortEnvInstance = OrtEnv.Instance();
|
||||
|
||||
using (SessionOptions opt = new SessionOptions())
|
||||
|
|
@ -1938,7 +1938,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
var session = (deviceId.HasValue)
|
||||
? new InferenceSession(model, option)
|
||||
: new InferenceSession(model);
|
||||
float[] inputData = TestDataLoader.LoadTensorFromEmbeddedResource("bench.in");
|
||||
float[] inputData = TestDataLoader.LoadTensorFromEmbeddedResource("bench.in");
|
||||
float[] expectedOutput = TestDataLoader.LoadTensorFromEmbeddedResource("bench.expected_out");
|
||||
var inputMeta = session.InputMetadata;
|
||||
var tensor = new DenseTensor<float>(inputData, inputMeta["data_0"].Dimensions);
|
||||
|
|
@ -1961,6 +1961,21 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
}
|
||||
|
||||
internal class DoubleComparer : IEqualityComparer<double>
|
||||
{
|
||||
private double atol = 1e-3;
|
||||
private double rtol = 1.7e-2;
|
||||
|
||||
public bool Equals(double x, double y)
|
||||
{
|
||||
return Math.Abs(x - y) <= (atol + rtol * Math.Abs(y));
|
||||
}
|
||||
public int GetHashCode(double x)
|
||||
{
|
||||
return x.GetHashCode();
|
||||
}
|
||||
}
|
||||
|
||||
class ExactComparer<T> : IEqualityComparer<T>
|
||||
{
|
||||
public bool Equals(T x, T y)
|
||||
|
|
@ -2069,4 +2084,4 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
#endregion
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -518,6 +518,10 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
{
|
||||
Assert.Equal(result.AsTensor<float>(), outputValue.AsTensor<float>(), new FloatComparer());
|
||||
}
|
||||
else if (outputMeta.ElementType == typeof(double))
|
||||
{
|
||||
Assert.Equal(result.AsTensor<double>(), outputValue.AsTensor<double>(), new DoubleComparer());
|
||||
}
|
||||
else if (outputMeta.ElementType == typeof(int))
|
||||
{
|
||||
Assert.Equal(result.AsTensor<int>(), outputValue.AsTensor<int>(), new ExactComparer<int>());
|
||||
|
|
@ -560,12 +564,12 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
else
|
||||
{
|
||||
Assert.True(false, "The TestPretrainedModels does not yet support output of type " + nameof(outputMeta.ElementType));
|
||||
Assert.True(false, $"{nameof(TestPreTrainedModels)} does not yet support output of type {outputMeta.ElementType}");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
Assert.True(false, "TestPretrainedModel cannot handle non-tensor outputs yet");
|
||||
Assert.True(false, $"{nameof(TestPreTrainedModels)} cannot handle non-tensor outputs yet");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -808,4 +812,4 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
return modelsDir;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue