mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
clang-format signal_defs.cc (#11767)
This commit is contained in:
parent
750cb42f87
commit
79db92f8fe
1 changed files with 250 additions and 254 deletions
|
|
@ -193,91 +193,88 @@ void RegisterSignalSchemas() {
|
|||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain scalar length types to int64_t.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
bool is_onesided = static_cast<bool>(getAttribute(ctx, "onesided", 0));
|
||||
bool inverse = static_cast<bool>(getAttribute(ctx, "inverse", 0));
|
||||
bool is_onesided = static_cast<bool>(getAttribute(ctx, "onesided", 0));
|
||||
bool inverse = static_cast<bool>(getAttribute(ctx, "inverse", 0));
|
||||
|
||||
if (inverse && is_onesided) {
|
||||
fail_shape_inference("is_onesided and inverse attributes cannot be enabled at the same time");
|
||||
if (inverse && is_onesided) {
|
||||
fail_shape_inference("is_onesided and inverse attributes cannot be enabled at the same time");
|
||||
}
|
||||
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
if (!hasInputShape(ctx, 0)) {
|
||||
// If no shape is available for the input, skip shape inference...
|
||||
return;
|
||||
}
|
||||
|
||||
// In general the output shape will match the input shape exactly
|
||||
// So initialize the output shape with the input shape
|
||||
auto& input_shape = getInputShape(ctx, 0);
|
||||
ONNX_NAMESPACE::TensorShapeProto result_shape_proto = input_shape;
|
||||
|
||||
// Get the axis where the DFT will be performed.
|
||||
auto axis = static_cast<int>(getAttribute(ctx, "axis", 1));
|
||||
auto rank = input_shape.dim_size();
|
||||
|
||||
if (!(-rank <= axis && axis < rank)) {
|
||||
fail_shape_inference(
|
||||
"axis attribute value ",
|
||||
axis,
|
||||
" is invalid for a tensor of rank ",
|
||||
rank);
|
||||
}
|
||||
|
||||
auto axis_idx = (axis >= 0 ? axis : axis + rank);
|
||||
|
||||
// If dft_length is specified, then we should honor the shape.
|
||||
// Set the output dimension to match the dft_length on the axis.
|
||||
// If onesided this will be adjusted later on...
|
||||
const ONNX_NAMESPACE::TensorProto* dft_length = nullptr;
|
||||
if (ctx.getNumInputs() >= 2 && ctx.getInputType(1) != nullptr) {
|
||||
dft_length = ctx.getInputData(1);
|
||||
if (dft_length == nullptr) {
|
||||
// If we cannot read the dft_length, we cannot infer shape
|
||||
// return...
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
if (!hasInputShape(ctx, 0))
|
||||
{
|
||||
// If no shape is available for the input, skip shape inference...
|
||||
return;
|
||||
if (nullptr != dft_length) {
|
||||
if (dft_length->dims_size() != 0) {
|
||||
fail_shape_inference("dft_length input must be a scalar.");
|
||||
}
|
||||
|
||||
// In general the output shape will match the input shape exactly
|
||||
// So initialize the output shape with the input shape
|
||||
auto& input_shape = getInputShape(ctx, 0);
|
||||
ONNX_NAMESPACE::TensorShapeProto result_shape_proto = input_shape;
|
||||
|
||||
// Get the axis where the DFT will be performed.
|
||||
auto axis = static_cast<int>(getAttribute(ctx, "axis", 1));
|
||||
auto rank = input_shape.dim_size();
|
||||
|
||||
if (!(-rank <= axis && axis < rank)) {
|
||||
fail_shape_inference(
|
||||
"axis attribute value ",
|
||||
axis,
|
||||
" is invalid for a tensor of rank ",
|
||||
rank);
|
||||
}
|
||||
|
||||
auto axis_idx = (axis >= 0 ? axis : axis + rank);
|
||||
|
||||
// If dft_length is specified, then we should honor the shape.
|
||||
// Set the output dimension to match the dft_length on the axis.
|
||||
// If onesided this will be adjusted later on...
|
||||
const ONNX_NAMESPACE::TensorProto* dft_length = nullptr;
|
||||
if (ctx.getNumInputs() >= 2 && ctx.getInputType(1) != nullptr) {
|
||||
dft_length = ctx.getInputData(1);
|
||||
if (dft_length == nullptr) {
|
||||
// If we cannot read the dft_length, we cannot infer shape
|
||||
// return...
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (nullptr != dft_length) {
|
||||
if (dft_length->dims_size() != 0) {
|
||||
fail_shape_inference("dft_length input must be a scalar.");
|
||||
}
|
||||
auto dft_length_value = get_scalar_value_from_tensor<int64_t>(dft_length);
|
||||
result_shape_proto.mutable_dim(axis_idx)->set_dim_value(dft_length_value);
|
||||
}
|
||||
// When DFT is onesided, the output shape is half the size of the input shape
|
||||
// along the specified axis.
|
||||
if (is_onesided) {
|
||||
auto axis_dimension = result_shape_proto.dim(axis_idx);
|
||||
// We need to update the output shape dimension along the specified axis,
|
||||
// but sometimes the dimension will be a free dimension or be otherwise unset.
|
||||
// Only perform inference when a input dimension value exists.
|
||||
if (axis_dimension.has_dim_value())
|
||||
{
|
||||
auto original_signal_size = axis_dimension.dim_value();
|
||||
auto half_signal_size = (original_signal_size >> 1) + 1;
|
||||
result_shape_proto.mutable_dim(axis_idx)->set_dim_value(half_signal_size);
|
||||
} else
|
||||
{
|
||||
// Clear the value and param (which would otherwie be inherited from the input).
|
||||
result_shape_proto.mutable_dim(axis_idx)->clear_dim_value();
|
||||
result_shape_proto.mutable_dim(axis_idx)->clear_dim_param();
|
||||
}
|
||||
}
|
||||
|
||||
// Coerce the last dimension to 2.
|
||||
auto dim_size = static_cast<int64_t>(result_shape_proto.dim_size());
|
||||
auto has_component_dimension = dim_size > 2;
|
||||
|
||||
// This if check is retained in the contrib op and not the official spec for back compat
|
||||
if (has_component_dimension) {
|
||||
result_shape_proto.mutable_dim(static_cast<int>(dim_size - 1))->set_dim_value(2);
|
||||
auto dft_length_value = get_scalar_value_from_tensor<int64_t>(dft_length);
|
||||
result_shape_proto.mutable_dim(axis_idx)->set_dim_value(dft_length_value);
|
||||
}
|
||||
// When DFT is onesided, the output shape is half the size of the input shape
|
||||
// along the specified axis.
|
||||
if (is_onesided) {
|
||||
auto axis_dimension = result_shape_proto.dim(axis_idx);
|
||||
// We need to update the output shape dimension along the specified axis,
|
||||
// but sometimes the dimension will be a free dimension or be otherwise unset.
|
||||
// Only perform inference when a input dimension value exists.
|
||||
if (axis_dimension.has_dim_value()) {
|
||||
auto original_signal_size = axis_dimension.dim_value();
|
||||
auto half_signal_size = (original_signal_size >> 1) + 1;
|
||||
result_shape_proto.mutable_dim(axis_idx)->set_dim_value(half_signal_size);
|
||||
} else {
|
||||
result_shape_proto.add_dim()->set_dim_value(2);
|
||||
// Clear the value and param (which would otherwie be inherited from the input).
|
||||
result_shape_proto.mutable_dim(axis_idx)->clear_dim_value();
|
||||
result_shape_proto.mutable_dim(axis_idx)->clear_dim_param();
|
||||
}
|
||||
}
|
||||
|
||||
updateOutputShape(ctx, 0, result_shape_proto);
|
||||
// Coerce the last dimension to 2.
|
||||
auto dim_size = static_cast<int64_t>(result_shape_proto.dim_size());
|
||||
auto has_component_dimension = dim_size > 2;
|
||||
|
||||
// This if check is retained in the contrib op and not the official spec for back compat
|
||||
if (has_component_dimension) {
|
||||
result_shape_proto.mutable_dim(static_cast<int>(dim_size - 1))->set_dim_value(2);
|
||||
} else {
|
||||
result_shape_proto.add_dim()->set_dim_value(2);
|
||||
}
|
||||
|
||||
updateOutputShape(ctx, 0, result_shape_proto);
|
||||
});
|
||||
|
||||
MS_SIGNAL_OPERATOR_SCHEMA(IDFT)
|
||||
|
|
@ -319,29 +316,29 @@ void RegisterSignalSchemas() {
|
|||
1,
|
||||
OpSchema::NonDifferentiable)
|
||||
.TypeConstraint(
|
||||
"T1",
|
||||
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
|
||||
"Constrain input and output types to float tensors.")
|
||||
"T1",
|
||||
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
|
||||
"Constrain input and output types to float tensors.")
|
||||
.TypeConstraint(
|
||||
"T2",
|
||||
{"tensor(int64)"},
|
||||
"Constrain scalar length types to int64_t.")
|
||||
"T2",
|
||||
{"tensor(int64)"},
|
||||
"Constrain scalar length types to int64_t.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
const int64_t batch_ndim = 1;
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
const int64_t batch_ndim = 1;
|
||||
|
||||
auto& input_shape = getInputShape(ctx, 0);
|
||||
ONNX_NAMESPACE::TensorShapeProto result_shape = input_shape;
|
||||
auto dim_size = static_cast<int64_t>(input_shape.dim_size());
|
||||
auto has_component_dimension = dim_size > 2;
|
||||
auto& input_shape = getInputShape(ctx, 0);
|
||||
ONNX_NAMESPACE::TensorShapeProto result_shape = input_shape;
|
||||
auto dim_size = static_cast<int64_t>(input_shape.dim_size());
|
||||
auto has_component_dimension = dim_size > 2;
|
||||
|
||||
if (has_component_dimension) {
|
||||
result_shape.mutable_dim(static_cast<int>(dim_size - 1))->set_dim_value(2);
|
||||
} else {
|
||||
result_shape.add_dim()->set_dim_value(2);
|
||||
}
|
||||
if (has_component_dimension) {
|
||||
result_shape.mutable_dim(static_cast<int>(dim_size - 1))->set_dim_value(2);
|
||||
} else {
|
||||
result_shape.add_dim()->set_dim_value(2);
|
||||
}
|
||||
|
||||
updateOutputShape(ctx, 0, result_shape);
|
||||
updateOutputShape(ctx, 0, result_shape);
|
||||
});
|
||||
|
||||
MS_SIGNAL_OPERATOR_SCHEMA(STFT)
|
||||
|
|
@ -349,21 +346,22 @@ void RegisterSignalSchemas() {
|
|||
.SinceVersion(1)
|
||||
.SetDoc(R"DOC(STFT)DOC")
|
||||
.Attr(
|
||||
"onesided",
|
||||
"If onesided is 1, only values for w in [0, 1, 2, ..., floor(n_fft/2) + 1] are returned because "
|
||||
"the real-to-complex Fourier transform satisfies the conjugate symmetry, i.e., X[m, w] = X[m,w]=X[m,n_fft-w]*. "
|
||||
"Note if the input or window tensors are complex, then onesided output is not possible. "
|
||||
"Enabling onesided with real inputs performs a Real-valued fast Fourier transform (RFFT)."
|
||||
"When invoked with real or complex valued input, the default value is 1. "
|
||||
"Values can be 0 or 1.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(1))
|
||||
"onesided",
|
||||
"If onesided is 1, only values for w in [0, 1, 2, ..., floor(n_fft/2) + 1] are returned because "
|
||||
"the real-to-complex Fourier transform satisfies the conjugate symmetry, i.e., X[m, w] = X[m,w] = "
|
||||
"X[m,n_fft-w]*. Note if the input or window tensors are complex, then onesided output is not possible. "
|
||||
"Enabling onesided with real inputs performs a Real-valued fast Fourier transform (RFFT)."
|
||||
"When invoked with real or complex valued input, the default value is 1. "
|
||||
"Values can be 0 or 1.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(1))
|
||||
.Input(0,
|
||||
"signal",
|
||||
"Input tensor representing a real or complex valued signal. "
|
||||
"For real input, the following shape is expected: [batch_size][signal_length][1]. "
|
||||
"For complex input, the following shape is expected: [batch_size][signal_length][2], where "
|
||||
"[batch_size][signal_length][0] represents the real component and [batch_size][signal_length][1] represents the imaginary component of the signal.",
|
||||
"[batch_size][signal_length][0] represents the real component and [batch_size][signal_length][1] "
|
||||
"represents the imaginary component of the signal.",
|
||||
"T1",
|
||||
OpSchema::Single,
|
||||
true,
|
||||
|
|
@ -399,8 +397,10 @@ void RegisterSignalSchemas() {
|
|||
.Output(0,
|
||||
"output",
|
||||
"The Short-time Fourier Transform of the signals."
|
||||
"If onesided is 1, the output has the shape: [batch_size][frames][dft_unique_bins][2], where dft_unique_bins is frame_length // 2 + 1 (the unique components of the DFT) "
|
||||
"If onesided is 0, the output has the shape: [batch_size][frames][frame_length][2], where frame_length is the length of the DFT.",
|
||||
"If onesided is 1, the output has the shape: [batch_size][frames][dft_unique_bins][2], where "
|
||||
"dft_unique_bins is frame_length // 2 + 1 (the unique components of the DFT) "
|
||||
"If onesided is 0, the output has the shape: [batch_size][frames][frame_length][2], where frame_length "
|
||||
"is the length of the DFT.",
|
||||
"T1",
|
||||
OpSchema::Single,
|
||||
true,
|
||||
|
|
@ -409,141 +409,136 @@ void RegisterSignalSchemas() {
|
|||
.TypeConstraint(
|
||||
"T1",
|
||||
{"tensor(float)",
|
||||
"tensor(float16)",
|
||||
"tensor(double)",
|
||||
"tensor(bfloat16)"},
|
||||
"tensor(float16)",
|
||||
"tensor(double)",
|
||||
"tensor(bfloat16)"},
|
||||
"Constrain signal and output to float tensors.")
|
||||
.TypeConstraint(
|
||||
"T2",
|
||||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain scalar length types to int64_t.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
|
||||
// Get signal size
|
||||
// The signal size is needed to perform inference because the size of the signal
|
||||
// is needed to compute the number of DFTs in the output.
|
||||
//
|
||||
// 1) Check if shape exists, return if not
|
||||
// 2) Get the shape
|
||||
// 3) Check if signal dim value exists, return if not
|
||||
if (!hasInputShape(ctx, 0)) {
|
||||
return;
|
||||
}
|
||||
// Get signal size
|
||||
// The signal size is needed to perform inference because the size of the signal
|
||||
// is needed to compute the number of DFTs in the output.
|
||||
//
|
||||
// 1) Check if shape exists, return if not
|
||||
// 2) Get the shape
|
||||
// 3) Check if signal dim value exists, return if not
|
||||
if (!hasInputShape(ctx, 0)) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& input_shape = getInputShape(ctx, 0);
|
||||
auto signal_dim = input_shape.dim(1);
|
||||
if (!signal_dim.has_dim_value())
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto signal_size = signal_dim.dim_value();
|
||||
auto& input_shape = getInputShape(ctx, 0);
|
||||
auto signal_dim = input_shape.dim(1);
|
||||
if (!signal_dim.has_dim_value()) {
|
||||
return;
|
||||
}
|
||||
auto signal_size = signal_dim.dim_value();
|
||||
|
||||
// The frame step is a required input.
|
||||
// Its value is needed to compute the number output nDFTs, so return early is missing.
|
||||
const auto* frame_step = ctx.getInputData(1);
|
||||
if (nullptr == frame_step) {
|
||||
return;
|
||||
}
|
||||
auto frame_step_value = get_scalar_value_from_tensor<int64_t>(frame_step);
|
||||
// The frame step is a required input.
|
||||
// Its value is needed to compute the number output nDFTs, so return early is missing.
|
||||
const auto* frame_step = ctx.getInputData(1);
|
||||
if (nullptr == frame_step) {
|
||||
return;
|
||||
}
|
||||
auto frame_step_value = get_scalar_value_from_tensor<int64_t>(frame_step);
|
||||
|
||||
// Determine the size of the DFT based on the 2 optional inputs window and frame_length.
|
||||
// One must be set.
|
||||
int64_t dft_size = -1;
|
||||
const ONNX_NAMESPACE::TensorProto* frame_length = nullptr;
|
||||
if (ctx.getNumInputs() >= 4 && ctx.getInputType(3) != nullptr) {
|
||||
frame_length = ctx.getInputData(3);
|
||||
if (frame_length == nullptr) {
|
||||
// If we cannot read the frame_length, we cannot infer shape
|
||||
// return...
|
||||
return;
|
||||
}
|
||||
}
|
||||
// Determine the size of the DFT based on the 2 optional inputs window and frame_length.
|
||||
// One must be set.
|
||||
int64_t dft_size = -1;
|
||||
const ONNX_NAMESPACE::TensorProto* frame_length = nullptr;
|
||||
if (ctx.getNumInputs() >= 4 && ctx.getInputType(3) != nullptr) {
|
||||
frame_length = ctx.getInputData(3);
|
||||
if (frame_length == nullptr) {
|
||||
// If we cannot read the frame_length, we cannot infer shape
|
||||
// return...
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorShapeProto* window_shape = nullptr;
|
||||
if (ctx.getNumInputs() >= 3) {
|
||||
window_shape = getOptionalInputShape(ctx, 2);
|
||||
} else {
|
||||
window_shape = nullptr;
|
||||
}
|
||||
const ONNX_NAMESPACE::TensorShapeProto* window_shape = nullptr;
|
||||
if (ctx.getNumInputs() >= 3) {
|
||||
window_shape = getOptionalInputShape(ctx, 2);
|
||||
} else {
|
||||
window_shape = nullptr;
|
||||
}
|
||||
|
||||
if (window_shape == nullptr && frame_length == nullptr)
|
||||
{
|
||||
// STFT expects to have at least one of these inputs set: [window, frame_length],
|
||||
// but they may not be available at shape inference time
|
||||
return;
|
||||
} else if (window_shape != nullptr && frame_length != nullptr)
|
||||
{
|
||||
if (frame_length->dims_size() != 0) {
|
||||
fail_shape_inference("frame_length input must be scalar.");
|
||||
}
|
||||
auto frame_length_value = get_scalar_value_from_tensor<int64_t>(frame_length);
|
||||
if (window_shape == nullptr && frame_length == nullptr) {
|
||||
// STFT expects to have at least one of these inputs set: [window, frame_length],
|
||||
// but they may not be available at shape inference time
|
||||
return;
|
||||
} else if (window_shape != nullptr && frame_length != nullptr) {
|
||||
if (frame_length->dims_size() != 0) {
|
||||
fail_shape_inference("frame_length input must be scalar.");
|
||||
}
|
||||
auto frame_length_value = get_scalar_value_from_tensor<int64_t>(frame_length);
|
||||
|
||||
// Ensure that the window length and the dft_length match.
|
||||
if (window_shape->dim_size() != 1) {
|
||||
fail_shape_inference("window input must have rank = 1.");
|
||||
}
|
||||
if (window_shape->dim(0).has_dim_value())
|
||||
{
|
||||
auto window_length = window_shape->dim(0).dim_value();
|
||||
if (window_length != frame_length_value)
|
||||
{
|
||||
fail_type_inference("If STFT has both a window input and frame_length specified, the dimension of the window must match the frame_length specified!");
|
||||
}
|
||||
}
|
||||
// Ensure that the window length and the dft_length match.
|
||||
if (window_shape->dim_size() != 1) {
|
||||
fail_shape_inference("window input must have rank = 1.");
|
||||
}
|
||||
if (window_shape->dim(0).has_dim_value()) {
|
||||
auto window_length = window_shape->dim(0).dim_value();
|
||||
if (window_length != frame_length_value) {
|
||||
fail_type_inference(
|
||||
"If STFT has both a window input and frame_length specified, the dimension of the "
|
||||
"window must match the frame_length specified!");
|
||||
}
|
||||
}
|
||||
|
||||
dft_size = frame_length_value;
|
||||
} else if (window_shape != nullptr)
|
||||
{
|
||||
// Ensure that the window length and the dft_length match.
|
||||
if (window_shape->dim_size() != 1) {
|
||||
fail_shape_inference("window input must have rank = 1.");
|
||||
}
|
||||
if (window_shape->dim(0).has_dim_value()) {
|
||||
dft_size = window_shape->dim(0).dim_value();
|
||||
} else {
|
||||
// Cannot determine the window size, and there is no frame_length,
|
||||
// So shape inference cannot proceed.
|
||||
return;
|
||||
}
|
||||
} else if (frame_length != nullptr)
|
||||
{
|
||||
if (frame_length->dims_size() != 0) {
|
||||
fail_shape_inference("frame_length input must be scalar.");
|
||||
}
|
||||
dft_size = get_scalar_value_from_tensor<int64_t>(frame_length);
|
||||
}
|
||||
dft_size = frame_length_value;
|
||||
} else if (window_shape != nullptr) {
|
||||
// Ensure that the window length and the dft_length match.
|
||||
if (window_shape->dim_size() != 1) {
|
||||
fail_shape_inference("window input must have rank = 1.");
|
||||
}
|
||||
if (window_shape->dim(0).has_dim_value()) {
|
||||
dft_size = window_shape->dim(0).dim_value();
|
||||
} else {
|
||||
// Cannot determine the window size, and there is no frame_length,
|
||||
// So shape inference cannot proceed.
|
||||
return;
|
||||
}
|
||||
} else if (frame_length != nullptr) {
|
||||
if (frame_length->dims_size() != 0) {
|
||||
fail_shape_inference("frame_length input must be scalar.");
|
||||
}
|
||||
dft_size = get_scalar_value_from_tensor<int64_t>(frame_length);
|
||||
}
|
||||
|
||||
bool is_onesided = static_cast<bool>(getAttribute(ctx, "onesided", 0));
|
||||
if (is_onesided) {
|
||||
dft_size = is_onesided ? ((dft_size >> 1) + 1) : dft_size;
|
||||
}
|
||||
bool is_onesided = static_cast<bool>(getAttribute(ctx, "onesided", 0));
|
||||
if (is_onesided) {
|
||||
dft_size = is_onesided ? ((dft_size >> 1) + 1) : dft_size;
|
||||
}
|
||||
|
||||
auto n_dfts = static_cast<int64_t>((signal_size - dft_size) / static_cast<float>(frame_step_value)) + 1;
|
||||
auto n_dfts = static_cast<int64_t>((signal_size - dft_size) / static_cast<float>(frame_step_value)) + 1;
|
||||
|
||||
// The output has the following shape: [batch_size][frames][dft_unique_bins][2]
|
||||
ONNX_NAMESPACE::TensorShapeProto result_shape_proto;
|
||||
result_shape_proto.add_dim()->set_dim_value(input_shape.dim(0).dim_value()); // batch size
|
||||
result_shape_proto.add_dim()->set_dim_value(n_dfts);
|
||||
result_shape_proto.add_dim()->set_dim_value(dft_size);
|
||||
result_shape_proto.add_dim()->set_dim_value(2);
|
||||
updateOutputShape(ctx, 0, result_shape_proto);
|
||||
});
|
||||
// The output has the following shape: [batch_size][frames][dft_unique_bins][2]
|
||||
ONNX_NAMESPACE::TensorShapeProto result_shape_proto;
|
||||
result_shape_proto.add_dim()->set_dim_value(input_shape.dim(0).dim_value()); // batch size
|
||||
result_shape_proto.add_dim()->set_dim_value(n_dfts);
|
||||
result_shape_proto.add_dim()->set_dim_value(dft_size);
|
||||
result_shape_proto.add_dim()->set_dim_value(2);
|
||||
updateOutputShape(ctx, 0, result_shape_proto);
|
||||
});
|
||||
|
||||
// Window Functions
|
||||
MS_SIGNAL_OPERATOR_SCHEMA(HannWindow)
|
||||
.SetDomain(kMSExperimentalDomain)
|
||||
.SinceVersion(1)
|
||||
.FillUsing(CosineSumWindowOpDocGenerator("Hann"))
|
||||
.TypeConstraint(
|
||||
.TypeConstraint(
|
||||
"T1",
|
||||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain the input size to int64_t.")
|
||||
.TypeConstraint(
|
||||
"Constrain the input size to int64_t.")
|
||||
.TypeConstraint(
|
||||
"T2",
|
||||
ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(),
|
||||
"Constrain output types to numeric tensors.")
|
||||
.FunctionBody(R"ONNX(
|
||||
.FunctionBody(R"ONNX(
|
||||
{
|
||||
A0 = Constant <value = float {0.5}>()
|
||||
A1 = Constant <value = float {0.5}>()
|
||||
|
|
@ -565,22 +560,21 @@ void RegisterSignalSchemas() {
|
|||
Temp1 = Sub (A0, Temp0)
|
||||
output = Cast <to : int = @output_datatype> (Temp1)
|
||||
}
|
||||
)ONNX"
|
||||
);
|
||||
)ONNX");
|
||||
|
||||
MS_SIGNAL_OPERATOR_SCHEMA(HammingWindow)
|
||||
.SetDomain(kMSExperimentalDomain)
|
||||
.SinceVersion(1)
|
||||
.FillUsing(CosineSumWindowOpDocGenerator("Hamming"))
|
||||
.TypeConstraint(
|
||||
.TypeConstraint(
|
||||
"T1",
|
||||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain the input size to int64_t.")
|
||||
.TypeConstraint(
|
||||
"Constrain the input size to int64_t.")
|
||||
.TypeConstraint(
|
||||
"T2",
|
||||
ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(),
|
||||
"Constrain output types to numeric tensors.")
|
||||
.FunctionBody(R"ONNX(
|
||||
.FunctionBody(R"ONNX(
|
||||
{
|
||||
A0 = Constant <value = float {0.54347826087}>()
|
||||
A1 = Constant <value = float {0.45652173913}>()
|
||||
|
|
@ -602,22 +596,21 @@ void RegisterSignalSchemas() {
|
|||
Temp1 = Sub (A0, Temp0)
|
||||
output = Cast <to : int = @output_datatype> (Temp1)
|
||||
}
|
||||
)ONNX"
|
||||
);
|
||||
)ONNX");
|
||||
|
||||
MS_SIGNAL_OPERATOR_SCHEMA(BlackmanWindow)
|
||||
.SetDomain(kMSExperimentalDomain)
|
||||
.SinceVersion(1)
|
||||
.FillUsing(CosineSumWindowOpDocGenerator("Blackman"))
|
||||
.TypeConstraint(
|
||||
.TypeConstraint(
|
||||
"T1",
|
||||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain the input size to int64_t.")
|
||||
.TypeConstraint(
|
||||
"Constrain the input size to int64_t.")
|
||||
.TypeConstraint(
|
||||
"T2",
|
||||
ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(),
|
||||
"Constrain output types to numeric tensors.")
|
||||
.FunctionBody(R"ONNX(
|
||||
.FunctionBody(R"ONNX(
|
||||
{
|
||||
A0 = Constant <value = float {0.42}>()
|
||||
A1 = Constant <value = float {0.5}>()
|
||||
|
|
@ -639,18 +632,20 @@ void RegisterSignalSchemas() {
|
|||
Temp1 = Sub (A0, Temp0)
|
||||
output = Cast <to : int = @output_datatype> (Temp1)
|
||||
}
|
||||
)ONNX"
|
||||
);
|
||||
)ONNX");
|
||||
|
||||
static const char* MelWeightMatrix_ver17_doc = R"DOC(
|
||||
Generate a MelWeightMatrix that can be used to re-weight a Tensor containing a linearly sampled frequency spectra (from DFT or STFT) into num_mel_bins frequency information based on the [lower_edge_hertz, upper_edge_hertz] range on the mel scale.
|
||||
static const char* MelWeightMatrix_ver17_doc = R"DOC(
|
||||
Generate a MelWeightMatrix that can be used to re-weight a Tensor containing a linearly sampled frequency spectra
|
||||
(from DFT or STFT) into num_mel_bins frequency information based on the [lower_edge_hertz, upper_edge_hertz] range
|
||||
on the mel scale.
|
||||
This function defines the mel scale in terms of a frequency in hertz according to the following formula:
|
||||
|
||||
mel(f) = 2595 * log10(1 + f/700)
|
||||
|
||||
In the returned matrix, all the triangles (filterbanks) have a peak value of 1.0.
|
||||
|
||||
The returned MelWeightMatrix can be used to right-multiply a spectrogram S of shape [frames, num_spectrogram_bins] of linear scale spectrum values (e.g. STFT magnitudes) to generate a "mel spectrogram" M of shape [frames, num_mel_bins].
|
||||
The returned MelWeightMatrix can be used to right-multiply a spectrogram S of shape [frames, num_spectrogram_bins] of
|
||||
linear scale spectrum values (e.g. STFT magnitudes) to generate a "mel spectrogram" M of shape [frames, num_mel_bins].
|
||||
)DOC";
|
||||
|
||||
MS_SIGNAL_OPERATOR_SCHEMA(MelWeightMatrix)
|
||||
|
|
@ -687,56 +682,57 @@ The returned MelWeightMatrix can be used to right-multiply a spectrogram S of sh
|
|||
"The MEL Matrix",
|
||||
"T3")
|
||||
.TypeConstraint(
|
||||
"T1",
|
||||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain to integer tensors.")
|
||||
"T1",
|
||||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain to integer tensors.")
|
||||
.TypeConstraint(
|
||||
"T2",
|
||||
{"tensor(float)",
|
||||
"tensor(float16)",
|
||||
"tensor(double)",
|
||||
"tensor(bfloat16)"},
|
||||
"Constrain to float tensors")
|
||||
"T2",
|
||||
{"tensor(float)",
|
||||
"tensor(float16)",
|
||||
"tensor(double)",
|
||||
"tensor(bfloat16)"},
|
||||
"Constrain to float tensors")
|
||||
.TypeConstraint(
|
||||
"T3",
|
||||
ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(),
|
||||
"Constrain to any numerical types.")
|
||||
"T3",
|
||||
ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(),
|
||||
"Constrain to any numerical types.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
auto output_datatype = getAttribute(ctx, "output_datatype", static_cast<int64_t>(onnx::TensorProto_DataType::TensorProto_DataType_FLOAT));
|
||||
auto output_datatype = getAttribute(
|
||||
ctx, "output_datatype", static_cast<int64_t>(onnx::TensorProto::DataType::TensorProto_DataType_FLOAT));
|
||||
updateOutputElemType(ctx, 0, static_cast<int32_t>(output_datatype));
|
||||
|
||||
if (!hasInputShape(ctx, 0) || !hasInputShape(ctx, 1)) {
|
||||
return;
|
||||
return;
|
||||
}
|
||||
|
||||
const auto* num_mel_bins = ctx.getInputData(0);
|
||||
const auto* dft_length = ctx.getInputData(1);
|
||||
if (nullptr == num_mel_bins || nullptr == dft_length) {
|
||||
return;
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t num_mel_bins_value = -1;
|
||||
int64_t dft_length_value = -1;
|
||||
if (num_mel_bins->dims_size() != 0) {
|
||||
fail_shape_inference("num_mel_bins input must be scalar.");
|
||||
fail_shape_inference("num_mel_bins input must be scalar.");
|
||||
}
|
||||
num_mel_bins_value = get_scalar_value_from_tensor<int64_t>(num_mel_bins);
|
||||
|
||||
if (dft_length->dims_size() != 0) {
|
||||
fail_shape_inference("dft_length input must be scalar.");
|
||||
fail_shape_inference("dft_length input must be scalar.");
|
||||
}
|
||||
dft_length_value = get_scalar_value_from_tensor<int64_t>(dft_length);
|
||||
|
||||
if (num_mel_bins_value > 0 && dft_length_value > 0) {
|
||||
ONNX_NAMESPACE::TensorShapeProto result_shape;
|
||||
result_shape.add_dim()->set_dim_value(static_cast<int64_t>((dft_length_value >> 1) + 1));
|
||||
result_shape.add_dim()->set_dim_value(num_mel_bins_value);
|
||||
updateOutputShape(ctx, 0, result_shape);
|
||||
ONNX_NAMESPACE::TensorShapeProto result_shape;
|
||||
result_shape.add_dim()->set_dim_value(static_cast<int64_t>((dft_length_value >> 1) + 1));
|
||||
result_shape.add_dim()->set_dim_value(num_mel_bins_value);
|
||||
updateOutputShape(ctx, 0, result_shape);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace audio
|
||||
} // namespace signal
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in a new issue