cuda fix to unblock the tf model tests (#333)

* Check the pads attribute on Conv, and automatically fall back to CPU if the padding is not symmetric

* Insert copy nodes after all graph transformers have run. Running the memory copy transformer before the cast transformer causes some issues.
This commit is contained in:
Hector Li 2019-01-15 14:05:47 -08:00 committed by Changming Sun
parent 7977871740
commit 835b511fa8
3 changed files with 32 additions and 5 deletions

View file

@ -830,6 +830,30 @@ bool CUDAExecutionProvider::RNNNeedFallbackToCPU(const onnxruntime::Node& node,
return false;
}
// Decides whether a Conv node must be handed to the CPU provider instead of CUDA.
//
// Returns true when the node has an INTS-typed "pads" attribute in which, for
// some index i in the first half, value i differs from value i + rank (rank
// being half the number of pad values). Returns false otherwise, including
// when the node has no "pads" attribute. An odd number of pad values trips
// ORT_ENFORCE.
bool CUDAExecutionProvider::ConvNeedFallbackToCPU(const onnxruntime::Node& node) const {
  // Bind by reference throughout: taking these by value would copy the
  // attribute container, every attribute in it, and the pads list.
  const auto& node_attributes = node.GetAttributes();
  // Check attributes
  for (const auto& attr : node_attributes) {
    const auto& attr_name = attr.first;
    const auto& attr_value = attr.second;
    // cudnn only supports symmetric padding
    if ("pads" == attr_name && ::onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS == attr_value.type()) {
      const auto& pads = attr_value.ints();
      const int pads_size = pads.size();
      ORT_ENFORCE(pads_size % 2 == 0);
      const int rank = pads_size / 2;
      // Compare each value in the first half with its counterpart in the second half.
      for (int i = 0; i < rank; i++) {
        if (pads.Get(i) != pads.Get(i + rank)) {
          return true;
        }
      }
    }
  }
  return false;
}
std::vector<std::unique_ptr<ComputeCapability>>
CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
const std::vector<const KernelRegistry*>& kernel_registries) const {
@ -847,6 +871,8 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
} else if ("GRU" == node.OpType()) {
std::vector<std::string> activations_supported{"sigmoid", "tanh", "sigmoid", "tanh"};
fallback_to_cpu_provider = RNNNeedFallbackToCPU(node, activations_supported, node.OpType());
} else if ("Conv" == node.OpType()) {
fallback_to_cpu_provider = ConvNeedFallbackToCPU(node);
}
if (fallback_to_cpu_provider) {

View file

@ -174,6 +174,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
void ReleasePerThreadStuffs() const;
bool RNNNeedFallbackToCPU(const onnxruntime::Node& node, const std::vector<std::string> activations_supported, const std::string& op_type) const;
bool ConvNeedFallbackToCPU(const onnxruntime::Node& node) const;
};
} // namespace onnxruntime

View file

@ -278,7 +278,11 @@ class InferenceSession::Impl {
GraphPartitioner partitioner(kernel_registry_manager, providers);
ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.ExportDll(), const_cast<FuncManager*>(session_state.GetFuncMgr())));
// Insert copy nodes.
// Insert cast node/s.
bool modified = false;
ORT_RETURN_IF_ERROR(insert_cast_transformer.Apply(graph, modified));
// Insert copy nodes after all graph transformer.
for (auto& provider : providers) {
if (provider->Type() != onnxruntime::kCpuExecutionProvider &&
provider->Type() != onnxruntime::kMklDnnExecutionProvider &&
@ -288,10 +292,6 @@ class InferenceSession::Impl {
}
}
// Insert cast node/s.
bool modified = false;
ORT_RETURN_IF_ERROR(insert_cast_transformer.Apply(graph, modified));
return common::Status::OK();
}