mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
To work around a bug where an `abs` call appears to be ignored before a subsequent `log` call. The bug can be reproduced by running the following code (submitted as FB16415011):
```swift
import Metal
/// Runs `kernel_name` from `library` over `nelem` elements and prints the results.
///
/// Fills a shared-storage input buffer with `sin(2 + i)` samples, dispatches one
/// GPU thread per element (a single threadgroup of width `nelem`), waits for
/// completion, and prints the output buffer.
///
/// - Parameters:
///   - library: Compiled `MTLLibrary` that contains `kernel_name`.
///   - kernel_name: Name of the compute kernel to execute.
///   - type: Element type of both buffers (e.g. `Float16.self` or `Float.self`);
///     must match the kernel's pointer types.
///   - nelem: Element count; must not exceed the device's max threadgroup width.
func run_shader<T: BinaryFloatingPoint> (library: MTLLibrary, kernel_name: String, type: T.Type, nelem: Int = 16) {
    guard let mfunc = library.makeFunction(name: kernel_name) else { fatalError("Can't find function") }
    let device = library.device
    guard let queue = device.makeCommandQueue() else { fatalError("Can't make queue") }
    guard let cmdBuffer = queue.makeCommandBuffer() else { fatalError("Can't make command buffer") }
    guard let computeEncoder = cmdBuffer.makeComputeCommandEncoder() else { fatalError("Can't make compute encoder") }
    // Use `stride` (not `size`) when sizing arrays of elements: stride includes
    // any per-element alignment padding, matching how the GPU indexes the buffer.
    guard let ibuf = device.makeBuffer(length: nelem * MemoryLayout<T>.stride, options: [.storageModeShared]) else { fatalError("Can't alloc") }
    let ibuf_data = ibuf.contents().assumingMemoryBound(to: T.self)
    for i in 0..<nelem {
        // Deterministic, sign-varying inputs so `abs` actually matters.
        ibuf_data[i] = T(sin(Float(2 + i)))
    }
    guard let obuf = device.makeBuffer(length: nelem * MemoryLayout<T>.stride, options: [.storageModeShared]) else { fatalError("Can't alloc") }
    let obuf_data = obuf.contents().assumingMemoryBound(to: T.self)
    computeEncoder.setComputePipelineState(try! device.makeComputePipelineState(function: mfunc))
    computeEncoder.setBuffer(obuf, offset: 0, index: 0)
    computeEncoder.setBuffer(ibuf, offset: 0, index: 1)
    computeEncoder.dispatchThreads(MTLSizeMake(nelem, 1, 1), threadsPerThreadgroup: MTLSizeMake(nelem, 1, 1))
    computeEncoder.endEncoding()
    cmdBuffer.commit()
    // Shared-storage buffers are only coherent after the command buffer completes.
    cmdBuffer.waitUntilCompleted()
    print("Results for \(String(describing: T.self)):", terminator: " ")
    for i in 0..<nelem {
        print(obuf_data[i], terminator: " ")
    }
    print()
}
// Metal source for the reproducer. NOTE: the original post rendered the
// `::metal::` namespace qualifier as the `:metal:` emoji (🤘); restored here so
// the shader actually compiles.
let shader_source = """
#include <metal_stdlib>
template<typename T>
float foo(T x) {
const auto abs_x = ::metal::abs(static_cast<float>(x));
auto rc = ::metal::log(abs_x);
return rc - ::metal::log(::metal::abs(abs_x * ::metal::sinpi(abs_x)));
}
kernel void half_kernel(
device half* out_ptr0,
constant half* in_ptr0,
uint xindex [[thread_position_in_grid]]
) {
auto inp = in_ptr0[xindex];
auto out = foo(inp);
out_ptr0[xindex] = static_cast<half>(out);
}
kernel void float_kernel(
device float* out_ptr0,
constant float* in_ptr0,
uint xindex [[thread_position_in_grid]]
) {
auto inp = in_ptr0[xindex];
auto out = foo(inp);
out_ptr0[xindex] = static_cast<float>(out);
}
"""
// Compile with safe math and precise float functions so the compiler is not
// permitted to algebraically elide the `abs` call — the bug reproduces even
// under these conservative options.
let options = MTLCompileOptions()
options.mathMode = .safe
options.mathFloatingPointFunctions = .precise
guard let device = MTLCopyAllDevices().first else { fatalError("No Metal device found") }
// `try!` is acceptable in a throwaway reproducer: a compile failure is fatal anyway.
let library = try! device.makeLibrary(source: shader_source, options: options)
run_shader(library: library, kernel_name: "half_kernel", type: Float16.self)
run_shader(library: library, kernel_name: "float_kernel", type: Float.self)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145740
Approved by: https://github.com/dcci
| | | |
|---|---|---|
| .. | ||
| benchmark | ||
| core | ||
| cuda | ||
| hip | ||
| macros | ||
| metal | ||
| mobile | ||
| test | ||
| util | ||
| xpu | ||
| BUCK.oss | ||
| BUILD.bazel | ||
| build.bzl | ||
| CMakeLists.txt | ||
| ovrsource_defs.bzl | ||