Skip to content

Commit

Permalink
Do not use CPU-specific bounds check on CUDA
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Apr 16, 2024
1 parent 64b91ab commit 40cfcec
Show file tree
Hide file tree
Showing 16 changed files with 138 additions and 101 deletions.
29 changes: 17 additions & 12 deletions mops/src/internal/checks/opsa.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef MOPS_CHECKS_OPSA_HPP
#define MOPS_CHECKS_OPSA_HPP

#include <string>
#include "mops/tensor.hpp"
#include "utils.hpp"

Expand All @@ -9,13 +10,16 @@ void check_opsa(
mops::Tensor<scalar_t, 3> output,
mops::Tensor<scalar_t, 2> A,
mops::Tensor<scalar_t, 2> B,
mops::Tensor<int32_t, 1> indices_output
mops::Tensor<int32_t, 1> indices_output,
std::string operation_name
) {
check_sizes(A, "A", 0, B, "B", 0, "opsa");
check_sizes(A, "A", 1, output, "output", 1, "opsa");
check_sizes(B, "B", 1, output, "output", 2, "opsa");
check_sizes(A, "A", 0, indices_output, "indices_output", 0, "opsa");
check_index_tensor(indices_output, "indices_output", output.shape[0], "opsa");
check_sizes(A, "A", 0, B, "B", 0, operation_name);
check_sizes(A, "A", 1, output, "output", 1, operation_name);
check_sizes(B, "B", 1, output, "output", 2, operation_name);
check_sizes(A, "A", 0, indices_output, "indices_output", 0, operation_name);
if (operation_name.rfind("cuda_", 0) != 0) {
check_index_tensor(indices_output, "indices_output", output.shape[0], operation_name);
}
}

template <typename scalar_t>
Expand All @@ -25,17 +29,18 @@ void check_opsa_vjp(
mops::Tensor<scalar_t, 3> grad_output,
mops::Tensor<scalar_t, 2> A,
mops::Tensor<scalar_t, 2> B,
mops::Tensor<int32_t, 1> indices_output
mops::Tensor<int32_t, 1> indices_output,
std::string operation_name
) {
if (grad_A.data != nullptr) {
check_sizes(grad_A, "grad_A", 0, A, "A", 0, "opsa_vjp");
check_sizes(grad_A, "grad_A", 1, A, "A", 1, "opsa_vjp");
check_sizes(grad_A, "grad_A", 0, A, "A", 0, operation_name);
check_sizes(grad_A, "grad_A", 1, A, "A", 1, operation_name);
}
if (grad_B.data != nullptr) {
check_sizes(grad_B, "grad_B", 0, B, "B", 0, "opsa_vjp");
check_sizes(grad_B, "grad_B", 1, B, "B", 1, "opsa_vjp");
check_sizes(grad_B, "grad_B", 0, B, "B", 0, operation_name);
check_sizes(grad_B, "grad_B", 1, B, "B", 1, operation_name);
}
check_opsa(grad_output, A, B, indices_output);
check_opsa(grad_output, A, B, indices_output, operation_name);
}

#endif
41 changes: 23 additions & 18 deletions mops/src/internal/checks/opsaw.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef MOPS_CHECKS_OPSAW_HPP
#define MOPS_CHECKS_OPSAW_HPP

#include <string>
#include "mops/tensor.hpp"
#include "utils.hpp"

