Fixes to get stable diffusion benchmark running (#15755)

### Description

Added changes to MIGraphX EP to support stable diffusion

1. Added parameterized input dimensions to not trigger a precompile to
set input parameters in the EP
2. Removed input checking for Resize operator in EP as MIGraphX already
performs these checks
3. Add support to benchmark script to use the MIGraphX execution
provider
4. Add support for an odd valued batch size (3) that was seen on other
benchmarks we were performing comparison on.

### Motivation and Context

These changes are required to get stable diffusion models to run on
MIGraphX through the EP. Without these changes we see the following
incorrect behavior.

1. Resize operators are pushed onto the CPU EP instead of MIGraphX,
causing a significant slowdown during runs
2. Precompile operations incorrectly parse input_ids parameter for our
text model, with a 1, which breaks during MIGraphX Compile of onnx. This
in turn throws an error and stops any setup before inference.
3. Selecting the correct EP in the benchmark script which was previously
missing the MIGraphX option
4. Suppressed an error we keep seeing with pthread_setaffinity_np - this
is a quality of life change when using the MIGraphX EP

This was tested with the benchmark.py script using stable diffusion v2
located in

onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion/

---------

Co-authored-by: Ted Themistokleous <tthemist@amd.com>
This commit is contained in:
Ted Themistokleous 2023-05-06 05:35:21 -04:00 committed by GitHub
parent 41457885e0
commit 42d62b8f2b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 13 deletions

View file

@ -248,11 +248,13 @@ class PosixThread : public EnvThread {
<< ", mask: " << *p->affinity;
} else {
auto [err_no, err_msg] = GetSystemError(ret);
#if !defined(USE_MIGRAPHX)
LOGS_DEFAULT(ERROR) << "pthread_setaffinity_np failed for thread: " << syscall(SYS_gettid)
<< ", index: " << p->index
<< ", mask: " << *p->affinity
<< ", error code: " << err_no << " error msg: " << err_msg
<< ". Specify the number of threads explicitly so the affinity is not set.";
#endif
}
}
#endif

View file

@ -451,15 +451,6 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
}
}
const auto& args = node->InputDefs();
if (args.size() > 1) {
std::vector<std::size_t> indices(args.size() - 1);
std::iota(indices.begin(), indices.end(), 1);
if (canEvalNodeArgument(graph_viewer, node, indices, input_nodes)) {
return false;
}
return true;
}
} else if (optype == "ReduceSum") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
@ -952,8 +943,15 @@ bool get_input_output_names(const GraphViewer& graph,
if (sptr == nullptr)
return true;
auto dim_size = sptr->dim_size();
return (dim_size == 0);
if (sptr->dim_size() == 0)
return true;
for (int i = 0; i < sptr->dim_size(); i++) {
if (sptr->dim(i).has_dim_param())
return true;
}
return false;
});
const auto& out_args = graph.GetOutputs();
@ -1002,7 +1000,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
}
std::vector<std::string> input_names, output_names;
no_input_shape = no_input_shape or get_input_output_names(graph_body_viewer, input_names, output_names);
no_input_shape = get_input_output_names(graph_body_viewer, input_names, output_names);
// by parsing the model_proto, create a program corresponding to
// the input fused_node

View file

@ -19,6 +19,7 @@ SD_MODELS = {
PROVIDERS = {
"cuda": "CUDAExecutionProvider",
"rocm": "ROCMExecutionProvider",
"migraphx": "MIGraphXExecutionProvider",
}
@ -570,7 +571,7 @@ def parse_arguments():
"--batch_size",
type=int,
default=1,
choices=[1, 2, 4, 8, 10, 16, 32],
choices=[1, 2, 3, 4, 8, 10, 16, 32],
help="Number of images per batch. Default is 1.",
)