Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-miotk committed Aug 2, 2024
1 parent dfdd052 commit 12626fc
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ ov::pass::ConvertLSTMSequenceToTensorIterator::ConvertLSTMSequenceToTensorIterat
}

ov::pass::ConvertSequenceToTensorIterator::ConvertSequenceToTensorIterator() {
// add_matcher<ConvertLSTMSequenceToTensorIterator>();
add_matcher<ConvertLSTMSequenceToTensorIterator>();
add_matcher<ConvertRNNSequenceToTensorIterator>();
add_matcher<ConvertGRUSequenceToTensorIterator>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ KERNEL(lstm_seq)(
hidden_result[k] += initial_hidden_state[INPUT1_GET_INDEX_SAFE(b, 0, j, 0)]*R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)];
}else{
int prev_idx = i-1;
if(DIRECTION){ //reverse
if(DIRECTION == 1){ //reverse
prev_idx = real_seq_length - i ;
}
hidden_result[k] += hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, prev_idx, j)]*R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)];
}
}

for(int j=0;j<INPUT_SIZE;j++) {
if(DIRECTION){ //reverse
if(DIRECTION == 1){ //reverse
input_result[k] += x[INPUT0_GET_INDEX_SAFE(b, real_seq_length-1-i, j, 0)]*W[INPUT4_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)];
} else {
input_result[k] += x[INPUT0_GET_INDEX_SAFE(b, i, j, 0)]*W[INPUT4_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)];
Expand Down Expand Up @@ -80,7 +80,7 @@ KERNEL(lstm_seq)(
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] += (OUTPUT_TYPE)(gate_output[1]*gate_output[2]);
}
int cur_history_idx = i;
if(DIRECTION){ //reverse
if(DIRECTION == 1){ //reverse
cur_history_idx = real_seq_length - 1 - i ;
}
hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = (OUTPUT_TYPE)(gate_output[3]*ACTIVATION_H(cell_state[OUTPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], ACTIVATION_PARAMS_H));
Expand All @@ -92,6 +92,7 @@ KERNEL(lstm_seq)(
//printf("R is %p B is %p ; hidden history %p cell state %p batch %d\n", &R[0], &B[0], &hidden_history[0], &cell_state[0], b);
for(int i=0;i<real_seq_length;i++){
//hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)] = i;
//printf("result is %f for hididx %d b %d\n", hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)], hidden_idx, b);
printf("DIR %d result is %f for hididx %d b %d\n", DIRECTION, hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)], hidden_idx, b);
printf("DIR %d hidden state is %f for hid idx %d b %d \n", DIRECTION, hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], hidden_idx, b);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ JitConstants LSTMSeqKernelBase::GetJitConstants(const lstm_seq_params& params) c
if (params.input_forget) {
jit.AddConstants({MakeJitConstant("INPUT_FORGET", true)});
}
jit.AddConstants({MakeJitConstant("DIRECTION", params.direction)});
jit.AddConstants({MakeJitConstant("DIRECTION", static_cast<int>(params.direction))});

size_t size = params.inputs[1].Y().v;
jit.AddConstants({
Expand Down

0 comments on commit 12626fc

Please sign in to comment.