Expand All @@ -11,17 +12,20 @@ void check_opsaw(
mops::Tensor<scalar_t, 2> B,
mops::Tensor<scalar_t, 2> W,
mops::Tensor<int32_t, 1> indices_W,
mops::Tensor<int32_t, 1> indices_output
mops::Tensor<int32_t, 1> indices_output,
std::string operation_name
) {
check_sizes(A, "A", 0, B, "B", 0, "opsaw");
check_sizes(A, "A", 1, output, "output", 1, "opsaw");
check_sizes(B, "B", 1, output, "output", 2, "opsaw");
check_sizes(A, "A", 0, indices_output, "indices_output", 0, "opsaw");
check_sizes(A, "A", 0, indices_W, "indices_W", 0, "opsaw");
check_sizes(W, "W", 0, output, "output", 0, "opsaw");
check_sizes(B, "B", 1, W, "W", 1, "opsaw");
check_index_tensor(indices_output, "indices_output", output.shape[0], "opsaw");
check_index_tensor(indices_W, "indices_W", output.shape[0], "opsaw");
check_sizes(A, "A", 0, B, "B", 0, operation_name);
check_sizes(A, "A", 1, output, "output", 1, operation_name);
check_sizes(B, "B", 1, output, "output", 2, operation_name);
check_sizes(A, "A", 0, indices_output, "indices_output", 0, operation_name);
check_sizes(A, "A", 0, indices_W, "indices_W", 0, operation_name);
check_sizes(W, "W", 0, output, "output", 0, operation_name);
check_sizes(B, "B", 1, W, "W", 1, operation_name);
if (operation_name.rfind("cuda_", 0) != 0) {
check_index_tensor(indices_output, "indices_output", output.shape[0], operation_name);
check_index_tensor(indices_W, "indices_W", output.shape[0], operation_name);
}
}

template <typename scalar_t>
Expand All @@ -34,21 +38,22 @@ void check_opsaw_vjp(
mops::Tensor<scalar_t, 2> B,
mops::Tensor<scalar_t, 2> W,
mops::Tensor<int32_t, 1> indices_W,
mops::Tensor<int32_t, 1> indices_output
mops::Tensor<int32_t, 1> indices_output,
std::string operation_name
) {
if (grad_A.data != nullptr) {
check_sizes(grad_A, "grad_A", 0, A, "A", 0, "opsaw_vjp");
check_sizes(grad_A, "grad_A", 1, A, "A", 1, "opsaw_vjp");
check_sizes(grad_A, "grad_A", 0, A, "A", 0, operation_name);
check_sizes(grad_A, "grad_A", 1, A, "A", 1, operation_name);
}
if (grad_B.data != nullptr) {
check_sizes(grad_B, "grad_B", 0, B, "B", 0, "opsaw_vjp");
check_sizes(grad_B, "grad_B", 1, B, "B", 1, "opsaw_vjp");
check_sizes(grad_B, "grad_B", 0, B, "B", 0, operation_name);
check_sizes(grad_B, "grad_B", 1, B, "B", 1, operation_name);
}
if (grad_W.data != nullptr) {
check_sizes(grad_W, "grad_W", 0, W, "W", 0, "opsaw_vjp");
check_sizes(grad_W, "grad_W", 1, W, "W", 1, "opsaw_vjp");
check_sizes(grad_W, "grad_W", 0, W, "W", 0, operation_name);
check_sizes(grad_W, "grad_W", 1, W, "W", 1, operation_name);
}
check_opsaw(grad_output, A, B, W, indices_W, indices_output);
check_opsaw(grad_output, A, B, W, indices_W, indices_output, operation_name);
}

#endif
61 changes: 38 additions & 23 deletions mops/src/internal/checks/sasaw.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef MOPS_CHECKS_SASAW_HPP
#define MOPS_CHECKS_SASAW_HPP

#include <string>
#include "mops/tensor.hpp"
#include "utils.hpp"

