From a43382e390dd9244cc5e46f308ff19d9e4aa311c Mon Sep 17 00:00:00 2001 From: jignparm Date: Thu, 20 Dec 2018 09:58:03 -0800 Subject: [PATCH] Jignparm/csharp gpu (#221) * Minor updates to exception message * update models folder to new location * update copy to preservenewest * reenable pretrained test * added some debugging info for build * update pretrained test, and tensor proto definition --- .../InferenceTest.cs | 29 +++-- .../Microsoft.ML.OnnxRuntime.Tests/OnnxMl.cs | 112 +++++++++--------- 2 files changed, 73 insertions(+), 68 deletions(-) diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index 8044665ab7..256792f734 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -205,40 +205,43 @@ namespace Microsoft.ML.OnnxRuntime.Tests session.Dispose(); } - [Fact(Skip = "Disable temporarily")] + [Fact] private void TestPreTrainedModelsOpset7And8() { var opsets = new[] { "opset7", "opset8" }; foreach (var opset in opsets) { var modelRoot = new DirectoryInfo(opset); - foreach (var model in modelRoot.EnumerateDirectories()) + foreach (var modelDir in modelRoot.EnumerateDirectories()) { // TODO: dims contains 'None'. Session throws error. - if (model.ToString() == "test_tiny_yolov2") + if (modelDir.Name== "test_tiny_yolov2") continue; + + String onnxModelFileName = null; try { - var modelNames = model.GetFiles("*.onnx"); - if (modelNames.Count() != 1) + var onnxModelNames = modelDir.GetFiles("*.onnx"); + if (onnxModelNames.Count() != 1) { // TODO remove file "._resnet34v2.onnx" from test set - if (modelNames[0].ToString() == "._resnet34v2.onnx") - modelNames[0] = modelNames[1]; + if (onnxModelNames[0].Name == "._resnet34v2.onnx") + onnxModelNames[0] = onnxModelNames[1]; else { - var modelNamesList = string.Join(",", modelNames.Select(x => x.ToString())); - throw new Exception($"Opset {opset}: Model {model}. Can't determine model file name. Found these :{modelNamesList}"); + var modelNamesList = string.Join(",", onnxModelNames.Select(x => x.ToString())); + throw new Exception($"Opset {opset}: Model {modelDir}. Can't determine model file name. Found these :{modelNamesList}"); } } - var session = new InferenceSession($"{opset}\\{model}\\{modelNames[0].ToString()}"); + onnxModelFileName = $"{opset}\\{modelDir.Name}\\{onnxModelNames[0].Name}"; + var session = new InferenceSession(onnxModelFileName); var inMeta = session.InputMetadata; var innodepair = inMeta.First(); var innodename = innodepair.Key; var innodedims = innodepair.Value.Dimensions; - var dataIn = LoadTensorFromFilePb($"{opset}\\{model}\\test_data_set_0\\input_0.pb"); - var dataOut = LoadTensorFromFilePb($"{opset}\\{model}\\test_data_set_0\\output_0.pb"); + var dataIn = LoadTensorFromFilePb($"{opset}\\{modelDir.Name}\\test_data_set_0\\input_0.pb"); + var dataOut = LoadTensorFromFilePb($"{opset}\\{modelDir.Name}\\test_data_set_0\\output_0.pb"); var tensorIn = new DenseTensor(dataIn, innodedims); var nov = new List(); nov.Add(NamedOnnxValue.CreateFromTensor(innodename, tensorIn)); @@ -249,7 +252,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests } catch (Exception ex) { - var msg = $"Opset {opset}: Model {model}: error = {ex.Message}"; + var msg = $"Opset {opset}: Model {modelDir}: ModelFile = {onnxModelFileName} error = {ex.Message}"; throw new Exception(msg); } } //model diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/OnnxMl.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/OnnxMl.cs index baf8f0e80e..209e62c30a 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/OnnxMl.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/OnnxMl.cs @@ -53,41 +53,39 @@ namespace Onnx { "b3RvEhIKCmRvY19zdHJpbmcYCiABKAkSIwoFaW5wdXQYCyADKAsyFC5vbm54", "LlZhbHVlSW5mb1Byb3RvEiQKBm91dHB1dBgMIAMoCzIULm9ubnguVmFsdWVJ", "bmZvUHJvdG8SKAoKdmFsdWVfaW5mbxgNIAMoCzIULm9ubnguVmFsdWVJbmZv", - "UHJvdG8ivQQKC1RlbnNvclByb3RvEgwKBGRpbXMYASADKAMSLQoJZGF0YV90", - "eXBlGAIgASgOMhoub25ueC5UZW5zb3JQcm90by5EYXRhVHlwZRIqCgdzZWdt", - "ZW50GAMgASgLMhkub25ueC5UZW5zb3JQcm90by5TZWdtZW50EhYKCmZsb2F0", - "X2RhdGEYBCADKAJCAhABEhYKCmludDMyX2RhdGEYBSADKAVCAhABEhMKC3N0", - "cmluZ19kYXRhGAYgAygMEhYKCmludDY0X2RhdGEYByADKANCAhABEgwKBG5h", - "bWUYCCABKAkSEgoKZG9jX3N0cmluZxgMIAEoCRIQCghyYXdfZGF0YRgJIAEo", - "DBIXCgtkb3VibGVfZGF0YRgKIAMoAUICEAESFwoLdWludDY0X2RhdGEYCyAD", - "KARCAhABGiUKB1NlZ21lbnQSDQoFYmVnaW4YASABKAMSCwoDZW5kGAIgASgD", - "ItoBCghEYXRhVHlwZRINCglVTkRFRklORUQQABIJCgVGTE9BVBABEgkKBVVJ", - "TlQ4EAISCAoESU5UOBADEgoKBlVJTlQxNhAEEgkKBUlOVDE2EAUSCQoFSU5U", - "MzIQBhIJCgVJTlQ2NBAHEgoKBlNUUklORxAIEggKBEJPT0wQCRILCgdGTE9B", - "VDE2EAoSCgoGRE9VQkxFEAsSCgoGVUlOVDMyEAwSCgoGVUlOVDY0EA0SDQoJ", - "Q09NUExFWDY0EA4SDgoKQ09NUExFWDEyOBAPEgwKCEJGTE9BVDE2EBAilQEK", - "EFRlbnNvclNoYXBlUHJvdG8SLQoDZGltGAEgAygLMiAub25ueC5UZW5zb3JT", - "aGFwZVByb3RvLkRpbWVuc2lvbhpSCglEaW1lbnNpb24SEwoJZGltX3ZhbHVl", - "GAEgASgDSAASEwoJZGltX3BhcmFtGAIgASgJSAASEgoKZGVub3RhdGlvbhgD", - "IAEoCUIHCgV2YWx1ZSKWBQoJVHlwZVByb3RvEi0KC3RlbnNvcl90eXBlGAEg", - "ASgLMhYub25ueC5UeXBlUHJvdG8uVGVuc29ySAASMQoNc2VxdWVuY2VfdHlw", - "ZRgEIAEoCzIYLm9ubnguVHlwZVByb3RvLlNlcXVlbmNlSAASJwoIbWFwX3R5", - "cGUYBSABKAsyEy5vbm54LlR5cGVQcm90by5NYXBIABItCgtvcGFxdWVfdHlw", - "ZRgHIAEoCzIWLm9ubnguVHlwZVByb3RvLk9wYXF1ZUgAEjoKEnNwYXJzZV90", - "ZW5zb3JfdHlwZRgIIAEoCzIcLm9ubnguVHlwZVByb3RvLlNwYXJzZVRlbnNv", - "ckgAEhIKCmRlbm90YXRpb24YBiABKAkaXgoGVGVuc29yEi0KCWVsZW1fdHlw", - "ZRgBIAEoDjIaLm9ubnguVGVuc29yUHJvdG8uRGF0YVR5cGUSJQoFc2hhcGUY", - "AiABKAsyFi5vbm54LlRlbnNvclNoYXBlUHJvdG8aLgoIU2VxdWVuY2USIgoJ", - "ZWxlbV90eXBlGAEgASgLMg8ub25ueC5UeXBlUHJvdG8aWAoDTWFwEiwKCGtl", - "eV90eXBlGAEgASgOMhoub25ueC5UZW5zb3JQcm90by5EYXRhVHlwZRIjCgp2", - "YWx1ZV90eXBlGAIgASgLMg8ub25ueC5UeXBlUHJvdG8aJgoGT3BhcXVlEg4K", - "BmRvbWFpbhgBIAEoCRIMCgRuYW1lGAIgASgJGmQKDFNwYXJzZVRlbnNvchIt", - "CgllbGVtX3R5cGUYASABKA4yGi5vbm54LlRlbnNvclByb3RvLkRhdGFUeXBl", - "EiUKBXNoYXBlGAIgASgLMhYub25ueC5UZW5zb3JTaGFwZVByb3RvQgcKBXZh", - "bHVlIjUKEk9wZXJhdG9yU2V0SWRQcm90bxIOCgZkb21haW4YASABKAkSDwoH", - "dmVyc2lvbhgCIAEoAypjCgdWZXJzaW9uEhIKDl9TVEFSVF9WRVJTSU9OEAAS", - "GQoVSVJfVkVSU0lPTl8yMDE3XzEwXzEwEAESGQoVSVJfVkVSU0lPTl8yMDE3", - "XzEwXzMwEAISDgoKSVJfVkVSU0lPThADYgZwcm90bzM=")); + "UHJvdG8ioQQKC1RlbnNvclByb3RvEgwKBGRpbXMYASADKAMSEQoJZGF0YV90", + "eXBlGAIgASgFEioKB3NlZ21lbnQYAyABKAsyGS5vbm54LlRlbnNvclByb3Rv", + "LlNlZ21lbnQSFgoKZmxvYXRfZGF0YRgEIAMoAkICEAESFgoKaW50MzJfZGF0", + "YRgFIAMoBUICEAESEwoLc3RyaW5nX2RhdGEYBiADKAwSFgoKaW50NjRfZGF0", + "YRgHIAMoA0ICEAESDAoEbmFtZRgIIAEoCRISCgpkb2Nfc3RyaW5nGAwgASgJ", + "EhAKCHJhd19kYXRhGAkgASgMEhcKC2RvdWJsZV9kYXRhGAogAygBQgIQARIX", + "Cgt1aW50NjRfZGF0YRgLIAMoBEICEAEaJQoHU2VnbWVudBINCgViZWdpbhgB", + "IAEoAxILCgNlbmQYAiABKAMi2gEKCERhdGFUeXBlEg0KCVVOREVGSU5FRBAA", + "EgkKBUZMT0FUEAESCQoFVUlOVDgQAhIICgRJTlQ4EAMSCgoGVUlOVDE2EAQS", + "CQoFSU5UMTYQBRIJCgVJTlQzMhAGEgkKBUlOVDY0EAcSCgoGU1RSSU5HEAgS", + "CAoEQk9PTBAJEgsKB0ZMT0FUMTYQChIKCgZET1VCTEUQCxIKCgZVSU5UMzIQ", + "DBIKCgZVSU5UNjQQDRINCglDT01QTEVYNjQQDhIOCgpDT01QTEVYMTI4EA8S", + "DAoIQkZMT0FUMTYQECKVAQoQVGVuc29yU2hhcGVQcm90bxItCgNkaW0YASAD", + "KAsyIC5vbm54LlRlbnNvclNoYXBlUHJvdG8uRGltZW5zaW9uGlIKCURpbWVu", + "c2lvbhITCglkaW1fdmFsdWUYASABKANIABITCglkaW1fcGFyYW0YAiABKAlI", + "ABISCgpkZW5vdGF0aW9uGAMgASgJQgcKBXZhbHVlIsIECglUeXBlUHJvdG8S", + "LQoLdGVuc29yX3R5cGUYASABKAsyFi5vbm54LlR5cGVQcm90by5UZW5zb3JI", + "ABIxCg1zZXF1ZW5jZV90eXBlGAQgASgLMhgub25ueC5UeXBlUHJvdG8uU2Vx", + "dWVuY2VIABInCghtYXBfdHlwZRgFIAEoCzITLm9ubnguVHlwZVByb3RvLk1h", + "cEgAEi0KC29wYXF1ZV90eXBlGAcgASgLMhYub25ueC5UeXBlUHJvdG8uT3Bh", + "cXVlSAASOgoSc3BhcnNlX3RlbnNvcl90eXBlGAggASgLMhwub25ueC5UeXBl", + "UHJvdG8uU3BhcnNlVGVuc29ySAASEgoKZGVub3RhdGlvbhgGIAEoCRpCCgZU", + "ZW5zb3ISEQoJZWxlbV90eXBlGAEgASgFEiUKBXNoYXBlGAIgASgLMhYub25u", + "eC5UZW5zb3JTaGFwZVByb3RvGi4KCFNlcXVlbmNlEiIKCWVsZW1fdHlwZRgB", + "IAEoCzIPLm9ubnguVHlwZVByb3RvGjwKA01hcBIQCghrZXlfdHlwZRgBIAEo", + "BRIjCgp2YWx1ZV90eXBlGAIgASgLMg8ub25ueC5UeXBlUHJvdG8aJgoGT3Bh", + "cXVlEg4KBmRvbWFpbhgBIAEoCRIMCgRuYW1lGAIgASgJGkgKDFNwYXJzZVRl", + "bnNvchIRCgllbGVtX3R5cGUYASABKAUSJQoFc2hhcGUYAiABKAsyFi5vbm54", + "LlRlbnNvclNoYXBlUHJvdG9CBwoFdmFsdWUiNQoST3BlcmF0b3JTZXRJZFBy", + "b3RvEg4KBmRvbWFpbhgBIAEoCRIPCgd2ZXJzaW9uGAIgASgDKmMKB1ZlcnNp", + "b24SEgoOX1NUQVJUX1ZFUlNJT04QABIZChVJUl9WRVJTSU9OXzIwMTdfMTBf", + "MTAQARIZChVJUl9WRVJTSU9OXzIwMTdfMTBfMzAQAhIOCgpJUl9WRVJTSU9O", + "EANiBnByb3RvMw==")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { }, new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Onnx.Version), }, new pbr::GeneratedClrTypeInfo[] { @@ -2116,12 +2114,13 @@ namespace Onnx { /// Field number for the "data_type" field. public const int DataTypeFieldNumber = 2; - private global::Onnx.TensorProto.Types.DataType dataType_ = 0; + private int dataType_; /// /// The data type of the tensor. + /// This field MUST have a valid TensorProto.DataType value /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Onnx.TensorProto.Types.DataType DataType { + public int DataType { get { return dataType_; } set { dataType_ = value; @@ -2355,7 +2354,7 @@ namespace Onnx { dims_.WriteTo(output, _repeated_dims_codec); if (DataType != 0) { output.WriteRawTag(16); - output.WriteEnum((int) DataType); + output.WriteInt32(DataType); } if (segment_ != null) { output.WriteRawTag(26); @@ -2389,7 +2388,7 @@ namespace Onnx { int size = 0; size += dims_.CalculateSize(_repeated_dims_codec); if (DataType != 0) { - size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) DataType); + size += 1 + pb::CodedOutputStream.ComputeInt32Size(DataType); } if (segment_ != null) { size += 1 + pb::CodedOutputStream.ComputeMessageSize(Segment); @@ -2462,7 +2461,7 @@ namespace Onnx { break; } case 16: { - dataType_ = (global::Onnx.TensorProto.Types.DataType) input.ReadEnum(); + DataType = input.ReadInt32(); break; } case 26: { @@ -3517,13 +3516,14 @@ namespace Onnx { /// Field number for the "elem_type" field. public const int ElemTypeFieldNumber = 1; - private global::Onnx.TensorProto.Types.DataType elemType_ = 0; + private int elemType_; /// /// This field MUST NOT have the value of UNDEFINED + /// This field MUST have a valid TensorProto.DataType value /// This field MUST be present for this version of the IR. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Onnx.TensorProto.Types.DataType ElemType { + public int ElemType { get { return elemType_; } set { elemType_ = value; @@ -3579,7 +3579,7 @@ namespace Onnx { public void WriteTo(pb::CodedOutputStream output) { if (ElemType != 0) { output.WriteRawTag(8); - output.WriteEnum((int) ElemType); + output.WriteInt32(ElemType); } if (shape_ != null) { output.WriteRawTag(18); @@ -3594,7 +3594,7 @@ namespace Onnx { public int CalculateSize() { int size = 0; if (ElemType != 0) { - size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ElemType); + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ElemType); } if (shape_ != null) { size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape); @@ -3631,7 +3631,7 @@ namespace Onnx { _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); break; case 8: { - elemType_ = (global::Onnx.TensorProto.Types.DataType) input.ReadEnum(); + ElemType = input.ReadInt32(); break; } case 18: { @@ -3829,13 +3829,14 @@ namespace Onnx { /// Field number for the "key_type" field. public const int KeyTypeFieldNumber = 1; - private global::Onnx.TensorProto.Types.DataType keyType_ = 0; + private int keyType_; /// + /// This field MUST have a valid TensorProto.DataType value /// This field MUST be present for this version of the IR. /// This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Onnx.TensorProto.Types.DataType KeyType { + public int KeyType { get { return keyType_; } set { keyType_ = value; @@ -3894,7 +3895,7 @@ namespace Onnx { public void WriteTo(pb::CodedOutputStream output) { if (KeyType != 0) { output.WriteRawTag(8); - output.WriteEnum((int) KeyType); + output.WriteInt32(KeyType); } if (valueType_ != null) { output.WriteRawTag(18); @@ -3909,7 +3910,7 @@ namespace Onnx { public int CalculateSize() { int size = 0; if (KeyType != 0) { - size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) KeyType); + size += 1 + pb::CodedOutputStream.ComputeInt32Size(KeyType); } if (valueType_ != null) { size += 1 + pb::CodedOutputStream.ComputeMessageSize(ValueType); @@ -3946,7 +3947,7 @@ namespace Onnx { _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); break; case 8: { - keyType_ = (global::Onnx.TensorProto.Types.DataType) input.ReadEnum(); + KeyType = input.ReadInt32(); break; } case 18: { @@ -4162,13 +4163,14 @@ namespace Onnx { /// Field number for the "elem_type" field. public const int ElemTypeFieldNumber = 1; - private global::Onnx.TensorProto.Types.DataType elemType_ = 0; + private int elemType_; /// /// This field MUST NOT have the value of UNDEFINED + /// This field MUST have a valid TensorProto.DataType value /// This field MUST be present for this version of the IR. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Onnx.TensorProto.Types.DataType ElemType { + public int ElemType { get { return elemType_; } set { elemType_ = value; @@ -4224,7 +4226,7 @@ namespace Onnx { public void WriteTo(pb::CodedOutputStream output) { if (ElemType != 0) { output.WriteRawTag(8); - output.WriteEnum((int) ElemType); + output.WriteInt32(ElemType); } if (shape_ != null) { output.WriteRawTag(18); @@ -4239,7 +4241,7 @@ namespace Onnx { public int CalculateSize() { int size = 0; if (ElemType != 0) { - size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ElemType); + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ElemType); } if (shape_ != null) { size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape); @@ -4276,7 +4278,7 @@ namespace Onnx { _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); break; case 8: { - elemType_ = (global::Onnx.TensorProto.Types.DataType) input.ReadEnum(); + ElemType = input.ReadInt32(); break; } case 18: {