pytorch/c10
Nikita Shulga 3a23d75b37 [MPS] Fix c10::metal::log_gamma correctness on M4 (#145740)
Works around a bug where the `abs` call seems to be ignored before calling `log`, which can be reproduced by running the following code (submitted as FB16415011):
```swift
import Metal

/// Runs the compute kernel `kernel_name` from `library` over `nelem` elements and prints the results.
///
/// The input buffer is filled with `sin(2 + i)` for `i` in `0..<nelem`, converted to `T`.
/// The output buffer is bound at index 0 and the input buffer at index 1, one thread is
/// dispatched per element, and the call blocks until the command buffer has completed.
/// Any setup or execution failure terminates the process with `fatalError`.
///
/// - Parameters:
///   - library: Compiled Metal library that contains the kernel.
///   - kernel_name: Name of the kernel function to look up in `library`.
///   - type: Element type used for both the input and the output buffer.
///   - nelem: Number of elements to process; must be positive.
func run_shader<T: BinaryFloatingPoint> (library: MTLLibrary, kernel_name: String, type: T.Type, nelem: Int = 16) {
  precondition(nelem > 0, "nelem must be positive")
  guard let mfunc = library.makeFunction(name: kernel_name) else { fatalError("Can't find function") }
  let device = library.device

  // Build the pipeline up front so a failure is reported with the underlying error.
  let pipeline: MTLComputePipelineState
  do {
    pipeline = try device.makeComputePipelineState(function: mfunc)
  } catch {
    fatalError("Can't make pipeline state: \(error)")
  }

  guard let queue = device.makeCommandQueue() else { fatalError("Can't make queue") }
  guard let cmdBuffer = queue.makeCommandBuffer() else { fatalError("Can't make command buffer") }
  guard let computeEncoder = cmdBuffer.makeComputeCommandEncoder() else { fatalError("Can't make compute encoder") }

  // `stride` (not `size`) is the distance between consecutive elements of an array of `T`.
  let byteCount = nelem * MemoryLayout<T>.stride
  guard let ibuf = device.makeBuffer(length: byteCount, options: [.storageModeShared]) else { fatalError("Can't alloc") }
  let ibuf_data = ibuf.contents().assumingMemoryBound(to: T.self)
  for i in 0..<nelem {
    ibuf_data[i] = T(sin(Float(2 + i)))
  }
  guard let obuf = device.makeBuffer(length: byteCount, options: [.storageModeShared]) else { fatalError("Can't alloc") }
  let obuf_data = obuf.contents().assumingMemoryBound(to: T.self)

  computeEncoder.setComputePipelineState(pipeline)
  computeEncoder.setBuffer(obuf, offset: 0, index: 0)
  computeEncoder.setBuffer(ibuf, offset: 0, index: 1)
  // A threadgroup may not exceed the pipeline's limit, so clamp instead of assuming nelem fits.
  let groupWidth = min(nelem, pipeline.maxTotalThreadsPerThreadgroup)
  computeEncoder.dispatchThreads(
    MTLSize(width: nelem, height: 1, depth: 1),
    threadsPerThreadgroup: MTLSize(width: groupWidth, height: 1, depth: 1))
  computeEncoder.endEncoding()
  cmdBuffer.commit()
  cmdBuffer.waitUntilCompleted()

  // Do not print buffer contents if the GPU reported a failure.
  if let error = cmdBuffer.error {
    fatalError("Command buffer failed: \(error)")
  }

  print("Results for \(String(describing: T.self)):", terminator: " ")
  for value in UnsafeBufferPointer(start: obuf_data, count: nelem) {
    print(value, terminator: " ")
  }
  print()
}

// Metal source for the repro. `foo` takes abs(x) as a float and returns
// log(abs_x) - log(abs(abs_x * sinpi(abs_x))); `half_kernel` and `float_kernel`
// apply it element-wise, reading from buffer 1 and writing to buffer 0.
let shader_source = """
#include <metal_stdlib>

template<typename T>
float foo(T x) {
  const auto abs_x = ::metal::abs(static_cast<float>(x));
  auto rc = ::metal::log(abs_x);

  return rc - ::metal::log(::metal::abs(abs_x * ::metal::sinpi(abs_x)));
}

kernel void half_kernel(
    device half* out_ptr0,
    constant half* in_ptr0,
    uint xindex [[thread_position_in_grid]]
) {
  auto inp = in_ptr0[xindex];
  auto out = foo(inp);
  out_ptr0[xindex] = static_cast<half>(out);
}

kernel void float_kernel(
    device float* out_ptr0,
    constant float* in_ptr0,
    uint xindex [[thread_position_in_grid]]
) {
  auto inp = in_ptr0[xindex];
  auto out = foo(inp);
  out_ptr0[xindex] = static_cast<float>(out);
}
"""
// Compile with safe math mode and precise floating-point functions.
let options = MTLCompileOptions()
options.mathMode = .safe
options.mathFloatingPointFunctions = .precise

// Use the first Metal device reported by the system.
guard let device = MTLCopyAllDevices().first else { fatalError("No Metal device found") }

// Compile the shader source, reporting the compiler diagnostics on failure.
let library: MTLLibrary = {
  do {
    return try device.makeLibrary(source: shader_source, options: options)
  } catch {
    fatalError("Can't compile shader library: \(error)")
  }
}()

// Run the same computation in half and single precision.
run_shader(library: library, kernel_name: "half_kernel", type: Float16.self)
run_shader(library: library, kernel_name: "float_kernel", type: Float.self)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145740
Approved by: https://github.com/dcci
2025-01-27 21:24:22 +00:00
..
benchmark Set RUNPATH so installed tests can find the required shared libraries (#136627) 2024-10-25 09:38:08 +00:00
core [autocast][pytorch] Support autocast for MTIA (#145627) 2025-01-25 03:24:59 +00:00
cuda Set RUNPATH on CUDA and XPU tests (#144305) 2025-01-26 08:40:22 +00:00
hip Fix hardcoded ROCm paths in Caffe2Targets.cmake (#136283) 2024-09-26 00:34:43 +00:00
macros [ROCm][Windows] Fix export macros (#144098) 2025-01-04 17:12:46 +00:00
metal [MPS] Fix c10::metal::log_gamma correctness on M4 (#145740) 2025-01-27 21:24:22 +00:00
mobile [2/N] Fix extra warnings brought by clang-tidy-17 (#137459) 2024-10-08 19:05:02 +00:00
test c10::string_view -> std::string_view in pytorch (#143591) 2025-01-13 21:44:05 +00:00
util Fix PythonMod printing for C++ (#143385) 2025-01-22 14:58:35 +00:00
xpu Set RUNPATH on CUDA and XPU tests (#144305) 2025-01-26 08:40:22 +00:00
BUCK.oss
BUILD.bazel
build.bzl
CMakeLists.txt [pytorch][monitoring] Dynamic backend for WaitCounter (#135967) 2024-09-15 18:07:49 +00:00
ovrsource_defs.bzl