Expand All @@ -15,21 +16,24 @@ void check_sasaw(
mops::Tensor<int32_t, 1> indices_W_1,
mops::Tensor<int32_t, 1> indices_W_2,
mops::Tensor<int32_t, 1> indices_output_1,
mops::Tensor<int32_t, 1> indices_output_2
mops::Tensor<int32_t, 1> indices_output_2,
std::string operation_name
) {
check_sizes(A, "A", 0, B, "B", 0, "sasaw");
check_sizes(W, "W", 0, output, "output", 0, "sasaw");
check_sizes(B, "B", 1, W, "W", 2, "sasaw");
check_sizes(C, "C", 0, indices_A, "indices_A", 0, "sasaw");
check_sizes(C, "C", 0, indices_W_2, "indices_W_2", 0, "sasaw");
check_sizes(C, "C", 0, indices_output_2, "indices_output_2", 0, "sasaw");
check_sizes(A, "A", 0, indices_output_1, "indices_output_1", 0, "sasaw");
check_sizes(A, "A", 0, indices_W_1, "indices_W_1", 0, "sasaw");
check_index_tensor(indices_output_1, "indices_output_1", output.shape[0], "sasaw");
check_index_tensor(indices_W_1, "indices_W_1", output.shape[0], "sasaw");
check_index_tensor(indices_A, "indices_A", A.shape[1], "sasaw");
check_index_tensor(indices_W_2, "indices_W_2", B.shape[1], "sasaw");
check_index_tensor(indices_output_2, "indices_output_2", output.shape[1], "sasaw");
check_sizes(A, "A", 0, B, "B", 0, operation_name);
check_sizes(W, "W", 0, output, "output", 0, operation_name);
check_sizes(B, "B", 1, W, "W", 2, operation_name);
check_sizes(C, "C", 0, indices_A, "indices_A", 0, operation_name);
check_sizes(C, "C", 0, indices_W_2, "indices_W_2", 0, operation_name);
check_sizes(C, "C", 0, indices_output_2, "indices_output_2", 0, operation_name);
check_sizes(A, "A", 0, indices_output_1, "indices_output_1", 0, operation_name);
check_sizes(A, "A", 0, indices_W_1, "indices_W_1", 0, operation_name);
if (operation_name.rfind("cuda_", 0) != 0) {
check_index_tensor(indices_A, "indices_A", A.shape[1], operation_name);
check_index_tensor(indices_W_1, "indices_W_1", output.shape[0], operation_name);
check_index_tensor(indices_W_2, "indices_W_2", B.shape[1], operation_name);
check_index_tensor(indices_output_1, "indices_output_1", output.shape[0], operation_name);
check_index_tensor(indices_output_2, "indices_output_2", output.shape[1], operation_name);
}
}

template <typename scalar_t>
Expand All @@ -46,23 +50,34 @@ void check_sasaw_vjp(
mops::Tensor<int32_t, 1> indices_W_1,
mops::Tensor<int32_t, 1> indices_W_2,
mops::Tensor<int32_t, 1> indices_output_1,
mops::Tensor<int32_t, 1> indices_output_2
mops::Tensor<int32_t, 1> indices_output_2,
std::string operation_name
) {
if (grad_A.data != nullptr) {
check_sizes(grad_A, "grad_A", 0, A, "A", 0, "sasaw_vjp");
check_sizes(grad_A, "grad_A", 1, A, "A", 1, "sasaw_vjp");
check_sizes(grad_A, "grad_A", 0, A, "A", 0, operation_name);
check_sizes(grad_A, "grad_A", 1, A, "A", 1, operation_name);
}
if (grad_B.data != nullptr) {
check_sizes(grad_B, "grad_B", 0, B, "B", 0, "sasaw_vjp");
check_sizes(grad_B, "grad_B", 1, B, "B", 1, "sasaw_vjp");
check_sizes(grad_B, "grad_B", 0, B, "B", 0, operation_name);
check_sizes(grad_B, "grad_B", 1, B, "B", 1, operation_name);
}
if (grad_W.data != nullptr) {
check_sizes(grad_W, "grad_W", 0, W, "W", 0, "sasaw_vjp");
check_sizes(grad_W, "grad_W", 1, W, "W", 1, "sasaw_vjp");
check_sizes(grad_W, "grad_W", 2, W, "W", 2, "sasaw_vjp");
check_sizes(grad_W, "grad_W", 0, W, "W", 0, operation_name);
check_sizes(grad_W, "grad_W", 1, W, "W", 1, operation_name);
check_sizes(grad_W, "grad_W", 2, W, "W", 2, operation_name);
}
check_sasaw(
grad_output, A, B, C, W, indices_A, indices_W_1, indices_W_2, indices_output_1, indices_output_2
grad_output,
A,
B,
C,
W,
indices_A,
indices_W_1,
indices_W_2,
indices_output_1,
indices_output_2,
operation_name
);
}

