Skip to content

Commit

Permalink
wip
Browse files (browse the repository at this point in the history)
  • Loading branch information
michal-miotk committed Jul 30, 2024
1 parent ab1307c commit 1ccdacc
Showing 1 changed file with 11 additions and 16 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,20 +19,16 @@ KERNEL(lstm_seq)(
const uint hidden_idx = get_global_id(0);
float local_hidden_state = 0;
const uint b = get_global_id(1);
//
const int weight_offsets[4] = {GEMM_OFFSET_F, GEMM_OFFSET_I, GEMM_OFFSET_Z, GEMM_OFFSET_O};
const int gate_num = 4;
float hidden_result[gate_num];
float input_result[gate_num];
float gate_output[gate_num];
for(int i=0;i<BATCH_SIZE;i++){
for(int j=0;j<HIDDEN_SIZE;j++){
for(int k=0;k<gate_num;k++){
hidden_result[k] = 0;
input_result[k] = 0;
gate_output[k] = 0;
}
}
for(int k=0;k<gate_num;k++){
gate_output[k] = 0;
}
pinrtf("W is %f R is %f B is %f\n, W[INPUT4_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[0], 0, 0)], R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[0], 0, 0)], B[INPUT6_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[0], 0, 0)]);
for(int i=0;i<MAX_SEQ_LENGTH;i++){
for(int k=0;k<gate_num;k++){
hidden_result[k] = 0;
Expand All @@ -50,9 +46,8 @@ KERNEL(lstm_seq)(
for(int j=0;j<INPUT_SIZE;j++) {
input_result[k] += x[INPUT0_GET_INDEX_SAFE(b, i, j, 0)]*W[INPUT4_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)];
}
for(int j=0;j<HIDDEN_SIZE;j++){
gate_output[k] = hidden_result[k] + input_result[k] + B[INPUT6_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], 0, 0)];
}
gate_output[k] = hidden_result[k] + input_result[k] + B[INPUT6_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], 0, 0)];

switch(k){
case 0:
case 1:
Expand All @@ -68,14 +63,14 @@ KERNEL(lstm_seq)(
}

if (i==0){
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = gate_output[0]*initial_cell_state[INPUT2_GET_INDEX_SAFE(b, 0, 0, 0)];
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = gate_output[0]*initial_cell_state[INPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)];
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] += gate_output[1]*gate_output[2];
}else{
cell_state[OUTPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] *= gate_output[0];
cell_state[OUTPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] += gate_output[1]*gate_output[2];
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] *= gate_output[0];
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] += gate_output[1]*gate_output[2];
}
local_hidden_state = gate_output[3]*ACTIVATION_H(ACTIVATION_CLIP(cell_state[OUTPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], ACTIVATION_PARAMS_CLIP), ACTIVATION_PARAMS_H);
hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, 0)] = local_hidden_state;
local_hidden_state = gate_output[3]*ACTIVATION_H(ACTIVATION_CLIP(cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], ACTIVATION_PARAMS_CLIP), ACTIVATION_PARAMS_H);
hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)] = local_hidden_state;
}
printf("R is %p B is %p ; hidden history %p cell state %p batch %d\n", &R[0], &B[0], &hidden_history[0], &cell_state[0], b);
printf("result is %f %f \n", hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, 0, 0)], hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, 1, 0)]);
Expand Down

0 comments on commit 1ccdacc

Please sign in to comment.