From dcd686f4780b43c7cbe69d576fbd48bebd9eb19d Mon Sep 17 00:00:00 2001 From: Denis Vieriu Date: Mon, 24 Apr 2023 15:41:47 +0000 Subject: [PATCH] [MPS] Add PSO caching for advanced indexing kernels (#99855) Use bindless Argument Buffers (unbounded arrays) for advanced indexing kernels - this allows caching of the PSOs since we don't have to query anymore the main metal function for the AB size (this is filled directly now on the CPU). Pull Request resolved: https://github.com/pytorch/pytorch/pull/99855 Approved by: https://github.com/kulinseth --- aten/src/ATen/mps/IndexKernels.h | 149 +++++++++++++----- aten/src/ATen/mps/MPSDevice.h | 11 +- aten/src/ATen/mps/MPSDevice.mm | 36 +++-- .../native/mps/operations/BinaryKernel.mm | 4 +- .../ATen/native/mps/operations/CrossKernel.mm | 4 +- .../ATen/native/mps/operations/Indexing.mm | 77 +++++---- 6 files changed, 191 insertions(+), 90 deletions(-) diff --git a/aten/src/ATen/mps/IndexKernels.h b/aten/src/ATen/mps/IndexKernels.h index 650da6ae951..1635fd21d22 100644 --- a/aten/src/ATen/mps/IndexKernels.h +++ b/aten/src/ATen/mps/IndexKernels.h @@ -9,27 +9,42 @@ static const char * indexing_metal_shaders = R"INDEX_METAL( using namespace metal; -constant uint32_t num_indices [[function_constant(0)]]; - +#if __METAL_VERSION__ < 300 struct IndexAB { // Allow up to 16 indices metal::array indexArray [[ id(0) ]]; }; +#else +struct IndexAB { + constant int64_t* indexArray; +}; + +#endif template kernel void index_select( +#if __METAL_VERSION__ >= 300 + constant IndexAB * indexAB [[buffer(0)]], +#else constant IndexAB & indexAB [[buffer(0)]], +#endif constant void * indexSizes [[buffer(1)]], constant void * indexStrides [[buffer(2)]], constant uint3 * offsets [[buffer(3)]], constant void * inputData [[buffer(4)]], device void * outputData [[buffer(5)]], + constant uint32_t & num_indices [[buffer(6)]], uint thread_index [[thread_position_in_grid]]) { constant int64_t * index_sizes = (constant int64_t *)indexSizes; constant int64_t * index_strides = (constant int64_t *)indexStrides; int64_t offset = 0; for (uint32_t i = 0; i < num_indices; i++) { - int64_t index = ((constant int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)]; +#if __METAL_VERSION__ >= 300 + constant int64_t* indexArray = indexAB[i].indexArray; +#else + constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i]; +#endif + int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)]; if (index < 0) { index += index_sizes[i]; } @@ -42,19 +57,30 @@ kernel void index_select( template kernel void index_put( +#if __METAL_VERSION__ >= 300 + constant IndexAB * indexAB [[buffer(0)]], +#else constant IndexAB & indexAB [[buffer(0)]], +#endif constant void * indexSizes [[buffer(1)]], constant void * indexStrides [[buffer(2)]], constant uint3 * offsets [[buffer(3)]], constant void * inputData [[buffer(4)]], device void * outputData [[buffer(5)]], + constant uint32_t & num_indices [[buffer(6)]], uint thread_index [[thread_position_in_grid]]) { constant int64_t * index_sizes = (constant int64_t *)indexSizes; constant int64_t * index_strides = (constant int64_t *)indexStrides; int64_t offset = 0; for (uint32_t i = 0; i < num_indices; i++) { - int64_t index = ((constant int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)]; +#if __METAL_VERSION__ >= 300 + constant int64_t* indexArray = indexAB[i].indexArray; +#else + constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i]; +#endif + int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)]; + if (index < 0) { index += index_sizes[i]; } @@ -65,6 +91,7 @@ kernel void index_put( *out = *in; } +#if __METAL_VERSION__ < 300 #define REGISTER_INDEX_OP(DTYPE_SIZE, DTYPE, INDEX_OP_TYPE) \ template \ [[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE)]] \ @@ -75,7 +102,22 @@ kernel void index_ ## INDEX_OP_TYPE( \ constant uint3 * offsets [[buffer(3)]], \ constant void * inputData [[buffer(4)]], \ device void * outputData [[buffer(5)]], \ + constant uint32_t & num_indices [[buffer(6)]], \ uint thread_index [[thread_position_in_grid]]); +#else +#define REGISTER_INDEX_OP(DTYPE_SIZE, DTYPE, INDEX_OP_TYPE) \ +template \ +[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE)]] \ +kernel void index_ ## INDEX_OP_TYPE( \ + constant IndexAB * indexAB [[buffer(0)]], \ + constant void * indexSizes [[buffer(1)]], \ + constant void * indexStrides [[buffer(2)]], \ + constant uint3 * offsets [[buffer(3)]], \ + constant void * inputData [[buffer(4)]], \ + device void * outputData [[buffer(5)]], \ + constant uint32_t & num_indices [[buffer(6)]], \ + uint thread_index [[thread_position_in_grid]]); +#endif #define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \ REGISTER_INDEX_OP(8bit, char, INDEX_OP_TYPE); \ @@ -92,29 +134,40 @@ kernel void kernel_index_offsets(constant packed_uint3 * strides [[buffe constant uint & num_dimensions [[buffer(3)]], constant uint & num_offsets [[buffer(4)]], uint thread_index [[thread_position_in_grid]]) { + data_offsets[thread_index] = 0; uint32_t idx = thread_index; for (uint32_t dim = 0; dim < num_dimensions; dim++) { uint32_t remainder = idx % iter_shape[dim]; idx /= iter_shape[dim]; - for (uint32_t offset = 0; offset < num_offsets; offset++) - data_offsets[thread_index][offset] += remainder * strides[dim][offset]; + data_offsets[thread_index] += remainder * strides[dim]; } } template -kernel void index_put_accumulate_native_dtypes(constant IndexAB & indexAB [[buffer(0)]], - constant void * indexSizes [[buffer(1)]], - constant void * indexStrides [[buffer(2)]], - constant uint3 * offsets [[buffer(3)]], - constant void * inputData [[buffer(4)]], - device void * outputData [[buffer(5)]], - uint thread_index [[thread_position_in_grid]]) { +kernel void index_put_accumulate_native_dtypes( +#if __METAL_VERSION__ >= 300 + constant IndexAB * indexAB [[buffer(0)]], +#else + constant IndexAB & indexAB [[buffer(0)]], +#endif + constant void * indexSizes [[buffer(1)]], + constant void * indexStrides [[buffer(2)]], + constant uint3 * offsets [[buffer(3)]], + constant void * inputData [[buffer(4)]], + device void * outputData [[buffer(5)]], + constant uint32_t& num_indices [[buffer(6)]], + uint thread_index [[thread_position_in_grid]]) { constant int64_t * index_sizes = (constant int64_t *)indexSizes; constant int64_t * index_strides = (constant int64_t *)indexStrides; int64_t offset = 0; for (uint32_t i = 0; i < num_indices; i++) { - int64_t index = ((constant int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)]; +#if __METAL_VERSION__ >= 300 + constant int64_t* indexArray = indexAB[i].indexArray; +#else + constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i]; +#endif + int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)]; if (index < 0) { index += index_sizes[i]; } @@ -136,18 +189,29 @@ __attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * a } template -kernel void atomic_index_put_accumulate(constant IndexAB & indexAB [[buffer(0)]], - constant void * indexSizes [[buffer(1)]], - constant void * indexStrides [[buffer(2)]], - constant uint3 * offsets [[buffer(3)]], - constant void * inputData [[buffer(4)]], - device void * outputData [[buffer(5)]], - uint thread_index [[thread_position_in_grid]]) { +kernel void atomic_index_put_accumulate( +#if __METAL_VERSION__ >= 300 + constant IndexAB * indexAB [[buffer(0)]], +#else + constant IndexAB & indexAB [[buffer(0)]], +#endif + constant void * indexSizes [[buffer(1)]], + constant void * indexStrides [[buffer(2)]], + constant uint3 * offsets [[buffer(3)]], + constant void * inputData [[buffer(4)]], + device void * outputData [[buffer(5)]], + constant uint32_t& num_indices [[buffer(6)]], + uint thread_index [[thread_position_in_grid]]) { constant int64_t * index_sizes = (constant int64_t *)indexSizes; constant int64_t * index_strides = (constant int64_t *)indexStrides; int64_t offset = 0; for (uint32_t i = 0; i < num_indices; i++) { - int64_t index = ((constant int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)]; +#if __METAL_VERSION__ >= 300 + constant int64_t* indexArray = indexAB[i].indexArray; +#else + constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i]; +#endif + int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)]; if (index < 0) { index += index_sizes[i]; } @@ -160,22 +224,35 @@ kernel void atomic_index_put_accumulate(constant IndexAB & indexAB [[b template [[host_name("index_put_accumulate_32bit_float")]] -kernel void atomic_index_put_accumulate(constant IndexAB & indexAB [[buffer(0)]], - constant void * indexSizes [[buffer(1)]], - constant void * indexStrides [[buffer(2)]], - constant uint3 * offsets [[buffer(3)]], - constant void * inputData [[buffer(4)]], - device void * outputData [[buffer(5)]], - uint thread_index [[thread_position_in_grid]]); +kernel void atomic_index_put_accumulate( +#if __METAL_VERSION__ >= 300 + constant IndexAB * indexAB [[buffer(0)]], +#else + constant IndexAB & indexAB [[buffer(0)]], +#endif + constant void * indexSizes [[buffer(1)]], + constant void * indexStrides [[buffer(2)]], + constant uint3 * offsets [[buffer(3)]], + constant void * inputData [[buffer(4)]], + device void * outputData [[buffer(5)]], + constant uint32_t& num_indices [[buffer(6)]], + uint thread_index [[thread_position_in_grid]]); + template [[host_name("index_put_accumulate_32bit_int")]] -kernel void index_put_accumulate_native_dtypes(constant IndexAB & indexAB [[buffer(0)]], - constant void * indexSizes [[buffer(1)]], - constant void * indexStrides [[buffer(2)]], - constant uint3 * offsets [[buffer(3)]], - constant void * inputData [[buffer(4)]], - device void * outputData [[buffer(5)]], - uint thread_index [[thread_position_in_grid]]); +kernel void index_put_accumulate_native_dtypes( +#if __METAL_VERSION__ >= 300 + constant IndexAB * indexAB [[buffer(0)]], +#else + constant IndexAB & indexAB [[buffer(0)]], +#endif + constant void * indexSizes [[buffer(1)]], + constant void * indexStrides [[buffer(2)]], + constant uint3 * offsets [[buffer(3)]], + constant void * inputData [[buffer(4)]], + device void * outputData [[buffer(5)]], + constant uint32_t& num_indices [[buffer(6)]], + uint thread_index [[thread_position_in_grid]]); )INDEX_METAL"; static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER( diff --git a/aten/src/ATen/mps/MPSDevice.h b/aten/src/ATen/mps/MPSDevice.h index 5e9c948bdd6..70da9e5a038 100644 --- a/aten/src/ATen/mps/MPSDevice.h +++ b/aten/src/ATen/mps/MPSDevice.h @@ -12,14 +12,14 @@ #include typedef id MTLDevice_t; typedef id MTLLibrary_t; -typedef id MTLFunction_t; -typedef MTLFunctionConstantValues* MTLFunctionConstantValues_t; +typedef id MTLComputePipelineState_t; +typedef id MTLLibrary_t; #else typedef void* MTLDevice; typedef void* MTLDevice_t; typedef void* MTLLibrary_t; -typedef void* MTLFunction_t; -typedef void* MTLFunctionConstantValues_t; +typedef void* MTLComputePipelineState_t; +typedef void* MTLLibrary_t; #endif using namespace std; @@ -66,7 +66,8 @@ class TORCH_API MPSDevice { */ bool isMacOS13Plus(MacOSVersion version) const; - MTLFunction_t metalIndexingFunction(const std::string &kernel, MTLFunctionConstantValues_t constantValues); + MTLComputePipelineState_t metalIndexingFunction(const std::string &kernel); + MTLLibrary_t getMetalIndexingLibrary(); ~MPSDevice(); diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm index f7a37371bb7..7015bee6a43 100644 --- a/aten/src/ATen/mps/MPSDevice.mm +++ b/aten/src/ATen/mps/MPSDevice.mm @@ -13,10 +13,15 @@ namespace mps { static std::unique_ptr mps_device; static c10::once_flag mpsdev_init; -static inline MTLLanguageVersion getMetalLanguageVersion(const id& device) { +static inline MTLLanguageVersion getMetalLanguageVersion(const id& device, bool macOS13Plus) { // MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants) // host_name attribute needs at least Metal 2.2 MTLLanguageVersion languageVersion = MTLLanguageVersion2_2; +#if defined(__MAC_13_0) + if (macOS13Plus) { + languageVersion = MTLLanguageVersion3_0; + } +#endif TORCH_CHECK([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2"); return languageVersion; @@ -27,12 +32,12 @@ MPSDevice* MPSDevice::getInstance() { return mps_device.get(); } -id MPSDevice::metalIndexingFunction(const std::string& kernel, MTLFunctionConstantValues* constantValues) { +id MPSDevice::getMetalIndexingLibrary() { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device); NSError* error = nil; if (!_mtl_indexing_library) { MTLCompileOptions* options = [MTLCompileOptions new]; - [options setLanguageVersion:getMetalLanguageVersion(_mtl_device)]; + [options setLanguageVersion:getMetalLanguageVersion(_mtl_device, isMacOS13Plus(MacOSVersion::MACOS_VER_13_0_PLUS))]; [options setFastMathEnabled:YES]; _mtl_indexing_library = [_mtl_device newLibraryWithSource:[NSString stringWithCString:mps::indexing_metal_shaders encoding:NSASCIIStringEncoding] @@ -40,24 +45,31 @@ id MPSDevice::metalIndexingFunction(const std::string& kernel, MTLF error:&error]; TORCH_CHECK(_mtl_indexing_library, "Failed to create indexing library, error: ", [[error description] UTF8String]); } + return _mtl_indexing_library; +} - id indexFunction = nil; - if (constantValues) { - indexFunction = [[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()] - constantValues:constantValues - error:&error] autorelease]; - } else { - indexFunction = - [[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]] autorelease]; +id MPSDevice::metalIndexingFunction(const std::string& kernel) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device); + NSError* error = nil; + static std::unordered_map> psoCache; + id indexing_lib = getMetalIndexingLibrary(); + id state = psoCache[kernel]; + if (state) { + return state; } + id indexFunction = + [[indexing_lib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]] autorelease]; TORCH_CHECK(indexFunction, "Failed to create specialized function state object: ", kernel, ", error: ", [[error description] UTF8String]); - return indexFunction; + state = [_mtl_device newComputePipelineStateWithFunction:indexFunction error:&error]; + TORCH_CHECK(state, error.localizedDescription.UTF8String); + psoCache[kernel] = state; + return state; } MPSDevice::~MPSDevice() { diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm index 40f8a1ab947..c86793988c9 100644 --- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm @@ -197,10 +197,8 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) { } } - id kernelDataOffsetsFunction = - MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil); id kernelDataOffsetsPSO = - [[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease]; + MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets"); id kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3) options:0] autorelease]; TORCH_CHECK( diff --git a/aten/src/ATen/native/mps/operations/CrossKernel.mm b/aten/src/ATen/native/mps/operations/CrossKernel.mm index a9dea1fa69f..32274164613 100644 --- a/aten/src/ATen/native/mps/operations/CrossKernel.mm +++ b/aten/src/ATen/native/mps/operations/CrossKernel.mm @@ -155,10 +155,8 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other, } } - id kernelDataOffsetsFunction = - MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil); id kernelDataOffsetsPSO = - [[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease]; + MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets"); id kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3) options:0] autorelease]; TORCH_CHECK( diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm index 7e47583e1f2..f97de92cded 100644 --- a/aten/src/ATen/native/mps/operations/Indexing.mm +++ b/aten/src/ATen/native/mps/operations/Indexing.mm @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -77,12 +78,10 @@ static bool dispatchIndexKernel(TensorIteratorBase& iter, MTLSize gridSize = MTLSizeMake(numThreads, 1, 1); id computeEncoder = mpsStream->commandEncoder(); - id kernelDataOffsetsFunction = - MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil); id kernelDataOffsetsPSO = - [[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease]; - id kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3) - options:0] autorelease]; + MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets"); + id kernelDataOffsets = + (id)getIMPSAllocator()->allocate(numThreads * sizeof(simd_uint3)).get(); TORCH_CHECK( kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); @@ -94,39 +93,53 @@ static bool dispatchIndexKernel(TensorIteratorBase& iter, [computeEncoder setBytes:&nOffsets length:sizeof(uint32_t) atIndex:4]; NSUInteger kernelOffsetsTGSize = kernelDataOffsetsPSO.maxTotalThreadsPerThreadgroup; - if (kernelOffsetsTGSize > numThreads) + if (kernelOffsetsTGSize > numThreads) { kernelOffsetsTGSize = numThreads; + } MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1); [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:kernelOffsetsThreadGroupSize]; - MTLFunctionConstantValues* constantValues = [[MTLFunctionConstantValues new] autorelease]; - [constantValues setConstantValue:&num_indices type:MTLDataTypeUInt atIndex:0]; - std::string indexFunction = getIndexFunctionName(inputTensor.scalar_type(), index_select, accumulate); - id indexKernelFunction = - MPSDevice::getInstance()->metalIndexingFunction(indexFunction, constantValues); - id argumentEncoder = [[indexKernelFunction newArgumentEncoderWithBufferIndex:0] autorelease]; - NSUInteger argumentBufferLength = argumentEncoder.encodedLength; - id indexAB = [[device newBufferWithLength:argumentBufferLength options:0] autorelease]; - [argumentEncoder setArgumentBuffer:indexAB offset:0]; + id indexSelectPSO = nil; + id indexAB = nil; +#if defined(__MAC_13_0) + if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_0_PLUS)) { + indexSelectPSO = MPSDevice::getInstance()->metalIndexingFunction(indexFunction); + size_t argumentBufferLength = sizeof(uint64_t) * num_indices; + indexAB = [[device newBufferWithLength:argumentBufferLength options:0] autorelease]; + uint64_t* indexABContents = (uint64_t*)(indexAB.contents); + for (uint32_t idx = 0; idx < num_indices; idx++) { + const Tensor& indexTensor = iter.tensor(idx + 2); + indexABContents[idx] = + getMTLBufferStorage(indexTensor).gpuAddress + (indexTensor.storage_offset() * indexTensor.element_size()); + TORCH_CHECK(indexTensor.scalar_type() == ScalarType::Long, "index(): Expected dtype int64 for Index"); + [computeEncoder useResource:getMTLBufferStorage(indexTensor) usage:MTLResourceUsageRead]; + } + } else +#endif + { + id lib = MPSDevice::getInstance()->getMetalIndexingLibrary(); + id indexKernelFunction = + [[lib newFunctionWithName:[NSString stringWithUTF8String:indexFunction.c_str()]] autorelease]; + id argumentEncoder = + [[indexKernelFunction newArgumentEncoderWithBufferIndex:0] autorelease]; + NSUInteger argumentBufferLength = argumentEncoder.encodedLength; + indexAB = [[device newBufferWithLength:argumentBufferLength options:0] autorelease]; + [argumentEncoder setArgumentBuffer:indexAB offset:0]; - for (uint32_t idx = 0; idx < num_indices; idx++) { - const Tensor& indexTensor = iter.tensor(idx + 2); - [argumentEncoder setBuffer:getMTLBufferStorage(indexTensor) - offset:indexTensor.storage_offset() * indexTensor.element_size() - atIndex:idx]; - TORCH_CHECK(indexTensor.scalar_type() == ScalarType::Long, "index(): Expected dtype int64 for Index"); - } + for (uint32_t idx = 0; idx < num_indices; idx++) { + const Tensor& indexTensor = iter.tensor(idx + 2); + [argumentEncoder setBuffer:getMTLBufferStorage(indexTensor) + offset:indexTensor.storage_offset() * indexTensor.element_size() + atIndex:idx]; + TORCH_CHECK(indexTensor.scalar_type() == ScalarType::Long, "index(): Expected dtype int64 for Index"); + [computeEncoder useResource:getMTLBufferStorage(indexTensor) usage:MTLResourceUsageRead]; + } - // FIXME: PSO needs to be cached - id indexSelectPSO = [[device newComputePipelineStateWithFunction:indexKernelFunction - error:&error] autorelease]; - TORCH_CHECK(indexSelectPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); - - for (uint32_t idx = 0; idx < num_indices; idx++) { - const Tensor& indexTensor = iter.tensor(idx + 2); - [computeEncoder useResource:getMTLBufferStorage(indexTensor) usage:MTLResourceUsageRead]; + indexSelectPSO = [[device newComputePipelineStateWithFunction:indexKernelFunction error:&error] autorelease]; + TORCH_CHECK( + indexSelectPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); } [computeEncoder setComputePipelineState:indexSelectPSO]; @@ -138,10 +151,12 @@ static bool dispatchIndexKernel(TensorIteratorBase& iter, [computeEncoder setBuffer:outputBuffer offset:outputTensor.storage_offset() * outputTensor.element_size() atIndex:5]; + [computeEncoder setBytes:&num_indices length:sizeof(uint32_t) atIndex:6]; NSUInteger tgSize = indexSelectPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > numThreads) + if (tgSize > numThreads) { tgSize = numThreads; + } MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];