[C# Tests] Add support for double tensor output in TestPreTrainedModels. (#12008)

Add support for double tensor output in TestPreTrainedModels.
This commit is contained in:
Edward Chen 2022-06-27 18:49:19 -07:00 committed by GitHub
parent 7d712c8f8b
commit 466b2d9f3d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 6 deletions

View file

@ -27,7 +27,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
[Fact(DisplayName = "TestSessionOptions")]
public void TestSessionOptions()
{
// get instance to setup logging
// get instance to setup logging
var ortEnvInstance = OrtEnv.Instance();
using (SessionOptions opt = new SessionOptions())
@ -1938,7 +1938,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
var session = (deviceId.HasValue)
? new InferenceSession(model, option)
: new InferenceSession(model);
float[] inputData = TestDataLoader.LoadTensorFromEmbeddedResource("bench.in");
float[] inputData = TestDataLoader.LoadTensorFromEmbeddedResource("bench.in");
float[] expectedOutput = TestDataLoader.LoadTensorFromEmbeddedResource("bench.expected_out");
var inputMeta = session.InputMetadata;
var tensor = new DenseTensor<float>(inputData, inputMeta["data_0"].Dimensions);
@ -1961,6 +1961,21 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
internal class DoubleComparer : IEqualityComparer<double>
{
private double atol = 1e-3;
private double rtol = 1.7e-2;
public bool Equals(double x, double y)
{
return Math.Abs(x - y) <= (atol + rtol * Math.Abs(y));
}
public int GetHashCode(double x)
{
return x.GetHashCode();
}
}
class ExactComparer<T> : IEqualityComparer<T>
{
public bool Equals(T x, T y)
@ -2069,4 +2084,4 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
#endregion
}
}
}

View file

@ -518,6 +518,10 @@ namespace Microsoft.ML.OnnxRuntime.Tests
{
Assert.Equal(result.AsTensor<float>(), outputValue.AsTensor<float>(), new FloatComparer());
}
else if (outputMeta.ElementType == typeof(double))
{
Assert.Equal(result.AsTensor<double>(), outputValue.AsTensor<double>(), new DoubleComparer());
}
else if (outputMeta.ElementType == typeof(int))
{
Assert.Equal(result.AsTensor<int>(), outputValue.AsTensor<int>(), new ExactComparer<int>());
@ -560,12 +564,12 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
else
{
Assert.True(false, "The TestPretrainedModels does not yet support output of type " + nameof(outputMeta.ElementType));
Assert.True(false, $"{nameof(TestPreTrainedModels)} does not yet support output of type {outputMeta.ElementType}");
}
}
else
{
Assert.True(false, "TestPretrainedModel cannot handle non-tensor outputs yet");
Assert.True(false, $"{nameof(TestPreTrainedModels)} cannot handle non-tensor outputs yet");
}
}
}
@ -808,4 +812,4 @@ namespace Microsoft.ML.OnnxRuntime.Tests
return modelsDir;
}
}
}
}