mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
[webgpu] Use override shape in shader key (#23188)
### Description This PR 1) uses the override shape instead of the tensor's original shape in the shader key, to reduce the number of shader variants; 2) adds the rank of each indices shape to the shader key, to guard against potential errors.
This commit is contained in:
parent
519fae019b
commit
4883ec50c4
3 changed files with 22 additions and 12 deletions
|
|
@ -193,10 +193,7 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
|
|||
.AddIndices(reshaped_output_shape)
|
||||
.AddIndices(reshaped_lhs_shape)
|
||||
.AddIndices(reshaped_rhs_shape)
|
||||
.CacheHint("V" + absl::StrJoin({reshaped_lhs_shape.NumDimensions(),
|
||||
reshaped_rhs_shape.NumDimensions(),
|
||||
reshaped_output_shape.NumDimensions()},
|
||||
";"));
|
||||
.CacheHint("V");
|
||||
} else {
|
||||
// Mode Broadcast
|
||||
// cache hint: "B"
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ namespace webgpu {
|
|||
|
||||
namespace {
|
||||
// append the info of an input or output to the cachekey
|
||||
void AppendTensorInfo(std::ostream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency,
|
||||
void AppendTensorInfo(std::ostream& ss, const TensorShape& tensor_shape, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency,
|
||||
bool& first) {
|
||||
if (first) {
|
||||
first = false;
|
||||
|
|
@ -35,9 +35,9 @@ void AppendTensorInfo(std::ostream& ss, const Tensor& tensor, ProgramVariableDat
|
|||
}
|
||||
|
||||
if ((dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) {
|
||||
ss D("Dims=") << tensor.Shape().ToString();
|
||||
ss D("Dims=") << tensor_shape.ToString();
|
||||
} else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) {
|
||||
ss D("Rank=") << tensor.Shape().NumDimensions();
|
||||
ss D("Rank=") << tensor_shape.NumDimensions();
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
|
@ -97,13 +97,26 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp
|
|||
ss << ":" D("Inputs=");
|
||||
first = true;
|
||||
for (const auto& input : program.Inputs()) {
|
||||
AppendTensorInfo(ss, *input.tensor, input.var_type, input.dependency, first);
|
||||
AppendTensorInfo(ss, input.use_override_shape ? input.override_shape : input.tensor->Shape(), input.var_type, input.dependency, first);
|
||||
}
|
||||
|
||||
ss << ":" D("Outputs=");
|
||||
first = true;
|
||||
for (const auto& output : program.Outputs()) {
|
||||
AppendTensorInfo(ss, *output.tensor, output.var_type, output.dependency, first);
|
||||
AppendTensorInfo(ss, output.use_override_shape ? output.override_shape : output.tensor->Shape(), output.var_type, output.dependency, first);
|
||||
}
|
||||
|
||||
if (!program.Indices().empty()) {
|
||||
ss << ":" D("Indices=");
|
||||
first = true;
|
||||
for (const auto& indices_shape : program.Indices()) {
|
||||
if (first) {
|
||||
first = false;
|
||||
} else {
|
||||
ss << '|';
|
||||
}
|
||||
ss D("Rank=") << indices_shape.NumDimensions();
|
||||
}
|
||||
}
|
||||
|
||||
return SS_GET(ss);
|
||||
|
|
|
|||
|
|
@ -134,9 +134,9 @@ Status Where::ComputeInternal(ComputeContext& context) const {
|
|||
program
|
||||
.CacheHint(is_broadcast)
|
||||
.SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
|
||||
.AddInputs({{cond_tensor, ProgramTensorMetadataDependency::Rank, {(cond_shape.Size() + 3) / 4}, 4},
|
||||
{x_tensor, ProgramTensorMetadataDependency::Rank, {(x_shape.Size() + 3) / 4}, 4},
|
||||
{y_tensor, ProgramTensorMetadataDependency::Rank, {(y_shape.Size() + 3) / 4}, 4}})
|
||||
.AddInputs({{cond_tensor, ProgramTensorMetadataDependency::Type, {(cond_shape.Size() + 3) / 4}, 4},
|
||||
{x_tensor, ProgramTensorMetadataDependency::Type, {(x_shape.Size() + 3) / 4}, 4},
|
||||
{y_tensor, ProgramTensorMetadataDependency::Type, {(y_shape.Size() + 3) / 4}, 4}})
|
||||
.AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4})
|
||||
.AddUniformVariables({
|
||||
{static_cast<uint32_t>(vec_size)},
|
||||
|
|
|
|||
Loading…
Reference in a new issue