Workaround issue with VS2017 compiler. (#5279)

The definitions for some Eigen classes don't get pulled in leading to errors. Split out the broadcast function creation logic from the functions using std::enable_if to workaround that.
This commit is contained in:
Scott McKay 2020-09-25 06:50:14 +10:00 committed by GitHub
parent 5a71819be6
commit b49ff6151e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -49,7 +49,7 @@ template <typename T, typename R>
using EnableIfEigenNotScalar = typename std::enable_if<!std::is_arithmetic<T>::value, R>::type;
template <typename T>
EnableIfEigenScalar<T, ProcessBroadcastSpanFuncs> SelectBroadcastFuncs() {
ProcessBroadcastSpanFuncs CreateScalarBroadcastFuncs() {
return ProcessBroadcastSpanFuncs{
[](BroadcastHelper& per_iter_bh) {
bool target = per_iter_bh.GetUserData();
@ -81,7 +81,7 @@ EnableIfEigenScalar<T, ProcessBroadcastSpanFuncs> SelectBroadcastFuncs() {
}
template <typename T>
EnableIfEigenNotScalar<T, ProcessBroadcastSpanFuncs> SelectBroadcastFuncs() {
ProcessBroadcastSpanFuncs CreateNonScalarBroadcastFuncs() {
return ProcessBroadcastSpanFuncs{
[](BroadcastHelper& per_iter_bh) {
bool target = per_iter_bh.GetUserData();
@ -116,6 +116,19 @@ EnableIfEigenNotScalar<T, ProcessBroadcastSpanFuncs> SelectBroadcastFuncs() {
}};
}
template <typename T>
EnableIfEigenScalar<T, ProcessBroadcastSpanFuncs> SelectBroadcastFuncs() {
// NOTE: Workaround a VS2017 bug by calling a separate function to create the broadcast funcs.
// If we create them directly here it doesn't bring in the definitions of the Eigen classes leading to
// a 'class has no constructors' error
return CreateScalarBroadcastFuncs<T>();
}
template <typename T>
EnableIfEigenNotScalar<T, ProcessBroadcastSpanFuncs> SelectBroadcastFuncs() {
return CreateNonScalarBroadcastFuncs<T>();
}
template <typename T>
void MergeScalarAndVector(EigenVectorMap<T> output, const T& scalar_value, ConstEigenVectorMap<T> vector_value) {
if (scalar_value != T{}) {
@ -231,8 +244,13 @@ Status Where<T>::Compute(OpKernelContext* context) const {
// X_selection = condition ? X : default value
// Similarly, we broadcast over condition and Y to select the values from Y:
// Y_selection = !condition ? Y : default value
//
// These selections are handled within UntypedSelect.
//
// Finally, we broadcast over and merge X_selection and Y_selection:
// output = (X_selection != default value) ? X_selection : Y_selection
//
// The merging is handled within UntypedMerge.
auto X_selection_tensor = UntypedSelect(*context, true, tensor_allocator, typed_tensor_allocation, funcs);
auto Y_selection_tensor = UntypedSelect(*context, false, tensor_allocator, typed_tensor_allocation, funcs);
@ -240,4 +258,5 @@ Status Where<T>::Compute(OpKernelContext* context) const {
return Status::OK();
}
} // namespace onnxruntime