mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[MPS] Add PSO caching for advanced indexing kernels (#99855)
Use bindless Argument Buffers (unbounded arrays) for advanced indexing kernels - this allows caching of the PSOs since we don't have to query anymore the main metal function for the AB size (this is filled directly now on the CPU). Pull Request resolved: https://github.com/pytorch/pytorch/pull/99855 Approved by: https://github.com/kulinseth
This commit is contained in:
parent
09b189edc3
commit
dcd686f478
6 changed files with 191 additions and 90 deletions
|
|
@ -9,27 +9,42 @@ static const char * indexing_metal_shaders = R"INDEX_METAL(
|
|||
|
||||
using namespace metal;
|
||||
|
||||
constant uint32_t num_indices [[function_constant(0)]];
|
||||
|
||||
#if __METAL_VERSION__ < 300
|
||||
struct IndexAB {
|
||||
// Allow up to 16 indices
|
||||
metal::array<constant void *, 16> indexArray [[ id(0) ]];
|
||||
};
|
||||
#else
|
||||
struct IndexAB {
|
||||
constant int64_t* indexArray;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
template<typename T>
|
||||
kernel void index_select(
|
||||
#if __METAL_VERSION__ >= 300
|
||||
constant IndexAB * indexAB [[buffer(0)]],
|
||||
#else
|
||||
constant IndexAB & indexAB [[buffer(0)]],
|
||||
#endif
|
||||
constant void * indexSizes [[buffer(1)]],
|
||||
constant void * indexStrides [[buffer(2)]],
|
||||
constant uint3 * offsets [[buffer(3)]],
|
||||
constant void * inputData [[buffer(4)]],
|
||||
device void * outputData [[buffer(5)]],
|
||||
constant uint32_t & num_indices [[buffer(6)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
||||
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
||||
int64_t offset = 0;
|
||||
for (uint32_t i = 0; i < num_indices; i++) {
|
||||
int64_t index = ((constant int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)];
|
||||
#if __METAL_VERSION__ >= 300
|
||||
constant int64_t* indexArray = indexAB[i].indexArray;
|
||||
#else
|
||||
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
|
||||
#endif
|
||||
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
|
||||
if (index < 0) {
|
||||
index += index_sizes[i];
|
||||
}
|
||||
|
|
@ -42,19 +57,30 @@ kernel void index_select(
|
|||
|
||||
template<typename T>
|
||||
kernel void index_put(
|
||||
#if __METAL_VERSION__ >= 300
|
||||
constant IndexAB * indexAB [[buffer(0)]],
|
||||
#else
|
||||
constant IndexAB & indexAB [[buffer(0)]],
|
||||
#endif
|
||||
constant void * indexSizes [[buffer(1)]],
|
||||
constant void * indexStrides [[buffer(2)]],
|
||||
constant uint3 * offsets [[buffer(3)]],
|
||||
constant void * inputData [[buffer(4)]],
|
||||
device void * outputData [[buffer(5)]],
|
||||
constant uint32_t & num_indices [[buffer(6)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
|
||||
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
||||
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
||||
int64_t offset = 0;
|
||||
for (uint32_t i = 0; i < num_indices; i++) {
|
||||
int64_t index = ((constant int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)];
|
||||
#if __METAL_VERSION__ >= 300
|
||||
constant int64_t* indexArray = indexAB[i].indexArray;
|
||||
#else
|
||||
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
|
||||
#endif
|
||||
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
|
||||
|
||||
if (index < 0) {
|
||||
index += index_sizes[i];
|
||||
}
|
||||
|
|
@ -65,6 +91,7 @@ kernel void index_put(
|
|||
*out = *in;
|
||||
}
|
||||
|
||||
#if __METAL_VERSION__ < 300
|
||||
#define REGISTER_INDEX_OP(DTYPE_SIZE, DTYPE, INDEX_OP_TYPE) \
|
||||
template \
|
||||
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE)]] \
|
||||
|
|
@ -75,7 +102,22 @@ kernel void index_ ## INDEX_OP_TYPE<DTYPE>( \
|
|||
constant uint3 * offsets [[buffer(3)]], \
|
||||
constant void * inputData [[buffer(4)]], \
|
||||
device void * outputData [[buffer(5)]], \
|
||||
constant uint32_t & num_indices [[buffer(6)]], \
|
||||
uint thread_index [[thread_position_in_grid]]);
|
||||
#else
|
||||
#define REGISTER_INDEX_OP(DTYPE_SIZE, DTYPE, INDEX_OP_TYPE) \
|
||||
template \
|
||||
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE)]] \
|
||||
kernel void index_ ## INDEX_OP_TYPE<DTYPE>( \
|
||||
constant IndexAB * indexAB [[buffer(0)]], \
|
||||
constant void * indexSizes [[buffer(1)]], \
|
||||
constant void * indexStrides [[buffer(2)]], \
|
||||
constant uint3 * offsets [[buffer(3)]], \
|
||||
constant void * inputData [[buffer(4)]], \
|
||||
device void * outputData [[buffer(5)]], \
|
||||
constant uint32_t & num_indices [[buffer(6)]], \
|
||||
uint thread_index [[thread_position_in_grid]]);
|
||||
#endif
|
||||
|
||||
#define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
|
||||
REGISTER_INDEX_OP(8bit, char, INDEX_OP_TYPE); \
|
||||
|
|
@ -92,29 +134,40 @@ kernel void kernel_index_offsets(constant packed_uint3 * strides [[buffe
|
|||
constant uint & num_dimensions [[buffer(3)]],
|
||||
constant uint & num_offsets [[buffer(4)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
data_offsets[thread_index] = 0;
|
||||
uint32_t idx = thread_index;
|
||||
for (uint32_t dim = 0; dim < num_dimensions; dim++) {
|
||||
uint32_t remainder = idx % iter_shape[dim];
|
||||
idx /= iter_shape[dim];
|
||||
|
||||
for (uint32_t offset = 0; offset < num_offsets; offset++)
|
||||
data_offsets[thread_index][offset] += remainder * strides[dim][offset];
|
||||
data_offsets[thread_index] += remainder * strides[dim];
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename E>
|
||||
kernel void index_put_accumulate_native_dtypes(constant IndexAB & indexAB [[buffer(0)]],
|
||||
constant void * indexSizes [[buffer(1)]],
|
||||
constant void * indexStrides [[buffer(2)]],
|
||||
constant uint3 * offsets [[buffer(3)]],
|
||||
constant void * inputData [[buffer(4)]],
|
||||
device void * outputData [[buffer(5)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
kernel void index_put_accumulate_native_dtypes(
|
||||
#if __METAL_VERSION__ >= 300
|
||||
constant IndexAB * indexAB [[buffer(0)]],
|
||||
#else
|
||||
constant IndexAB & indexAB [[buffer(0)]],
|
||||
#endif
|
||||
constant void * indexSizes [[buffer(1)]],
|
||||
constant void * indexStrides [[buffer(2)]],
|
||||
constant uint3 * offsets [[buffer(3)]],
|
||||
constant void * inputData [[buffer(4)]],
|
||||
device void * outputData [[buffer(5)]],
|
||||
constant uint32_t& num_indices [[buffer(6)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
||||
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
||||
int64_t offset = 0;
|
||||
for (uint32_t i = 0; i < num_indices; i++) {
|
||||
int64_t index = ((constant int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)];
|
||||
#if __METAL_VERSION__ >= 300
|
||||
constant int64_t* indexArray = indexAB[i].indexArray;
|
||||
#else
|
||||
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
|
||||
#endif
|
||||
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
|
||||
if (index < 0) {
|
||||
index += index_sizes[i];
|
||||
}
|
||||
|
|
@ -136,18 +189,29 @@ __attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * a
|
|||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void atomic_index_put_accumulate(constant IndexAB & indexAB [[buffer(0)]],
|
||||
constant void * indexSizes [[buffer(1)]],
|
||||
constant void * indexStrides [[buffer(2)]],
|
||||
constant uint3 * offsets [[buffer(3)]],
|
||||
constant void * inputData [[buffer(4)]],
|
||||
device void * outputData [[buffer(5)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
kernel void atomic_index_put_accumulate(
|
||||
#if __METAL_VERSION__ >= 300
|
||||
constant IndexAB * indexAB [[buffer(0)]],
|
||||
#else
|
||||
constant IndexAB & indexAB [[buffer(0)]],
|
||||
#endif
|
||||
constant void * indexSizes [[buffer(1)]],
|
||||
constant void * indexStrides [[buffer(2)]],
|
||||
constant uint3 * offsets [[buffer(3)]],
|
||||
constant void * inputData [[buffer(4)]],
|
||||
device void * outputData [[buffer(5)]],
|
||||
constant uint32_t& num_indices [[buffer(6)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
||||
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
||||
int64_t offset = 0;
|
||||
for (uint32_t i = 0; i < num_indices; i++) {
|
||||
int64_t index = ((constant int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)];
|
||||
#if __METAL_VERSION__ >= 300
|
||||
constant int64_t* indexArray = indexAB[i].indexArray;
|
||||
#else
|
||||
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
|
||||
#endif
|
||||
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
|
||||
if (index < 0) {
|
||||
index += index_sizes[i];
|
||||
}
|
||||
|
|
@ -160,22 +224,35 @@ kernel void atomic_index_put_accumulate(constant IndexAB & indexAB [[b
|
|||
|
||||
template
|
||||
[[host_name("index_put_accumulate_32bit_float")]]
|
||||
kernel void atomic_index_put_accumulate<float>(constant IndexAB & indexAB [[buffer(0)]],
|
||||
constant void * indexSizes [[buffer(1)]],
|
||||
constant void * indexStrides [[buffer(2)]],
|
||||
constant uint3 * offsets [[buffer(3)]],
|
||||
constant void * inputData [[buffer(4)]],
|
||||
device void * outputData [[buffer(5)]],
|
||||
uint thread_index [[thread_position_in_grid]]);
|
||||
kernel void atomic_index_put_accumulate<float>(
|
||||
#if __METAL_VERSION__ >= 300
|
||||
constant IndexAB * indexAB [[buffer(0)]],
|
||||
#else
|
||||
constant IndexAB & indexAB [[buffer(0)]],
|
||||
#endif
|
||||
constant void * indexSizes [[buffer(1)]],
|
||||
constant void * indexStrides [[buffer(2)]],
|
||||
constant uint3 * offsets [[buffer(3)]],
|
||||
constant void * inputData [[buffer(4)]],
|
||||
device void * outputData [[buffer(5)]],
|
||||
constant uint32_t& num_indices [[buffer(6)]],
|
||||
uint thread_index [[thread_position_in_grid]]);
|
||||
|
||||
template
|
||||
[[host_name("index_put_accumulate_32bit_int")]]
|
||||
kernel void index_put_accumulate_native_dtypes<atomic_int, int>(constant IndexAB & indexAB [[buffer(0)]],
|
||||
constant void * indexSizes [[buffer(1)]],
|
||||
constant void * indexStrides [[buffer(2)]],
|
||||
constant uint3 * offsets [[buffer(3)]],
|
||||
constant void * inputData [[buffer(4)]],
|
||||
device void * outputData [[buffer(5)]],
|
||||
uint thread_index [[thread_position_in_grid]]);
|
||||
kernel void index_put_accumulate_native_dtypes<atomic_int, int>(
|
||||
#if __METAL_VERSION__ >= 300
|
||||
constant IndexAB * indexAB [[buffer(0)]],
|
||||
#else
|
||||
constant IndexAB & indexAB [[buffer(0)]],
|
||||
#endif
|
||||
constant void * indexSizes [[buffer(1)]],
|
||||
constant void * indexStrides [[buffer(2)]],
|
||||
constant uint3 * offsets [[buffer(3)]],
|
||||
constant void * inputData [[buffer(4)]],
|
||||
device void * outputData [[buffer(5)]],
|
||||
constant uint32_t& num_indices [[buffer(6)]],
|
||||
uint thread_index [[thread_position_in_grid]]);
|
||||
)INDEX_METAL";
|
||||
|
||||
static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
|
||||
|
|
|
|||
|
|
@ -12,14 +12,14 @@
|
|||
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
typedef id<MTLDevice> MTLDevice_t;
|
||||
typedef id<MTLLibrary> MTLLibrary_t;
|
||||
typedef id<MTLFunction> MTLFunction_t;
|
||||
typedef MTLFunctionConstantValues* MTLFunctionConstantValues_t;
|
||||
typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
|
||||
typedef id<MTLLibrary> MTLLibrary_t;
|
||||
#else
|
||||
typedef void* MTLDevice;
|
||||
typedef void* MTLDevice_t;
|
||||
typedef void* MTLLibrary_t;
|
||||
typedef void* MTLFunction_t;
|
||||
typedef void* MTLFunctionConstantValues_t;
|
||||
typedef void* MTLComputePipelineState_t;
|
||||
typedef void* MTLLibrary_t;
|
||||
#endif
|
||||
|
||||
using namespace std;
|
||||
|
|
@ -66,7 +66,8 @@ class TORCH_API MPSDevice {
|
|||
*/
|
||||
bool isMacOS13Plus(MacOSVersion version) const;
|
||||
|
||||
MTLFunction_t metalIndexingFunction(const std::string &kernel, MTLFunctionConstantValues_t constantValues);
|
||||
MTLComputePipelineState_t metalIndexingFunction(const std::string &kernel);
|
||||
MTLLibrary_t getMetalIndexingLibrary();
|
||||
|
||||
~MPSDevice();
|
||||
|
||||
|
|
|
|||
|
|
@ -13,10 +13,15 @@ namespace mps {
|
|||
static std::unique_ptr<MPSDevice> mps_device;
|
||||
static c10::once_flag mpsdev_init;
|
||||
|
||||
static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device) {
|
||||
static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device, bool macOS13Plus) {
|
||||
// MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
|
||||
// host_name attribute needs at least Metal 2.2
|
||||
MTLLanguageVersion languageVersion = MTLLanguageVersion2_2;
|
||||
#if defined(__MAC_13_0)
|
||||
if (macOS13Plus) {
|
||||
languageVersion = MTLLanguageVersion3_0;
|
||||
}
|
||||
#endif
|
||||
|
||||
TORCH_CHECK([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
|
||||
return languageVersion;
|
||||
|
|
@ -27,12 +32,12 @@ MPSDevice* MPSDevice::getInstance() {
|
|||
return mps_device.get();
|
||||
}
|
||||
|
||||
id<MTLFunction> MPSDevice::metalIndexingFunction(const std::string& kernel, MTLFunctionConstantValues* constantValues) {
|
||||
id<MTLLibrary> MPSDevice::getMetalIndexingLibrary() {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
|
||||
NSError* error = nil;
|
||||
if (!_mtl_indexing_library) {
|
||||
MTLCompileOptions* options = [MTLCompileOptions new];
|
||||
[options setLanguageVersion:getMetalLanguageVersion(_mtl_device)];
|
||||
[options setLanguageVersion:getMetalLanguageVersion(_mtl_device, isMacOS13Plus(MacOSVersion::MACOS_VER_13_0_PLUS))];
|
||||
[options setFastMathEnabled:YES];
|
||||
_mtl_indexing_library = [_mtl_device newLibraryWithSource:[NSString stringWithCString:mps::indexing_metal_shaders
|
||||
encoding:NSASCIIStringEncoding]
|
||||
|
|
@ -40,24 +45,31 @@ id<MTLFunction> MPSDevice::metalIndexingFunction(const std::string& kernel, MTLF
|
|||
error:&error];
|
||||
TORCH_CHECK(_mtl_indexing_library, "Failed to create indexing library, error: ", [[error description] UTF8String]);
|
||||
}
|
||||
return _mtl_indexing_library;
|
||||
}
|
||||
|
||||
id<MTLFunction> indexFunction = nil;
|
||||
if (constantValues) {
|
||||
indexFunction = [[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]
|
||||
constantValues:constantValues
|
||||
error:&error] autorelease];
|
||||
} else {
|
||||
indexFunction =
|
||||
[[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]] autorelease];
|
||||
id<MTLComputePipelineState> MPSDevice::metalIndexingFunction(const std::string& kernel) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
|
||||
NSError* error = nil;
|
||||
static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
|
||||
id<MTLLibrary> indexing_lib = getMetalIndexingLibrary();
|
||||
id<MTLComputePipelineState> state = psoCache[kernel];
|
||||
if (state) {
|
||||
return state;
|
||||
}
|
||||
|
||||
id<MTLFunction> indexFunction =
|
||||
[[indexing_lib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]] autorelease];
|
||||
TORCH_CHECK(indexFunction,
|
||||
"Failed to create specialized function state object: ",
|
||||
kernel,
|
||||
", error: ",
|
||||
[[error description] UTF8String]);
|
||||
|
||||
return indexFunction;
|
||||
state = [_mtl_device newComputePipelineStateWithFunction:indexFunction error:&error];
|
||||
TORCH_CHECK(state, error.localizedDescription.UTF8String);
|
||||
psoCache[kernel] = state;
|
||||
return state;
|
||||
}
|
||||
|
||||
MPSDevice::~MPSDevice() {
|
||||
|
|
|
|||
|
|
@ -197,10 +197,8 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {
|
|||
}
|
||||
}
|
||||
|
||||
id<MTLFunction> kernelDataOffsetsFunction =
|
||||
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
|
||||
id<MTLComputePipelineState> kernelDataOffsetsPSO =
|
||||
[[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease];
|
||||
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets");
|
||||
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
|
||||
options:0] autorelease];
|
||||
TORCH_CHECK(
|
||||
|
|
|
|||
|
|
@ -155,10 +155,8 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other,
|
|||
}
|
||||
}
|
||||
|
||||
id<MTLFunction> kernelDataOffsetsFunction =
|
||||
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
|
||||
id<MTLComputePipelineState> kernelDataOffsetsPSO =
|
||||
[[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease];
|
||||
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets");
|
||||
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
|
||||
options:0] autorelease];
|
||||
TORCH_CHECK(
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#include <ATen/MemoryOverlap.h>
|
||||
#include <ATen/WrapDimUtilsMulti.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
#include <ATen/mps/MPSAllocatorInterface.h>
|
||||
#include <ATen/native/IndexKernel.h>
|
||||
#include <ATen/native/IndexingUtils.h>
|
||||
#include <ATen/native/LinearAlgebraUtils.h>
|
||||
|
|
@ -77,12 +78,10 @@ static bool dispatchIndexKernel(TensorIteratorBase& iter,
|
|||
|
||||
MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
|
||||
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
|
||||
id<MTLFunction> kernelDataOffsetsFunction =
|
||||
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
|
||||
id<MTLComputePipelineState> kernelDataOffsetsPSO =
|
||||
[[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease];
|
||||
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
|
||||
options:0] autorelease];
|
||||
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets");
|
||||
id<MTLBuffer> kernelDataOffsets =
|
||||
(id<MTLBuffer>)getIMPSAllocator()->allocate(numThreads * sizeof(simd_uint3)).get();
|
||||
TORCH_CHECK(
|
||||
kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
|
||||
|
||||
|
|
@ -94,39 +93,53 @@ static bool dispatchIndexKernel(TensorIteratorBase& iter,
|
|||
[computeEncoder setBytes:&nOffsets length:sizeof(uint32_t) atIndex:4];
|
||||
|
||||
NSUInteger kernelOffsetsTGSize = kernelDataOffsetsPSO.maxTotalThreadsPerThreadgroup;
|
||||
if (kernelOffsetsTGSize > numThreads)
|
||||
if (kernelOffsetsTGSize > numThreads) {
|
||||
kernelOffsetsTGSize = numThreads;
|
||||
}
|
||||
|
||||
MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1);
|
||||
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:kernelOffsetsThreadGroupSize];
|
||||
|
||||
MTLFunctionConstantValues* constantValues = [[MTLFunctionConstantValues new] autorelease];
|
||||
[constantValues setConstantValue:&num_indices type:MTLDataTypeUInt atIndex:0];
|
||||
|
||||
std::string indexFunction = getIndexFunctionName(inputTensor.scalar_type(), index_select, accumulate);
|
||||
id<MTLFunction> indexKernelFunction =
|
||||
MPSDevice::getInstance()->metalIndexingFunction(indexFunction, constantValues);
|
||||
id<MTLArgumentEncoder> argumentEncoder = [[indexKernelFunction newArgumentEncoderWithBufferIndex:0] autorelease];
|
||||
NSUInteger argumentBufferLength = argumentEncoder.encodedLength;
|
||||
id<MTLBuffer> indexAB = [[device newBufferWithLength:argumentBufferLength options:0] autorelease];
|
||||
[argumentEncoder setArgumentBuffer:indexAB offset:0];
|
||||
id<MTLComputePipelineState> indexSelectPSO = nil;
|
||||
id<MTLBuffer> indexAB = nil;
|
||||
#if defined(__MAC_13_0)
|
||||
if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_0_PLUS)) {
|
||||
indexSelectPSO = MPSDevice::getInstance()->metalIndexingFunction(indexFunction);
|
||||
size_t argumentBufferLength = sizeof(uint64_t) * num_indices;
|
||||
indexAB = [[device newBufferWithLength:argumentBufferLength options:0] autorelease];
|
||||
uint64_t* indexABContents = (uint64_t*)(indexAB.contents);
|
||||
for (uint32_t idx = 0; idx < num_indices; idx++) {
|
||||
const Tensor& indexTensor = iter.tensor(idx + 2);
|
||||
indexABContents[idx] =
|
||||
getMTLBufferStorage(indexTensor).gpuAddress + (indexTensor.storage_offset() * indexTensor.element_size());
|
||||
TORCH_CHECK(indexTensor.scalar_type() == ScalarType::Long, "index(): Expected dtype int64 for Index");
|
||||
[computeEncoder useResource:getMTLBufferStorage(indexTensor) usage:MTLResourceUsageRead];
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
id<MTLLibrary> lib = MPSDevice::getInstance()->getMetalIndexingLibrary();
|
||||
id<MTLFunction> indexKernelFunction =
|
||||
[[lib newFunctionWithName:[NSString stringWithUTF8String:indexFunction.c_str()]] autorelease];
|
||||
id<MTLArgumentEncoder> argumentEncoder =
|
||||
[[indexKernelFunction newArgumentEncoderWithBufferIndex:0] autorelease];
|
||||
NSUInteger argumentBufferLength = argumentEncoder.encodedLength;
|
||||
indexAB = [[device newBufferWithLength:argumentBufferLength options:0] autorelease];
|
||||
[argumentEncoder setArgumentBuffer:indexAB offset:0];
|
||||
|
||||
for (uint32_t idx = 0; idx < num_indices; idx++) {
|
||||
const Tensor& indexTensor = iter.tensor(idx + 2);
|
||||
[argumentEncoder setBuffer:getMTLBufferStorage(indexTensor)
|
||||
offset:indexTensor.storage_offset() * indexTensor.element_size()
|
||||
atIndex:idx];
|
||||
TORCH_CHECK(indexTensor.scalar_type() == ScalarType::Long, "index(): Expected dtype int64 for Index");
|
||||
}
|
||||
for (uint32_t idx = 0; idx < num_indices; idx++) {
|
||||
const Tensor& indexTensor = iter.tensor(idx + 2);
|
||||
[argumentEncoder setBuffer:getMTLBufferStorage(indexTensor)
|
||||
offset:indexTensor.storage_offset() * indexTensor.element_size()
|
||||
atIndex:idx];
|
||||
TORCH_CHECK(indexTensor.scalar_type() == ScalarType::Long, "index(): Expected dtype int64 for Index");
|
||||
[computeEncoder useResource:getMTLBufferStorage(indexTensor) usage:MTLResourceUsageRead];
|
||||
}
|
||||
|
||||
// FIXME: PSO needs to be cached
|
||||
id<MTLComputePipelineState> indexSelectPSO = [[device newComputePipelineStateWithFunction:indexKernelFunction
|
||||
error:&error] autorelease];
|
||||
TORCH_CHECK(indexSelectPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
|
||||
|
||||
for (uint32_t idx = 0; idx < num_indices; idx++) {
|
||||
const Tensor& indexTensor = iter.tensor(idx + 2);
|
||||
[computeEncoder useResource:getMTLBufferStorage(indexTensor) usage:MTLResourceUsageRead];
|
||||
indexSelectPSO = [[device newComputePipelineStateWithFunction:indexKernelFunction error:&error] autorelease];
|
||||
TORCH_CHECK(
|
||||
indexSelectPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
|
||||
}
|
||||
|
||||
[computeEncoder setComputePipelineState:indexSelectPSO];
|
||||
|
|
@ -138,10 +151,12 @@ static bool dispatchIndexKernel(TensorIteratorBase& iter,
|
|||
[computeEncoder setBuffer:outputBuffer
|
||||
offset:outputTensor.storage_offset() * outputTensor.element_size()
|
||||
atIndex:5];
|
||||
[computeEncoder setBytes:&num_indices length:sizeof(uint32_t) atIndex:6];
|
||||
|
||||
NSUInteger tgSize = indexSelectPSO.maxTotalThreadsPerThreadgroup;
|
||||
if (tgSize > numThreads)
|
||||
if (tgSize > numThreads) {
|
||||
tgSize = numThreads;
|
||||
}
|
||||
|
||||
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
|
||||
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];
|
||||
|
|
|
|||
Loading…
Reference in a new issue