Register ONNX Training Ops (#3252)

This commit is contained in:
Sherlock 2020-03-18 12:36:57 -07:00 committed by GitHub
parent c5576d70a6
commit 03d14bae2b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 6 deletions

View file

@ -8,6 +8,7 @@
#include "core/graph/op.h"
#include "onnx/defs/operator_sets.h"
#include "onnx/defs/operator_sets-ml.h"
#include "onnx/defs/operator_sets-training.h"
#ifndef DISABLE_CONTRIB_OPS
#include "core/graph/contrib_ops/contrib_defs.h"
#endif
@ -56,8 +57,8 @@ Status Environment::Initialize() {
#ifdef USE_DML
ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime::kMSDmlDomain, 1, 1);
#endif
// Register contributed schemas.
// The corresponding kernels are registered inside the appropriate execution provider.
// Register contributed schemas.
// The corresponding kernels are registered inside the appropriate execution provider.
#ifndef DISABLE_CONTRIB_OPS
contrib::RegisterContribSchemas();
#endif
@ -69,6 +70,7 @@ Status Environment::Initialize() {
#endif
RegisterOnnxOperatorSetSchema();
RegisterOnnxMLOperatorSetSchema();
RegisterOnnxTrainingOperatorSetSchema();
#ifdef ENABLE_TRAINING
// preserve this order: this depends on operatorsetschema registration.

View file

@ -501,12 +501,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
{"shrink", "Invalid rank for input", {"onnx141"}},
{"split_zero_size_splits", "Invalid value", {"onnxtip"}},
{"dropout_random", "result differs", {"onnxtip"}},
{"adagrad", "invalid model", {"onnxtip"}},
{"adagrad_multiple", "invalid model", {"onnxtip"}},
{"celu", "invalid model", {"onnxtip"}},
{"celu_expanded", "invalid model", {"onnxtip"}},
{"gradient_of_add", "invalid model", {"onnxtip"}},
{"gradient_of_add_and_mul", "invalid model", {"onnxtip"}},
{"max_float16", "invalid model", {"onnxtip"}},
{"mean_square_distance_mean_3d_expanded", "invalid model", {"onnxtip"}},
{"mean_square_distance_mean_4d_expanded", "invalid model", {"onnxtip"}},