Skip to content

Commit

Permalink
[GPU opt guide][SYCL][joint matrix] update the test to match the guide
Browse files Browse the repository at this point in the history
  • Loading branch information
dkhaldi authored Jan 24, 2024
1 parent 4bb588d commit d217052
Showing 1 changed file with 34 additions and 42 deletions.
76 changes: 34 additions & 42 deletions Publications/GPU-Opt-Guide/joint-matrix/joint-matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@
#include <cstring> // std::memcpy for bit-exact float <-> integer conversions
#include <iostream>
#include <sycl/sycl.hpp>

// using joint_matrix = sycl::ext::oneapi::experimental::matrix;
using use = sycl::ext::oneapi::experimental::matrix::use;
using layout = sycl::ext::oneapi::experimental::matrix::layout;
using bfloat16 = sycl::ext::oneapi::bfloat16;

// Sub-group size required by the joint-matrix kernel.
constexpr size_t SG_SZ = 16;

// joint_matrix tile shape: C (TM x TN) accumulates A (TM x TK) * B (TK x TN).
constexpr size_t TM = 8;
constexpr size_t TN = SG_SZ;
constexpr size_t TK = 16;

// Scale factor applied element-wise to C after the multiply.
constexpr float ALPHA = 2.0;

// 2^-7 — one bfloat16 mantissa step at magnitude 1.0; comparison tolerance.
constexpr float BF16_EPSILON = 0.00781250;

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
private:
Expand All @@ -42,10 +43,9 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,

sycl::queue q;
q.submit([&](sycl::handler &cgh) {
sycl::accessor accC(bufC, cgh, sycl::read_write, sycl::no_init);
sycl::accessor accC(bufC, cgh, sycl::read_write);
sycl::accessor accA(bufA, cgh, sycl::read_only);
sycl::accessor accB(bufB, cgh, sycl::read_only);

cgh.parallel_for(
sycl::nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
[=](sycl::nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]
Expand All @@ -66,30 +66,32 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
// For B, we assume B has been already VNNIed.
sycl::ext::oneapi::experimental::matrix::joint_matrix<
sycl::sub_group, bfloat16, use::b, TK, TN,
sycl::ext::intel::experimental::matrix::layout::packed>
layout::ext_intel_packed>
sub_b;
sycl::ext::oneapi::experimental::matrix::joint_matrix<
sycl::sub_group, float, use::accumulator, TM, TN>
sub_c;

joint_matrix_load(sg, sub_c,
accC.get_pointer() + (sg_startx * TM) * N +
sg_starty / SG_SZ * TN,
N, layout::row_major);
for (int k = 0; k < K / TK; k += 1) { //
joint_matrix_fill(sg, sub_c, 1.0);
for (int k = 0; k < K / TK; k += 1) {
joint_matrix_load(
sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
sg, sub_a,
accA.template get_multi_ptr<sycl::access::decorated::no>() +
(sg_startx * TM) * K + k * TK,
K);
joint_matrix_load(sg, sub_b,
accB.get_pointer() + (k * TK / 2) * (N * 2) +
sg_starty / SG_SZ * TN * 2,
N * 2);
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
joint_matrix_load(
sg, sub_b,
accB.template get_multi_ptr<sycl::access::decorated::no>() +
(k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2,
N * 2);
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
}
joint_matrix_store(sg, sub_c,
accC.get_pointer() + (sg_startx * TM) * N +
sg_starty / SG_SZ * TN,
N, layout::row_major);
joint_matrix_apply(sg, sub_c, [=](float &x) { x *= ALPHA; });
joint_matrix_store(
sg, sub_c,
accC.template get_multi_ptr<sycl::access::decorated::no>() +
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
N, layout::row_major);
}); // parallel for
}).wait();
// kernel end
Expand All @@ -100,53 +102,43 @@ static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
bfloat16 A[MATRIX_M][MATRIX_K];
bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
unsigned short Aref[MATRIX_M][MATRIX_K];
unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2];
float C[MATRIX_M][MATRIX_N];
float D[MATRIX_M][MATRIX_N];

// Expand bfloat16 storage bits into a float: a bfloat16 value is exactly the
// upper 16 bits of an IEEE-754 binary32, so shift the bits into the high half
// and reinterpret. Uses memcpy to read only the 16 stored bits — the previous
// `*((int *)&x)` read 4 bytes of a 2-byte object (out-of-bounds read plus a
// strict-aliasing violation).
float make_fp32(bfloat16 x) {
  // NOTE(review): assumes sizeof(bfloat16) == 2 (true for
  // sycl::ext::oneapi::bfloat16, whose storage is a 16-bit integer).
  unsigned short raw;
  std::memcpy(&raw, &x, sizeof(raw));
  unsigned int y = static_cast<unsigned int>(raw) << 16;
  float res;
  std::memcpy(&res, &y, sizeof(res));
  return res;
}

// Convert a float to bfloat16 storage bits by truncation: keep the sign,
// exponent, and top 7 mantissa bits (round-toward-zero). memcpy replaces the
// reinterpret_cast write-through (strict-aliasing UB), and the shift is done
// on an unsigned value so it is well-defined for negative floats.
unsigned short make_bf16(float x) {
  unsigned int bits;
  std::memcpy(&bits, &x, sizeof(bits)); // bit-exact copy of the float
  return static_cast<unsigned short>(bits >> 16);
}

void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
int K) {
for (int m = 0; m < M; m++)
for (int n = 0; n < N; n++) {
for (int k = 0; k < K; k++) {
short *va = (short *)(A_mem + m * K + k);
short *vb = (short *)(B_mem + k * N + n);
// Because B was assumed VNNIed
bfloat16 *va = (bfloat16 *)(A_mem + m * K + k);
bfloat16 *vb = (bfloat16 *)(B_mem + k * N + n);
float acc = *((float *)(C_mem + m * N + n));
for (int i = 0; i < 2; i++) {
acc += (make_fp32(va[i]) * make_fp32(vb[i]));
}
*((float *)(C_mem + m * N + n)) = acc;
}
*((float *)(C_mem + m * N + n)) *= ALPHA;
}
}

int main() {
for (int i = 0; i < MATRIX_M; i++) {
for (int j = 0; j < MATRIX_K; j++) {
// bfloat16 is created using unsigned short since conversion from float to
// bfloat16 is not supported on the host side yet
A[i][j] = bfloat16(1.0f * (i + j));
Aref[i][j] = make_bf16(1.0f * (i + j));
}
}
for (int i = 0; i < MATRIX_K / 2; i++) {
for (int j = 0; j < MATRIX_N * 2; j++) {
B[i][j] = bfloat16(2.0f * i + 3.0f * j);
Bref[i][j] = make_bf16(2.0f * i + 3.0f * j);
}
}
for (int i = 0; i < MATRIX_M; i++) {
Expand All @@ -161,13 +153,13 @@ int main() {
big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);
matrix_multiply(MC, MA, MB);
matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M,
matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M,
MATRIX_N, MATRIX_K / 2);

bool res = true;
for (int i = 0; i < MATRIX_M; i++) {
for (int j = 0; j < MATRIX_N; j++) {
if ((fabs(C[i][j]) - fabs(D[i][j])) > BF16_EPSILON)
if ((fabs(C[i][j] - D[i][j])) > BF16_EPSILON)
res = false;
}
}
Expand Down

0 comments on commit d217052

Please sign in to comment.