Apply feedback: use more suitable data types

This commit is contained in:
vraspar 2025-02-04 12:19:39 -08:00
parent e6d8a09286
commit a2c7069fd3
2 changed files with 5 additions and 7 deletions

View file

@ -38,7 +38,7 @@ ONNX_OPERATOR_KERNEL_EX(
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
Softmax);
static std::string MaxVector(std::string name, int components) {
static std::string MaxVector(const std::string& name, int components) {
switch (components) {
case 1:
return name;
@ -53,7 +53,7 @@ static std::string MaxVector(std::string name, int components) {
}
}
static std::string SumVector(std::string x, int components) {
static std::string SumVector(const std::string& x, int components) {
switch (components) {
case 1:
return x;
@ -184,7 +184,7 @@ Status Softmax::ComputeInternal(ComputeContext& context) const {
perm[axis] = input_rank - 1;
perm[input_rank - 1] = axis;
std::vector<int64_t> transposed_input_dims;
TensorShapeVector transposed_input_dims;
for (auto e : perm) {
transposed_input_dims.push_back(input_shape[e]);
}

View file

@ -108,10 +108,8 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl:
output_dims[i] = input_dims[permutations[i]];
}
TensorShape output_shape(output_dims);
InlinedVector<int64_t> new_shape{};
InlinedVector<int64_t> new_perm{};
TensorShapeVector new_shape{};
TensorShapeVector new_perm{};
SqueezeShape(input_shape.GetDims(), permutations, new_shape, new_perm);
const bool channels_last = new_perm == InlinedVector<int64_t>({2, 3, 1});
const bool channels_first = new_perm == InlinedVector<int64_t>({3, 1, 2});