mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
* move all contrib ops to one place * namespace changes * bug fix - remove redundant file after merge master * plus more minor bug fixes * bug fix * fix extra space in include header + namespace fix * fix linux build failure: * fix test group names * remove redundant test
42 lines
1,016 B
C++
42 lines
1,016 B
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#pragma once
|
|
|
|
#include "core/common/common.h"
|
|
#include "core/framework/op_kernel.h"
|
|
#include "core/util/math_cpuonly.h"
|
|
|
|
namespace onnxruntime {
|
|
namespace contrib {
|
|
template <typename T>
|
|
class Affine final : public OpKernel {
|
|
public:
|
|
Affine(const OpKernelInfo& info) : OpKernel(info) {
|
|
// Either model-supplied or default values should be returned for alpha and beta
|
|
ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
|
|
ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK());
|
|
}
|
|
|
|
Status Compute(OpKernelContext* context) const override;
|
|
|
|
private:
|
|
float alpha_;
|
|
float beta_;
|
|
};
|
|
|
|
template <typename T>
|
|
class Scale final : public OpKernel {
|
|
public:
|
|
Scale(const OpKernelInfo& info) : OpKernel(info) {
|
|
ORT_ENFORCE(info.GetAttr("scale", &scale_).IsOK());
|
|
}
|
|
|
|
Status Compute(OpKernelContext* context) const override;
|
|
|
|
private:
|
|
float scale_;
|
|
};
|
|
|
|
} // namespace contrib
|
|
} // namespace onnxruntime
|