Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-miotk committed Jul 26, 2024
1 parent 0451429 commit c6b74d3
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ KERNEL(lstm_seq)(
}
}
printf("initial hidden state is %f\n", initial_hidden_state[INPUT1_GET_INDEX(b, 0, hidden_idx, 0)]);
printf("initial cell state is %f\n", initial_hidden_state[INPUT2_GET_INDEX(b, 0, hidden_idx, 0)]);
printf("initial cell state is %f\n", initial_cell_state[INPUT2_GET_INDEX(b, 0, hidden_idx, 0)]);
printf("seq lentghts are %d \n", sequence_lengths[INPUT3_GET_INDEX(b, 0, 0, 0)]);
printf("INPUT SIZE is %d \n", INPUT_SIZE);
for(int i=0;i<INPUT_SIZE;i++){
for(int j=0;j<4;j++){
//printf("j is %d hidden_idx is %d HIDDEN_SIZE is %d idx is %d hidden_idx+j*HIDDEN_SIZE is %d\n", j, hidden_idx, HIDDEN_SIZE, INPUT4_GET_INDEX(0, hidden_idx+j*HIDDEN_SIZE, i, 0), W[INPUT4_GET_INDEX(0, hidden_idx+j*HIDDEN_SIZE, i, 0)], hidden_idx+j*HIDDEN_SIZE);
printf("W are %f", W[INPUT4_GET_INDEX(0, hidden_idx+j*HIDDEN_SIZE, 0, i, 0)]);
//W[INPUT4_GET_INDEX(hidden_idx+j*HIDDEN_SIZE, 0, i, 0)], hidden_idx+j*HIDDEN_SIZE);
printf("W are %f for idx %d\n", W[INPUT4_GET_INDEX(0, 0, hidden_idx+j*HIDDEN_SIZE, 0)], hidden_idx+j*HIDDEN_SIZE);
}
}
printf("input0 pitches %d, %d, %d, %d \n", INPUT0_PITCHES[0], INPUT0_PITCHES[1], INPUT0_PITCHES[2], INPUT0_PITCHES[3]);
Expand All @@ -42,22 +43,24 @@ KERNEL(lstm_seq)(
printf("input4 pitches %d, %d, %d, %d \n", INPUT4_PITCHES[0], INPUT4_PITCHES[1], INPUT4_PITCHES[2], INPUT4_PITCHES[3]);
printf("input5 pitches %d, %d, %d, %d \n", INPUT5_PITCHES[0], INPUT5_PITCHES[1], INPUT5_PITCHES[2], INPUT5_PITCHES[3]);
printf("input6 pitches %d, %d, %d, %d \n", INPUT6_PITCHES[0], INPUT6_PITCHES[1], INPUT6_PITCHES[2], INPUT6_PITCHES[3]);
printf("output pitches %d, %d, %d, %d \n", OUTPUT_PITCHES[0], OUTPUT_PITCHES[1], OUTPUT_PITCHES[2], OUTPUT_PITCHES[3]);
printf("output1 pitches %d, %d, %d, %d \n", OUTPUT1_PITCHES[0], OUTPUT1_PITCHES[1], OUTPUT1_PITCHES[2], OUTPUT1_PITCHES[3]);
printf("output2 pitches %d, %d, %d, %d \n", OUTPUT2_PITCHES[0], OUTPUT2_PITCHES[1], OUTPUT2_PITCHES[2], OUTPUT2_PITCHES[3]);
for(int i=0;i<HIDDEN_SIZE;i++){
for(int j=0;j<4;j++){
printf("R are %f \n", R[INPUT5_GET_INDEX(0,hidden_idx+j*HIDDEN_SIZE, 0, i, 0)]);
printf("R are %f \n", R[INPUT5_GET_INDEX(0,hidden_idx+j*HIDDEN_SIZE, 0, i)]);
}
}
for(int j=0;j<4;j++){
printf("B are %f \n", B[INPUT6_GET_INDEX(0, hidden_idx+j*HIDDEN_SIZE, 0, 0)]);
}
const int weight_offsets[4] = {GEMM_OFFSET_F, GEMM_OFFSET_I, GEMM_OFFSET_Z, GEMM_OFFSET_O};
const uint hidden_idxl = get_local_id(0);
const uint bl = get_local_id(1);
const int gate_num = 4;
__local float hidden_result[BATCH_SIZE][HIDDEN_SIZE][gate_num];
__local float input_result[BATCH_SIZE][HIDDEN_SIZE][gate_num];
__local float gate_output[BATCH_SIZE][HIDDEN_SIZE][gate_num];
int weight_offsets[4] = {GEMM_OFFSET_F, GEMM_OFFSET_I, GEMM_OFFSET_Z, GEMM_OFFSET_O};

for(int i=0;i<BATCH_SIZE;i++){
for(int j=0;j<HIDDEN_SIZE;j++){
for(int k=0;k<gate_num;k++){
Expand All @@ -67,26 +70,32 @@ KERNEL(lstm_seq)(
}
}
}
for(int i=0;i<4;i++) {
printf("HIDDEN_SIZE is %d weight is %d \n", HIDDEN_SIZE, weight_offsets[i]);
}
printf("offsets %d %d %d %d \n", GEMM_OFFSET_F, GEMM_OFFSET_I, GEMM_OFFSET_Z, GEMM_OFFSET_O);

printf("kernel usage %d seq len\n", MAX_SEQ_LENGTH);
for(int i=0;i<MAX_SEQ_LENGTH;i++){
for(int k=0;k<gate_num;k++){
for(int j=0;j<HIDDEN_SIZE;j++) {
if(i==0){
hidden_result[b][hidden_idx][k] += initial_hidden_state[INPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)]*R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)];
printf("mult %f %f\n", initial_hidden_state[INPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)], R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)]);
hidden_result[b][hidden_idx][k] += initial_hidden_state[INPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)]*R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)];

