mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Add GridSample implementation to DirectML (#15788)
Adds a GridSample implementation to the DirectML EP. Temporarily adds an HLSL shader in the DirectML EP to handle GridSample until it is officially added to DirectML.
This commit is contained in:
parent
45f5c27632
commit
2b7f26af7c
39 changed files with 218019 additions and 18 deletions
|
|
@ -955,6 +955,7 @@ Do not modify directly.*
|
|||
|||7+|**T** = tensor(float), tensor(float16)<br/> **T1** = tensor(bool)|
|
||||
|GreaterOrEqual|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
|
||||
|||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
|
||||
|GridSample|*in* X:**T1**<br> *in* grid:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|
||||
|HardSigmoid|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(float), tensor(float16)|
|
||||
|Hardmax|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(float), tensor(float16)|
|
||||
|||11+|**T** = tensor(float), tensor(float16)|
|
||||
|
|
|
|||
|
|
@ -661,9 +661,10 @@ namespace Dml
|
|||
|
||||
bool IsCustomOpShader(const onnxruntime::Node& node)
|
||||
{
|
||||
auto custom_ops = std::array<char*, 2>{
|
||||
auto custom_ops = std::array<char*, 3>{
|
||||
"DFT",
|
||||
"STFT"
|
||||
"STFT",
|
||||
"GridSample"
|
||||
};
|
||||
|
||||
for (auto& custom_op : custom_ops)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,988 @@
|
|||
#pragma once
|
||||
|
||||
#include "../../../OperatorAuthorHelper/OperatorHelper.h"
|
||||
#include "../MLOperatorAuthorImpl.h"
|
||||
|
||||
#include "../External/D3DX12/d3dx12.h"
|
||||
#include <d3d12.h>
|
||||
|
||||
// NOTE: When this operator's implementation is moved into DML, the associated FP16 fallback
|
||||
// should be removed from IsCustomOpShader(...) in
|
||||
// onnxruntime\core\providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp
|
||||
|
||||
// The shader headers are produced using "GeneratedShaders/GenerateShaders.bat"
|
||||
namespace GridSample_uint16_float
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_uint16_float.h"
|
||||
}
|
||||
|
||||
namespace GridSample_uint_float
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_uint_float.h"
|
||||
}
|
||||
|
||||
namespace GridSample_uint64_float
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_uint64_float.h"
|
||||
}
|
||||
|
||||
namespace GridSample_int16_float
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_int16_float.h"
|
||||
}
|
||||
|
||||
namespace GridSample_int_float
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_int_float.h"
|
||||
}
|
||||
|
||||
namespace GridSample_int64_float
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_int64_float.h"
|
||||
}
|
||||
|
||||
namespace GridSample_fp16_float
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_fp16_float.h"
|
||||
}
|
||||
|
||||
namespace GridSample_float_float
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_float_float.h"
|
||||
}
|
||||
|
||||
namespace GridSample_double_float
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_double_float.h"
|
||||
}
|
||||
|
||||
namespace GridSample_bool_float
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_bool_float.h"
|
||||
}
|
||||
|
||||
namespace GridSample_uint16_fp16
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_uint16_fp16.h"
|
||||
}
|
||||
|
||||
namespace GridSample_uint_fp16
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_uint_fp16.h"
|
||||
}
|
||||
|
||||
namespace GridSample_uint64_fp16
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_uint64_fp16.h"
|
||||
}
|
||||
|
||||
namespace GridSample_int16_fp16
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_int16_fp16.h"
|
||||
}
|
||||
|
||||
namespace GridSample_int_fp16
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_int_fp16.h"
|
||||
}
|
||||
|
||||
namespace GridSample_int64_fp16
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_int64_fp16.h"
|
||||
}
|
||||
|
||||
namespace GridSample_fp16_fp16
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_fp16_fp16.h"
|
||||
}
|
||||
|
||||
namespace GridSample_float_fp16
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_float_fp16.h"
|
||||
}
|
||||
|
||||
namespace GridSample_double_fp16
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_double_fp16.h"
|
||||
}
|
||||
|
||||
namespace GridSample_bool_fp16
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_bool_fp16.h"
|
||||
}
|
||||
|
||||
namespace GridSample_uint16_double
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_uint16_double.h"
|
||||
}
|
||||
|
||||
namespace GridSample_uint_double
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_uint_double.h"
|
||||
}
|
||||
|
||||
namespace GridSample_uint64_double
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_uint64_double.h"
|
||||
}
|
||||
|
||||
namespace GridSample_int16_double
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_int16_double.h"
|
||||
}
|
||||
|
||||
namespace GridSample_int_double
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_int_double.h"
|
||||
}
|
||||
|
||||
namespace GridSample_int64_double
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_int64_double.h"
|
||||
}
|
||||
|
||||
namespace GridSample_fp16_double
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_fp16_double.h"
|
||||
}
|
||||
|
||||
namespace GridSample_float_double
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_float_double.h"
|
||||
}
|
||||
|
||||
namespace GridSample_double_double
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_double_double.h"
|
||||
}
|
||||
|
||||
namespace GridSample_bool_double
|
||||
{
|
||||
#include "GeneratedShaders/grid_sample_bool_double.h"
|
||||
}
|
||||
|
||||
|
||||
#include <wrl/client.h>
|
||||
#include <wrl/implements.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
using namespace Microsoft::WRL;
|
||||
|
||||
// Indices of the operator's bound input tensors.
enum DmlGridSampleKernelInputIndex : uint32_t
{
    Input, // 0: the data tensor being sampled
    Grid,  // 1: the sampling-coordinate grid
};

// Interpolation modes; the value is passed to the shader via the "Mode" root constant.
enum DmlGridSampleMode : uint32_t
{
    Bilinear,
    Nearest,
    Bicubic,
};

