mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Refine cast optimizer for safety (#15658)
### Description Cast optimizer may convert a fp16 node to fp32. This used to be safe as all fp16 kernels has fp32 implementation. As this assumption is no longer true, we need to check the validity of the operation ### Motivation and Context Main work here is to introduce an API to check whether a kernel is registered. Currently we don't have a way to do that without an operator node. This needs to be augmented. We need to query whether a kernel is registered by its property only, so that we can judge whether it is safe to construct a node long before we actually do so.
This commit is contained in:
parent
c415bc725f
commit
be08b47e7b
7 changed files with 242 additions and 78 deletions
|
|
@ -44,6 +44,25 @@ class KernelRegistry {
|
|||
const TypeConstraintMap& type_constraints,
|
||||
const KernelCreateInfo** out) const;
|
||||
|
||||
/**
|
||||
* @brief Find out whether a kernel is registered, without a node.
|
||||
* This should be useful in graph optimizers, to check whether
|
||||
* the node it is about to generate, is supported or not.
|
||||
* @param exec_provider
|
||||
* @param op_type
|
||||
* @param domain
|
||||
* @param version
|
||||
* @param type_constraints
|
||||
* @param out
|
||||
* @return
|
||||
*/
|
||||
Status TryFindKernel(ProviderType exec_provider,
|
||||
std::string_view op_type,
|
||||
std::string_view domain,
|
||||
int version,
|
||||
const KernelRegistry::TypeConstraintMap& type_constraints,
|
||||
const KernelCreateInfo** out) const;
|
||||
|
||||
static bool HasImplementationOf(const KernelRegistry& r, const Node& node,
|
||||
ProviderType exec_provider,
|
||||
const IKernelTypeStrResolver& kernel_type_str_resolver) {
|
||||
|
|
|
|||
|
|
@ -116,35 +116,41 @@ bool MatchKernelDefTypes(const std::unordered_map<std::string, std::vector<MLDat
|
|||
}
|
||||
} // namespace
|
||||
|
||||
bool KernelRegistry::VerifyKernelDef(const Node& node,
|
||||
const KernelDef& kernel_def,
|
||||
const IKernelTypeStrResolver* kernel_type_str_resolver,
|
||||
const TypeConstraintMap* type_constraint_values,
|
||||
std::string& error_str) {
|
||||
static bool VerifyVersion(int since_ver, const KernelDef& kernel_def, std::string& error_str) {
|
||||
// check if version matches
|
||||
int node_version = node.SinceVersion();
|
||||
int kernel_start_version;
|
||||
int kernel_end_version;
|
||||
kernel_def.SinceVersion(&kernel_start_version, &kernel_end_version);
|
||||
|
||||
bool valid_version =
|
||||
// exact match. typical usage.
|
||||
kernel_start_version == node_version ||
|
||||
kernel_start_version == since_ver ||
|
||||
// allow match if the kernel def has an end version. if it does not, all we know is that the kernel supported
|
||||
// the start version when it was created, and not whether a new version of the operator was added since then
|
||||
// that the kernel doesn't support.
|
||||
(kernel_end_version != INT_MAX &&
|
||||
kernel_start_version <= node_version && kernel_end_version >= node_version);
|
||||
kernel_start_version <= since_ver && kernel_end_version >= since_ver);
|
||||
|
||||
if (!valid_version) {
|
||||
std::ostringstream ostr;
|
||||
ostr << "Op with name (" << node.Name() << ")"
|
||||
<< " and type (" << node.OpType() << ")"
|
||||
<< " Version mismatch."
|
||||
<< " node_version: " << node_version
|
||||
ostr << " Version mismatch."
|
||||
<< " node_version: " << since_ver
|
||||
<< " kernel start version: " << kernel_start_version
|
||||
<< " kernel_end_version: " << kernel_end_version;
|
||||
error_str = ostr.str();
|
||||
}
|
||||
return valid_version;
|
||||
}
|
||||
|
||||
bool KernelRegistry::VerifyKernelDef(const Node& node,
|
||||
const KernelDef& kernel_def,
|
||||
const IKernelTypeStrResolver* kernel_type_str_resolver,
|
||||
const TypeConstraintMap* type_constraint_values,
|
||||
std::string& error_str) {
|
||||
// check if version matches
|
||||
bool valid_version = VerifyVersion(node.SinceVersion(), kernel_def, error_str);
|
||||
|
||||
if (!valid_version) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -157,12 +163,9 @@ bool KernelRegistry::VerifyKernelDef(const Node& node,
|
|||
|
||||
if (!matched) {
|
||||
std::ostringstream ostr;
|
||||
ostr << "Found kernel for Op with name (" << node.Name() << ")"
|
||||
<< " and type (" << node.OpType() << ")"
|
||||
ostr << "Kernel found kernel"
|
||||
<< " in the supported version range"
|
||||
<< " (node_version: " << node_version
|
||||
<< " kernel start version: " << kernel_start_version
|
||||
<< " kernel_end_version: " << kernel_end_version << ")."
|
||||
<< " (node_version: " << node.SinceVersion() << ")."
|
||||
<< " However the types are incompatible. " << mismatch_reason;
|
||||
error_str = ostr.str();
|
||||
}
|
||||
|
|
@ -203,6 +206,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node,
|
|||
if (!verify_kernel_def_error_strs.empty()) {
|
||||
std::ostringstream oss;
|
||||
oss << "Op with name (" << node.Name() << ")"
|
||||
<< " domain (" << node.Domain() << ")"
|
||||
<< " and type (" << node.OpType() << ")"
|
||||
<< " kernel is not supported in " << expected_provider << "."
|
||||
<< " Encountered following errors: (";
|
||||
|
|
@ -229,6 +233,68 @@ Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provide
|
|||
return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, out);
|
||||
}
|
||||
|
||||
static bool KernelDefCompatible(int version, const KernelDef& kernel_def,
|
||||
const KernelRegistry::TypeConstraintMap& type_constraint_values,
|
||||
std::string& error_str) {
|
||||
if (!VerifyVersion(version, kernel_def, error_str)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& kernel_type_constraints = kernel_def.TypeConstraints();
|
||||
bool matched = MatchKernelDefTypes(kernel_type_constraints, type_constraint_values);
|
||||
|
||||
if (!matched) {
|
||||
std::ostringstream ostr;
|
||||
ostr << "Kernel found kernel"
|
||||
<< " in the supported version range"
|
||||
<< " (node_version: " << version << ")."
|
||||
<< " However the types are incompatible.";
|
||||
error_str = ostr.str();
|
||||
}
|
||||
|
||||
return matched;
|
||||
}
|
||||
|
||||
Status KernelRegistry::TryFindKernel(ProviderType exec_provider,
|
||||
std::string_view op_type,
|
||||
std::string_view domain,
|
||||
int version,
|
||||
const KernelRegistry::TypeConstraintMap& type_constraints,
|
||||
const KernelCreateInfo** out) const {
|
||||
auto range = kernel_creator_fn_map_.equal_range(GetMapKey(op_type, domain, exec_provider));
|
||||
if (out) *out = nullptr;
|
||||
|
||||
std::vector<std::string> verify_kernel_def_error_strs;
|
||||
|
||||
for (auto i = range.first; i != range.second; ++i) {
|
||||
std::string error_str;
|
||||
if (KernelDefCompatible(version, *i->second.kernel_def, type_constraints, error_str)) {
|
||||
if (out) {
|
||||
*out = &i->second;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
verify_kernel_def_error_strs.push_back(error_str);
|
||||
}
|
||||
|
||||
if (!verify_kernel_def_error_strs.empty()) {
|
||||
std::ostringstream oss;
|
||||
oss << "Op type (" << op_type << ")"
|
||||
<< " domain (" << domain << ")"
|
||||
<< " kernel is not supported in " << exec_provider << "."
|
||||
<< " Encountered following errors: (";
|
||||
std::copy(verify_kernel_def_error_strs.begin(), verify_kernel_def_error_strs.end(),
|
||||
std::ostream_iterator<std::string>(oss, "\n"));
|
||||
oss << ")";
|
||||
|
||||
VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str();
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, oss.str());
|
||||
}
|
||||
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "Kernel not found");
|
||||
}
|
||||
|
||||
Status KernelRegistry::Register(KernelDefBuilder& kernel_builder,
|
||||
const KernelCreateFn& kernel_creator) {
|
||||
return Register(KernelCreateInfo(kernel_builder.Build(), kernel_creator));
|
||||
|
|
|
|||
|
|
@ -84,9 +84,7 @@ static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) {
|
|||
// going to a node that will need a Cast.
|
||||
//
|
||||
// Return true if all the fp16 inputs and outputs are connected to nodes that will be cast to fp32.
|
||||
static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph) {
|
||||
bool isolated_fp16_node = false;
|
||||
|
||||
static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) {
|
||||
// we can check if it's an isolated fp16 node
|
||||
// if node has input coming from other nodes (only consuming graph inputs or initializers if it doesn't),
|
||||
// does not have a subgraph (would have to alter subgraph inputs if we cast the input to this node),
|
||||
|
|
@ -96,70 +94,135 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
|
|||
!node.ContainsSubgraph() &&
|
||||
!graph.NodeProducesGraphOutput(node) &&
|
||||
node.GetExecutionProviderType() == kCpuExecutionProvider) {
|
||||
do {
|
||||
// find the number of fp16 inputs as we need to make sure they're all coming from nodes that will be cast
|
||||
const auto& input_defs = node.InputDefs();
|
||||
size_t num_fp16_inputs = std::count_if(input_defs.cbegin(), input_defs.cend(),
|
||||
[](const NodeArg* input_def) {
|
||||
return IsMLFloat16Tensor(*input_def);
|
||||
});
|
||||
//
|
||||
// Three tasks here:
|
||||
// 1. make sure all tensor(float16) inputs and first output coming from or
|
||||
// going to nodes that will be cast to fp32
|
||||
// 2. check the current node is float16 node.
|
||||
// 3. check the current node has a float32 implementation
|
||||
// Only return true when all three are satisfied
|
||||
//
|
||||
const auto* schema = node.Op();
|
||||
if (!schema) {
|
||||
// no way to know whether it is safe to convert this to fp32, give up
|
||||
return false;
|
||||
}
|
||||
|
||||
if (num_fp16_inputs == 0) {
|
||||
break;
|
||||
const TypeConstraintMap& type_schema = schema->typeConstraintMap();
|
||||
InlinedHashMap<std::string, MLDataType> type_constraint_map;
|
||||
type_constraint_map.reserve(type_schema.size());
|
||||
|
||||
// For each formal parameters, there might be 0-n
|
||||
// actual inputs, this makes it very tricky to find out which
|
||||
// actual input should map to which formal parameter
|
||||
|
||||
const auto& input_arg_counts = node.InputArgCount();
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& formal_inputs = schema->inputs();
|
||||
const size_t num_inputs = std::min(formal_inputs.size(), input_arg_counts.size());
|
||||
|
||||
InlinedHashSet<int> fp16_args;
|
||||
int input_idx_start = 0;
|
||||
for (size_t formal_idx = 0;
|
||||
formal_idx < num_inputs;
|
||||
input_idx_start += input_arg_counts[formal_idx], formal_idx++) {
|
||||
const auto& type_str = formal_inputs[formal_idx].GetTypeStr();
|
||||
TypeConstraintMap::const_iterator it = type_schema.find(type_str);
|
||||
if (it == type_schema.end()) {
|
||||
// Don't care about parameter that does not have a type constraint.
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t num_fp16_input_edges = 0;
|
||||
// type_str is like T, T1 or T2 ...
|
||||
for (int input_idx = 0; input_idx < input_arg_counts[formal_idx]; input_idx++) {
|
||||
const size_t idx = static_cast<size_t>(input_idx_start) + static_cast<size_t>(input_idx);
|
||||
ORT_ENFORCE(idx < input_defs.size());
|
||||
const NodeArg* input_def = input_defs[idx];
|
||||
if (!input_def || !input_def->Exists()) {
|
||||
continue;
|
||||
}
|
||||
if (IsMLFloat16Tensor(*input_def)) {
|
||||
fp16_args.emplace(static_cast<int>(idx));
|
||||
type_constraint_map[type_str] = DataTypeImpl::GetTensorType<float>();
|
||||
break; // we don't have multiple tensors feeding into one input
|
||||
}
|
||||
type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*(input_def->TypeAsProto()));
|
||||
break; // we don't have multiple tensors feeding into one input
|
||||
}
|
||||
}
|
||||
|
||||
// check if all nodes providing our fp16 input need to be cast to fp32
|
||||
for (auto input_edge = node.InputEdgesBegin(), end = node.InputEdgesEnd(); input_edge != end; ++input_edge) {
|
||||
const NodeArg& input_def = *input_defs[input_edge->GetDstArgIndex()];
|
||||
if (fp16_args.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (IsMLFloat16Tensor(input_def)) {
|
||||
// if the node producing our fp16 input does not need its input cast to fp32 we should run in fp16
|
||||
if (!NodeNeedsInputCastToFp32(input_edge->GetNode())) {
|
||||
break;
|
||||
}
|
||||
|
||||
++num_fp16_input_edges;
|
||||
// check if all nodes providing our fp16 input need to be cast to fp32
|
||||
for (auto input_edge = node.InputEdgesBegin(), end = node.InputEdgesEnd(); input_edge != end; ++input_edge) {
|
||||
const int arg_idx = input_edge->GetDstArgIndex();
|
||||
if (fp16_args.find(arg_idx) != fp16_args.end()) {
|
||||
// if the node producing our fp16 input does not need its input cast to fp32 we should run in fp16
|
||||
if (!NodeNeedsInputCastToFp32(input_edge->GetNode())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// one or more fp16 inputs are coming from a graph input or initializer
|
||||
if (num_fp16_inputs != num_fp16_input_edges) {
|
||||
break;
|
||||
// if we got here all nodes providing our fp16 input/s will be cast to fp32.
|
||||
// check if the same applies to the nodes consuming our fp16 output.
|
||||
fp16_args.clear();
|
||||
const auto& output_defs = node.OutputDefs();
|
||||
const auto& formal_outputs = schema->outputs();
|
||||
const size_t num_outputs = std::min(formal_outputs.size(), output_defs.size());
|
||||
for (size_t idx = 0; idx < num_outputs; idx++) {
|
||||
const auto& type_str = formal_outputs[idx].GetTypeStr();
|
||||
TypeConstraintMap::const_iterator it = type_schema.find(type_str);
|
||||
if (it == type_schema.end()) {
|
||||
// Don't care about parameter that does not have a type constraint.
|
||||
continue;
|
||||
}
|
||||
|
||||
// if we got here all nodes providing our fp16 input/s will be cast to fp32.
|
||||
// check if the same applies to all nodes consuming our fp16 output.
|
||||
const NodeArg* output_def = output_defs[idx];
|
||||
if (!output_def || !output_def->Exists()) {
|
||||
continue;
|
||||
}
|
||||
if (IsMLFloat16Tensor(*output_def)) {
|
||||
fp16_args.emplace((int)idx);
|
||||
type_constraint_map[type_str] = DataTypeImpl::GetTensorType<float>();
|
||||
} else {
|
||||
type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*(output_def->TypeAsProto()));
|
||||
}
|
||||
}
|
||||
|
||||
bool node_has_fp16_output = false;
|
||||
if (fp16_args.empty()) {
|
||||
return false; // no fp16 output
|
||||
}
|
||||
|
||||
for (auto output_edge = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); output_edge != end; ++output_edge) {
|
||||
const NodeArg& output_def = *node.OutputDefs()[output_edge->GetSrcArgIndex()];
|
||||
if (IsMLFloat16Tensor(output_def)) {
|
||||
node_has_fp16_output = true;
|
||||
|
||||
// if the node consuming our fp16 output does not need a cast, we should run in fp16
|
||||
if (!NodeNeedsInputCastToFp32(output_edge->GetNode())) {
|
||||
break;
|
||||
}
|
||||
for (auto output_edge = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); output_edge != end; ++output_edge) {
|
||||
const int arg_idx = output_edge->GetSrcArgIndex();
|
||||
if (fp16_args.find(arg_idx) != fp16_args.end()) {
|
||||
// if the node producing our fp16 input does not need its input cast to fp32 we should run in fp16
|
||||
if (!NodeNeedsInputCastToFp32(output_edge->GetNode())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (node_has_fp16_output) {
|
||||
// all nodes providing our fp16 input/s will be cast to fp32, and
|
||||
// we produce one or more fp16 outputs, and all nodes consuming those outputs will be cast to fp32
|
||||
isolated_fp16_node = true;
|
||||
}
|
||||
} while (false);
|
||||
// now all fp16 inputs and outputs would have a cast
|
||||
// make sure fp32 version of the kernel is available.
|
||||
const KernelCreateInfo* kernel_create_info{};
|
||||
const auto lookup_status = cpu_kernel_registry.TryFindKernel(
|
||||
kCpuExecutionProvider, node.OpType(), node.Domain(),
|
||||
node.SinceVersion(), type_constraint_map, &kernel_create_info);
|
||||
if (lookup_status.IsOK() && kernel_create_info != nullptr) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return isolated_fp16_node;
|
||||
return false;
|
||||
}
|
||||
|
||||
Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph) {
|
||||
static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) {
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (IsIsolatedFp16NodeOnCpu(node, graph)) {
|
||||
if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) {
|
||||
// unassign the node so that NeedInsertCast will return true for it, forcing it to fp32
|
||||
node.SetExecutionProviderType("");
|
||||
}
|
||||
|
|
@ -338,7 +401,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
|
|||
Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level,
|
||||
const logging::Logger& logger) const {
|
||||
if (force_cpu_fp32_)
|
||||
ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph));
|
||||
ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_));
|
||||
|
||||
GraphViewer graph_viewer(graph);
|
||||
auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@
|
|||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/framework/kernel_registry_manager.h"
|
||||
#include "core/framework/kernel_registry.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -16,19 +18,26 @@ Transformer to insert cast node that casts float16 to float for cpu nodes
|
|||
*/
|
||||
class InsertCastTransformer : public onnxruntime::GraphTransformer {
|
||||
public:
|
||||
InsertCastTransformer(const std::string& name)
|
||||
/**
|
||||
* @brief Initializer
|
||||
* @param name for logging purpose
|
||||
* @param cpu_kernel_registry used to query whether an op node can be safely created
|
||||
*/
|
||||
InsertCastTransformer(const std::string& name, const KernelRegistry* cpu_kernel_registry)
|
||||
: onnxruntime::GraphTransformer(name),
|
||||
force_cpu_fp32_(true) {
|
||||
}
|
||||
cpu_kernel_registries_(cpu_kernel_registry),
|
||||
force_cpu_fp32_(cpu_kernel_registry != nullptr) {}
|
||||
|
||||
private:
|
||||
Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
bool NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const;
|
||||
|
||||
const KernelRegistry* cpu_kernel_registries_;
|
||||
|
||||
// Currently because we only have very few cpu kernels support float16, place those nodes on float16
|
||||
// will introduce many cast between fp32 and fp16, which will slow the execution.
|
||||
// A better solution is to have a cost model to evaluate does it works to place the node on float16.
|
||||
// Here for simplify, we only force the single-node-float16 sub-graph to float32
|
||||
bool force_cpu_fp32_;
|
||||
const bool force_cpu_fp32_;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -976,7 +976,13 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
|
|||
|
||||
// Insert cast node/s.
|
||||
{
|
||||
InsertCastTransformer insert_cast_transformer{"CastFloat16Transformer"};
|
||||
const InlinedVector<gsl::not_null<const KernelRegistry*>> kernel_regs =
|
||||
kernel_registry_manager_.GetKernelRegistriesByProviderType(kCpuExecutionProvider);
|
||||
const KernelRegistry* cpu_regs = nullptr;
|
||||
if (!kernel_regs.empty()) {
|
||||
cpu_regs = kernel_regs[0];
|
||||
}
|
||||
InsertCastTransformer insert_cast_transformer{"CastFloat16Transformer", cpu_regs};
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(insert_cast_transformer, *session_logger_, graph));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include "gtest/gtest.h"
|
||||
#include "test_utils.h"
|
||||
#include "test/test_environment.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
#include "test/util/include/inference_session_wrapper.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
|
||||
|
|
@ -38,7 +39,7 @@ TEST(TransformerTest, InsertCastGPUTest) {
|
|||
|
||||
auto status = graph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
InsertCastTransformer transformer("Test");
|
||||
InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get());
|
||||
|
||||
bool modified = true;
|
||||
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
|
||||
|
|
@ -86,7 +87,7 @@ TEST(TransformerTest, InsertCastAllCPUTest) {
|
|||
auto status = graph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
InsertCastTransformer transformer("Test");
|
||||
InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get());
|
||||
|
||||
bool modified = true;
|
||||
EXPECT_TRUE(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
|
|
@ -123,7 +124,7 @@ TEST(TransformerTest, ThreeInARowRemoval) {
|
|||
// we want to remove 2 of the first 3
|
||||
ASSERT_TRUE(op_to_count["Cast"] == 4);
|
||||
|
||||
InsertCastTransformer transformer("Test");
|
||||
InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get());
|
||||
|
||||
bool modified = false;
|
||||
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
|
||||
|
|
@ -146,7 +147,7 @@ TEST(TransformerTest, RandomNormalLikeWithFloat16Inputs) {
|
|||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
Graph& graph = model->MainGraph();
|
||||
InsertCastTransformer transformer("Test");
|
||||
InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get());
|
||||
|
||||
bool modified = false;
|
||||
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
|
||||
|
|
@ -166,7 +167,7 @@ TEST(TransformerTest, MultinomialWithFloat16Input) {
|
|||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
Graph& graph = model->MainGraph();
|
||||
InsertCastTransformer transformer("Test");
|
||||
InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get());
|
||||
|
||||
bool modified = false;
|
||||
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
|
||||
|
|
@ -186,7 +187,7 @@ TEST(TransformerTest, InsertCastNodeTwice) {
|
|||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
Graph& graph = model->MainGraph();
|
||||
InsertCastTransformer transformer("Test");
|
||||
InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get());
|
||||
|
||||
// First insert
|
||||
bool modified = false;
|
||||
|
|
@ -279,7 +280,7 @@ TEST(TransformerTest, IsIsolatedFp16NodeOnCpuTest) {
|
|||
auto status = graph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
InsertCastTransformer transformer("Test");
|
||||
InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get());
|
||||
|
||||
bool modified = true;
|
||||
EXPECT_TRUE(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ void CompareOpTester::CompareWithCPU(const std::string& target_provider_type,
|
|||
// the function body is instead used for CPU pass. This option allows the comparison with
|
||||
// the CPU kernel by adding the input/output casts before looking for a registered CPU kernel.
|
||||
if (need_cpu_cast) {
|
||||
InsertCastTransformer transformer("Test");
|
||||
InsertCastTransformer transformer("Test", GetExecutionProvider(kCpuExecutionProvider)->GetKernelRegistry().get());
|
||||
bool modified = false;
|
||||
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
|
|
|
|||
Loading…
Reference in a new issue