Expand Down
4 changes: 2 additions & 2 deletions mops/src/opsa/cpu.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ void mops::outer_product_scatter_add(
Tensor<scalar_t, 2> B,
Tensor<int32_t, 1> indices_output
) {
check_opsa(output, A, B, indices_output);
check_opsa(output, A, B, indices_output, "cpu_outer_product_scatter_add");

size_t size_output = output.shape[0];
size_t size_output_inner = output.shape[1] * output.shape[2];
Expand Down Expand Up @@ -61,7 +61,7 @@ void mops::outer_product_scatter_add_vjp(
Tensor<scalar_t, 2> B,
Tensor<int32_t, 1> indices_output
) {
check_opsa_vjp(grad_A, grad_B, grad_output, A, B, indices_output);
check_opsa_vjp(grad_A, grad_B, grad_output, A, B, indices_output, "cpu_outer_product_scatter_add_vjp");

bool calculate_grad_A = grad_A.data != nullptr;
bool calculate_grad_B = grad_B.data != nullptr;
Expand Down
6 changes: 4 additions & 2 deletions mops/src/opsa/opsa.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void mops::cuda::outer_product_scatter_add(
Tensor<scalar_t, 2> B,
Tensor<int32_t, 1> indices_output
) {
check_opsa(output, A, B, indices_output);
check_opsa(output, A, B, indices_output, "cuda_outer_product_scatter_add");

int32_t* first_occurences = calculate_first_occurences_cuda(
indices_output.data, indices_output.shape[0], output.shape[0]
Expand Down Expand Up @@ -250,7 +250,9 @@ void mops::cuda::outer_product_scatter_add_vjp(
Tensor<scalar_t, 2> B,
Tensor<int32_t, 1> indices_output
) {
check_opsa_vjp(grad_A, grad_B, grad_output, A, B, indices_output);
check_opsa_vjp(
grad_A, grad_B, grad_output, A, B, indices_output, "cuda_outer_product_scatter_add_vjp"
);

int32_t* first_occurences = calculate_first_occurences_cuda(
indices_output.data, indices_output.shape[0], grad_output.shape[0]
Expand Down
4 changes: 2 additions & 2 deletions mops/src/opsaw/cpu.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ void mops::outer_product_scatter_add_with_weights(
Tensor<int32_t, 1> indices_W,
Tensor<int32_t, 1> indices_output
) {
check_opsaw(output, A, B, W, indices_W, indices_output);
check_opsaw(output, A, B, W, indices_W, indices_output, "cpu_outer_product_scatter_add_with_weights");

scalar_t* o_ptr = output.data;
scalar_t* a_ptr = A.data;
Expand Down Expand Up @@ -66,7 +66,7 @@ void mops::outer_product_scatter_add_with_weights_vjp(
Tensor<int32_t, 1> indices_W,
Tensor<int32_t, 1> indices_output
) {
check_opsaw_vjp(grad_A, grad_B, grad_W, grad_output, A, B, W, indices_W, indices_output);
check_opsaw_vjp(grad_A, grad_B, grad_W, grad_output, A, B, W, indices_W, indices_output, "cpu_outer_product_scatter_add_with_weights_vjp");

bool calculate_grad_A = grad_A.data != nullptr;
bool calculate_grad_B = grad_B.data != nullptr;
Expand Down
4 changes: 2 additions & 2 deletions mops/src/sasaw/cpu.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void mops::sparse_accumulation_scatter_add_with_weights(
Tensor<int32_t, 1> indices_output_1,
Tensor<int32_t, 1> indices_output_2
) {
check_sasaw(output, A, B, C, W, indices_A, indices_W_1, indices_W_2, indices_output_1, indices_output_2);
check_sasaw(output, A, B, C, W, indices_A, indices_W_1, indices_W_2, indices_output_1, indices_output_2, "cpu_sparse_accumulation_scatter_add_with_weights");

scalar_t* o_ptr = output.data;
scalar_t* a_ptr = A.data;
Expand Down Expand Up @@ -88,7 +88,7 @@ void mops::sparse_accumulation_scatter_add_with_weights_vjp(
Tensor<int32_t, 1> indices_output_1,
Tensor<int32_t, 1> indices_output_2
) {
check_sasaw_vjp(grad_A, grad_B, grad_W, grad_output, A, B, C, W, indices_A, indices_W_1, indices_W_2, indices_output_1, indices_output_2);
check_sasaw_vjp(grad_A, grad_B, grad_W, grad_output, A, B, C, W, indices_A, indices_W_1, indices_W_2, indices_output_1, indices_output_2, "cpu_sparse_accumulation_scatter_add_with_weights_vjp");

bool calculate_grad_A = grad_A.data != nullptr;
bool calculate_grad_B = grad_B.data != nullptr;
Expand Down
10 changes: 5 additions & 5 deletions python/mops-torch/tests/opsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ def test_opsa_grads(dtype, device):
(A, B, indices, output_size),
)

if device != "cuda": # not yet implemented
assert torch.autograd.gradgradcheck(
mops.torch.outer_product_scatter_add,
(A, B, indices, output_size),
)
# not yet implemented
# assert torch.autograd.gradgradcheck(
# mops.torch.outer_product_scatter_add,
# (A, B, indices, output_size),
# )


def test_opsa_ref():
Expand Down
10 changes: 5 additions & 5 deletions python/mops-torch/tests/opsaw.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def test_opsaw_grads(dtype, device):
(A, B, W, indices_W, indices_output),
)

if device != "cuda": # not yet implemented
assert torch.autograd.gradgradcheck(
mops.torch.outer_product_scatter_add_with_weights,
(A, B, W, indices_W, indices_output),
)
# not yet implemented
# assert torch.autograd.gradgradcheck(
# mops.torch.outer_product_scatter_add_with_weights,
# (A, B, W, indices_W, indices_output),
# )


def test_opsaw_ref():
Expand Down
32 changes: 16 additions & 16 deletions python/mops-torch/tests/sasaw.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,22 @@ def test_sasaw_grads(dtype, device):
),
)

if device != "cuda": # not yet implemented
assert torch.autograd.gradgradcheck(
mops.torch.sparse_accumulation_scatter_add_with_weights,
(
A,
B,
C,
W,
indices_A,
indices_W_1,
indices_W_2,
indices_output_1,
indices_output_2,
output_size_2,
),
)
# not yet implemented
# assert torch.autograd.gradgradcheck(
# mops.torch.sparse_accumulation_scatter_add_with_weights,
# (
# A,
# B,
# C,
# W,
# indices_A,
# indices_W_1,
# indices_W_2,
# indices_output_1,
# indices_output_2,
# output_size_2,
# ),
# )


def test_sasaw_ref():
Expand Down
5 changes: 3 additions & 2 deletions python/mops/tests/opsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def test_opsa_size_mismatch(valid_arguments):
with pytest.raises(
mops.status.MopsError,
match="Dimension mismatch: the sizes of A along "
"dimension 0 and indices_output along dimension 0 must match in opsa",
"dimension 0 and indices_output along dimension 0 must match in "
"cpu_outer_product_scatter_add",
):
opsa(A, B, indices_output, output_size)

Expand All @@ -81,7 +82,7 @@ def test_opsa_out_of_bounds(valid_arguments):

with pytest.raises(
mops.status.MopsError,
match="Index array indices_output in operation opsa "
match="Index array indices_output in operation cpu_outer_product_scatter_add "
"contains elements up to 10; "
"this would cause out-of-bounds accesses. With the provided "
"parameters, it can only contain elements up to 9",
Expand Down
Loading

0 comments on commit 40cfcec

Please sign in to comment.