mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
Apply feedback: use more suitable data types
This commit is contained in:
parent
e6d8a09286
commit
a2c7069fd3
2 changed files with 5 additions and 7 deletions
|
|
@ -38,7 +38,7 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
|
||||
Softmax);
|
||||
|
||||
static std::string MaxVector(std::string name, int components) {
|
||||
static std::string MaxVector(const std::string& name, int components) {
|
||||
switch (components) {
|
||||
case 1:
|
||||
return name;
|
||||
|
|
@ -53,7 +53,7 @@ static std::string MaxVector(std::string name, int components) {
|
|||
}
|
||||
}
|
||||
|
||||
static std::string SumVector(std::string x, int components) {
|
||||
static std::string SumVector(const std::string& x, int components) {
|
||||
switch (components) {
|
||||
case 1:
|
||||
return x;
|
||||
|
|
@ -184,7 +184,7 @@ Status Softmax::ComputeInternal(ComputeContext& context) const {
|
|||
perm[axis] = input_rank - 1;
|
||||
perm[input_rank - 1] = axis;
|
||||
|
||||
std::vector<int64_t> transposed_input_dims;
|
||||
TensorShapeVector transposed_input_dims;
|
||||
for (auto e : perm) {
|
||||
transposed_input_dims.push_back(input_shape[e]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -108,10 +108,8 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl:
|
|||
output_dims[i] = input_dims[permutations[i]];
|
||||
}
|
||||
|
||||
TensorShape output_shape(output_dims);
|
||||
|
||||
InlinedVector<int64_t> new_shape{};
|
||||
InlinedVector<int64_t> new_perm{};
|
||||
TensorShapeVector new_shape{};
|
||||
TensorShapeVector new_perm{};
|
||||
SqueezeShape(input_shape.GetDims(), permutations, new_shape, new_perm);
|
||||
const bool channels_last = new_perm == InlinedVector<int64_t>({2, 3, 1});
|
||||
const bool channels_first = new_perm == InlinedVector<int64_t>({3, 1, 2});
|
||||
|
|
|
|||
Loading…
Reference in a new issue