From b49ff6151e4dc080cd9fe8200fd2672e9bbba51c Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 25 Sep 2020 06:50:14 +1000 Subject: [PATCH] Workaround issue with VS2017 compiler. (#5279) The definitions for some Eigen classes don't get pulled in leading to errors. Split out the broadcast function creation logic from the functions using std::enable_if to workaround that. --- .../core/providers/cpu/tensor/where_op.cc | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/where_op.cc b/onnxruntime/core/providers/cpu/tensor/where_op.cc index 5b4f1b9159..05be3d4ff4 100644 --- a/onnxruntime/core/providers/cpu/tensor/where_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/where_op.cc @@ -49,7 +49,7 @@ template using EnableIfEigenNotScalar = typename std::enable_if::value, R>::type; template -EnableIfEigenScalar SelectBroadcastFuncs() { +ProcessBroadcastSpanFuncs CreateScalarBroadcastFuncs() { return ProcessBroadcastSpanFuncs{ [](BroadcastHelper& per_iter_bh) { bool target = per_iter_bh.GetUserData(); @@ -81,7 +81,7 @@ EnableIfEigenScalar SelectBroadcastFuncs() { } template -EnableIfEigenNotScalar SelectBroadcastFuncs() { +ProcessBroadcastSpanFuncs CreateNonScalarBroadcastFuncs() { return ProcessBroadcastSpanFuncs{ [](BroadcastHelper& per_iter_bh) { bool target = per_iter_bh.GetUserData(); @@ -116,6 +116,19 @@ EnableIfEigenNotScalar SelectBroadcastFuncs() { }}; } +template +EnableIfEigenScalar SelectBroadcastFuncs() { + // NOTE: Workaround a VS2017 bug by calling a separate function to create the broadcast funcs. + // If we create them directly here it doesn't bring in the definitions of the Eigen classes leading to + // a 'class has no constructors' error + return CreateScalarBroadcastFuncs(); +} + +template +EnableIfEigenNotScalar SelectBroadcastFuncs() { + return CreateNonScalarBroadcastFuncs(); +} + template void MergeScalarAndVector(EigenVectorMap output, const T& scalar_value, ConstEigenVectorMap vector_value) { if (scalar_value != T{}) { @@ -231,8 +244,13 @@ Status Where::Compute(OpKernelContext* context) const { // X_selection = condition ? X : default value // Similarly, we broadcast over condition and Y to select the values from Y: // Y_selection = !condition ? Y : default value + // + // These selections are handled within UntypedSelect. + // // Finally, we broadcast over and merge X_selection and Y_selection: // output = (X_selection != default value) ? X_selection : Y_selection + // + // The merging is handled within UntypedMerge. auto X_selection_tensor = UntypedSelect(*context, true, tensor_allocator, typed_tensor_allocation, funcs); auto Y_selection_tensor = UntypedSelect(*context, false, tensor_allocator, typed_tensor_allocation, funcs); @@ -240,4 +258,5 @@ Status Where::Compute(OpKernelContext* context) const { return Status::OK(); } + } // namespace onnxruntime