printf("mult %f %f \n", initial_hidden_state[INPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)], R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+GEMM_OFFSET_F, 0, 0)]);
}else{
hidden_result[b][hidden_idx][k] += hidden_state[INPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)]*R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)];
printf("hidden_result[b][hidden_idx][k] %f %f\n",hidden_state[INPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)], R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)]);
}
}

for(int j=0;j<INPUT_SIZE;j++) {
input_result[b][hidden_idx][k] += x[INPUT0_GET_INDEX_SAFE(b, hidden_idx, j, 0)]*W[INPUT4_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)];
printf("input_result[b][hidden_idx][k] %f %f\n", x[INPUT0_GET_INDEX_SAFE(b, hidden_idx, j, 0)], W[INPUT4_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)]);
input_result[b][hidden_idx][k] += x[INPUT0_GET_INDEX_SAFE(b, hidden_idx, j, 0)]*W[INPUT4_GET_INDEX_SAFE(0, 0, hidden_idx+weight_offsets[k], 0)];
printf("input_result[b][hidden_idx][k] %f %f\n", x[INPUT0_GET_INDEX_SAFE(b, hidden_idx, j, 0)], W[INPUT4_GET_INDEX_SAFE(0, 0, hidden_idx+weight_offsets[k], 0)]);
}
for(int j=0;j<HIDDEN_SIZE;j++){
gate_output[b][hidden_idx][k] = hidden_result[b][j][k] + input_result[b][j][k] + B[INPUT6_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], 0, 0)];
printf("gate_output[b][hidden_idx][k] %f %f %f\n", hidden_result[b][j][k], input_result[b][j][k], B[INPUT6_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], 0, 0)]);
printf("gate_output[b][hidden_idx][k] %f %f %f for b %d and k %d\n", hidden_result[b][j][k], input_result[b][j][k], B[INPUT6_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], 0, 0)], b, k);
}
switch(k){
case 0:
Expand All @@ -109,19 +118,22 @@ KERNEL(lstm_seq)(
}

if (i==0){
cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] = gate_output[b][hidden_idx][0]*initial_cell_state[INPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)];
printf("cell_stateeq %f %f\n" , gate_output[b][hidden_idx][0], initial_cell_state[INPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)]);
cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] += gate_output[b][hidden_idx][1]*gate_output[b][hidden_idx][2];
printf("cell_stateplus %f %f OUTPUT2_GET_INDEX(b, hidden_idx, 0, 0) %d\n" , gate_output[b][hidden_idx][1], gate_output[b][hidden_idx][2], OUTPUT2_GET_INDEX(b, hidden_idx, 0, 0));
cell_state[OUTPUT2_GET_INDEX_SAFE(b, 0, 0, 0)] = gate_output[b][hidden_idx][0]*initial_cell_state[INPUT2_GET_INDEX(b, 0, 0, 0)];
printf("cell_stateeq %f %f for b %d %d %d %d %d\n" , gate_output[b][hidden_idx][0], initial_cell_state[INPUT2_GET_INDEX_SAFE(b, 0, 0, 0)], b, OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0), OUTPUT2_GET_INDEX_SAFE(0, b, 0, 0), OUTPUT2_GET_INDEX_SAFE(0, 0, b, 0), OUTPUT2_GET_INDEX_SAFE(0, 0, 0, b));
cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] = gate_output[b][hidden_idx][1]*gate_output[b][hidden_idx][2];
printf("cell_stateplus %f %f OUTPUT2_GET_INDEX(b, hidden_idx, 0, 0) %d for b %d \n" , gate_output[b][hidden_idx][1], gate_output[b][hidden_idx][2], OUTPUT2_GET_INDEX(b, hidden_idx, 0, 0), b);
}else{
cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] = gate_output[b][hidden_idx][0]*cell_state[OUTPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)];
printf("cell_stateeqq is %f\n", cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)]);
cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] += gate_output[b][hidden_idx][1]*gate_output[b][hidden_idx][2];
printf("cell_stateppliu is %f\n", cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] );
if(CELL_TERM) {
cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] += gate_output[b][hidden_idx][1]*gate_output[b][hidden_idx][2];
printf("cell_stateppliu is %f\n", cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] );
}
}
hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = gate_output[b][hidden_idx][3]*ACTIVATION_H(ACTIVATION_CLIP(cell_state[OUTPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], ACTIVATION_PARAMS_CLIP), ACTIVATION_PARAMS_H);
printf("hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] is %f\n", hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)]);
printf("hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] is %f on b %d\n", hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], b);
hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)] = hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)];
printf("hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)] is %f\n", hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)]);
}
printf("cell state for 3 is %f", cell_state[OUTPUT2_GET_INDEX_SAFE(3, 0, 0, 0)]);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ JitConstants LSTMSeqKernelBase::GetJitConstants(const lstm_seq_params& params) c
if (params.has_cell) {
jit.AddConstants({MakeJitConstant("CELL_TERM", true),
MakeJitConstant("CELL_DIRECTION", params.cell_direction)});
} else {
jit.AddConstants({MakeJitConstant("CELL_TERM", false)});
}
if (params.input_forget) {
jit.AddConstants({MakeJitConstant("INPUT_FORGET", true)});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ ParamsKey LSTMSeqKernelRef::GetSupportedKey() const {
k.EnableTensorPitches();
k.EnableBatching();
k.EnableLSTMSeqCell();
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
return k;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3449,10 +3449,7 @@ std::vector<LSTMSequenceV1Params> generateV1ParamsBF16() {

std::vector<LSTMSequenceParams> generateCombinedParams() {
const std::vector<std::vector<LSTMSequenceParams>> generatedParams{
generateParams<element::Type_t::f64>(),
generateParams<element::Type_t::f32>(),
generateParamsBF16<element::Type_t::f16>(),
generateParamsBF16<element::Type_t::bf16>(),
};
std::vector<LSTMSequenceParams> combinedParams;

Expand All @@ -3477,7 +3474,7 @@ std::vector<LSTMSequenceV1Params> generateV1CombinedParams() {
return combinedParams;
}

INSTANTIATE_TEST_SUITE_P(smoke_LSTMSequence_With_Hardcoded_Refs,
INSTANTIATE_TEST_SUITE_P(*smoke_LSTMSequence_With_Hardcoded_Refs*,
ReferenceLSTMSequenceTest,
testing::ValuesIn(generateCombinedParams()),
ReferenceLSTMSequenceTest::getTestCaseName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ void LSTMSequenceTest::SetUp() {
ov::pass::Manager manager;
if (direction == ov::op::RecurrentSequenceDirection::BIDIRECTIONAL)
manager.register_pass<ov::pass::BidirectionalLSTMSequenceDecomposition>();
//manager.register_pass<ov::pass::ConvertLSTMSequenceToTensorIterator>();
manager.register_pass<ov::pass::ConvertLSTMSequenceToTensorIterator>();
manager.run_passes(function);
bool ti_found = ov::test::utils::is_tensor_iterator_exist(function);
EXPECT_EQ(ti_found, true);
Expand Down

0 comments on commit c6b74d3

Please sign in to comment.