// Out-of-bounds handling; the value is passed to the shader via the "PaddingMode" root constant.
enum DmlGridSamplePaddingMode : uint32_t
{
    Zeros,
    Border,
    Reflection,
};
|
||||
|
||||
// Helper to derive dimensions and attributes from either the GridSample shape inferrer or the GridSample kernel constructor.
|
||||
struct DmlGridSampleParameters
|
||||
{
|
||||
uint32_t batchSize = 0;
|
||||
uint32_t channelSize = 0;
|
||||
uint32_t height = 0;
|
||||
uint32_t width = 0;
|
||||
int64_t alignCorners = 0;
|
||||
DmlGridSampleMode mode = DmlGridSampleMode::Bilinear;
|
||||
DmlGridSamplePaddingMode paddingMode = DmlGridSamplePaddingMode::Zeros;
|
||||
|
||||
DML_TENSOR_DATA_TYPE dataType = DML_TENSOR_DATA_TYPE_UNKNOWN;
|
||||
|
||||
DmlGridSampleParameters(){}
|
||||
|
||||
DmlGridSampleParameters(
|
||||
const OperatorHelper::IKernelInformationAdapter& kernelInfo,
|
||||
const OperatorHelper::IShapeInformationAdapter& shapeInfo)
|
||||
{
|
||||
auto& attributes = kernelInfo.GetAttributes();
|
||||
|
||||
alignCorners = attributes.GetOptionalAttribute<int64_t>(AttrName::AlignCorners, 0);
|
||||
|
||||
std::string str_attrib = attributes.GetOptionalAttribute<std::string>(AttrName::Mode, "bilinear");
|
||||
ML_CHECK_VALID_ARGUMENT(str_attrib == "bilinear" || str_attrib == "nearest" || str_attrib == "bicubic");
|
||||
if (str_attrib == "bilinear")
|
||||
{
|
||||
mode = DmlGridSampleMode::Bilinear;
|
||||
}
|
||||
else if (str_attrib == "nearest")
|
||||
{
|
||||
mode = DmlGridSampleMode::Nearest;
|
||||
}
|
||||
else if (str_attrib == "bicubic")
|
||||
{
|
||||
mode = DmlGridSampleMode::Bicubic;
|
||||
}
|
||||
|
||||
str_attrib = attributes.GetOptionalAttribute<std::string>(AttrName::PaddingMode, "zeros");
|
||||
ML_CHECK_VALID_ARGUMENT(str_attrib == "zeros" || str_attrib == "border" || str_attrib == "reflection");
|
||||
if (str_attrib == "zeros")
|
||||
{
|
||||
paddingMode = DmlGridSamplePaddingMode::Zeros;
|
||||
}
|
||||
else if (str_attrib == "border")
|
||||
{
|
||||
paddingMode = DmlGridSamplePaddingMode::Border;
|
||||
}
|
||||
else if (str_attrib == "reflection")
|
||||
{
|
||||
paddingMode = DmlGridSamplePaddingMode::Reflection;
|
||||
}
|
||||
|
||||
// input 0: signal (required; tensor)
|
||||
{
|
||||
// Input shape is expected to be [batch_size, channels, height, width]
|
||||
// 4-D tensor of shape (N, C, H_out, W_out) of sampled values.
|
||||
// For integer input types, intermediate values are computed as floating point and cast to integer at the end. uint32_t rank = shapeInfo.GetInputTensorDimensionCount(DmlGridSampleKernelInputIndex::Input);
|
||||
uint32_t rank = shapeInfo.GetInputTensorDimensionCount(DmlGridSampleKernelInputIndex::Input);
|
||||
ML_CHECK_VALID_ARGUMENT(rank == 4, "Input shape must be 4D.");
|
||||
|
||||
auto dims = shapeInfo.GetInputTensorShape(DmlGridSampleKernelInputIndex::Input);
|
||||
assert(dims.size() == rank);
|
||||
this->batchSize = dims[0];
|
||||
this->channelSize = dims[1];
|
||||
|
||||
MLOperatorEdgeDescription edgeDesc = kernelInfo.GetInputEdgeDescription(DmlGridSampleKernelInputIndex::Input);
|
||||
|
||||
assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor);
|
||||
this->dataType = Dml::GetDmlDataTypeFromMlDataType(edgeDesc.tensorDataType);
|
||||
}
|
||||
|
||||
// input 1: grid (required; tensor)
|
||||
{
|
||||
// Grid shape is expected to be [batch_size, height_out, width_out, 2]
|
||||
uint32_t rank = shapeInfo.GetInputTensorDimensionCount(DmlGridSampleKernelInputIndex::Grid);
|
||||
ML_CHECK_VALID_ARGUMENT(rank == 4, "Input shape must be 4D.");
|
||||
|
||||
auto dims = shapeInfo.GetInputTensorShape(DmlGridSampleKernelInputIndex::Grid);
|
||||
assert(dims.size() == rank);
|
||||
this->height = dims[1];
|
||||
this->width = dims[2];
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
namespace GridSampleHelpers
|
||||
{
|
||||
// Divides and rounds
|
||||
inline uint32_t CeilDivide(uint32_t dividend, uint32_t divisor)
|
||||
{
|
||||
uint64_t temp = static_cast<uint64_t>(dividend) + divisor - 1;
|
||||
return static_cast<uint32_t>(temp / divisor);
|
||||
}
|
||||
|
||||
// Gets the next number of elements to dispatch to the GPU within a loop handling a large
|
||||
// total number of tensor elements and threads.
|
||||
void GetNextDispatchSize(
|
||||
uint32_t elementCount,
|
||||
uint32_t elementsPerThread,
|
||||
uint32_t numThreads,
|
||||
_Out_ uint32_t& dispatch,
|
||||
_Out_ uint32_t& pendingElementCount
|
||||
)
|
||||
{
|
||||
// Max threads per workgroup is 2^10 (1024). Max dispatch per dimension is 2^16. Taken together, we can dispatch a maximum of
|
||||
// 2^26 (268,435,456) threads along a single dimension. This should suffice for a majority of the workload. Therefore, even
|
||||
// though it is possible to dispatch up to (2^16)^3 workgroups simultaneously, we stick to the simpler 1D dispatch alternative.
|
||||
assert(numThreads <= D3D12_CS_THREAD_GROUP_MAX_THREADS_PER_GROUP);
|
||||
|
||||
const uint32_t maxThreadsPerDispatch = numThreads * D3D12_CS_DISPATCH_MAX_THREAD_GROUPS_PER_DIMENSION;
|
||||
|
||||
const uint32_t requiredThreadCount = CeilDivide(elementCount, elementsPerThread);
|
||||
|
||||
// Compute max dispatchable elements
|
||||
const uint32_t availableThreadCount = std::min(requiredThreadCount, maxThreadsPerDispatch);
|
||||
|
||||
// Compute required thread group count
|
||||
uint32_t workGroupCount1D = CeilDivide(availableThreadCount, numThreads);
|
||||
|
||||
// Compute min dispatch size
|
||||
dispatch = workGroupCount1D;
|
||||
|
||||
// With the dispatch size computed, compute the dispatched element count
|
||||
const uint32_t dispatchedElementCount = workGroupCount1D * numThreads * elementsPerThread;
|
||||
|
||||
// Update the pending element count
|
||||
pendingElementCount = (dispatchedElementCount < elementCount) ? elementCount - dispatchedElementCount : 0;
|
||||
}
|
||||
}
|
||||
|
||||
class DmlGridSampleOperator : public WRL::Base<IMLOperatorKernel>
|
||||
{
|
||||
private:
|
||||
ComPtr<ID3D12Device> m_device;
|
||||
ComPtr<ID3D12RootSignature> m_gridSampleRootSignature;
|
||||
ComPtr<ID3D12PipelineState> m_gridSamplePipelineState;
|
||||
DmlGridSampleParameters m_params = {};
|
||||
|
||||
|
||||
// Allocate temporary buffers if needed
|
||||
struct ResourceDesc
|
||||
{
|
||||
ComPtr<ID3D12Resource> Resource;
|
||||
std::array<uint32_t, 4> Sizes;
|
||||
std::array<uint32_t, 4> Strides;
|
||||
};
|
||||
|
||||
struct GridSampleShaderConstants
|
||||
{
|
||||
uint32_t StartIndex;
|
||||
uint32_t ElementCount;
|
||||
uint32_t Mode;
|
||||
uint32_t PaddingMode;
|
||||
uint32_t InputSizes[4];
|
||||
uint32_t InputStrides[4];
|
||||
uint32_t GridSizes[4];
|
||||
uint32_t GridStrides[4];
|
||||
uint32_t OutputSizes[4];
|
||||
uint32_t OutputStrides[4];
|
||||
uint32_t AlignCorners;
|
||||
};
|
||||
|
||||
public:
|
||||
|
||||
// Captures the D3D12 device, derives the operator parameters from the kernel
// creation context, and builds the root signature + PSO for the shader variant
// matching the input/grid data types.
DmlGridSampleOperator(IMLOperatorKernelCreationContext* context)
{
    ComPtr<IUnknown> executionObject;
    context->GetExecutionInterface(executionObject.GetAddressOf());

    ComPtr<ID3D12GraphicsCommandList> commandList;
    executionObject.As(&commandList);

    ORT_THROW_IF_FAILED(commandList->GetDevice(IID_ID3D12Device, &m_device));

    MLOperatorKernelCreationContext creationContext(context);
    OperatorHelper::KernelInformationAdapter kernelInfo{creationContext};
    OperatorHelper::ShapeInformationAdapter shapeInfo{creationContext};
    m_params = DmlGridSampleParameters(kernelInfo, shapeInfo);

    MLOperatorEdgeDescription inputEdgeDesc;
    ORT_THROW_IF_FAILED(context->GetInputEdgeDescription(DmlGridSampleKernelInputIndex::Input, &inputEdgeDesc));
    assert(inputEdgeDesc.edgeType == MLOperatorEdgeType::Tensor);

    // BUGFIX: this previously queried input index 0 again, so the grid's data type
    // was taken from the input tensor and the wrong shader variant was selected
    // whenever the two types differed. Query input 1 (the grid) instead.
    MLOperatorEdgeDescription gridEdgeDesc;
    ORT_THROW_IF_FAILED(context->GetInputEdgeDescription(DmlGridSampleKernelInputIndex::Grid, &gridEdgeDesc));
    assert(gridEdgeDesc.edgeType == MLOperatorEdgeType::Tensor);

    PrepareGridSample(inputEdgeDesc.tensorDataType, gridEdgeDesc.tensorDataType);
}
|
||||
|
||||
// Builds the compute root signature (3 root UAVs + one root-constant block) and the
// pipeline state object whose shader variant matches the (input, grid) data types.
// Throws E_INVALIDARG for any type combination with no precompiled shader.
void PrepareGridSample(MLOperatorTensorDataType inputDataType, MLOperatorTensorDataType gridDataType)
{
    // Compute root signature.
    constexpr uint32_t uavCount = 3; // 3 bound UAVs: input, grid & output
    std::vector<CD3DX12_ROOT_PARAMETER1> rootParameters(uavCount + 1);

    for (uint32_t i = 0; i < uavCount; i++)
    {
        rootParameters[i].InitAsUnorderedAccessView(i);
    }

    // The root constants mirror the shader's `cbuffer Constants`, whose layout matches
    // GridSampleShaderConstants packed as 32-bit values. Deriving the count from the
    // struct replaces the previous hard-coded 29 and cannot drift from the layout.
    static_assert(sizeof(GridSampleShaderConstants) % sizeof(uint32_t) == 0);
    constexpr uint32_t constantCount = sizeof(GridSampleShaderConstants) / sizeof(uint32_t);
    rootParameters[uavCount].InitAsConstants(constantCount, 0);

    CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC desc;
    desc.Init_1_1(static_cast<uint32_t>(rootParameters.size()), rootParameters.data());

    ComPtr<ID3DBlob> rootSignatureBlob;
    ComPtr<ID3DBlob> rootSignatureErrorBlob;
    ORT_THROW_IF_FAILED(D3D12SerializeVersionedRootSignature(
        &desc,
        rootSignatureBlob.GetAddressOf(),
        rootSignatureErrorBlob.GetAddressOf()
    ));

    ORT_THROW_IF_FAILED(m_device->CreateRootSignature(
        0,
        rootSignatureBlob->GetBufferPointer(),
        rootSignatureBlob->GetBufferSize(),
        IID_ID3D12RootSignature,
        &m_gridSampleRootSignature
    ));

    // Selects the precompiled shader bytecode for the (grid type, input type) pair.
    // This replaces thirty near-identical `computePsoDesc.CS = ...; break;` case
    // bodies with one table-like mapping.
    auto selectShader = [&]() -> D3D12_SHADER_BYTECODE
    {
        switch (gridDataType)
        {
        case MLOperatorTensorDataType::Float:
            switch (inputDataType)
            {
            case MLOperatorTensorDataType::UInt16:  return CD3DX12_SHADER_BYTECODE(GridSample_uint16_float::g_GridSample, sizeof(GridSample_uint16_float::g_GridSample));
            case MLOperatorTensorDataType::UInt32:  return CD3DX12_SHADER_BYTECODE(GridSample_uint_float::g_GridSample, sizeof(GridSample_uint_float::g_GridSample));
            case MLOperatorTensorDataType::UInt64:  return CD3DX12_SHADER_BYTECODE(GridSample_uint64_float::g_GridSample, sizeof(GridSample_uint64_float::g_GridSample));
            case MLOperatorTensorDataType::Int16:   return CD3DX12_SHADER_BYTECODE(GridSample_int16_float::g_GridSample, sizeof(GridSample_int16_float::g_GridSample));
            case MLOperatorTensorDataType::Int32:   return CD3DX12_SHADER_BYTECODE(GridSample_int_float::g_GridSample, sizeof(GridSample_int_float::g_GridSample));
            case MLOperatorTensorDataType::Int64:   return CD3DX12_SHADER_BYTECODE(GridSample_int64_float::g_GridSample, sizeof(GridSample_int64_float::g_GridSample));
            case MLOperatorTensorDataType::Float16: return CD3DX12_SHADER_BYTECODE(GridSample_fp16_float::g_GridSample, sizeof(GridSample_fp16_float::g_GridSample));
            case MLOperatorTensorDataType::Float:   return CD3DX12_SHADER_BYTECODE(GridSample_float_float::g_GridSample, sizeof(GridSample_float_float::g_GridSample));
            case MLOperatorTensorDataType::Double:  return CD3DX12_SHADER_BYTECODE(GridSample_double_float::g_GridSample, sizeof(GridSample_double_float::g_GridSample));
            case MLOperatorTensorDataType::Bool:    return CD3DX12_SHADER_BYTECODE(GridSample_bool_float::g_GridSample, sizeof(GridSample_bool_float::g_GridSample));
            default: ORT_THROW_HR(E_INVALIDARG);
            }

        case MLOperatorTensorDataType::Float16:
            switch (inputDataType)
            {
            case MLOperatorTensorDataType::UInt16:  return CD3DX12_SHADER_BYTECODE(GridSample_uint16_fp16::g_GridSample, sizeof(GridSample_uint16_fp16::g_GridSample));
            case MLOperatorTensorDataType::UInt32:  return CD3DX12_SHADER_BYTECODE(GridSample_uint_fp16::g_GridSample, sizeof(GridSample_uint_fp16::g_GridSample));
            case MLOperatorTensorDataType::UInt64:  return CD3DX12_SHADER_BYTECODE(GridSample_uint64_fp16::g_GridSample, sizeof(GridSample_uint64_fp16::g_GridSample));
            case MLOperatorTensorDataType::Int16:   return CD3DX12_SHADER_BYTECODE(GridSample_int16_fp16::g_GridSample, sizeof(GridSample_int16_fp16::g_GridSample));
            case MLOperatorTensorDataType::Int32:   return CD3DX12_SHADER_BYTECODE(GridSample_int_fp16::g_GridSample, sizeof(GridSample_int_fp16::g_GridSample));
            case MLOperatorTensorDataType::Int64:   return CD3DX12_SHADER_BYTECODE(GridSample_int64_fp16::g_GridSample, sizeof(GridSample_int64_fp16::g_GridSample));
            case MLOperatorTensorDataType::Float16: return CD3DX12_SHADER_BYTECODE(GridSample_fp16_fp16::g_GridSample, sizeof(GridSample_fp16_fp16::g_GridSample));
            case MLOperatorTensorDataType::Float:   return CD3DX12_SHADER_BYTECODE(GridSample_float_fp16::g_GridSample, sizeof(GridSample_float_fp16::g_GridSample));
            case MLOperatorTensorDataType::Double:  return CD3DX12_SHADER_BYTECODE(GridSample_double_fp16::g_GridSample, sizeof(GridSample_double_fp16::g_GridSample));
            case MLOperatorTensorDataType::Bool:    return CD3DX12_SHADER_BYTECODE(GridSample_bool_fp16::g_GridSample, sizeof(GridSample_bool_fp16::g_GridSample));
            default: ORT_THROW_HR(E_INVALIDARG);
            }

        case MLOperatorTensorDataType::Double:
            switch (inputDataType)
            {
            case MLOperatorTensorDataType::UInt16:  return CD3DX12_SHADER_BYTECODE(GridSample_uint16_double::g_GridSample, sizeof(GridSample_uint16_double::g_GridSample));
            case MLOperatorTensorDataType::UInt32:  return CD3DX12_SHADER_BYTECODE(GridSample_uint_double::g_GridSample, sizeof(GridSample_uint_double::g_GridSample));
            case MLOperatorTensorDataType::UInt64:  return CD3DX12_SHADER_BYTECODE(GridSample_uint64_double::g_GridSample, sizeof(GridSample_uint64_double::g_GridSample));
            case MLOperatorTensorDataType::Int16:   return CD3DX12_SHADER_BYTECODE(GridSample_int16_double::g_GridSample, sizeof(GridSample_int16_double::g_GridSample));
            case MLOperatorTensorDataType::Int32:   return CD3DX12_SHADER_BYTECODE(GridSample_int_double::g_GridSample, sizeof(GridSample_int_double::g_GridSample));
            case MLOperatorTensorDataType::Int64:   return CD3DX12_SHADER_BYTECODE(GridSample_int64_double::g_GridSample, sizeof(GridSample_int64_double::g_GridSample));
            case MLOperatorTensorDataType::Float16: return CD3DX12_SHADER_BYTECODE(GridSample_fp16_double::g_GridSample, sizeof(GridSample_fp16_double::g_GridSample));
            case MLOperatorTensorDataType::Float:   return CD3DX12_SHADER_BYTECODE(GridSample_float_double::g_GridSample, sizeof(GridSample_float_double::g_GridSample));
            case MLOperatorTensorDataType::Double:  return CD3DX12_SHADER_BYTECODE(GridSample_double_double::g_GridSample, sizeof(GridSample_double_double::g_GridSample));
            case MLOperatorTensorDataType::Bool:    return CD3DX12_SHADER_BYTECODE(GridSample_bool_double::g_GridSample, sizeof(GridSample_bool_double::g_GridSample));
            default: ORT_THROW_HR(E_INVALIDARG);
            }

        default:
            ORT_THROW_HR(E_INVALIDARG);
        }
    };

    // Describe and create the compute pipeline state object (PSO).
    D3D12_COMPUTE_PIPELINE_STATE_DESC computePsoDesc = {};
    computePsoDesc.pRootSignature = m_gridSampleRootSignature.Get();
    computePsoDesc.CS = selectShader();

    ORT_THROW_IF_FAILED(m_device->CreateComputePipelineState(&computePsoDesc, IID_ID3D12PipelineState, &m_gridSamplePipelineState));
}
|
||||
|
||||
// Computes the outputs of the kernel. This may be called multiple times
|
||||
// simultaneously within the same instance of the class. Implementations
|
||||
// of this method must be thread-safe.
|
||||
STDMETHOD(Compute)(IMLOperatorKernelContext* context)
|
||||
{
|
||||
try
|
||||
{
|
||||
// Get the input tensor
|
||||
ComPtr<IMLOperatorTensor> inputTensor;
|
||||
ORT_THROW_IF_FAILED(context->GetInputTensor(0, inputTensor.GetAddressOf()));
|
||||
|
||||
// Get the grid tensor
|
||||
ComPtr<IMLOperatorTensor> gridTensor;
|
||||
ORT_THROW_IF_FAILED(context->GetInputTensor(1, gridTensor.GetAddressOf()));
|
||||
|
||||
// Get the output tensor
|
||||
ComPtr<IMLOperatorTensor> outputTensor;
|
||||
context->GetOutputTensor(0, outputTensor.GetAddressOf());
|
||||
|
||||
if (outputTensor->IsCpuData() || inputTensor->IsCpuData() || gridTensor->IsCpuData())
|
||||
{
|
||||
return E_UNEXPECTED;
|
||||
}
|
||||
|
||||
ComPtr<IUnknown> executionObject;
|
||||
ComPtr<ID3D12GraphicsCommandList> commandList;
|
||||
context->GetExecutionInterface(executionObject.GetAddressOf());
|
||||
executionObject.As(&commandList);
|
||||
|
||||
// Get the input and output shape sizes
|
||||
auto inputDims = GetTensorDimensions(inputTensor.Get());
|
||||
auto gridDims = GetTensorDimensions(gridTensor.Get());
|
||||
auto outputDims = GetTensorDimensions(outputTensor.Get());
|
||||
|
||||
ComPtr<IUnknown> inputUnknown;
|
||||
ComPtr<ID3D12Resource> inputResource;
|
||||
inputTensor->GetDataInterface(inputUnknown.GetAddressOf());
|
||||
ORT_THROW_IF_FAILED(inputUnknown.As(&inputResource));
|
||||
|
||||
ComPtr<IUnknown> gridUnknown;
|
||||
ComPtr<ID3D12Resource> gridResource;
|
||||
gridTensor->GetDataInterface(gridUnknown.GetAddressOf());
|
||||
ORT_THROW_IF_FAILED(gridUnknown.As(&gridResource));
|
||||
|
||||
ComPtr<IUnknown> outputUnknown;
|
||||
ComPtr<ID3D12Resource> outputResource;
|
||||
outputTensor->GetDataInterface(outputUnknown.GetAddressOf());
|
||||
ORT_THROW_IF_FAILED(outputUnknown.As(&outputResource));
|
||||
|
||||
return Compute(
|
||||
commandList.Get(),
|
||||
context,
|
||||
inputResource.Get(),
|
||||
inputDims,
|
||||
gridResource.Get(),
|
||||
gridDims,
|
||||
outputResource.Get(),
|
||||
outputDims
|
||||
);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
return E_FAIL;
|
||||
}
|
||||
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
HRESULT Compute(
|
||||
ID3D12GraphicsCommandList* commandList,
|
||||
IMLOperatorKernelContext* context,
|
||||
ID3D12Resource* inputResource,
|
||||
gsl::span<const uint32_t> inputDims,
|
||||
ID3D12Resource* gridResource,
|
||||
gsl::span<const uint32_t> gridDims,
|
||||
ID3D12Resource* outputResource,
|
||||
gsl::span<const uint32_t> outputDims)
|
||||
{
|
||||
try
|
||||
{
|
||||
GridSample(
|
||||
inputResource,
|
||||
inputDims,
|
||||
gridResource,
|
||||
gridDims,
|
||||
outputResource,
|
||||
outputDims,
|
||||
commandList);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
return E_FAIL;
|
||||
}
|
||||
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
// Records the GridSample work into `commandList`: transitions the three resources
// to UAV state, binds the root signature and PSO, fills the shader root constants,
// dispatches over all output elements, then transitions the resources back to COMMON.
void GridSample(
    ID3D12Resource* inputResource,
    gsl::span<const uint32_t> inputDims,
    ID3D12Resource* gridResource,
    gsl::span<const uint32_t> gridDims,
    ID3D12Resource* outputResource,
    gsl::span<const uint32_t> outputDims,
    ID3D12GraphicsCommandList* commandList)
{
    // Packed (descending, contiguous) strides for each tensor.
    std::array<uint32_t, 4> inputStrides;
    std::array<uint32_t, 4> gridStrides;
    std::array<uint32_t, 4> outputStrides;
    Dml::GetDescendingPackedStrides(inputDims, inputStrides);
    Dml::GetDescendingPackedStrides(gridDims, gridStrides);
    Dml::GetDescendingPackedStrides(outputDims, outputStrides);

    // Transition resources from common to UAV state
    D3D12_RESOURCE_BARRIER barriers[3];

    barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition(
        inputResource,
        D3D12_RESOURCE_STATE_COMMON,
        D3D12_RESOURCE_STATE_UNORDERED_ACCESS
    );

    barriers[1] = CD3DX12_RESOURCE_BARRIER::Transition(
        gridResource,
        D3D12_RESOURCE_STATE_COMMON,
        D3D12_RESOURCE_STATE_UNORDERED_ACCESS
    );

    barriers[2] = CD3DX12_RESOURCE_BARRIER::Transition(
        outputResource,
        D3D12_RESOURCE_STATE_COMMON,
        D3D12_RESOURCE_STATE_UNORDERED_ACCESS
    );

    // Debug names for graphics tools.
    inputResource->SetName(L"InputResource");
    outputResource->SetName(L"OutputResource");
    gridResource->SetName(L"GridResource");

    commandList->ResourceBarrier(3, barriers);

    // Set the root signature and pipeline state
    commandList->SetComputeRootSignature(m_gridSampleRootSignature.Get());
    commandList->SetPipelineState(m_gridSamplePipelineState.Get());

    // Fill the root constants consumed by the shader. (A stale comment here
    // previously referred to the Stockham DFT levels; unlike DFT, GridSample
    // needs only a single logical pass over the output elements.)
    GridSampleShaderConstants constants = {};
    constants.AlignCorners = static_cast<uint32_t>(m_params.alignCorners);
    constants.Mode = static_cast<uint32_t>(m_params.mode);
    constants.PaddingMode = static_cast<uint32_t>(m_params.paddingMode);
    std::copy(inputDims.begin(), inputDims.end(), constants.InputSizes);
    std::copy(inputStrides.begin(), inputStrides.end(), constants.InputStrides);
    std::copy(gridDims.begin(), gridDims.end(), constants.GridSizes);
    std::copy(gridStrides.begin(), gridStrides.end(), constants.GridStrides);
    std::copy(outputDims.begin(), outputDims.end(), constants.OutputSizes);
    std::copy(outputStrides.begin(), outputStrides.end(), constants.OutputStrides);

    // One shader thread per output element.
    constants.ElementCount = ComputeElementCountFromDimensions(constants.OutputSizes);
    std::array<ID3D12Resource*, 3> uav_resources = { inputResource, gridResource, outputResource };
    Dispatch(uav_resources, constants, commandList);

    // Transition resources back to common state
    barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition(
        inputResource,
        D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
        D3D12_RESOURCE_STATE_COMMON
    );

    barriers[1] = CD3DX12_RESOURCE_BARRIER::Transition(
        gridResource,
        D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
        D3D12_RESOURCE_STATE_COMMON
    );

    barriers[2] = CD3DX12_RESOURCE_BARRIER::Transition(
        outputResource,
        D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
        D3D12_RESOURCE_STATE_COMMON
    );

    commandList->ResourceBarrier(3, barriers);
}
|
||||
|
||||
std::vector<uint32_t> GetTensorDimensions(IMLOperatorTensor* tensor)
|
||||
{
|
||||
auto inputDimsSize = tensor->GetDimensionCount();
|
||||
auto dims = std::vector<uint32_t>(inputDimsSize);
|
||||
ORT_THROW_IF_FAILED(tensor->GetShape(static_cast<uint32_t>(dims.size()), dims.data()));
|
||||
return dims;
|
||||
}
|
||||
|
||||
template <typename TConstants, uint32_t TSize>
|
||||
void Dispatch(
|
||||
std::array<ID3D12Resource*, TSize>& resources,
|
||||
TConstants& constants,
|
||||
ID3D12GraphicsCommandList* commandList)
|
||||
{
|
||||
D3D12_RESOURCE_BARRIER uav_barriers[TSize];
|
||||
|
||||
std::transform(
|
||||
resources.begin(), resources.end(),
|
||||
uav_barriers,
|
||||
[](auto& resource) { return CD3DX12_RESOURCE_BARRIER::UAV(resource); } );
|
||||
commandList->ResourceBarrier(TSize, uav_barriers);
|
||||
|
||||
for (uint32_t i = 0; i < TSize; i++)
|
||||
{
|
||||
// Set resource views
|
||||
if (resources[i]) {
|
||||
commandList->SetComputeRootUnorderedAccessView(
|
||||
i, // root parameter index
|
||||
resources[i]->GetGPUVirtualAddress()
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
commandList->SetComputeRootUnorderedAccessView(
|
||||
i, // root parameter index
|
||||
{}
|
||||
);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
auto pendingElementCount = constants.ElementCount;
|
||||
|
||||
// Dispatch up to the maximum number of threads per iteration until
|
||||
// all elements are completed
|
||||
while (pendingElementCount > 0)
|
||||
{
|
||||
constants.StartIndex = constants.ElementCount - pendingElementCount;
|
||||
|
||||
uint32_t dispatchSizeX;
|
||||
|
||||
GridSampleHelpers::GetNextDispatchSize(
|
||||
pendingElementCount,
|
||||
1,
|
||||
64,
|
||||
dispatchSizeX,
|
||||
pendingElementCount
|
||||
);
|
||||
|
||||
// Set root constants
|
||||
commandList->SetComputeRoot32BitConstants(
|
||||
TSize, // root parameter index
|
||||
29, // Constant count
|
||||
&constants,
|
||||
0 // offset
|
||||
);
|
||||
|
||||
commandList->Dispatch(dispatchSizeX, 1, 1);
|
||||
}
|
||||
|
||||
commandList->ResourceBarrier(2, uav_barriers);
|
||||
}
|
||||
};
|
||||
|
||||
struct GridSampleShapeInferrer : public WRL::Base<IMLOperatorShapeInferrer>
{
    // Infers the GridSample output shape: batch/channels come from the input
    // tensor, spatial dimensions from the grid tensor.
    STDMETHOD(InferOutputShapes)(IMLOperatorShapeInferenceContext* context) noexcept
    {
        try
        {
            // Require the private shape-inference interface to be available.
            ComPtr<IMLOperatorShapeInferenceContextPrivate> privateContext;
            ORT_THROW_IF_FAILED(context->QueryInterface(IID_PPV_ARGS(&privateContext)));

            MLShapeInferenceContext shapeContext(context);
            OperatorHelper::KernelInformationAdapter kernelAdapter{shapeContext};
            OperatorHelper::ShapeInformationAdapter shapeAdapter{shapeContext};
            DmlGridSampleParameters gridSampleParams(kernelAdapter, shapeAdapter);

            // Output layout is NCHW.
            const std::array<uint32_t, 4> outputShape =
            {
                gridSampleParams.batchSize,
                gridSampleParams.channelSize,
                gridSampleParams.height,
                gridSampleParams.width,
            };

            ORT_THROW_IF_FAILED(context->SetOutputTensorShape(
                0,
                onnxruntime::narrow<uint32_t>(outputShape.size()),
                outputShape.data()));
            return S_OK;
        }
        catch (...)
        {
            return E_FAIL;
        }
    }
};
|
||||
|
||||
class DmlGridSampleOperatorFactory : public WRL::Base<IMLOperatorKernelFactory>
{
public:
    // Creates one DmlGridSampleOperator kernel instance for the given
    // creation context and transfers ownership to the caller.
    STDMETHOD(CreateKernel)(
        IMLOperatorKernelCreationContext* context,
        IMLOperatorKernel** kernel)
    {
        try
        {
            auto gridSampleOperator = wil::MakeOrThrow<DmlGridSampleOperator>(context);
            gridSampleOperator.CopyTo(kernel);
            return S_OK;
        }
        catch (...)
        {
            return E_FAIL;
        }
    }

    // Registers the custom D3D12 GridSample kernel (opset 16+) with the
    // operator registry, including its type constraints and the default
    // values for the align_corners / mode / padding_mode attributes.
    static void RegisterGridSampleKernel(IMLOperatorRegistry* registry)
    {
        MLOperatorKernelDescription kernelDescription = {};
        kernelDescription.domain = "";
        kernelDescription.name = "GridSample";
        kernelDescription.minimumOperatorSetVersion = 16;
        kernelDescription.executionType = MLOperatorExecutionType::D3D12;

        // T1 (input/output element types): float, float16, and the
        // 8/16/32/64-bit signed and unsigned integer types.
        MLOperatorEdgeTypeConstrant t1Constraint;
        t1Constraint.typeLabel = "T1";
        std::vector<MLOperatorEdgeDescription> t1AllowedEdges
        {
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Float },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Float16 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Int8 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Int16 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Int32 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Int64 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::UInt8 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::UInt16 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::UInt32 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::UInt64 },
        };
        t1Constraint.allowedTypes = t1AllowedEdges.data();
        t1Constraint.allowedTypeCount = static_cast<uint32_t>(t1AllowedEdges.size());

        // T2 (grid element type): float16, float.
        MLOperatorEdgeTypeConstrant t2Constraint;
        t2Constraint.typeLabel = "T2";
        std::vector<MLOperatorEdgeDescription> t2AllowedEdges
        {
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Float16 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Float },
        };
        t2Constraint.allowedTypes = t2AllowedEdges.data();
        t2Constraint.allowedTypeCount = static_cast<uint32_t>(t2AllowedEdges.size());

        std::vector<MLOperatorEdgeTypeConstrant> typeConstraints{ t1Constraint, t2Constraint };
        kernelDescription.typeConstraints = typeConstraints.data();
        kernelDescription.typeConstraintCount = static_cast<uint32_t>(typeConstraints.size());

        // align_corners defaults to 0 per the ONNX GridSample-16 spec.
        MLOperatorAttributeNameValue alignedCornersAttributeValue;
        alignedCornersAttributeValue.name = AttrName::AlignCorners;
        alignedCornersAttributeValue.type = MLOperatorAttributeType::Int;
        alignedCornersAttributeValue.valueCount = 1;
        static const int64_t alignedCorners[] = { 0 };
        alignedCornersAttributeValue.ints = alignedCorners;

        // mode defaults to "bilinear".
        MLOperatorAttributeNameValue modeAttributeValue;
        modeAttributeValue.name = AttrName::Mode;
        modeAttributeValue.type = MLOperatorAttributeType::String;
        modeAttributeValue.valueCount = 1;
        static const char* modes[] = { "bilinear" };
        modeAttributeValue.strings = modes;

        // padding_mode defaults to "zeros".
        // Bug fix: this default was registered under AttrName::Mode (a
        // copy-paste of the block above), duplicating the "mode" default and
        // leaving "padding_mode" without one.
        MLOperatorAttributeNameValue paddingModeAttributeValue;
        paddingModeAttributeValue.name = AttrName::PaddingMode;
        paddingModeAttributeValue.type = MLOperatorAttributeType::String;
        paddingModeAttributeValue.valueCount = 1;
        static const char* paddingModes[] = { "zeros" };
        paddingModeAttributeValue.strings = paddingModes;

        std::vector<MLOperatorAttributeNameValue> attributeDefaultValues{
            alignedCornersAttributeValue,
            modeAttributeValue,
            paddingModeAttributeValue
        };

        kernelDescription.defaultAttributes = attributeDefaultValues.data();
        kernelDescription.defaultAttributeCount = static_cast<uint32_t>(attributeDefaultValues.size());
        kernelDescription.options = MLOperatorKernelOptions::None;
        kernelDescription.executionOptions = 0;

        auto shapeInferrer = wil::MakeOrThrow<GridSampleShapeInferrer>();
        auto factory = wil::MakeOrThrow<DmlGridSampleOperatorFactory>();

        ComPtr<IMLOperatorRegistryPrivate> registryPrivate;
        ORT_THROW_IF_FAILED(registry->QueryInterface(IID_PPV_ARGS(&registryPrivate)));

        ORT_THROW_IF_FAILED(registryPrivate->RegisterOperatorKernel(
            &kernelDescription,
            factory.Get(),
            shapeInferrer.Get(),
            nullptr,
            false, // isInternalOperator
            false, // alias
            false, // supportsGraph
            nullptr,
            nullptr,
            0));
    }
};
|
||||
|
|
@ -7,10 +7,78 @@ if "%1" == "DEBUG" (
|
|||
|
||||
fxc.exe ..\Shaders\bluestein_chirp.hlsl -E BluesteinZChirp -T cs_5_0 /DTBUFFER=float /Zi /Od /Fh bluestein_chirp.h
|
||||
dxc.exe ..\Shaders\bluestein_chirp.hlsl -E BluesteinZChirp -T cs_6_2 -DTBUFFER=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh bluestein_chirp_fp16.h
|
||||
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint16_float.h
|
||||
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=uint /DTBUFFER2=float /Zi /Od /Fh grid_sample_uint_float.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint64_float.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int16_float.h
|
||||
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=int /DTBUFFER2=float /Zi /Od /Fh grid_sample_int_float.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int64_float.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_fp16_float.h
|
||||
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=float /DTBUFFER2=float /Zi /Od /Fh grid_sample_float_float.h
|
||||
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=double /DTBUFFER2=float /Zi /Od /Fh grid_sample_double_float.h
|
||||
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=bool /DTBUFFER2=float /Zi /Od /Fh grid_sample_bool_float.h
|
||||
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint16_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint64_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int16_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int64_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_fp16_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_float_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_double_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_bool_fp16.h
|
||||
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint16_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint64_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int16_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int64_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_fp16_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_float_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_double_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_bool_double.h
|
||||
|
||||
) else (
|
||||
fxc.exe ..\Shaders\stockham.hlsl -E DFT -T cs_5_0 /DTBUFFER=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh stockham.h
|
||||
dxc.exe ..\Shaders\stockham.hlsl -E DFT -T cs_6_2 -DTBUFFER=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh stockham_fp16.h
|
||||
|
||||
fxc.exe ..\Shaders\bluestein_chirp.hlsl -E BluesteinZChirp -T cs_5_0 /DTBUFFER=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh bluestein_chirp.h
|
||||
dxc.exe ..\Shaders\bluestein_chirp.hlsl -E BluesteinZChirp -T cs_6_2 -DTBUFFER=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh bluestein_chirp_fp16.h
|
||||
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint16_float.h
|
||||
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=uint /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_uint_float.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint64_float.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int16_float.h
|
||||
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=int /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_int_float.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int64_float.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_fp16_float.h
|
||||
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=float /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_float_float.h
|
||||
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=double /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_double_float.h
|
||||
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=bool /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_bool_float.h
|
||||
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint16_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint64_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int16_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int64_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_fp16_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_float_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_double_fp16.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_bool_fp16.h
|
||||
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint16_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint64_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int16_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int64_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_fp16_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_float_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_double_double.h
|
||||
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_bool_double.h
|
||||
|
||||
)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
|
@ -4,6 +4,7 @@
|
|||
#include "precomp.h"
|
||||
#include "DmlDFT.h"
|
||||
#include "DmlSTFT.h"
|
||||
#include "DmlGridSample.h"
|
||||
#include "OperatorRegistration.h"
|
||||
#include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h"
|
||||
#include "core/providers/dml/OperatorAuthorHelper/OperatorVersions.h"
|
||||
|
|
@ -1054,6 +1055,7 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
|
|||
|
||||
GpuDFTOperatorFactory::RegisterDFTKernel(registry);
|
||||
DmlSTFTOperatorFactory::RegisterSTFTKernel(registry);
|
||||
DmlGridSampleOperatorFactory::RegisterGridSampleKernel(registry);
|
||||
}
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -0,0 +1,289 @@
|
|||
// TBUFFER1 is the element type of the input/output tensors; TBUFFER2 is the
// element type of the grid tensor. Arithmetic is always done in FP32.
#if !defined(TBUFFER1)
#define TBUFFER1 float
#endif
#if !defined(TBUFFER2)
#define TBUFFER2 float
#endif

// u0: input tensor, u1: sampling grid, u2: output tensor.
RWStructuredBuffer<TBUFFER1> input : register(u0);
RWStructuredBuffer<TBUFFER2> grid : register(u1);
RWStructuredBuffer<TBUFFER1> output : register(u2);

// PaddingMode values — must stay in sync with the host-side constants.
static const uint Zeros = 0;
static const uint Border = 1;
static const uint Reflection = 2;

// Mode (interpolation) values — must stay in sync with the host-side constants.
static const uint Bilinear = 0;
static const uint Nearest = 1;
static const uint Bicubic = 2;

cbuffer Constants
{
    uint StartIndex;     // first flattened output element handled by this dispatch
    uint ElementCount;   // total number of output elements
    uint Mode;           // interpolation mode (Bilinear/Nearest/Bicubic)
    uint PaddingMode;    // out-of-bounds handling (Zeros/Border/Reflection)
    uint4 InputSizes;    // input shape in (n, c, h, w)
    uint4 InputStrides;  // input element strides in (n, c, h, w)
    uint4 GridSizes;     // grid shape in (n, h, w, 2)
    uint4 GridStrides;   // grid element strides in (n, h, w, 2)
    uint4 OutputSizes;   // output shape in (n, c, h, w)
    uint4 OutputStrides; // output element strides in (n, c, h, w)
    uint AlignCorners;   // nonzero => align_corners semantics
};
|
||||
|
||||
// Decomposes a flattened output element index into (n, c, h, w) coordinates.
// NOTE(review): the quotients use size-derived denominators while the
// remainders subtract via OutputStrides — these only agree when
// OutputStrides describe a packed NCHW layout (x == c*h*w, y == h*w, z == w).
// TODO confirm the host always passes packed strides.
uint4 DecomposeIndex(uint index)
{
    uint4 idx = uint4(0, 0, 0, 0);

    uint n = OutputSizes.x;
    uint c = OutputSizes.y;
    uint h = OutputSizes.z;
    uint w = OutputSizes.w;

    // Number of flattened elements spanned by one step of each coordinate.
    uint width_denominator = 1;
    uint height_denominator = w;
    uint channel_denominator = w * h;
    uint batch_denominator = w * h * c;

    idx.x = (index) / batch_denominator; // batch
    idx.y = (index - (idx.x * OutputStrides.x)) / channel_denominator; // channel
    idx.z = (index - (idx.x * OutputStrides.x) - (idx.y * OutputStrides.y)) / height_denominator; // height
    idx.w = (index - (idx.x * OutputStrides.x) - (idx.y * OutputStrides.y) - (idx.z * OutputStrides.z)) / width_denominator; // width
    return idx;
}
|
||||
// Returns the 2-element flow vector (x, y) read from the grid tensor for the
// given output coordinate. (The previous header comment — "real and complex
// output uav" — was stale copy-paste from the DFT shader.)
float2 FetchGridVector(uint4 index)
{
    // The index is in (n,c,h,w)
    // The shape of gridsizes (and gridstrides) is (n, h, w, 2);
    // the channel coordinate does not participate in grid addressing.
    uint n = index.x;
    uint h = index.z;
    uint w = index.w;

    // Bug fix: compute the flattened buffer offset in integer arithmetic.
    // The original used float4/float2 for the dot-product and offsets, which
    // silently loses precision for offsets at or above 2^24 elements.
    uint gridOffset = n * GridStrides.x + h * GridStrides.y + w * GridStrides.z;
    return float2((float)grid[gridOffset],
                  (float)grid[gridOffset + GridStrides.w]);
}
|
||||
|
||||
// Returns the valid sampling rectangle as {x_min, y_min, x_max, y_max}.
// Computed in float to avoid possible issues in the integer-T case.
float4 CalculateBorders()
{
    if (AlignCorners)
    {
        // Corners coincide exactly with the first/last pixel centers.
        return float4(0.f, 0.f,
                      InputSizes.w - 1.f,  // W_in
                      InputSizes.z - 1.f); // H_in
    }

    // Half-pixel convention: the image spans [-0.5, dim - 0.5] on each axis.
    return float4(-0.5f, -0.5f,
                  InputSizes.w - 0.5f,  // W_in
                  InputSizes.z - 0.5f); // H_in
}
|
||||
|
||||
|
||||
// Reflect by the near border till within the borders.
// Mirrors an out-of-range coordinate back into [x_min, x_max]; coordinates
// already inside the range pass through unchanged.
// Use float for borders to avoid potential issues with integer T.
float Reflect(float x, float x_min, float x_max) {
    float range = x_max - x_min;
    if (x < x_min) {
        float dx = x_min - x;
        // Whole number of ranges the point overshoots by (truncating divide).
        uint n = dx / range;
        float r = dx - n * range;
        // An even count of full traversals reflects off x_min; odd off x_max.
        if (n % 2 == 0) {
            x = x_min + r;
        } else {
            x = x_max - r;
        }
    } else if (x > x_max) {
        float dx = x - x_max;
        uint n = dx / range;
        float r = dx - n * range;
        // Mirror image of the branch above: even reflects off x_max first.
        if (n % 2 == 0) {
            x = x_max - r;
        } else {
            x = x_min + r;
        }
    }
    // else fallthrough: already inside [x_min, x_max].
    return x;
}
|
||||
|
||||
|
||||
// Restore normalized location to actual image location
//   When align_corners is true:
//     Normalized location (-1, -1) points to the top-left pixel.
//     Normalized location (1, 1) points to the bottom-right pixel.
//   When align_corners is false [default]:
//     Normalized location (-1, -1) points to the top-left pixel minus half
//     pixel in both directions, i.e, (-0.5, -0.5) in actual image space.
//     Normalized location (1, 1) points to the bottom-right pixel plus half
//     pixel in both directions, i.e. (H - 0.5, W - 0.5) in actual image space.
// `n` is the normalized (x, y) location; `border` is {x_min, y_min, x_max,
// y_max} from CalculateBorders(). Returns the denormalized (x, y) location.
float2 DenormalizeInput(float2 n, float4 border)
{
    float2 dims = InputSizes.wz; // w-h
    if (AlignCorners == 1)
    {
        // AlignCorners: true => [-1, 1] to [0, dims - 1]
        n = (n + 1) / 2.f * (dims - 1);
    }
    else
    {
        // AlignCorners: false => [-1, 1] to [-0.5, dims - 0.5]
        n = ((n + 1) * dims - 1) / 2.f;
    }

    // Nearest-neighbor snaps to the closest pixel center before any
    // out-of-bound handling below.
    if (Mode == Nearest)
    {
        n = round(n);
    }

    float x_min = border.x;
    float y_min = border.y;
    float x_max = border.z;
    float y_max = border.w;

    if (n.x < x_min || n.x > x_max || n.y < y_min || n.y > y_max) { // out of bound
        if (PaddingMode == Border) {
            // use original border in both align_corner cases
            n.x = clamp(n.x, 0, InputSizes.w - 1);
            n.y = clamp(n.y, 0, InputSizes.z - 1);
        } else if (PaddingMode == Reflection) {
            n.x = Reflect(n.x, x_min, x_max);
            n.y = Reflect(n.y, y_min, y_max);
        }
        // PaddingMode == Zeros is handled later, per-tap, in PixelAtGrid.
    } // out of bound

    return n;
}
|
||||
|
||||
// Reads one input element at (n, c, h, w), promoted to float.
float FetchInputPixel(uint4 index)
{
    // index and InputStrides are both ordered (n, c, h, w).
    uint flatIndex = dot(index, InputStrides);
    return (float)input[flatIndex];
}
|
||||
|
||||
// Samples one input pixel at a possibly out-of-bounds location, applying the
// active padding mode. inputIdx is (n, c, h, w) where h/w may be fractional
// or negative; border is {x_min, y_min, x_max, y_max} from CalculateBorders().
float PixelAtGrid(float4 inputIdx, float4 border) {
    float pixel = 0; // default 0

    if (PaddingMode == Zeros)
    {
        // Out-of-bounds taps contribute 0. The signed >= 0 check runs first,
        // so the unsigned upper-bound compare is safe.
        if (inputIdx.w >= 0 && (uint)inputIdx.w < (uint)InputSizes.w &&
            inputIdx.z >= 0 && (uint)inputIdx.z < (uint)InputSizes.z)
        {
            pixel = FetchInputPixel(inputIdx);
        }
    }
    else if (PaddingMode == Border)
    {
        // Clamp to the nearest edge pixel.
        uint w = clamp(inputIdx.w, 0, InputSizes.w - 1);
        uint z = clamp(inputIdx.z, 0, InputSizes.z - 1);
        pixel = FetchInputPixel(float4(inputIdx.xy, z, w));
    }
    else if (PaddingMode == Reflection)
    {
        // Mirror across the borders; the uint assignment truncates the
        // reflected float coordinate to an integer pixel index.
        uint w = Reflect(inputIdx.w, border.x, border.z);
        uint z = Reflect(inputIdx.z, border.y, border.w);
        pixel = FetchInputPixel(float4(inputIdx.xy, z, w));
    }

    return pixel;
}
|
||||
|
||||
// Evaluates the 1-D cubic convolution interpolation kernel (a = -0.75) at
// parameter t, given four consecutive samples f(-1), f(0), f(1), f(2).
float BicubicConvolutionPFunction(float t, float fminus1, float f0, float f1, float f2)
{
    static const float a = -.75;
    // Rows give the sample weights for the 1, t, t^2, t^3 terms respectively.
    static const float4x4 bicubicConvolutionMatrix =
    {
        0, 1, 0, 0,
        a, 0, -a, 0,
        -2*a, -3-a, 3+2*a, a,
        a, 2+a, -2-a, -a
    };

    // p(t) = [1, t, t^2, t^3] * M * [f(-1), f(0), f(1), f(2)]^T
    float4 t_vec = float4( 1, t, t * t, t * t * t);
    float4 f_vec = float4(fminus1, f0, f1, f2);
    return mul(t_vec, mul(bicubicConvolutionMatrix, f_vec));
}
|
||||
|
||||
// Compute-shader entry point: each thread produces one flattened output
// element by sampling the input at the location given by the grid, using the
// configured interpolation mode and padding mode.
[numthreads(64, 1, 1)]
void GridSample(uint3 dtid : SV_DispatchThreadId)
{
    uint n = StartIndex + dtid.x;
    if (n < ElementCount)
    {
        float4 border = CalculateBorders();
        uint4 index = DecomposeIndex(n);
        // Normalized (x, y) sampling location for this output element.
        float2 flowVector = FetchGridVector(index);
        float2 inputWidthAndHeightIdx = DenormalizeInput(flowVector, border);
        float4 inputIdx = float4(index.x, // N
                                 index.y, // C
                                 inputWidthAndHeightIdx.y, // H
                                 inputWidthAndHeightIdx.x); // W

        if (Mode == Nearest)
        {
            // DenormalizeInput already rounded the coordinates for Nearest.
            output[n] = (TBUFFER1)(PixelAtGrid(inputIdx, border));
        }
        else if (Mode == Bilinear)
        {
            // Four surrounding integer corners of the sample point.
            float x1 = floor(inputIdx.w);
            float y1 = floor(inputIdx.z);
            float x2 = x1 + 1;
            float y2 = y1 + 1;

            float p11 = PixelAtGrid(float4(index.x, index.y, y1, x1), border);
            float p12 = PixelAtGrid(float4(index.x, index.y, y1, x2), border);
            float p21 = PixelAtGrid(float4(index.x, index.y, y2, x1), border);
            float p22 = PixelAtGrid(float4(index.x, index.y, y2, x2), border);

            // p11--p12
            // |    |
            // p21--p22
            // frac(x) = x - floor(x) is in [0, 1) even for negative x,
            // matching the floor-based corner choice above.
            float p1 = lerp(p11, p12, frac(inputIdx.w));
            float p2 = lerp(p21, p22, frac(inputIdx.w));
            float p = lerp(p1, p2, frac(inputIdx.z));
            output[n] = (TBUFFER1)(p);
        }
        else if (Mode == Bicubic)
        {
            // 4x4 neighborhood starting one pixel above/left of the sample.
            float x0 = floor(inputIdx.w) - 1;
            float y0 = floor(inputIdx.z) - 1;

            float f00 = PixelAtGrid(float4(index.x, index.y, y0 + 0, x0 + 0), border);
            float f01 = PixelAtGrid(float4(index.x, index.y, y0 + 0, x0 + 1), border);
            float f02 = PixelAtGrid(float4(index.x, index.y, y0 + 0, x0 + 2), border);
            float f03 = PixelAtGrid(float4(index.x, index.y, y0 + 0, x0 + 3), border);
            float f10 = PixelAtGrid(float4(index.x, index.y, y0 + 1, x0 + 0), border);
            float f11 = PixelAtGrid(float4(index.x, index.y, y0 + 1, x0 + 1), border);
            float f12 = PixelAtGrid(float4(index.x, index.y, y0 + 1, x0 + 2), border);
            float f13 = PixelAtGrid(float4(index.x, index.y, y0 + 1, x0 + 3), border);
            float f20 = PixelAtGrid(float4(index.x, index.y, y0 + 2, x0 + 0), border);
            float f21 = PixelAtGrid(float4(index.x, index.y, y0 + 2, x0 + 1), border);
            float f22 = PixelAtGrid(float4(index.x, index.y, y0 + 2, x0 + 2), border);
            float f23 = PixelAtGrid(float4(index.x, index.y, y0 + 2, x0 + 3), border);
            float f30 = PixelAtGrid(float4(index.x, index.y, y0 + 3, x0 + 0), border);
            float f31 = PixelAtGrid(float4(index.x, index.y, y0 + 3, x0 + 1), border);
            float f32 = PixelAtGrid(float4(index.x, index.y, y0 + 3, x0 + 2), border);
            float f33 = PixelAtGrid(float4(index.x, index.y, y0 + 3, x0 + 3), border);

            float tx = frac(inputIdx.w);
            float ty = frac(inputIdx.z);

            // Interpolate each column along y, then across columns along x.
            float bminus1 = BicubicConvolutionPFunction(ty, f00, f10, f20, f30);
            float b0 = BicubicConvolutionPFunction(ty, f01, f11, f21, f31);
            float b1 = BicubicConvolutionPFunction(ty, f02, f12, f22, f32);
            float b2 = BicubicConvolutionPFunction(ty, f03, f13, f23, f33);
            float p = BicubicConvolutionPFunction(tx, bminus1, b0, b1, b2);

            output[n] = (TBUFFER1)(p);
        }
    }
}
|
||||
|
|
@ -11,6 +11,7 @@ namespace AttrName
|
|||
static constexpr const char* Activations = "activations";
|
||||
static constexpr const char* AllowZero = "allowzero";
|
||||
static constexpr const char* Alpha = "alpha";
|
||||
static constexpr const char* AlignCorners = "align_corners";
|
||||
static constexpr const char* AutoPad = "auto_pad";
|
||||
static constexpr const char* Axes = "axes";
|
||||
static constexpr const char* Axis = "axis";
|
||||
|
|
@ -64,6 +65,7 @@ namespace AttrName
|
|||
static constexpr const char* NoopWithEmptyAxes = "noop_with_empty_axes";
|
||||
static constexpr const char* NormalizeVariance = "normalize_variance";
|
||||
static constexpr const char* P = "p";
|
||||
static constexpr const char* PaddingMode = "padding_mode";
|
||||
static constexpr const char* OutputHeight = "output_height";
|
||||
static constexpr const char* OutputShape = "output_shape";
|
||||
static constexpr const char* OutputPadding = "output_padding";
|
||||
|
|
|
|||
|
|
@ -1124,16 +1124,174 @@ static void ModelBuilding_ConstantMatmul() {
|
|||
#endif
|
||||
}
|
||||
|
||||
|
||||
#if !defined(BUILD_INBOX)
|
||||
|
||||
enum class Mode : uint32_t {
|
||||
Bilinear,
|
||||
Nearest,
|
||||
Bicubic,
|
||||
};
|
||||
|
||||
enum class PaddingMode : uint32_t {
|
||||
Zeros,
|
||||
Border,
|
||||
Reflection,
|
||||
};
|
||||
|
||||
template <typename T, typename U>
|
||||
static void GridSample(
|
||||
LearningModelDeviceKind kind,
|
||||
const std::vector<T>& input,
|
||||
const std::vector<int64_t>& input_dims,
|
||||
const std::vector<U>& grid,
|
||||
const std::vector<int64_t>& grid_dims,
|
||||
bool align_corners,
|
||||
Mode mode,
|
||||
PaddingMode padding_mode
|
||||
) {
|
||||
const hstring modes[] = {
|
||||
L"bilinear",
|
||||
L"nearest",
|
||||
L"bicubic"};
|
||||
|
||||
const hstring padding_modes[] = {
|
||||
L"zeros",
|
||||
L"border",
|
||||
L"reflection"};
|
||||
|
||||
auto model =
|
||||
LearningModelBuilder::Create(17)
|
||||
.Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input", TensorKind::Float, input_dims))
|
||||
.Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Grid", TensorKind::Float, grid_dims))
|
||||
.Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output", TensorKind::Float, {-1, -1, -1, -1}))
|
||||
.Operators().Add(Operator(L"GridSample")
|
||||
.SetInput(L"X", L"Input")
|
||||
.SetInput(L"grid", L"Grid")
|
||||
.SetAttribute(L"align_corners", TensorInt64Bit::CreateFromArray({ }, {INT64(align_corners)}))
|
||||
.SetAttribute(L"mode", TensorString::CreateFromArray({ }, { modes[static_cast<uint32_t>(mode)] }))
|
||||
.SetAttribute(L"padding_mode", TensorString::CreateFromArray({ }, { padding_modes[static_cast<uint32_t>(padding_mode)] }))
|
||||
.SetOutput(L"Y", L"Output"))
|
||||
.CreateModel();
|
||||
auto cpu_device = LearningModelDevice(LearningModelDeviceKind::Cpu);
|
||||
auto device = LearningModelDevice(kind);
|
||||
LearningModelSession device_session(model, device);
|
||||
LearningModelBinding device_binding(device_session);
|
||||
LearningModelSession cpu_session(model, cpu_device);
|
||||
LearningModelBinding cpu_binding(cpu_session);
|
||||
|
||||
device_binding.Bind(L"Input", TensorFloat::CreateFromShapeArrayAndDataArray(input_dims, input));
|
||||
device_binding.Bind(L"Grid", TensorFloat::CreateFromShapeArrayAndDataArray(grid_dims, grid));
|
||||
cpu_binding.Bind(L"Input", TensorFloat::CreateFromShapeArrayAndDataArray(input_dims, input));
|
||||
cpu_binding.Bind(L"Grid", TensorFloat::CreateFromShapeArrayAndDataArray(grid_dims, grid));
|
||||
|
||||
auto cpu_result = cpu_session.Evaluate(cpu_binding, L"");
|
||||
|
||||
// Evaluate
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
auto device_result = device_session.Evaluate(device_binding, L"");
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double, std::micro> evaluate_duration_in_microseconds = end - start;
|
||||
printf("GridSample[Mode=%ls, PaddingMode=%ls, AlignCorners=%s] took %fus.\n",
|
||||
modes[static_cast<uint32_t>(mode)].c_str(),
|
||||
padding_modes[static_cast<uint32_t>(padding_mode)].c_str(),
|
||||
align_corners ? "True" : "False",
|
||||
evaluate_duration_in_microseconds.count());
|
||||
|
||||
// Check results
|
||||
constexpr float error_threshold = .001f;
|
||||
auto device_y_tensor = device_result.Outputs().Lookup(L"Output").as<TensorFloat>();
|
||||
auto device_y_ivv = device_y_tensor.GetAsVectorView();
|
||||
auto cpu_y_tensor = cpu_result.Outputs().Lookup(L"Output").as<TensorFloat>();
|
||||
auto cpu_y_ivv = cpu_y_tensor.GetAsVectorView();
|
||||
WINML_EXPECT_EQUAL(device_y_ivv.Size(), cpu_y_ivv.Size());
|
||||
for (uint32_t i = 0; i < device_y_ivv.Size(); i++) {
|
||||
bool in_range = abs(device_y_ivv.GetAt(i) - cpu_y_ivv.GetAt(i)) < error_threshold;
|
||||
if (!in_range) {
|
||||
printf("[%d] ACTUAL(%f) EXPECTED(%f)\n", (int)i, device_y_ivv.GetAt(i), cpu_y_ivv.GetAt(i));
|
||||
}
|
||||
WINML_EXPECT_TRUE(in_range);
|
||||
}
|
||||
}
|
||||
|
||||
static void GridSampleRunner(LearningModelDeviceKind kind,
|
||||
const std::vector<float>& input,
|
||||
const std::vector<int64_t>& input_dims,
|
||||
const std::vector<float>& grid,
|
||||
const std::vector<int64_t>& grid_dims)
|
||||
{
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, false, Mode::Bilinear, PaddingMode::Zeros);
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, false, Mode::Bilinear, PaddingMode::Border);
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, false, Mode::Bilinear, PaddingMode::Reflection);
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, false, Mode::Nearest, PaddingMode::Zeros);
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, false, Mode::Nearest, PaddingMode::Border);
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, false, Mode::Nearest, PaddingMode::Reflection);
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, false, Mode::Bicubic, PaddingMode::Zeros);
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, false, Mode::Bicubic, PaddingMode::Border);
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, false, Mode::Bicubic, PaddingMode::Reflection);
|
||||
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, true, Mode::Bilinear, PaddingMode::Zeros);
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, true, Mode::Bilinear, PaddingMode::Border);
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, true, Mode::Bilinear, PaddingMode::Reflection);
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, true, Mode::Nearest, PaddingMode::Zeros);
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, true, Mode::Nearest, PaddingMode::Border);
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, true, Mode::Nearest, PaddingMode::Reflection);
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, true, Mode::Bicubic, PaddingMode::Zeros);
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, true, Mode::Bicubic, PaddingMode::Border);
|
||||
GridSample(kind, input, input_dims, grid, grid_dims, true, Mode::Bicubic, PaddingMode::Reflection);
|
||||
}
|
||||
|
||||
static void ModelBuilding_GridSample_Internal(LearningModelDeviceKind kind) {
|
||||
std::vector<float> input =
|
||||
{
|
||||
0.00f, 1.00f, 2.00f, 3.00f,
|
||||
4.00f, 5.00f, 6.00f, 7.00f,
|
||||
8.00f, 9.00f, 10.00f, 11.00f,
|
||||
12.00f, 13.00f, 14.00f, 15.00f,
|
||||
};
|
||||
|
||||
std::vector<float> grid =
|
||||
{
|
||||
0.00f, 1.00f, 2.00f, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f, 9.00f,
|
||||
10.00f, 11.00f, 12.00f, 13.00f, 14.00f, 15.00f, 16.00f, 17.00f, 18.00f, 19.00f,
|
||||
20.00f, 21.00f, 22.00f, 23.00f, 24.00f, 25.00f, 26.00f, 27.00f, 28.00f, 29.00f,
|
||||
30.00f, 31.00f, 32.00f, 33.00f, 34.00f, 35.00f, 36.00f, 37.00f, 38.00f, 39.00f,
|
||||
40.00f, 41.00f, 42.00f, 43.00f, 44.00f, 45.00f, 46.00f, 47.00f, 48.00f, 49.00f,
|
||||
};
|
||||
std::transform(grid.begin(), grid.end(), grid.begin(), [&](auto& in) { return in / grid.size(); });
|
||||
std::vector<int64_t> input_dims = {1, 1, 4, 4};
|
||||
std::vector<int64_t> grid_dims = {1, 5, 5, 2};
|
||||
|
||||
GridSampleRunner(kind, input, input_dims, grid, grid_dims);
|
||||
|
||||
input = { 0.0f, 1.0f, 2.0f, 3.0f, 4.0, 5.0f };
|
||||
grid =
|
||||
{
|
||||
-10.0000f, -10.0000f,
|
||||
-5.0000f, -5.0000f,
|
||||
-0.2000f, -0.2000f,
|
||||
10.0000f, 10.0000f,
|
||||
|
||||
10.0000f, 10.0000f,
|
||||
-0.2000f, -0.2000f,
|
||||
5.0000f, 5.0000f,
|
||||
10.0000f, 10.0000f
|
||||
};
|
||||
input_dims = {1, 1, 3, 2};
|
||||
grid_dims = {1, 2, 4, 2};
|
||||
|
||||
GridSampleRunner(kind, input, input_dims, grid, grid_dims);
|
||||
}
|
||||
|
||||
static void ModelBuilding_DiscreteFourierTransform_Internal(LearningModelDeviceKind kind) {
|
||||
std::vector<float> real_input =
|
||||
{
|
||||
1.00f, 2.00, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
|
||||
1.00f, 2.00, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
|
||||
1.00f, 2.00, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
|
||||
1.00f, 2.00, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
|
||||
1.00f, 2.00, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
|
||||
};
|
||||
1.00f, 2.00f, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
|
||||
1.00f, 2.00f, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
|
||||
1.00f, 2.00f, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
|
||||
1.00f, 2.00f, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
|
||||
1.00f, 2.00f, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
|
||||
};
|
||||
|
||||
std::vector<std::complex<float>> real_expected_axis_0_two_sided = {
|
||||
{5.000f, 0.000f}, {10.000f, 0.000f}, {15.000f, 0.000f}, {20.000f, 0.000f}, {25.000f, 0.000f}, {30.000f, 0.000f}, {35.000f, 0.000f}, {40.000f, 0.000f},
|
||||
|
|
@ -1155,17 +1313,17 @@ static void ModelBuilding_DiscreteFourierTransform_Internal(LearningModelDeviceK
|
|||
|
||||
std::vector<std::complex<float>> input =
|
||||
{
|
||||
{1.00f, 0.00f}, {2.00, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
|
||||
{1.00f, 0.00f}, {2.00, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
|
||||
{1.00f, 0.00f}, {2.00, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
|
||||
{1.00f, 0.00f}, {2.00, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
|
||||
{1.00f, 0.00f}, {2.00, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
|
||||
{1.00f, 0.00f}, {2.00f, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
|
||||
{1.00f, 0.00f}, {2.00f, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
|
||||
{1.00f, 0.00f}, {2.00f, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
|
||||
{1.00f, 0.00f}, {2.00f, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
|
||||
{1.00f, 0.00f}, {2.00f, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
|
||||
|
||||
{2.00f, 1.00f}, {4.00, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
|
||||
{2.00f, 1.00f}, {4.00, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
|
||||
{2.00f, 1.00f}, {4.00, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
|
||||
{2.00f, 1.00f}, {4.00, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
|
||||
{2.00f, 1.00f}, {4.00, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
|
||||
{2.00f, 1.00f}, {4.00f, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
|
||||
{2.00f, 1.00f}, {4.00f, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
|
||||
{2.00f, 1.00f}, {4.00f, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
|
||||
{2.00f, 1.00f}, {4.00f, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
|
||||
{2.00f, 1.00f}, {4.00f, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
|
||||
};
|
||||
|
||||
std::vector<std::complex<float>> expected_axis_0_two_sided = {
|
||||
|
|
@ -1259,6 +1417,12 @@ static void ModelBuilding_DiscreteFourierTransform_Internal(LearningModelDeviceK
|
|||
}
|
||||
#endif
|
||||
|
||||
static void ModelBuilding_GridSampleDeviceDirectX() {
|
||||
#if !defined(BUILD_INBOX)
|
||||
ModelBuilding_GridSample_Internal(LearningModelDeviceKind::DirectX);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ModelBuilding_DiscreteFourierTransform() {
|
||||
#if !defined(BUILD_INBOX)
|
||||
ModelBuilding_DiscreteFourierTransform_Internal(LearningModelDeviceKind::Cpu);
|
||||
|
|
@ -1563,6 +1727,7 @@ const LearningModelSessionAPITestsApi& getapi() {
|
|||
ModelBuilding_DiscreteFourierTransformInverseIdentity,
|
||||
ModelBuilding_DiscreteFourierTransformDeviceDirectX,
|
||||
ModelBuilding_DiscreteFourierTransformInverseIdentityDeviceDirectX,
|
||||
ModelBuilding_GridSampleDeviceDirectX,
|
||||
ModelBuilding_HannWindow,
|
||||
ModelBuilding_HammingWindow,
|
||||
ModelBuilding_BlackmanWindow,
|
||||
|
|
@ -1581,6 +1746,7 @@ const LearningModelSessionAPITestsApi& getapi() {
|
|||
api.AdapterIdAndDevice = SkipTest;
|
||||
api.ModelBuilding_DiscreteFourierTransformDeviceDirectX = SkipTest;
|
||||
api.ModelBuilding_DiscreteFourierTransformInverseIdentityDeviceDirectX = SkipTest;
|
||||
api.ModelBuilding_GridSampleDeviceDirectX = SkipTest;
|
||||
}
|
||||
if (RuntimeParameterExists(L"EdgeCore")) {
|
||||
api.AdapterIdAndDevice = SkipTest;
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ struct LearningModelSessionAPITestsApi {
|
|||
VoidTest ModelBuilding_DiscreteFourierTransformInverseIdentity;
|
||||
VoidTest ModelBuilding_DiscreteFourierTransformDeviceDirectX;
|
||||
VoidTest ModelBuilding_DiscreteFourierTransformInverseIdentityDeviceDirectX;
|
||||
VoidTest ModelBuilding_GridSampleDeviceDirectX;
|
||||
VoidTest ModelBuilding_HannWindow;
|
||||
VoidTest ModelBuilding_HammingWindow;
|
||||
VoidTest ModelBuilding_BlackmanWindow;
|
||||
|
|
@ -68,6 +69,7 @@ WINML_TEST(LearningModelSessionAPITests, ModelBuilding_DiscreteFourierTransform)
|
|||
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_DiscreteFourierTransformInverseIdentity)
|
||||
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_DiscreteFourierTransformDeviceDirectX)
|
||||
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_DiscreteFourierTransformInverseIdentityDeviceDirectX)
|
||||
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_GridSampleDeviceDirectX)
|
||||
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_HannWindow)
|
||||
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_HammingWindow)
|
||||
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_BlackmanWindow)
|
||||
|
|
|
|||
Loading…
Reference in a new issue