mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
make ATen/native/cuda/ConvolutionMM2d.cu data_ptr-correct (#99323)
make ATen/native/cuda/ConvolutionMM2d.cu data_ptr-correct Test Plan: Rely on CI. Pull Request resolved: https://github.com/pytorch/pytorch/pull/99323 Approved by: https://github.com/ezyang
This commit is contained in:
parent
7880f9e7e3
commit
24d20ea194
1 changed files with 6 additions and 6 deletions
|
|
@ -198,8 +198,8 @@ void slow_conv2d_forward(
|
|||
|
||||
// Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
|
||||
auto gemm_in_ptr = requires_columns ?
|
||||
columns.data_ptr<scalar_t>() :
|
||||
input_n.data_ptr<scalar_t>();
|
||||
columns.const_data_ptr<scalar_t>() :
|
||||
input_n.const_data_ptr<scalar_t>();
|
||||
at::cuda::blas::gemm(
|
||||
'n', 'n',
|
||||
n, m, k,
|
||||
|
|
@ -337,12 +337,12 @@ void slow_conv2d_grad_weight(
|
|||
// Extract columns:
|
||||
at::native::im2col<scalar_t>(
|
||||
c10::cuda::getCurrentCUDAStream(),
|
||||
input_n.data_ptr<scalar_t>(),
|
||||
input_n.const_data_ptr<scalar_t>(),
|
||||
nInputPlane, inputHeight, inputWidth,
|
||||
outputHeight, outputWidth,
|
||||
kH, kW, padH, padW, dH, dW,
|
||||
1, 1,
|
||||
columns.data_ptr<scalar_t>()
|
||||
columns.mutable_data_ptr<scalar_t>()
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -354,8 +354,8 @@ void slow_conv2d_grad_weight(
|
|||
|
||||
// Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
|
||||
auto gemm_in_ptr = requires_columns ?
|
||||
columns.data_ptr<scalar_t>() :
|
||||
input_n.data_ptr<scalar_t>();
|
||||
columns.const_data_ptr<scalar_t>() :
|
||||
input_n.const_data_ptr<scalar_t>();
|
||||
at::cuda::blas::gemm(
|
||||
't', 'n',
|
||||
n, m, k,
|
||||
|
|
|
|||
Loading…
Reference in a new issue