pytorch/torch/mps
Nikita Shulga 95b17f6346 [MPS] Add CompileShader method (#141478)
This allows one to do something like the following:
```python
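import torch

# Tensor of ten ones allocated on the MPS (Metal) device
x = torch.ones(10, device="mps")
# Compile the Metal source; kernels are then callable by name, e.g. m.foo
m = torch.mps._compile_shader("""
   kernel void foo(device float* x, uint idx [[thread_position_in_grid]]) {
     x[idx] += idx;
   }
""")
# Run the kernel on x, which is modified in place
m.foo(x)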
```

More generally, it enables writing custom operators as Metal shaders purely in Python.
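To make the effect of the `foo` kernel above concrete, here is a minimal CPU sketch of what it is expected to compute. It assumes that `m.foo(x)` dispatches one thread per element of `x`, which this message does not state explicitly. `reference_foo` is a hypothetical helper introduced only for illustration and is not part of `torch.mps`. The optional comparison runs only if `torch` is installed, MPS is available, and the private `_compile_shader` entry point is present.

```python
def reference_foo(values):
    """CPU model of the `foo` kernel: thread `idx` adds `idx` to element `idx`."""
    return [v + idx for idx, v in enumerate(values)]


# Ten ones, mirroring torch.ones(10) in the example above.
expected = reference_foo([1.0] * 10)
print("CPU reference:", expected)

# Optional comparison against the real shader. Skipped when torch is missing,
# MPS is unavailable, or this build lacks the private _compile_shader API.
try:
    import torch
except ImportError:
    torch = None

if (
    torch is not None
    and torch.backends.mps.is_available()
    and hasattr(torch.mps, "_compile_shader")
):
    x = torch.ones(10, device="mps")
    m = torch.mps._compile_shader("""
       kernel void foo(device float* x, uint idx [[thread_position_in_grid]]) {
         x[idx] += idx;
       }
    """)
    m.foo(x)
    print("Matches MPS result:", x.cpu().tolist() == expected)
```

Keeping a plain-Python reference like this next to a custom shader gives a cheap correctness check while the kernel is being developed.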
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141478
Approved by: https://github.com/manuelcandales
2024-12-11 02:00:51 +00:00
__init__.py [MPS] Add CompileShader method (#141478) 2024-12-11 02:00:51 +00:00
event.py
profiler.py