diff --git a/orttraining/orttraining/models/pipeline_poc/main.cc b/orttraining/orttraining/models/pipeline_poc/main.cc index a80aa08e9f..7f423b9ee0 100644 --- a/orttraining/orttraining/models/pipeline_poc/main.cc +++ b/orttraining/orttraining/models/pipeline_poc/main.cc @@ -105,11 +105,12 @@ int main(int argc, char* argv[]) { InferenceSession session_object{so, *env}; + Status st; CUDAExecutionProviderInfo xp_info{static_cast(world_rank)}; - session_object.RegisterExecutionProvider(std::make_unique(xp_info)); + st = session_object.RegisterExecutionProvider(std::make_unique(xp_info)); + ORT_ENFORCE(st == Status::OK(), "MPI rank ", world_rank, ": ", st.ErrorMessage()); std::string model_at_rank; - Status st; if (world_rank == 0) { st = session_object.Load(params.model_stage0_name); ORT_ENFORCE(st == Status::OK(), "MPI rank ", world_rank, ": ", st.ErrorMessage()); @@ -214,4 +215,4 @@ int main(int, char* []) { ORT_NOT_IMPLEMENTED("P2P demo currently requires CUDA to run."); } -#endif \ No newline at end of file +#endif