Add GridSample implementation to DirectML (#15788)

Add GridSample implementation to DirectML EP.

Temporarily add an HLSL shader to the DirectML EP to handle GridSample until it is
officially added to DirectML.
This commit is contained in:
Sheil Kumar 2023-05-05 15:59:33 -07:00 committed by GitHub
parent 45f5c27632
commit 2b7f26af7c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
39 changed files with 218019 additions and 18 deletions

View file

@ -955,6 +955,7 @@ Do not modify directly.*
|||7+|**T** = tensor(float), tensor(float16)<br/> **T1** = tensor(bool)|
|GreaterOrEqual|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
|||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
|GridSample|*in* X:**T1**<br> *in* grid:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|HardSigmoid|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(float), tensor(float16)|
|Hardmax|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(float), tensor(float16)|
|||11+|**T** = tensor(float), tensor(float16)|

View file

@ -661,9 +661,10 @@ namespace Dml
bool IsCustomOpShader(const onnxruntime::Node& node)
{
auto custom_ops = std::array<char*, 2>{
auto custom_ops = std::array<char*, 3>{
"DFT",
"STFT"
"STFT",
"GridSample"
};
for (auto& custom_op : custom_ops)

View file

@ -0,0 +1,988 @@
#pragma once
#include "../../../OperatorAuthorHelper/OperatorHelper.h"
#include "../MLOperatorAuthorImpl.h"
#include "../External/D3DX12/d3dx12.h"
#include <d3d12.h>
// NOTE: When this operator's implementation is moved into DML, the associated FP16 fallback
// should be removed from IsCustomOpShader(...) in
// onnxruntime\core\providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp
// The shader headers are produced using "GeneratedShaders/GenerateShaders.bat"
namespace GridSample_uint16_float
{
#include "GeneratedShaders/grid_sample_uint16_float.h"
}
namespace GridSample_uint_float
{
#include "GeneratedShaders/grid_sample_uint_float.h"
}
namespace GridSample_uint64_float
{
#include "GeneratedShaders/grid_sample_uint64_float.h"
}
namespace GridSample_int16_float
{
#include "GeneratedShaders/grid_sample_int16_float.h"
}
namespace GridSample_int_float
{
#include "GeneratedShaders/grid_sample_int_float.h"
}
namespace GridSample_int64_float
{
#include "GeneratedShaders/grid_sample_int64_float.h"
}
namespace GridSample_fp16_float
{
#include "GeneratedShaders/grid_sample_fp16_float.h"
}
namespace GridSample_float_float
{
#include "GeneratedShaders/grid_sample_float_float.h"
}
namespace GridSample_double_float
{
#include "GeneratedShaders/grid_sample_double_float.h"
}
namespace GridSample_bool_float
{
#include "GeneratedShaders/grid_sample_bool_float.h"
}
namespace GridSample_uint16_fp16
{
#include "GeneratedShaders/grid_sample_uint16_fp16.h"
}
namespace GridSample_uint_fp16
{
#include "GeneratedShaders/grid_sample_uint_fp16.h"
}
namespace GridSample_uint64_fp16
{
#include "GeneratedShaders/grid_sample_uint64_fp16.h"
}
namespace GridSample_int16_fp16
{
#include "GeneratedShaders/grid_sample_int16_fp16.h"
}
namespace GridSample_int_fp16
{
#include "GeneratedShaders/grid_sample_int_fp16.h"
}
namespace GridSample_int64_fp16
{
#include "GeneratedShaders/grid_sample_int64_fp16.h"
}
namespace GridSample_fp16_fp16
{
#include "GeneratedShaders/grid_sample_fp16_fp16.h"
}
namespace GridSample_float_fp16
{
#include "GeneratedShaders/grid_sample_float_fp16.h"
}
namespace GridSample_double_fp16
{
#include "GeneratedShaders/grid_sample_double_fp16.h"
}
namespace GridSample_bool_fp16
{
#include "GeneratedShaders/grid_sample_bool_fp16.h"
}
namespace GridSample_uint16_double
{
#include "GeneratedShaders/grid_sample_uint16_double.h"
}
namespace GridSample_uint_double
{
#include "GeneratedShaders/grid_sample_uint_double.h"
}
namespace GridSample_uint64_double
{
#include "GeneratedShaders/grid_sample_uint64_double.h"
}
namespace GridSample_int16_double
{
#include "GeneratedShaders/grid_sample_int16_double.h"
}
namespace GridSample_int_double
{
#include "GeneratedShaders/grid_sample_int_double.h"
}
namespace GridSample_int64_double
{
#include "GeneratedShaders/grid_sample_int64_double.h"
}
namespace GridSample_fp16_double
{
#include "GeneratedShaders/grid_sample_fp16_double.h"
}
namespace GridSample_float_double
{
#include "GeneratedShaders/grid_sample_float_double.h"
}
namespace GridSample_double_double
{
#include "GeneratedShaders/grid_sample_double_double.h"
}
namespace GridSample_bool_double
{
#include "GeneratedShaders/grid_sample_bool_double.h"
}
#include <wrl/client.h>
#include <wrl/implements.h>
#include <sstream>
using namespace Microsoft::WRL;
// Indices of this kernel's input tensors: X (the image) is input 0, grid is input 1.
enum DmlGridSampleKernelInputIndex : uint32_t
{
Input,
Grid,
};
// Interpolation method, matching the ONNX GridSample 'mode' attribute
// ("bilinear" | "nearest" | "bicubic"). Values are passed to the shader as uints.
enum DmlGridSampleMode : uint32_t
{
Bilinear,
Nearest,
Bicubic,
};
// Out-of-bounds handling, matching the ONNX GridSample 'padding_mode' attribute
// ("zeros" | "border" | "reflection"). Values are passed to the shader as uints.
enum DmlGridSamplePaddingMode : uint32_t
{
Zeros,
Border,
Reflection
};
// Helper to derive dimensions and attributes from either the GridSample shape inferrer
// or the GridSample kernel constructor.
struct DmlGridSampleParameters
{
    uint32_t batchSize = 0;      // N, from input dims[0]
    uint32_t channelSize = 0;    // C, from input dims[1]
    uint32_t height = 0;         // H_out, from grid dims[1]
    uint32_t width = 0;          // W_out, from grid dims[2]
    int64_t alignCorners = 0;    // ONNX 'align_corners' attribute (defaults to 0)
    DmlGridSampleMode mode = DmlGridSampleMode::Bilinear;
    DmlGridSamplePaddingMode paddingMode = DmlGridSamplePaddingMode::Zeros;
    DML_TENSOR_DATA_TYPE dataType = DML_TENSOR_DATA_TYPE_UNKNOWN;  // input tensor's data type

    DmlGridSampleParameters(){}

    // Parses the GridSample attributes and validates/extracts the input and grid shapes.
    // Throws (via ML_CHECK_VALID_ARGUMENT) on unsupported attribute values or non-4D inputs.
    DmlGridSampleParameters(
        const OperatorHelper::IKernelInformationAdapter& kernelInfo,
        const OperatorHelper::IShapeInformationAdapter& shapeInfo)
    {
        auto& attributes = kernelInfo.GetAttributes();
        alignCorners = attributes.GetOptionalAttribute<int64_t>(AttrName::AlignCorners, 0);

        // 'mode': interpolation method used when sampling.
        std::string str_attrib = attributes.GetOptionalAttribute<std::string>(AttrName::Mode, "bilinear");
        ML_CHECK_VALID_ARGUMENT(str_attrib == "bilinear" || str_attrib == "nearest" || str_attrib == "bicubic");
        if (str_attrib == "bilinear")
        {
            mode = DmlGridSampleMode::Bilinear;
        }
        else if (str_attrib == "nearest")
        {
            mode = DmlGridSampleMode::Nearest;
        }
        else if (str_attrib == "bicubic")
        {
            mode = DmlGridSampleMode::Bicubic;
        }

        // 'padding_mode': how grid locations outside the input are handled.
        str_attrib = attributes.GetOptionalAttribute<std::string>(AttrName::PaddingMode, "zeros");
        ML_CHECK_VALID_ARGUMENT(str_attrib == "zeros" || str_attrib == "border" || str_attrib == "reflection");
        if (str_attrib == "zeros")
        {
            paddingMode = DmlGridSamplePaddingMode::Zeros;
        }
        else if (str_attrib == "border")
        {
            paddingMode = DmlGridSamplePaddingMode::Border;
        }
        else if (str_attrib == "reflection")
        {
            paddingMode = DmlGridSamplePaddingMode::Reflection;
        }

        // Input 0: signal (required; tensor)
        {
            // Input shape is expected to be [batch_size, channels, height, width].
            // For integer input types, intermediate values are computed as floating point
            // and cast to integer at the end.
            uint32_t rank = shapeInfo.GetInputTensorDimensionCount(DmlGridSampleKernelInputIndex::Input);
            ML_CHECK_VALID_ARGUMENT(rank == 4, "Input shape must be 4D.");

            auto dims = shapeInfo.GetInputTensorShape(DmlGridSampleKernelInputIndex::Input);
            assert(dims.size() == rank);

            this->batchSize = dims[0];
            this->channelSize = dims[1];

            MLOperatorEdgeDescription edgeDesc = kernelInfo.GetInputEdgeDescription(DmlGridSampleKernelInputIndex::Input);
            assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor);
            this->dataType = Dml::GetDmlDataTypeFromMlDataType(edgeDesc.tensorDataType);
        }

        // Input 1: grid (required; tensor)
        {
            // Grid shape is expected to be [batch_size, height_out, width_out, 2].
            uint32_t rank = shapeInfo.GetInputTensorDimensionCount(DmlGridSampleKernelInputIndex::Grid);
            // BUGFIX: the error message previously said "Input shape must be 4D." for the grid tensor.
            ML_CHECK_VALID_ARGUMENT(rank == 4, "Grid shape must be 4D.");

            auto dims = shapeInfo.GetInputTensorShape(DmlGridSampleKernelInputIndex::Grid);
            assert(dims.size() == rank);

            // The output spatial dimensions come from the grid, not the input.
            this->height = dims[1];
            this->width = dims[2];
        }
    }
};
namespace GridSampleHelpers
{
    // Divides and rounds up: ceil(dividend / divisor). The sum is computed in 64 bits
    // so that dividend + divisor - 1 cannot overflow uint32_t.
    inline uint32_t CeilDivide(uint32_t dividend, uint32_t divisor)
    {
        uint64_t temp = static_cast<uint64_t>(dividend) + divisor - 1;
        return static_cast<uint32_t>(temp / divisor);
    }

    // Gets the next number of elements to dispatch to the GPU within a loop handling a large
    // total number of tensor elements and threads.
    //
    // BUGFIX: marked 'inline' — this function is defined in a header, so a non-inline
    // definition would violate the ODR when the header is included by multiple
    // translation units.
    inline void GetNextDispatchSize(
        uint32_t elementCount,        // elements still pending before this dispatch
        uint32_t elementsPerThread,   // elements each GPU thread processes
        uint32_t numThreads,          // threads per thread group
        _Out_ uint32_t& dispatch,             // receives the 1D thread-group count
        _Out_ uint32_t& pendingElementCount   // receives elements remaining after this dispatch
        )
    {
        // Max threads per workgroup is 2^10 (1024). Max dispatch per dimension is 65535.
        // Taken together, we can dispatch a maximum of 1024 * 65535 = 67,107,840 (~2^26)
        // threads along a single dimension. This should suffice for a majority of the
        // workload. Therefore, even though it is possible to dispatch up to 65535^3
        // workgroups simultaneously, we stick to the simpler 1D dispatch alternative.
        assert(numThreads <= D3D12_CS_THREAD_GROUP_MAX_THREADS_PER_GROUP);
        const uint32_t maxThreadsPerDispatch = numThreads * D3D12_CS_DISPATCH_MAX_THREAD_GROUPS_PER_DIMENSION;

        const uint32_t requiredThreadCount = CeilDivide(elementCount, elementsPerThread);

        // Clamp to the maximum number of threads a single 1D dispatch can cover.
        const uint32_t availableThreadCount = std::min(requiredThreadCount, maxThreadsPerDispatch);

        // Compute required thread group count.
        uint32_t workGroupCount1D = CeilDivide(availableThreadCount, numThreads);

        dispatch = workGroupCount1D;

        // With the dispatch size computed, compute the dispatched element count.
        const uint32_t dispatchedElementCount = workGroupCount1D * numThreads * elementsPerThread;

        // Whatever was not covered by this dispatch remains pending for the next loop iteration.
        pendingElementCount = (dispatchedElementCount < elementCount) ? elementCount - dispatchedElementCount : 0;
    }
}
// Custom D3D12 compute kernel implementing ONNX GridSample (opset 16+) for the DML
// execution provider, until DirectML gains a native GridSample operator. The kernel
// binds the input, grid, and output buffers as raw UAVs and dispatches one of the
// precompiled shader variants selected by the (input type, grid type) pair.
class DmlGridSampleOperator : public WRL::Base<IMLOperatorKernel>
{
private:
    ComPtr<ID3D12Device> m_device;
    ComPtr<ID3D12RootSignature> m_gridSampleRootSignature;
    ComPtr<ID3D12PipelineState> m_gridSamplePipelineState;
    DmlGridSampleParameters m_params = {};

    // Describes a temporary buffer along with its logical sizes and strides.
    struct ResourceDesc
    {
        ComPtr<ID3D12Resource> Resource;
        std::array<uint32_t, 4> Sizes;
        std::array<uint32_t, 4> Strides;
    };

    // Root constants consumed by the shader. Field order and total size must match
    // the "Constants" cbuffer declared in Shaders/grid_sample.hlsl (29 x 32-bit values).
    struct GridSampleShaderConstants
    {
        uint32_t StartIndex;
        uint32_t ElementCount;
        uint32_t Mode;
        uint32_t PaddingMode;
        uint32_t InputSizes[4];
        uint32_t InputStrides[4];
        uint32_t GridSizes[4];
        uint32_t GridStrides[4];
        uint32_t OutputSizes[4];
        uint32_t OutputStrides[4];
        uint32_t AlignCorners;
    };

public:
    // Parses attributes/shapes and builds the root signature and PSO up front.
    DmlGridSampleOperator(IMLOperatorKernelCreationContext* context)
    {
        ComPtr<IUnknown> executionObject;
        context->GetExecutionInterface(executionObject.GetAddressOf());

        ComPtr<ID3D12GraphicsCommandList> commandList;
        ORT_THROW_IF_FAILED(executionObject.As(&commandList));

        ORT_THROW_IF_FAILED(commandList->GetDevice(IID_ID3D12Device, &m_device));

        MLOperatorKernelCreationContext creationContext(context);
        OperatorHelper::KernelInformationAdapter kernelInfo{creationContext};
        OperatorHelper::ShapeInformationAdapter shapeInfo{creationContext};
        m_params = DmlGridSampleParameters(kernelInfo, shapeInfo);

        MLOperatorEdgeDescription inputEdgeDesc;
        ORT_THROW_IF_FAILED(context->GetInputEdgeDescription(DmlGridSampleKernelInputIndex::Input, &inputEdgeDesc));
        assert(inputEdgeDesc.edgeType == MLOperatorEdgeType::Tensor);

        // BUGFIX: the grid edge description was previously queried from input 0
        // (the same edge as the input tensor); the grid is input 1.
        MLOperatorEdgeDescription gridEdgeDesc;
        ORT_THROW_IF_FAILED(context->GetInputEdgeDescription(DmlGridSampleKernelInputIndex::Grid, &gridEdgeDesc));
        assert(gridEdgeDesc.edgeType == MLOperatorEdgeType::Tensor);

        PrepareGridSample(inputEdgeDesc.tensorDataType, gridEdgeDesc.tensorDataType);
    }

    // Builds the compute root signature (3 raw UAVs + root constants) and creates the
    // PSO for the shader variant matching (inputDataType, gridDataType).
    // Throws E_INVALIDARG for unsupported type combinations.
    void PrepareGridSample(MLOperatorTensorDataType inputDataType, MLOperatorTensorDataType gridDataType)
    {
        // Compute root signature.
        const uint32_t uavCount = 3; // 3 bound UAVs: input, grid & output
        std::vector<CD3DX12_ROOT_PARAMETER1> rootParameters;
        rootParameters.resize(uavCount + 1);

        for (uint32_t i = 0; i < uavCount; i++)
        {
            rootParameters[i].InitAsUnorderedAccessView(i);
        }

        // cbuffer Constants
        // {
        //    uint StartIndex;
        //    uint ElementCount;
        //    uint Mode;
        //    uint PaddingMode;
        //    uint4 InputSizes;
        //    uint4 InputStrides;
        //    uint4 GridSizes;
        //    uint4 GridStrides;
        //    uint4 OutputSizes;
        //    uint4 OutputStrides;
        //    uint AlignCorners;
        // };
        // Derive the constant count from the struct so it cannot drift from the layout.
        static_assert(sizeof(GridSampleShaderConstants) == 29 * sizeof(uint32_t),
            "GridSampleShaderConstants must match the shader's 29-uint cbuffer layout.");
        const uint32_t constantCount = sizeof(GridSampleShaderConstants) / sizeof(uint32_t);
        rootParameters[uavCount].InitAsConstants(constantCount, 0);

        CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC desc;
        desc.Init_1_1(static_cast<uint32_t>(rootParameters.size()), rootParameters.data());

        ComPtr<ID3DBlob> rootSignatureBlob;
        ComPtr<ID3DBlob> rootSignatureErrorBlob;
        ORT_THROW_IF_FAILED(D3D12SerializeVersionedRootSignature(
            &desc,
            rootSignatureBlob.GetAddressOf(),
            rootSignatureErrorBlob.GetAddressOf()
        ));

        ORT_THROW_IF_FAILED(m_device->CreateRootSignature(
            0,
            rootSignatureBlob->GetBufferPointer(),
            rootSignatureBlob->GetBufferSize(),
            IID_ID3D12RootSignature,
            &m_gridSampleRootSignature
        ));

        // Describe and create the compute pipeline state object (PSO), selecting the
        // precompiled bytecode for the requested (input, grid) type combination.
        D3D12_COMPUTE_PIPELINE_STATE_DESC computePsoDesc = {};
        computePsoDesc.pRootSignature = m_gridSampleRootSignature.Get();
        switch (gridDataType)
        {
            case MLOperatorTensorDataType::Float:
            {
                switch (inputDataType)
                {
                    case MLOperatorTensorDataType::UInt16:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint16_float::g_GridSample, sizeof(GridSample_uint16_float::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::UInt32:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint_float::g_GridSample, sizeof(GridSample_uint_float::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::UInt64:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint64_float::g_GridSample, sizeof(GridSample_uint64_float::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Int16:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int16_float::g_GridSample, sizeof(GridSample_int16_float::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Int32:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int_float::g_GridSample, sizeof(GridSample_int_float::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Int64:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int64_float::g_GridSample, sizeof(GridSample_int64_float::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Float16:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_fp16_float::g_GridSample, sizeof(GridSample_fp16_float::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Float:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_float_float::g_GridSample, sizeof(GridSample_float_float::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Double:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_double_float::g_GridSample, sizeof(GridSample_double_float::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Bool:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_bool_float::g_GridSample, sizeof(GridSample_bool_float::g_GridSample));
                        break;
                    default:
                        ORT_THROW_HR(E_INVALIDARG);
                }
                break;
            }
            case MLOperatorTensorDataType::Float16:
            {
                switch (inputDataType)
                {
                    case MLOperatorTensorDataType::UInt16:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint16_fp16::g_GridSample, sizeof(GridSample_uint16_fp16::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::UInt32:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint_fp16::g_GridSample, sizeof(GridSample_uint_fp16::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::UInt64:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint64_fp16::g_GridSample, sizeof(GridSample_uint64_fp16::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Int16:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int16_fp16::g_GridSample, sizeof(GridSample_int16_fp16::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Int32:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int_fp16::g_GridSample, sizeof(GridSample_int_fp16::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Int64:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int64_fp16::g_GridSample, sizeof(GridSample_int64_fp16::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Float16:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_fp16_fp16::g_GridSample, sizeof(GridSample_fp16_fp16::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Float:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_float_fp16::g_GridSample, sizeof(GridSample_float_fp16::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Double:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_double_fp16::g_GridSample, sizeof(GridSample_double_fp16::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Bool:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_bool_fp16::g_GridSample, sizeof(GridSample_bool_fp16::g_GridSample));
                        break;
                    default:
                        ORT_THROW_HR(E_INVALIDARG);
                }
                break;
            }
            case MLOperatorTensorDataType::Double:
            {
                switch (inputDataType)
                {
                    case MLOperatorTensorDataType::UInt16:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint16_double::g_GridSample, sizeof(GridSample_uint16_double::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::UInt32:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint_double::g_GridSample, sizeof(GridSample_uint_double::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::UInt64:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint64_double::g_GridSample, sizeof(GridSample_uint64_double::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Int16:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int16_double::g_GridSample, sizeof(GridSample_int16_double::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Int32:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int_double::g_GridSample, sizeof(GridSample_int_double::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Int64:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int64_double::g_GridSample, sizeof(GridSample_int64_double::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Float16:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_fp16_double::g_GridSample, sizeof(GridSample_fp16_double::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Float:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_float_double::g_GridSample, sizeof(GridSample_float_double::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Double:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_double_double::g_GridSample, sizeof(GridSample_double_double::g_GridSample));
                        break;
                    case MLOperatorTensorDataType::Bool:
                        computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_bool_double::g_GridSample, sizeof(GridSample_bool_double::g_GridSample));
                        break;
                    default:
                        ORT_THROW_HR(E_INVALIDARG);
                }
                break;
            }
            default:
                ORT_THROW_HR(E_INVALIDARG);
        }

        ORT_THROW_IF_FAILED(m_device->CreateComputePipelineState(&computePsoDesc, IID_ID3D12PipelineState, &m_gridSamplePipelineState));
    }

    // Computes the outputs of the kernel. This may be called multiple times
    // simultaneously within the same instance of the class. Implementations
    // of this method must be thread-safe.
    STDMETHOD(Compute)(IMLOperatorKernelContext* context)
    {
        try
        {
            // Get the input tensor
            ComPtr<IMLOperatorTensor> inputTensor;
            ORT_THROW_IF_FAILED(context->GetInputTensor(0, inputTensor.GetAddressOf()));

            // Get the grid tensor
            ComPtr<IMLOperatorTensor> gridTensor;
            ORT_THROW_IF_FAILED(context->GetInputTensor(1, gridTensor.GetAddressOf()));

            // Get the output tensor. BUGFIX: this HRESULT was previously ignored.
            ComPtr<IMLOperatorTensor> outputTensor;
            ORT_THROW_IF_FAILED(context->GetOutputTensor(0, outputTensor.GetAddressOf()));

            // This kernel only operates on GPU-resident tensors.
            if (outputTensor->IsCpuData() || inputTensor->IsCpuData() || gridTensor->IsCpuData())
            {
                return E_UNEXPECTED;
            }

            ComPtr<IUnknown> executionObject;
            ComPtr<ID3D12GraphicsCommandList> commandList;
            context->GetExecutionInterface(executionObject.GetAddressOf());
            ORT_THROW_IF_FAILED(executionObject.As(&commandList));

            // Get the input, grid, and output shape sizes.
            auto inputDims = GetTensorDimensions(inputTensor.Get());
            auto gridDims = GetTensorDimensions(gridTensor.Get());
            auto outputDims = GetTensorDimensions(outputTensor.Get());

            // Unwrap the D3D12 resources backing each tensor.
            ComPtr<IUnknown> inputUnknown;
            ComPtr<ID3D12Resource> inputResource;
            inputTensor->GetDataInterface(inputUnknown.GetAddressOf());
            ORT_THROW_IF_FAILED(inputUnknown.As(&inputResource));

            ComPtr<IUnknown> gridUnknown;
            ComPtr<ID3D12Resource> gridResource;
            gridTensor->GetDataInterface(gridUnknown.GetAddressOf());
            ORT_THROW_IF_FAILED(gridUnknown.As(&gridResource));

            ComPtr<IUnknown> outputUnknown;
            ComPtr<ID3D12Resource> outputResource;
            outputTensor->GetDataInterface(outputUnknown.GetAddressOf());
            ORT_THROW_IF_FAILED(outputUnknown.As(&outputResource));

            return Compute(
                commandList.Get(),
                context,
                inputResource.Get(),
                inputDims,
                gridResource.Get(),
                gridDims,
                outputResource.Get(),
                outputDims
            );
        }
        catch (...)
        {
            return E_FAIL;
        }

        return S_OK;
    }

    // Resource-level overload: records the GridSample work into the command list.
    // Returns E_FAIL if recording throws; never propagates C++ exceptions.
    HRESULT Compute(
        ID3D12GraphicsCommandList* commandList,
        IMLOperatorKernelContext* context,
        ID3D12Resource* inputResource,
        gsl::span<const uint32_t> inputDims,
        ID3D12Resource* gridResource,
        gsl::span<const uint32_t> gridDims,
        ID3D12Resource* outputResource,
        gsl::span<const uint32_t> outputDims)
    {
        try
        {
            GridSample(
                inputResource,
                inputDims,
                gridResource,
                gridDims,
                outputResource,
                outputDims,
                commandList);
        }
        catch (...)
        {
            return E_FAIL;
        }

        return S_OK;
    }

    // Transitions the three buffers COMMON -> UAV, dispatches the shader over all
    // output elements, then transitions them back to COMMON.
    void GridSample(
        ID3D12Resource* inputResource,
        gsl::span<const uint32_t> inputDims,
        ID3D12Resource* gridResource,
        gsl::span<const uint32_t> gridDims,
        ID3D12Resource* outputResource,
        gsl::span<const uint32_t> outputDims,
        ID3D12GraphicsCommandList* commandList)
    {
        // Compute packed (contiguous, descending) strides for each tensor.
        std::array<uint32_t, 4> inputStrides;
        std::array<uint32_t, 4> gridStrides;
        std::array<uint32_t, 4> outputStrides;
        Dml::GetDescendingPackedStrides(inputDims, inputStrides);
        Dml::GetDescendingPackedStrides(gridDims, gridStrides);
        Dml::GetDescendingPackedStrides(outputDims, outputStrides);

        // Transition resources from common to UAV state
        D3D12_RESOURCE_BARRIER barriers[3];

        barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition(
            inputResource,
            D3D12_RESOURCE_STATE_COMMON,
            D3D12_RESOURCE_STATE_UNORDERED_ACCESS
        );

        barriers[1] = CD3DX12_RESOURCE_BARRIER::Transition(
            gridResource,
            D3D12_RESOURCE_STATE_COMMON,
            D3D12_RESOURCE_STATE_UNORDERED_ACCESS
        );

        barriers[2] = CD3DX12_RESOURCE_BARRIER::Transition(
            outputResource,
            D3D12_RESOURCE_STATE_COMMON,
            D3D12_RESOURCE_STATE_UNORDERED_ACCESS
        );

        // Name the resources to aid GPU debugging/capture tools.
        inputResource->SetName(L"InputResource");
        outputResource->SetName(L"OutputResource");
        gridResource->SetName(L"GridResource");

        commandList->ResourceBarrier(3, barriers);

        // Set the root signature and pipeline state
        commandList->SetComputeRootSignature(m_gridSampleRootSignature.Get());
        commandList->SetPipelineState(m_gridSamplePipelineState.Get());

        // Populate the root constants consumed by the shader.
        // (Removed a stale comment left over from the DFT operator this was based on.)
        GridSampleShaderConstants constants = {};
        constants.AlignCorners = static_cast<uint32_t>(m_params.alignCorners);
        constants.Mode = static_cast<uint32_t>(m_params.mode);
        constants.PaddingMode = static_cast<uint32_t>(m_params.paddingMode);
        std::copy(inputDims.begin(), inputDims.end(), constants.InputSizes);
        std::copy(inputStrides.begin(), inputStrides.end(), constants.InputStrides);
        std::copy(gridDims.begin(), gridDims.end(), constants.GridSizes);
        std::copy(gridStrides.begin(), gridStrides.end(), constants.GridStrides);
        std::copy(outputDims.begin(), outputDims.end(), constants.OutputSizes);
        std::copy(outputStrides.begin(), outputStrides.end(), constants.OutputStrides);
        constants.ElementCount = ComputeElementCountFromDimensions(constants.OutputSizes);

        std::array<ID3D12Resource*, 3> uav_resources = { inputResource, gridResource, outputResource };
        Dispatch(uav_resources, constants, commandList);

        // Transition resources back to common state
        barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition(
            inputResource,
            D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
            D3D12_RESOURCE_STATE_COMMON
        );

        barriers[1] = CD3DX12_RESOURCE_BARRIER::Transition(
            gridResource,
            D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
            D3D12_RESOURCE_STATE_COMMON
        );

        barriers[2] = CD3DX12_RESOURCE_BARRIER::Transition(
            outputResource,
            D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
            D3D12_RESOURCE_STATE_COMMON
        );

        commandList->ResourceBarrier(3, barriers);
    }

    // Returns the tensor's shape as a vector of uint32_t dimensions.
    std::vector<uint32_t> GetTensorDimensions(IMLOperatorTensor* tensor)
    {
        auto inputDimsSize = tensor->GetDimensionCount();
        auto dims = std::vector<uint32_t>(inputDimsSize);
        ORT_THROW_IF_FAILED(tensor->GetShape(static_cast<uint32_t>(dims.size()), dims.data()));
        return dims;
    }

    // Binds the UAVs and root constants, then dispatches the shader in a loop until
    // every element has been processed (large tensors can exceed the per-dispatch
    // thread limit, so StartIndex advances each iteration).
    template <typename TConstants, uint32_t TSize>
    void Dispatch(
        std::array<ID3D12Resource*, TSize>& resources,
        TConstants& constants,
        ID3D12GraphicsCommandList* commandList)
    {
        D3D12_RESOURCE_BARRIER uav_barriers[TSize];

        std::transform(
            resources.begin(), resources.end(),
            uav_barriers,
            [](auto& resource) { return CD3DX12_RESOURCE_BARRIER::UAV(resource); } );
        commandList->ResourceBarrier(TSize, uav_barriers);

        for (uint32_t i = 0; i < TSize; i++)
        {
            // Set resource views; a null resource binds a null GPU VA.
            if (resources[i]) {
                commandList->SetComputeRootUnorderedAccessView(
                    i, // root parameter index
                    resources[i]->GetGPUVirtualAddress()
                );
            }
            else
            {
                commandList->SetComputeRootUnorderedAccessView(
                    i, // root parameter index
                    {}
                );
            }
        }

        auto pendingElementCount = constants.ElementCount;

        // Dispatch up to the maximum number of threads per iteration until
        // all elements are completed
        while (pendingElementCount > 0)
        {
            constants.StartIndex = constants.ElementCount - pendingElementCount;

            uint32_t dispatchSizeX;

            GridSampleHelpers::GetNextDispatchSize(
                pendingElementCount,
                1,
                64,
                dispatchSizeX,
                pendingElementCount
            );

            // Set root constants; count derived from the struct instead of a magic 29.
            commandList->SetComputeRoot32BitConstants(
                TSize, // root parameter index
                static_cast<uint32_t>(sizeof(TConstants) / sizeof(uint32_t)), // constant count
                &constants,
                0 // offset
            );

            commandList->Dispatch(dispatchSizeX, 1, 1);
        }

        // BUGFIX: the barrier count was hard-coded to 2, silently skipping the last
        // UAV barrier when TSize == 3 (input, grid, output).
        commandList->ResourceBarrier(TSize, uav_barriers);
    }
};
// Infers the GridSample output shape: [N, C, H_out, W_out], where N and C come from
// the input tensor and H_out/W_out come from the grid tensor (both derived through
// DmlGridSampleParameters).
struct GridSampleShapeInferrer : public WRL::Base<IMLOperatorShapeInferrer>
{
STDMETHOD(InferOutputShapes)(IMLOperatorShapeInferenceContext* context) noexcept
{
try
{
// NOTE(review): contextPrivate is queried but never used below; the QI may be
// intended as a capability check — confirm whether it can be removed.
ComPtr<IMLOperatorShapeInferenceContextPrivate> contextPrivate;
ORT_THROW_IF_FAILED(context->QueryInterface(IID_PPV_ARGS(&contextPrivate)));
MLShapeInferenceContext inferenceContext(context);
OperatorHelper::KernelInformationAdapter kernelInfo{inferenceContext};
OperatorHelper::ShapeInformationAdapter shapeInfo{inferenceContext};
// Reuse the kernel's attribute/shape parsing to derive the output dimensions.
DmlGridSampleParameters params(kernelInfo, shapeInfo);
std::array<uint32_t, 4> outputDims = { params.batchSize, params.channelSize, params.height, params.width };
ORT_THROW_IF_FAILED(context->SetOutputTensorShape(0, onnxruntime::narrow<uint32_t>(outputDims.size()), outputDims.data()));
}
catch (...)
{
// COM boundary: exceptions must not escape; report failure via HRESULT.
return E_FAIL;
}
return S_OK;
}
};
// Factory that creates DmlGridSampleOperator kernels and registers the custom
// GridSample kernel (opset 16+) with the DML execution provider's registry.
class DmlGridSampleOperatorFactory : public WRL::Base<IMLOperatorKernelFactory>
{
public:
    // Creates one kernel instance per node; returns E_FAIL on any construction error.
    STDMETHOD(CreateKernel)(
        IMLOperatorKernelCreationContext* context,
        IMLOperatorKernel** kernel)
    {
        try
        {
            // Renamed from 'dftOperator' (leftover from the DFT kernel this was based on).
            auto gridSampleOperator = wil::MakeOrThrow<DmlGridSampleOperator>(context);
            gridSampleOperator.CopyTo(kernel);
            return S_OK;
        }
        catch (...)
        {
            return E_FAIL;
        }
    }

    // Registers the GridSample kernel description (type constraints, attribute
    // defaults, shape inferrer) with the operator registry.
    static void RegisterGridSampleKernel(IMLOperatorRegistry* registry)
    {
        MLOperatorKernelDescription kernelDescription = {};
        kernelDescription.domain = "";
        kernelDescription.name = "GridSample";
        kernelDescription.minimumOperatorSetVersion = 16;
        kernelDescription.executionType = MLOperatorExecutionType::D3D12;

        // T1 (input/output tensor): float, float16, and the 8/16/32/64-bit integer types.
        MLOperatorEdgeTypeConstrant t1Constraint;
        t1Constraint.typeLabel = "T1";
        std::vector<MLOperatorEdgeDescription> t1AllowedEdges
        {
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Float },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Float16 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Int8 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Int16 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Int32 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Int64 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::UInt8 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::UInt16 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::UInt32 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::UInt64 },
        };
        t1Constraint.allowedTypes = t1AllowedEdges.data();
        t1Constraint.allowedTypeCount = static_cast<uint32_t>(t1AllowedEdges.size());

        // T2 (grid tensor): tensor(float16), tensor(float).
        MLOperatorEdgeTypeConstrant t2Constraint;
        t2Constraint.typeLabel = "T2";
        std::vector<MLOperatorEdgeDescription> t2AllowedEdges
        {
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Float16 },
            MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Float },
        };
        t2Constraint.allowedTypes = t2AllowedEdges.data();
        t2Constraint.allowedTypeCount = static_cast<uint32_t>(t2AllowedEdges.size());

        std::vector<MLOperatorEdgeTypeConstrant> typeConstraints{ t1Constraint, t2Constraint };
        kernelDescription.typeConstraints = typeConstraints.data();
        kernelDescription.typeConstraintCount = static_cast<uint32_t>(typeConstraints.size());

        // Attribute defaults per the ONNX GridSample-16 spec:
        // align_corners = 0, mode = "bilinear", padding_mode = "zeros".
        MLOperatorAttributeNameValue alignedCornersAttributeValue;
        alignedCornersAttributeValue.name = AttrName::AlignCorners;
        alignedCornersAttributeValue.type = MLOperatorAttributeType::Int;
        alignedCornersAttributeValue.valueCount = 1;
        static const int64_t alignedCorners[] = { 0 };
        alignedCornersAttributeValue.ints = alignedCorners;

        MLOperatorAttributeNameValue modeAttributeValue;
        modeAttributeValue.name = AttrName::Mode;
        modeAttributeValue.type = MLOperatorAttributeType::String;
        modeAttributeValue.valueCount = 1;
        static const char* modes[] = { "bilinear" };
        modeAttributeValue.strings = modes;

        MLOperatorAttributeNameValue paddingModeAttributeValue;
        // BUGFIX: this default was previously registered under AttrName::Mode, colliding
        // with the 'mode' default and leaving 'padding_mode' without a registered default.
        paddingModeAttributeValue.name = AttrName::PaddingMode;
        paddingModeAttributeValue.type = MLOperatorAttributeType::String;
        paddingModeAttributeValue.valueCount = 1;
        static const char* paddingModes[] = { "zeros" };
        paddingModeAttributeValue.strings = paddingModes;

        std::vector<MLOperatorAttributeNameValue> attributeDefaultValues{
            alignedCornersAttributeValue,
            modeAttributeValue,
            paddingModeAttributeValue
        };

        kernelDescription.defaultAttributes = attributeDefaultValues.data();
        kernelDescription.defaultAttributeCount = static_cast<uint32_t>(attributeDefaultValues.size());
        kernelDescription.options = MLOperatorKernelOptions::None;
        kernelDescription.executionOptions = 0;

        auto shareInferrer = wil::MakeOrThrow<GridSampleShapeInferrer>();
        auto factory = wil::MakeOrThrow<DmlGridSampleOperatorFactory>();

        ComPtr<IMLOperatorRegistryPrivate> registryPrivate;
        ORT_THROW_IF_FAILED(registry->QueryInterface(IID_PPV_ARGS(&registryPrivate)));

        ORT_THROW_IF_FAILED(registryPrivate->RegisterOperatorKernel(
            &kernelDescription,
            factory.Get(),
            shareInferrer.Get(),
            nullptr,
            false, // isInternalOperator
            false, // alias
            false, // supportsGraph
            nullptr,
            nullptr,
            0));
    }
};

View file

@ -7,10 +7,78 @@ if "%1" == "DEBUG" (
fxc.exe ..\Shaders\bluestein_chirp.hlsl -E BluesteinZChirp -T cs_5_0 /DTBUFFER=float /Zi /Od /Fh bluestein_chirp.h
dxc.exe ..\Shaders\bluestein_chirp.hlsl -E BluesteinZChirp -T cs_6_2 -DTBUFFER=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh bluestein_chirp_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint16_float.h
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=uint /DTBUFFER2=float /Zi /Od /Fh grid_sample_uint_float.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint64_float.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int16_float.h
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=int /DTBUFFER2=float /Zi /Od /Fh grid_sample_int_float.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int64_float.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_fp16_float.h
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=float /DTBUFFER2=float /Zi /Od /Fh grid_sample_float_float.h
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=double /DTBUFFER2=float /Zi /Od /Fh grid_sample_double_float.h
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=bool /DTBUFFER2=float /Zi /Od /Fh grid_sample_bool_float.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint16_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint64_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int16_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int64_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_fp16_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_float_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_double_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_bool_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint16_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint64_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int16_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int64_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_fp16_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_float_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_double_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_bool_double.h
) else (
fxc.exe ..\Shaders\stockham.hlsl -E DFT -T cs_5_0 /DTBUFFER=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh stockham.h
dxc.exe ..\Shaders\stockham.hlsl -E DFT -T cs_6_2 -DTBUFFER=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh stockham_fp16.h
fxc.exe ..\Shaders\bluestein_chirp.hlsl -E BluesteinZChirp -T cs_5_0 /DTBUFFER=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh bluestein_chirp.h
dxc.exe ..\Shaders\bluestein_chirp.hlsl -E BluesteinZChirp -T cs_6_2 -DTBUFFER=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh bluestein_chirp_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint16_float.h
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=uint /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_uint_float.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint64_float.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int16_float.h
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=int /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_int_float.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int64_float.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_fp16_float.h
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=float /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_float_float.h
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=double /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_double_float.h
fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=bool /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_bool_float.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint16_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint64_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int16_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int64_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_fp16_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_float_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_double_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_bool_fp16.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint16_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint64_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int16_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int64_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_fp16_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_float_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_double_double.h
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_bool_double.h
)

View file

@ -4,6 +4,7 @@
#include "precomp.h"
#include "DmlDFT.h"
#include "DmlSTFT.h"
#include "DmlGridSample.h"
#include "OperatorRegistration.h"
#include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h"
#include "core/providers/dml/OperatorAuthorHelper/OperatorVersions.h"
@ -1054,6 +1055,7 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
GpuDFTOperatorFactory::RegisterDFTKernel(registry);
DmlSTFTOperatorFactory::RegisterSTFTKernel(registry);
DmlGridSampleOperatorFactory::RegisterGridSampleKernel(registry);
}
} // namespace Dml

View file

@ -0,0 +1,289 @@
// TBUFFER1 is the data type read from the input buffer and written to the
// output buffer; TBUFFER2 is the data type of the sampling grid.
// Arithmetic is always done in FP32.
#if !defined(TBUFFER1)
#define TBUFFER1 float
#endif
#if !defined(TBUFFER2)
#define TBUFFER2 float
#endif
// Input tensor, flattened (n, c, h, w) layout per InputSizes/InputStrides.
RWStructuredBuffer<TBUFFER1> input : register(u0);
// Sampling grid, flattened (n, h, w, 2) layout; last dim holds (x, y) in [-1, 1].
RWStructuredBuffer<TBUFFER2> grid : register(u1);
// Output tensor, flattened (n, c, h, w) layout per OutputSizes/OutputStrides.
RWStructuredBuffer<TBUFFER1> output : register(u2);
// PaddingMode constant values (must stay in sync with the kernel's
// padding_mode attribute mapping — TODO confirm against the CPU-side setup).
static const uint Zeros = 0;
static const uint Border = 1;
static const uint Reflection = 2;
// Mode constant values (interpolation mode).
static const uint Bilinear = 0;
static const uint Nearest = 1;
static const uint Bicubic = 2;
cbuffer Constants
{
    uint StartIndex;     // First flattened output element handled by this dispatch.
    uint ElementCount;   // Total number of output elements.
    uint Mode;           // Bilinear, Nearest, or Bicubic.
    uint PaddingMode;    // Zeros, Border, or Reflection.
    uint4 InputSizes;    // (N, C, H_in, W_in)
    uint4 InputStrides;  // Element strides for the input tensor.
    uint4 GridSizes;     // (N, H_out, W_out, 2)
    uint4 GridStrides;   // Element strides for the grid tensor.
    uint4 OutputSizes;   // (N, C, H_out, W_out)
    uint4 OutputStrides; // Element strides for the output tensor.
    uint AlignCorners;   // Nonzero when align_corners semantics are requested.
};
// Decomposes a flattened output element index into (n, c, h, w) coordinates.
// NOTE(review): the subtractions use OutputStrides while the divisions use
// denominators derived from OutputSizes; the two agree only when the output
// tensor is densely packed (strides == {c*h*w, h*w, w, 1}) — confirm callers
// never supply non-default strides.
uint4 DecomposeIndex(uint index)
{
    uint4 idx = uint4(0, 0, 0, 0);
    uint n = OutputSizes.x;
    uint c = OutputSizes.y;
    uint h = OutputSizes.z;
    uint w = OutputSizes.w;
    // Number of flattened elements spanned by one step along each axis.
    uint width_denominator = 1;
    uint height_denominator = w;
    uint channel_denominator = w * h;
    uint batch_denominator = w * h * c;
    // Peel off one coordinate at a time, subtracting the contribution of the
    // coordinates already extracted.
    idx.x = (index) / batch_denominator; // batch
    idx.y = (index - (idx.x * OutputStrides.x)) / channel_denominator; // channel
    idx.z = (index - (idx.x * OutputStrides.x) - (idx.y * OutputStrides.y)) / height_denominator; // height
    idx.w = (index - (idx.x * OutputStrides.x) - (idx.y * OutputStrides.y) - (idx.z * OutputStrides.z)) / width_denominator; // width
    return idx;
}
// Fetches the normalized (x, y) sampling vector from the grid tensor for one
// output coordinate. (Previous comment about "real and complex output uav"
// was a copy/paste remnant from the DFT shader.)
float2 FetchGridVector(uint4 index)
{
    // The index is in (n, c, h, w).
    // The shape of GridSizes (and GridStrides) is (n, h, w, 2).
    uint n = index.x;
    uint c = index.y; // Channel is ignored: one grid location serves all channels.
    uint h = index.z;
    uint w = index.w;
    float4 gridIdx = float4(n, h, w, 0);
    float2 flattenedGridIndex = float2(0, 0);
    // NOTE(review): flattened buffer indices are carried in floats and used to
    // index the structured buffer; exact only while they fit float precision.
    flattenedGridIndex.x = dot(gridIdx, GridStrides);
    // The y component sits one last-dimension stride past the x component.
    flattenedGridIndex.y = flattenedGridIndex.x + GridStrides.w;
    return float2((float)grid[flattenedGridIndex.x],
                  (float)grid[flattenedGridIndex.y]);
}
// Computes the valid sampling rectangle as {x_min, y_min, x_max, y_max}.
// With align_corners the extremes coincide with pixel centers [0, dim - 1];
// otherwise they extend half a pixel beyond them, [-0.5, dim - 0.5].
// Borders are floats to avoid possible issues with integer tensor types.
float4 CalculateBorders()
{
    float inputWidth = InputSizes.w;   // W_in
    float inputHeight = InputSizes.z;  // H_in
    if (AlignCorners)
    {
        return float4(0.f, 0.f, inputWidth - 1.f, inputHeight - 1.f);
    }
    return float4(-0.5f, -0.5f, inputWidth - 0.5f, inputHeight - 0.5f);
}
// Folds an out-of-range coordinate back into [x_min, x_max] by repeatedly
// mirroring it across the nearest border. Borders are floats to avoid
// precision problems with integer tensor types.
float Reflect(float x, float x_min, float x_max) {
    if (x >= x_min && x <= x_max) {
        return x; // Already inside the valid range.
    }
    float range = x_max - x_min;
    // Distance past the violated border, and which border was violated.
    bool belowMin = x < x_min;
    float dx = belowMin ? (x_min - x) : (x - x_max);
    uint fullReflections = dx / range;          // Whole mirrorings (truncated).
    float r = dx - fullReflections * range;     // Residual distance.
    // An even reflection count measures from the violated border,
    // an odd count from the opposite border.
    if (belowMin) {
        return (fullReflections % 2 == 0) ? (x_min + r) : (x_max - r);
    }
    return (fullReflections % 2 == 0) ? (x_max - r) : (x_min + r);
}
// Restore normalized location to actual image location.
//   When align_corners is true:
//     Normalized location (-1, -1) points to the top-left pixel.
//     Normalized location (1, 1) points to the bottom-right pixel.
//   When align_corners is false [default]:
//     Normalized location (-1, -1) points to the top-left pixel minus half
//     pixel in both directions, i.e. (-0.5, -0.5) in actual image space.
//     Normalized location (1, 1) points to the bottom-right pixel plus half
//     pixel in both directions, i.e. (H - 0.5, W - 0.5) in actual image space.
// `n` is the (x, y) grid vector; `border` is {x_min, y_min, x_max, y_max}
// from CalculateBorders(). Returns the (x, y) coordinate in image space.
float2 DenormalizeInput(float2 n, float4 border)
{
    float2 dims = InputSizes.wz; // (W_in, H_in), matching n's (x, y) order
    if (AlignCorners == 1)
    {
        // AlignCorners: true => [-1, 1] to [0, dims - 1]
        n = (n + 1) / 2.f * (dims - 1);
    }
    else
    {
        // AlignCorners: false => [-1, 1] to [-0.5, dims - 0.5]
        n = ((n + 1) * dims - 1) / 2.f;
    }
    if (Mode == Nearest)
    {
        // Snap to the nearest pixel before the out-of-bounds handling below.
        n = round(n);
    }
    float x_min = border.x;
    float y_min = border.y;
    float x_max = border.z;
    float y_max = border.w;
    if (n.x < x_min || n.x > x_max || n.y < y_min || n.y > y_max) { // out of bound
        if (PaddingMode == Border) {
            // Clamp to the original pixel-center range in both align_corner cases.
            n.x = clamp(n.x, 0, InputSizes.w - 1);
            n.y = clamp(n.y, 0, InputSizes.z - 1);
        } else if (PaddingMode == Reflection) {
            n.x = Reflect(n.x, x_min, x_max);
            n.y = Reflect(n.y, y_min, y_max);
        }
        // PaddingMode == Zeros is handled later, in PixelAtGrid.
    } // out of bound
    return n;
}
// Reads one input element at an in-bounds (n, c, h, w) coordinate and
// promotes it to float. Both `index` and InputStrides are ordered (n, c, h, w).
float FetchInputPixel(uint4 index)
{
    uint flattenedIndex = dot(index, InputStrides);
    return (float)input[flattenedIndex];
}
// Samples the input at coordinate (n, c, h, w) = inputIdx, applying the
// active padding mode when the (w, h) components fall outside the input.
// inputIdx arrives as float4 so callers can pass unrounded coordinates;
// `border` is {x_min, y_min, x_max, y_max} from CalculateBorders().
float PixelAtGrid(float4 inputIdx, float4 border) {
    float pixel = 0; // default 0
    if (PaddingMode == Zeros)
    {
        // Out-of-bounds samples keep the default 0.
        if (inputIdx.w >= 0 && (uint)inputIdx.w < (uint)InputSizes.w &&
            inputIdx.z >= 0 && (uint)inputIdx.z < (uint)InputSizes.z)
        {
            pixel = FetchInputPixel(inputIdx);
        }
    }
    else if (PaddingMode == Border)
    {
        // Clamp to the nearest edge pixel.
        uint w = clamp(inputIdx.w, 0, InputSizes.w - 1);
        uint z = clamp(inputIdx.z, 0, InputSizes.z - 1);
        pixel = FetchInputPixel(float4(inputIdx.xy, z, w));
    }
    else if (PaddingMode == Reflection)
    {
        // NOTE(review): with align_corners false the borders start at -0.5, so
        // Reflect can return slightly negative values; confirm the float->uint
        // conversion here lands on pixel 0 rather than wrapping.
        uint w = Reflect(inputIdx.w, border.x, border.z);
        uint z = Reflect(inputIdx.z, border.y, border.w);
        pixel = FetchInputPixel(float4(inputIdx.xy, z, w));
    }
    return pixel;
}
// Evaluates a 1-D cubic convolution interpolation kernel at parameter
// t in [0, 1), given the four neighboring samples f(-1), f(0), f(1), f(2).
// a = -0.75 is the coefficient commonly used for bicubic resampling.
float BicubicConvolutionPFunction(float t, float fminus1, float f0, float f1, float f2)
{
    static const float a = -.75;
    // Rows are multiplied by the powers-of-t vector; columns weight the
    // four samples.
    static const float4x4 bicubicConvolutionMatrix =
    {
        0, 1, 0, 0,
        a, 0, -a, 0,
        -2*a, -3-a, 3+2*a, a,
        a, 2+a, -2-a, -a
    };
    float4 t_vec = float4( 1, t, t * t, t * t * t);
    float4 f_vec = float4(fminus1, f0, f1, f2);
    return mul(t_vec, mul(bicubicConvolutionMatrix, f_vec));
}
// Entry point: each thread computes one flattened output element — decode its
// (n, c, h, w) coordinate, fetch its grid vector, map it into input image
// space, then interpolate with the configured mode.
[numthreads(64, 1, 1)]
void GridSample(uint3 dtid : SV_DispatchThreadId)
{
    uint n = StartIndex + dtid.x;
    if (n < ElementCount)
    {
        float4 border = CalculateBorders();
        uint4 index = DecomposeIndex(n);
        float2 flowVector = FetchGridVector(index);
        // (x, y) in input image space; may still be out of bounds when
        // PaddingMode == Zeros.
        float2 inputWidthAndHeightIdx = DenormalizeInput(flowVector, border);
        float4 inputIdx = float4(index.x, // N
                                 index.y, // C
                                 inputWidthAndHeightIdx.y, // H
                                 inputWidthAndHeightIdx.x); // W
        if (Mode == Nearest)
        {
            // DenormalizeInput already rounded the coordinate for Nearest.
            output[n] = (TBUFFER1)(PixelAtGrid(inputIdx, border));
        }
        else if (Mode == Bilinear)
        {
            // Interpolate between the four pixels surrounding the sample point.
            float x1 = floor(inputIdx.w);
            float y1 = floor(inputIdx.z);
            float x2 = x1 + 1;
            float y2 = y1 + 1;
            float p11 = PixelAtGrid(float4(index.x, index.y, y1, x1), border);
            float p12 = PixelAtGrid(float4(index.x, index.y, y1, x2), border);
            float p21 = PixelAtGrid(float4(index.x, index.y, y2, x1), border);
            float p22 = PixelAtGrid(float4(index.x, index.y, y2, x2), border);
            // p11--p12
            //  |    |
            // p21--p22
            float p1 = lerp(p11, p12, frac(inputIdx.w));
            float p2 = lerp(p21, p22, frac(inputIdx.w));
            float p = lerp(p1, p2, frac(inputIdx.z));
            output[n] = (TBUFFER1)(p);
        }
        else if (Mode == Bicubic)
        {
            // Cubic convolution over the 4x4 neighborhood: interpolate along y
            // within each of the four columns, then once along x across the
            // column results.
            float x0 = floor(inputIdx.w) - 1;
            float y0 = floor(inputIdx.z) - 1;
            float f00 = PixelAtGrid(float4(index.x, index.y, y0 + 0, x0 + 0), border);
            float f01 = PixelAtGrid(float4(index.x, index.y, y0 + 0, x0 + 1), border);
            float f02 = PixelAtGrid(float4(index.x, index.y, y0 + 0, x0 + 2), border);
            float f03 = PixelAtGrid(float4(index.x, index.y, y0 + 0, x0 + 3), border);
            float f10 = PixelAtGrid(float4(index.x, index.y, y0 + 1, x0 + 0), border);
            float f11 = PixelAtGrid(float4(index.x, index.y, y0 + 1, x0 + 1), border);
            float f12 = PixelAtGrid(float4(index.x, index.y, y0 + 1, x0 + 2), border);
            float f13 = PixelAtGrid(float4(index.x, index.y, y0 + 1, x0 + 3), border);
            float f20 = PixelAtGrid(float4(index.x, index.y, y0 + 2, x0 + 0), border);
            float f21 = PixelAtGrid(float4(index.x, index.y, y0 + 2, x0 + 1), border);
            float f22 = PixelAtGrid(float4(index.x, index.y, y0 + 2, x0 + 2), border);
            float f23 = PixelAtGrid(float4(index.x, index.y, y0 + 2, x0 + 3), border);
            float f30 = PixelAtGrid(float4(index.x, index.y, y0 + 3, x0 + 0), border);
            float f31 = PixelAtGrid(float4(index.x, index.y, y0 + 3, x0 + 1), border);
            float f32 = PixelAtGrid(float4(index.x, index.y, y0 + 3, x0 + 2), border);
            float f33 = PixelAtGrid(float4(index.x, index.y, y0 + 3, x0 + 3), border);
            float tx = frac(inputIdx.w);
            float ty = frac(inputIdx.z);
            float bminus1 = BicubicConvolutionPFunction(ty, f00, f10, f20, f30);
            float b0 = BicubicConvolutionPFunction(ty, f01, f11, f21, f31);
            float b1 = BicubicConvolutionPFunction(ty, f02, f12, f22, f32);
            float b2 = BicubicConvolutionPFunction(ty, f03, f13, f23, f33);
            float p = BicubicConvolutionPFunction(tx, bminus1, b0, b1, b2);
            output[n] = (TBUFFER1)(p);
        }
    }
}

View file

@ -11,6 +11,7 @@ namespace AttrName
static constexpr const char* Activations = "activations";
static constexpr const char* AllowZero = "allowzero";
static constexpr const char* Alpha = "alpha";
static constexpr const char* AlignCorners = "align_corners";
static constexpr const char* AutoPad = "auto_pad";
static constexpr const char* Axes = "axes";
static constexpr const char* Axis = "axis";
@ -64,6 +65,7 @@ namespace AttrName
static constexpr const char* NoopWithEmptyAxes = "noop_with_empty_axes";
static constexpr const char* NormalizeVariance = "normalize_variance";
static constexpr const char* P = "p";
static constexpr const char* PaddingMode = "padding_mode";
static constexpr const char* OutputHeight = "output_height";
static constexpr const char* OutputShape = "output_shape";
static constexpr const char* OutputPadding = "output_padding";

View file

@ -1124,16 +1124,174 @@ static void ModelBuilding_ConstantMatmul() {
#endif
}
#if !defined(BUILD_INBOX)
// Interpolation mode for the GridSample operator. Enumerator order must match
// the `modes` string table indexed by static_cast<uint32_t>(mode) in the
// GridSample test helper.
enum class Mode : uint32_t {
  Bilinear,
  Nearest,
  Bicubic,
};
// Out-of-bounds padding behavior for the GridSample operator. Enumerator
// order must match the `padding_modes` string table indexed by
// static_cast<uint32_t>(padding_mode) in the GridSample test helper.
enum class PaddingMode : uint32_t {
  Zeros,
  Border,
  Reflection,
};
// Builds a single-op GridSample model (opset 17), evaluates it on both the
// CPU device and the requested device, and asserts that the outputs agree
// elementwise within a small absolute tolerance. Also prints the device
// evaluation wall-clock time.
// NOTE(review): inputs are always bound as TensorFloat/TensorKind::Float, so
// T and U are effectively expected to be float despite the template.
template <typename T, typename U>
static void GridSample(
    LearningModelDeviceKind kind,
    const std::vector<T>& input,
    const std::vector<int64_t>& input_dims,
    const std::vector<U>& grid,
    const std::vector<int64_t>& grid_dims,
    bool align_corners,
    Mode mode,
    PaddingMode padding_mode
) {
  // String tables indexed by the enum values above.
  const hstring modes[] = {
      L"bilinear",
      L"nearest",
      L"bicubic"};
  const hstring padding_modes[] = {
      L"zeros",
      L"border",
      L"reflection"};
  // One GridSample node; the output shape is left dynamic (-1 per dim).
  auto model =
      LearningModelBuilder::Create(17)
          .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Input", TensorKind::Float, input_dims))
          .Inputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Grid", TensorKind::Float, grid_dims))
          .Outputs().Add(LearningModelBuilder::CreateTensorFeatureDescriptor(L"Output", TensorKind::Float, {-1, -1, -1, -1}))
          .Operators().Add(Operator(L"GridSample")
              .SetInput(L"X", L"Input")
              .SetInput(L"grid", L"Grid")
              .SetAttribute(L"align_corners", TensorInt64Bit::CreateFromArray({ }, {INT64(align_corners)}))
              .SetAttribute(L"mode", TensorString::CreateFromArray({ }, { modes[static_cast<uint32_t>(mode)] }))
              .SetAttribute(L"padding_mode", TensorString::CreateFromArray({ }, { padding_modes[static_cast<uint32_t>(padding_mode)] }))
              .SetOutput(L"Y", L"Output"))
          .CreateModel();

  // Evaluate the same model/bindings on the CPU (reference) and on `kind`.
  auto cpu_device = LearningModelDevice(LearningModelDeviceKind::Cpu);
  auto device = LearningModelDevice(kind);
  LearningModelSession device_session(model, device);
  LearningModelBinding device_binding(device_session);
  LearningModelSession cpu_session(model, cpu_device);
  LearningModelBinding cpu_binding(cpu_session);
  device_binding.Bind(L"Input", TensorFloat::CreateFromShapeArrayAndDataArray(input_dims, input));
  device_binding.Bind(L"Grid", TensorFloat::CreateFromShapeArrayAndDataArray(grid_dims, grid));
  cpu_binding.Bind(L"Input", TensorFloat::CreateFromShapeArrayAndDataArray(input_dims, input));
  cpu_binding.Bind(L"Grid", TensorFloat::CreateFromShapeArrayAndDataArray(grid_dims, grid));
  auto cpu_result = cpu_session.Evaluate(cpu_binding, L"");

  // Evaluate (timed on the target device only).
  auto start = std::chrono::high_resolution_clock::now();
  auto device_result = device_session.Evaluate(device_binding, L"");
  auto end = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double, std::micro> evaluate_duration_in_microseconds = end - start;
  printf("GridSample[Mode=%ls, PaddingMode=%ls, AlignCorners=%s] took %fus.\n",
         modes[static_cast<uint32_t>(mode)].c_str(),
         padding_modes[static_cast<uint32_t>(padding_mode)].c_str(),
         align_corners ? "True" : "False",
         evaluate_duration_in_microseconds.count());

  // Check results: every device element must be within error_threshold of
  // the CPU reference; mismatches are printed before failing.
  constexpr float error_threshold = .001f;
  auto device_y_tensor = device_result.Outputs().Lookup(L"Output").as<TensorFloat>();
  auto device_y_ivv = device_y_tensor.GetAsVectorView();
  auto cpu_y_tensor = cpu_result.Outputs().Lookup(L"Output").as<TensorFloat>();
  auto cpu_y_ivv = cpu_y_tensor.GetAsVectorView();
  WINML_EXPECT_EQUAL(device_y_ivv.Size(), cpu_y_ivv.Size());
  for (uint32_t i = 0; i < device_y_ivv.Size(); i++) {
    bool in_range = abs(device_y_ivv.GetAt(i) - cpu_y_ivv.GetAt(i)) < error_threshold;
    if (!in_range) {
      printf("[%d] ACTUAL(%f) EXPECTED(%f)\n", (int)i, device_y_ivv.GetAt(i), cpu_y_ivv.GetAt(i));
    }
    WINML_EXPECT_TRUE(in_range);
  }
}
// Exercises GridSample over the full cross product of align_corners,
// interpolation mode, and padding mode (18 combinations), in the same
// enumeration order as an explicit call list: align_corners outermost,
// then mode, then padding mode.
static void GridSampleRunner(LearningModelDeviceKind kind,
                             const std::vector<float>& input,
                             const std::vector<int64_t>& input_dims,
                             const std::vector<float>& grid,
                             const std::vector<int64_t>& grid_dims)
{
    for (bool align_corners : {false, true})
    {
        for (Mode mode : {Mode::Bilinear, Mode::Nearest, Mode::Bicubic})
        {
            for (PaddingMode padding_mode : {PaddingMode::Zeros, PaddingMode::Border, PaddingMode::Reflection})
            {
                GridSample(kind, input, input_dims, grid, grid_dims, align_corners, mode, padding_mode);
            }
        }
    }
}
// Runs the full GridSample mode/padding matrix on two test cases.
static void ModelBuilding_GridSample_Internal(LearningModelDeviceKind kind) {
  // Case 1: 1x1x4x4 input holding the ramp 0..15, sampled through a 1x5x5x2
  // grid whose entries sweep [0, 1) in steps of 1/50.
  std::vector<float> input(16);
  for (size_t i = 0; i < input.size(); ++i) {
    input[i] = static_cast<float>(i);
  }
  std::vector<float> grid(50);
  for (size_t i = 0; i < grid.size(); ++i) {
    grid[i] = static_cast<float>(i) / static_cast<float>(grid.size());
  }
  std::vector<int64_t> input_dims = {1, 1, 4, 4};
  std::vector<int64_t> grid_dims = {1, 5, 5, 2};
  GridSampleRunner(kind, input, input_dims, grid, grid_dims);

  // Case 2: 1x1x3x2 input with grid coordinates far outside [-1, 1] to
  // exercise the out-of-bounds padding paths.
  input = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
  grid =
  {
      -10.0000f, -10.0000f,
       -5.0000f,  -5.0000f,
       -0.2000f,  -0.2000f,
       10.0000f,  10.0000f,
       10.0000f,  10.0000f,
       -0.2000f,  -0.2000f,
        5.0000f,   5.0000f,
       10.0000f,  10.0000f
  };
  input_dims = {1, 1, 3, 2};
  grid_dims = {1, 2, 4, 2};
  GridSampleRunner(kind, input, input_dims, grid, grid_dims);
}
static void ModelBuilding_DiscreteFourierTransform_Internal(LearningModelDeviceKind kind) {
std::vector<float> real_input =
{
1.00f, 2.00, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
1.00f, 2.00, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
1.00f, 2.00, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
1.00f, 2.00, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
1.00f, 2.00, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
};
1.00f, 2.00f, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
1.00f, 2.00f, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
1.00f, 2.00f, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
1.00f, 2.00f, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
1.00f, 2.00f, 3.00f, 4.00f, 5.00f, 6.00f, 7.00f, 8.00f,
};
std::vector<std::complex<float>> real_expected_axis_0_two_sided = {
{5.000f, 0.000f}, {10.000f, 0.000f}, {15.000f, 0.000f}, {20.000f, 0.000f}, {25.000f, 0.000f}, {30.000f, 0.000f}, {35.000f, 0.000f}, {40.000f, 0.000f},
@ -1155,17 +1313,17 @@ static void ModelBuilding_DiscreteFourierTransform_Internal(LearningModelDeviceK
std::vector<std::complex<float>> input =
{
{1.00f, 0.00f}, {2.00, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
{1.00f, 0.00f}, {2.00, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
{1.00f, 0.00f}, {2.00, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
{1.00f, 0.00f}, {2.00, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
{1.00f, 0.00f}, {2.00, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
{1.00f, 0.00f}, {2.00f, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
{1.00f, 0.00f}, {2.00f, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
{1.00f, 0.00f}, {2.00f, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
{1.00f, 0.00f}, {2.00f, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
{1.00f, 0.00f}, {2.00f, 0.00f}, {3.00f, 0.00f}, {4.00f, 0.00f}, {5.00f, 0.00f}, {6.00f, 0.00f}, {7.00f, 0.00f}, {8.00f, 0.00f},
{2.00f, 1.00f}, {4.00, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
{2.00f, 1.00f}, {4.00, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
{2.00f, 1.00f}, {4.00, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
{2.00f, 1.00f}, {4.00, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
{2.00f, 1.00f}, {4.00, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
{2.00f, 1.00f}, {4.00f, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
{2.00f, 1.00f}, {4.00f, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
{2.00f, 1.00f}, {4.00f, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
{2.00f, 1.00f}, {4.00f, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
{2.00f, 1.00f}, {4.00f, 2.00f}, {6.00f, 3.00f}, {8.00f, 4.00f}, {10.00f, 5.00f}, {12.00f, 6.00f}, {14.00f, 7.00f}, {16.00f, 8.00f},
};
std::vector<std::complex<float>> expected_axis_0_two_sided = {
@ -1259,6 +1417,12 @@ static void ModelBuilding_DiscreteFourierTransform_Internal(LearningModelDeviceK
}
#endif
// Runs the GridSample model-building test on the DirectX device kind, exercising
// the DirectML EP path (this commit adds a GridSample HLSL implementation there).
// Mirrors the ModelBuilding_DiscreteFourierTransformDeviceDirectX wrapper pattern:
// the shared test body lives in ModelBuilding_GridSample_Internal, parameterized
// by LearningModelDeviceKind.
static void ModelBuilding_GridSampleDeviceDirectX() {
// Inbox (OS-shipped) builds exclude the model-building test internals, so this
// test compiles to a no-op there — presumably because the builder APIs used by
// the _Internal helper are unavailable inbox; same guard as the sibling tests.
#if !defined(BUILD_INBOX)
ModelBuilding_GridSample_Internal(LearningModelDeviceKind::DirectX);
#endif
}
static void ModelBuilding_DiscreteFourierTransform() {
#if !defined(BUILD_INBOX)
ModelBuilding_DiscreteFourierTransform_Internal(LearningModelDeviceKind::Cpu);
@ -1563,6 +1727,7 @@ const LearningModelSessionAPITestsApi& getapi() {
ModelBuilding_DiscreteFourierTransformInverseIdentity,
ModelBuilding_DiscreteFourierTransformDeviceDirectX,
ModelBuilding_DiscreteFourierTransformInverseIdentityDeviceDirectX,
ModelBuilding_GridSampleDeviceDirectX,
ModelBuilding_HannWindow,
ModelBuilding_HammingWindow,
ModelBuilding_BlackmanWindow,
@ -1581,6 +1746,7 @@ const LearningModelSessionAPITestsApi& getapi() {
api.AdapterIdAndDevice = SkipTest;
api.ModelBuilding_DiscreteFourierTransformDeviceDirectX = SkipTest;
api.ModelBuilding_DiscreteFourierTransformInverseIdentityDeviceDirectX = SkipTest;
api.ModelBuilding_GridSampleDeviceDirectX = SkipTest;
}
if (RuntimeParameterExists(L"EdgeCore")) {
api.AdapterIdAndDevice = SkipTest;

View file

@ -30,6 +30,7 @@ struct LearningModelSessionAPITestsApi {
VoidTest ModelBuilding_DiscreteFourierTransformInverseIdentity;
VoidTest ModelBuilding_DiscreteFourierTransformDeviceDirectX;
VoidTest ModelBuilding_DiscreteFourierTransformInverseIdentityDeviceDirectX;
VoidTest ModelBuilding_GridSampleDeviceDirectX;
VoidTest ModelBuilding_HannWindow;
VoidTest ModelBuilding_HammingWindow;
VoidTest ModelBuilding_BlackmanWindow;
@ -68,6 +69,7 @@ WINML_TEST(LearningModelSessionAPITests, ModelBuilding_DiscreteFourierTransform)
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_DiscreteFourierTransformInverseIdentity)
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_DiscreteFourierTransformDeviceDirectX)
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_DiscreteFourierTransformInverseIdentityDeviceDirectX)
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_GridSampleDeviceDirectX)
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_HannWindow)
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_HammingWindow)
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_BlackmanWindow)