[webgpu] Use override shape in shader key (#23188)

### Description
This PR 1) uses override shape instead of tensor original shape in
shader key to reduce some shader variants; 2) adds indices shape rank to
shader key in case some potential errors.
This commit is contained in:
Jiajia Qin 2025-01-08 07:36:02 +08:00 committed by GitHub
parent 519fae019b
commit 4883ec50c4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 22 additions and 12 deletions

View file

@ -193,10 +193,7 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
.AddIndices(reshaped_output_shape)
.AddIndices(reshaped_lhs_shape)
.AddIndices(reshaped_rhs_shape)
.CacheHint("V" + absl::StrJoin({reshaped_lhs_shape.NumDimensions(),
reshaped_rhs_shape.NumDimensions(),
reshaped_output_shape.NumDimensions()},
";"));
.CacheHint("V");
} else {
// Mode Broadcast
// cache hint: "B"

View file

@ -17,7 +17,7 @@ namespace webgpu {
namespace {
// append the info of an input or output to the cachekey
void AppendTensorInfo(std::ostream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency,
void AppendTensorInfo(std::ostream& ss, const TensorShape& tensor_shape, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency,
bool& first) {
if (first) {
first = false;
@ -35,9 +35,9 @@ void AppendTensorInfo(std::ostream& ss, const Tensor& tensor, ProgramVariableDat
}
if ((dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) {
ss D("Dims=") << tensor.Shape().ToString();
ss D("Dims=") << tensor_shape.ToString();
} else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) {
ss D("Rank=") << tensor.Shape().NumDimensions();
ss D("Rank=") << tensor_shape.NumDimensions();
}
}
} // namespace
@ -97,13 +97,26 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp
ss << ":" D("Inputs=");
first = true;
for (const auto& input : program.Inputs()) {
AppendTensorInfo(ss, *input.tensor, input.var_type, input.dependency, first);
AppendTensorInfo(ss, input.use_override_shape ? input.override_shape : input.tensor->Shape(), input.var_type, input.dependency, first);
}
ss << ":" D("Outputs=");
first = true;
for (const auto& output : program.Outputs()) {
AppendTensorInfo(ss, *output.tensor, output.var_type, output.dependency, first);
AppendTensorInfo(ss, output.use_override_shape ? output.override_shape : output.tensor->Shape(), output.var_type, output.dependency, first);
}
if (!program.Indices().empty()) {
ss << ":" D("Indices=");
first = true;
for (const auto& indices_shape : program.Indices()) {
if (first) {
first = false;
} else {
ss << '|';
}
ss D("Rank=") << indices_shape.NumDimensions();
}
}
return SS_GET(ss);

View file

@ -134,9 +134,9 @@ Status Where::ComputeInternal(ComputeContext& context) const {
program
.CacheHint(is_broadcast)
.SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddInputs({{cond_tensor, ProgramTensorMetadataDependency::Rank, {(cond_shape.Size() + 3) / 4}, 4},
{x_tensor, ProgramTensorMetadataDependency::Rank, {(x_shape.Size() + 3) / 4}, 4},
{y_tensor, ProgramTensorMetadataDependency::Rank, {(y_shape.Size() + 3) / 4}, 4}})
.AddInputs({{cond_tensor, ProgramTensorMetadataDependency::Type, {(cond_shape.Size() + 3) / 4}, 4},
{x_tensor, ProgramTensorMetadataDependency::Type, {(x_shape.Size() + 3) / 4}, 4},
{y_tensor, ProgramTensorMetadataDependency::Type, {(y_shape.Size() + 3) / 4}, 4}})
.AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4})
.AddUniformVariables({
{static_cast<uint32_t>(vec_size)},