mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Workaround issue with VS2017 compiler. (#5279)
The definitions for some Eigen classes don't get pulled in leading to errors. Split out the broadcast function creation logic from the functions using std::enable_if to workaround that.
This commit is contained in:
parent
5a71819be6
commit
b49ff6151e
1 changed files with 21 additions and 2 deletions
|
|
@ -49,7 +49,7 @@ template <typename T, typename R>
|
|||
using EnableIfEigenNotScalar = typename std::enable_if<!std::is_arithmetic<T>::value, R>::type;
|
||||
|
||||
template <typename T>
|
||||
EnableIfEigenScalar<T, ProcessBroadcastSpanFuncs> SelectBroadcastFuncs() {
|
||||
ProcessBroadcastSpanFuncs CreateScalarBroadcastFuncs() {
|
||||
return ProcessBroadcastSpanFuncs{
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
bool target = per_iter_bh.GetUserData();
|
||||
|
|
@ -81,7 +81,7 @@ EnableIfEigenScalar<T, ProcessBroadcastSpanFuncs> SelectBroadcastFuncs() {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
EnableIfEigenNotScalar<T, ProcessBroadcastSpanFuncs> SelectBroadcastFuncs() {
|
||||
ProcessBroadcastSpanFuncs CreateNonScalarBroadcastFuncs() {
|
||||
return ProcessBroadcastSpanFuncs{
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
bool target = per_iter_bh.GetUserData();
|
||||
|
|
@ -116,6 +116,19 @@ EnableIfEigenNotScalar<T, ProcessBroadcastSpanFuncs> SelectBroadcastFuncs() {
|
|||
}};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
EnableIfEigenScalar<T, ProcessBroadcastSpanFuncs> SelectBroadcastFuncs() {
|
||||
// NOTE: Workaround a VS2017 bug by calling a separate function to create the broadcast funcs.
|
||||
// If we create them directly here it doesn't bring in the definitions of the Eigen classes leading to
|
||||
// a 'class has no constructors' error
|
||||
return CreateScalarBroadcastFuncs<T>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
EnableIfEigenNotScalar<T, ProcessBroadcastSpanFuncs> SelectBroadcastFuncs() {
|
||||
return CreateNonScalarBroadcastFuncs<T>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MergeScalarAndVector(EigenVectorMap<T> output, const T& scalar_value, ConstEigenVectorMap<T> vector_value) {
|
||||
if (scalar_value != T{}) {
|
||||
|
|
@ -231,8 +244,13 @@ Status Where<T>::Compute(OpKernelContext* context) const {
|
|||
// X_selection = condition ? X : default value
|
||||
// Similarly, we broadcast over condition and Y to select the values from Y:
|
||||
// Y_selection = !condition ? Y : default value
|
||||
//
|
||||
// These selections are handled within UntypedSelect.
|
||||
//
|
||||
// Finally, we broadcast over and merge X_selection and Y_selection:
|
||||
// output = (X_selection != default value) ? X_selection : Y_selection
|
||||
//
|
||||
// The merging is handled within UntypedMerge.
|
||||
auto X_selection_tensor = UntypedSelect(*context, true, tensor_allocator, typed_tensor_allocation, funcs);
|
||||
auto Y_selection_tensor = UntypedSelect(*context, false, tensor_allocator, typed_tensor_allocation, funcs);
|
||||
|
||||
|
|
@ -240,4 +258,5 @@ Status Where<T>::Compute(OpKernelContext* context) const {
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue