Skip to content

Commit

Permalink
wip3
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-miotk committed Jul 25, 2024
1 parent c00ff8a commit 0451429
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ ov::pass::ConvertLSTMSequenceToTensorIterator::ConvertLSTMSequenceToTensorIterat
}

ov::pass::ConvertSequenceToTensorIterator::ConvertSequenceToTensorIterator() {
//d_matcher<ConvertLSTMSequenceToTensorIterator>();
//add_matcher<ConvertLSTMSequenceToTensorIterator>();
add_matcher<ConvertRNNSequenceToTensorIterator>();
add_matcher<ConvertGRUSequenceToTensorIterator>();
}
4 changes: 2 additions & 2 deletions src/plugins/intel_gpu/src/graph/impls/ocl/lstm_seq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ struct lstm_seq_impl : typed_primitive_impl_ocl<lstm_seq> {
for (size_t i = 0; i < instance.outputs_memory_count(); i++) {
args.outputs.push_back(instance.output_memory_ptr(i));
}
args.outputs.push_back(instance.second_output_mem());
args.outputs.push_back(instance.third_output_mem());
//args.outputs.push_back(instance.second_output_mem());
//args.outputs.push_back(instance.third_output_mem());
return args;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
// Copyright (C) 2018-2024 Intel Corporation
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "include/batch_headers/fetch_data.cl"

// initial_hidden_state
// initial_cell_state
// sequence_lengths
// WR
// B
// output0
//output1
//output2
KERNEL(lstm_seq)(
const __global INPUT0_TYPE* x,
const __global INPUT1_TYPE* initial_hidden_state,
Expand All @@ -25,8 +17,39 @@ KERNEL(lstm_seq)(
__global OUTPUT2_TYPE* cell_state
)
{

const uint hidden_idx = get_global_id(0);
const uint b = get_global_id(1);
for(int i=0;i<MAX_SEQ_LENGTH;i++){
for(int j=0;j<INPUT_SIZE;j++){
printf("x is %f\n", x[INPUT0_GET_INDEX(b, i, j, 0)]);
}
}
printf("initial hidden state is %f\n", initial_hidden_state[INPUT1_GET_INDEX(b, 0, hidden_idx, 0)]);
printf("initial cell state is %f\n", initial_hidden_state[INPUT2_GET_INDEX(b, 0, hidden_idx, 0)]);
printf("seq lentghts are %d \n", sequence_lengths[INPUT3_GET_INDEX(b, 0, 0, 0)]);
printf("INPUT SIZE is %d \n", INPUT_SIZE);
for(int i=0;i<INPUT_SIZE;i++){
for(int j=0;j<4;j++){
//printf("j is %d hidden_idx is %d HIDDEN_SIZE is %d idx is %d hidden_idx+j*HIDDEN_SIZE is %d\n", j, hidden_idx, HIDDEN_SIZE, INPUT4_GET_INDEX(0, hidden_idx+j*HIDDEN_SIZE, i, 0), W[INPUT4_GET_INDEX(0, hidden_idx+j*HIDDEN_SIZE, i, 0)], hidden_idx+j*HIDDEN_SIZE);
printf("W are %f", W[INPUT4_GET_INDEX(0, hidden_idx+j*HIDDEN_SIZE, 0, i, 0)]);
}
}
printf("input0 pitches %d, %d, %d, %d \n", INPUT0_PITCHES[0], INPUT0_PITCHES[1], INPUT0_PITCHES[2], INPUT0_PITCHES[3]);
printf("input1 pitches %d, %d, %d, %d \n", INPUT1_PITCHES[0], INPUT1_PITCHES[1], INPUT1_PITCHES[2], INPUT1_PITCHES[3]);
printf("input2 pitches %d, %d, %d, %d \n", INPUT2_PITCHES[0], INPUT2_PITCHES[1], INPUT2_PITCHES[2], INPUT2_PITCHES[3]);
printf("input3 pitches %d, %d, %d, %d \n", INPUT3_PITCHES[0], INPUT3_PITCHES[1], INPUT3_PITCHES[2], INPUT3_PITCHES[3]);
printf("input4 pitches %d, %d, %d, %d \n", INPUT4_PITCHES[0], INPUT4_PITCHES[1], INPUT4_PITCHES[2], INPUT4_PITCHES[3]);
printf("input5 pitches %d, %d, %d, %d \n", INPUT5_PITCHES[0], INPUT5_PITCHES[1], INPUT5_PITCHES[2], INPUT5_PITCHES[3]);
printf("input6 pitches %d, %d, %d, %d \n", INPUT6_PITCHES[0], INPUT6_PITCHES[1], INPUT6_PITCHES[2], INPUT6_PITCHES[3]);
for(int i=0;i<HIDDEN_SIZE;i++){
for(int j=0;j<4;j++){
printf("R are %f \n", R[INPUT5_GET_INDEX(0,hidden_idx+j*HIDDEN_SIZE, 0, i, 0)]);
}
}
for(int j=0;j<4;j++){
printf("B are %f \n", B[INPUT6_GET_INDEX(0, hidden_idx+j*HIDDEN_SIZE, 0, 0)]);
}
const uint hidden_idxl = get_local_id(0);
const uint bl = get_local_id(1);
const int gate_num = 4;
Expand All @@ -44,13 +67,13 @@ KERNEL(lstm_seq)(
}
}
}
printf("kernel usage %d %d %d %d\n", hidden_idx, b, hidden_idxl, bl);
printf("kernel usage %d seq len\n", MAX_SEQ_LENGTH);
for(int i=0;i<MAX_SEQ_LENGTH;i++){
for(int k=0;k<gate_num;k++){
for(int j=0;j<HIDDEN_SIZE;j++) {
if(i==0){
hidden_result[b][hidden_idx][k] += initial_hidden_state[INPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)]*R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)];
printf("mult %d %f %f\n", __LINE__, initial_hidden_state[INPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)], R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)]);
printf("mult %f %f\n", initial_hidden_state[INPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)], R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)]);
}else{
hidden_result[b][hidden_idx][k] += hidden_state[INPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)]*R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)];
printf("hidden_result[b][hidden_idx][k] %f %f\n",hidden_state[INPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)], R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)]);
Expand Down Expand Up @@ -96,7 +119,7 @@ KERNEL(lstm_seq)(
cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] += gate_output[b][hidden_idx][1]*gate_output[b][hidden_idx][2];
printf("cell_stateppliu is %f\n", cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] );
}
hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = gate_output[b][hidden_idx][3]*ACTIVATION_H(ACTIVATION_CLIP(cell_state[OUTPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 1)], ACTIVATION_PARAMS_CLIP), ACTIVATION_PARAMS_H);
hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = gate_output[b][hidden_idx][3]*ACTIVATION_H(ACTIVATION_CLIP(cell_state[OUTPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], ACTIVATION_PARAMS_CLIP), ACTIVATION_PARAMS_H);
printf("hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] is %f\n", hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)]);
hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)] = hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)];
printf("hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)] is %f\n", hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,22 +222,8 @@ void inline fill_data_random(T* pointer,
double_t start_from = 0,
const int32_t k = 1,
const int seed = 1) {
if (range == 0) {
for (std::size_t i = 0; i < size; i++) {
pointer[i] = static_cast<T>(start_from);
}
return;
}

testing::internal::Random random(seed);
const uint32_t k_range = k * range; // range with respect to k
random.Generate(k_range);

if (start_from < 0 && !std::numeric_limits<T>::is_signed) {
start_from = 0;
}
for (std::size_t i = 0; i < size; i++) {
pointer[i] = static_cast<T>(start_from + static_cast<double>(random.Generate(k_range)) / k);
pointer[i] = static_cast<T>(i);
}
}

Expand Down

0 comments on commit 0451429

Please sign in to comment.