diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 91e37d1f75..a256baddba 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -8,6 +8,7 @@ #include "core/graph/op.h" #include "onnx/defs/operator_sets.h" #include "onnx/defs/operator_sets-ml.h" +#include "onnx/defs/operator_sets-training.h" #ifndef DISABLE_CONTRIB_OPS #include "core/graph/contrib_ops/contrib_defs.h" #endif @@ -56,8 +57,8 @@ Status Environment::Initialize() { #ifdef USE_DML ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime::kMSDmlDomain, 1, 1); #endif - // Register contributed schemas. - // The corresponding kernels are registered inside the appropriate execution provider. +// Register contributed schemas. +// The corresponding kernels are registered inside the appropriate execution provider. #ifndef DISABLE_CONTRIB_OPS contrib::RegisterContribSchemas(); #endif @@ -69,6 +70,7 @@ Status Environment::Initialize() { #endif RegisterOnnxOperatorSetSchema(); RegisterOnnxMLOperatorSetSchema(); + RegisterOnnxTrainingOperatorSetSchema(); #ifdef ENABLE_TRAINING // preserve this order: this depends on operatorsetschema registration. diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index e06001f693..30d602ec0c 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -501,12 +501,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) { {"shrink", "Invalid rank for input", {"onnx141"}}, {"split_zero_size_splits", "Invalid value", {"onnxtip"}}, {"dropout_random", "result differs", {"onnxtip"}}, - {"adagrad", "invalid model", {"onnxtip"}}, - {"adagrad_multiple", "invalid model", {"onnxtip"}}, {"celu", "invalid model", {"onnxtip"}}, {"celu_expanded", "invalid model", {"onnxtip"}}, - {"gradient_of_add", "invalid model", {"onnxtip"}}, - {"gradient_of_add_and_mul", "invalid model", {"onnxtip"}}, {"max_float16", "invalid model", {"onnxtip"}}, {"mean_square_distance_mean_3d_expanded", "invalid model", {"onnxtip"}}, {"mean_square_distance_mean_4d_expanded", "invalid model", {"onnxtip"}},