Skip to content

Commit

Permalink
test passed
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-miotk committed Jul 30, 2024
1 parent bc65969 commit ab1307c
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 83 deletions.
4 changes: 2 additions & 2 deletions src/plugins/intel_gpu/include/intel_gpu/primitives/lstm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ struct lstm_seq : public primitive_base<lstm_seq> {
const lstm_weights_order offset_order = lstm_weights_order::iofz,
const uint32_t direction = 0,
const padding& output_padding = padding())
: primitive_base(id, {x, initial_hidden_state, initial_cell_state, seq_lenghts, W, R, B}, {output_padding}, {}, 2),
: primitive_base(id, {x, initial_hidden_state, initial_cell_state, seq_lenghts, W, R, B, out2_prim_id}, {output_padding}, {}, 1),
out2_prim_id(out2_prim_id),
cell(cell),
clip(clip),
Expand Down Expand Up @@ -282,7 +282,7 @@ struct lstm_seq : public primitive_base<lstm_seq> {
ret.push_back(cell);
*/
//ret.push_back(out1_prim_id);
ret.push_back(out2_prim_id);
//ret.push_back(out2_prim_id);
return ret;
}
};
Expand Down
10 changes: 1 addition & 9 deletions src/plugins/intel_gpu/src/graph/impls/ocl/lstm_seq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,7 @@ struct lstm_seq_impl : typed_primitive_impl_ocl<lstm_seq> {
protected:
// Assembles the kernel argument set for an lstm_seq primitive instance, starting
// from the arguments produced by parent::get_arguments() and returning the result.
// NOTE(review): this text looks like a rendered commit diff, so removed and added
// lines may appear side by side; presumably only a subset of the statements below
// is live in the committed file. Confirm against the repository before relying on it.
kernel_arguments_data get_arguments(const typed_primitive_inst<lstm_seq>& instance) const override {
kernel_arguments_data args = parent::get_arguments(instance);
// Overwrite the input list with the memory of dependencies 0 through 6.
args.inputs = { instance.dep_memory_ptr(0), instance.dep_memory_ptr(1), instance.dep_memory_ptr(2), instance.dep_memory_ptr(3),
instance.dep_memory_ptr(4), instance.dep_memory_ptr(5), instance.dep_memory_ptr(6)};
// Pass the cell memory only when the instance reports a cell term; otherwise nullptr.
args.cell = instance.cell_term() ? instance.cell_memory() : nullptr;
// New API for multiple outputs support
// (the loop bound is 1, so only output 0 is appended here)
for (size_t i = 0; i < 1; i++) {
args.outputs.push_back(instance.output_memory_ptr(i));
}
// Append the buffers returned by second_output_mem() and third_output_mem().
args.outputs.push_back(instance.second_output_mem());
args.outputs.push_back(instance.third_output_mem());
// Append the dependency memory at index input_size() - 1 as an output
// (presumably the last declared input of the primitive; verify against lstm_seq's input list).
args.outputs.push_back(instance.dep_memory_ptr(instance.desc()->input_size() - 1));
return args;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,50 +13,13 @@ KERNEL(lstm_seq)(
const __global INPUT5_TYPE* R,
const __global INPUT6_TYPE* B,
__global OUTPUT_TYPE* hidden_history,
__global OUTPUT2_TYPE* cell_state
__global OUTPUT1_TYPE* cell_state
)
{
const uint hidden_idx = get_global_id(0);
float local_hidden_state = 0;
const uint b = get_global_id(1);
for(int i=0;i<MAX_SEQ_LENGTH;i++){
for(int j=0;j<INPUT_SIZE;j++){
//printf("x is %f\n", x[INPUT0_GET_INDEX_SAFE(b, i, j, 0)]);
}
}
//printf("initial hidden state is %f\n", initial_hidden_state[INPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)]);
//printf("initial cell state is %f\n", initial_cell_state[INPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)]);
//printf("seq lentghts are %d \n", sequence_lengths[INPUT3_GET_INDEX_SAFE(b, 0, 0, 0)]);
//printf("INPUT SIZE is %d \n", INPUT_SIZE);
for(int i=0;i<INPUT_SIZE;i++){
for(int j=0;j<4;j++){
//printf("j is %d hidden_idx is %d HIDDEN_SIZE is %d idx is %d hidden_idx+j*HIDDEN_SIZE is %d\n", j, hidden_idx, HIDDEN_SIZE, INPUT4_GET_INDEX_SAFE(0, hidden_idx+j*HIDDEN_SIZE, i, 0), W[INPUT4_GET_INDEX_SAFE(0, hidden_idx+j*HIDDEN_SIZE, i, 0)], hidden_idx+j*HIDDEN_SIZE);
//W[INPUT4_GET_INDEX_SAFE(hidden_idx+j*HIDDEN_SIZE, 0, i, 0)], hidden_idx+j*HIDDEN_SIZE);
//printf("W are %f for idx %d %d\n", W[INPUT4_GET_INDEX_SAFE(0, i, hidden_idx+j*HIDDEN_SIZE, 0)], hidden_idx+j*HIDDEN_SIZE, i);
//printf("oj W are %f for idx %d %d\n", W[INPUT4_GET_INDEX_SAFE(0, hidden_idx+j*HIDDEN_SIZE, i, 0)], hidden_idx+j*HIDDEN_SIZE, i);
}
}
//printf("input0 pitches %d, %d, %d, %d \n", INPUT0_PITCHES[0], INPUT0_PITCHES[1], INPUT0_PITCHES[2], INPUT0_PITCHES[3]);
//printf("input1 pitches %d, %d, %d, %d \n", INPUT1_PITCHES[0], INPUT1_PITCHES[1], INPUT1_PITCHES[2], INPUT1_PITCHES[3]);
//printf("input2 pitches %d, %d, %d, %d \n", INPUT2_PITCHES[0], INPUT2_PITCHES[1], INPUT2_PITCHES[2], INPUT2_PITCHES[3]);
//printf("input3 pitches %d, %d, %d, %d \n", INPUT3_PITCHES[0], INPUT3_PITCHES[1], INPUT3_PITCHES[2], INPUT3_PITCHES[3]);
//printf("input4 pitches %d, %d, %d, %d \n", INPUT4_PITCHES[0], INPUT4_PITCHES[1], INPUT4_PITCHES[2], INPUT4_PITCHES[3]);
//printf("input5 pitches %d, %d, %d, %d \n", INPUT5_PITCHES[0], INPUT5_PITCHES[1], INPUT5_PITCHES[2], INPUT5_PITCHES[3]);
//printf("input6 pitches %d, %d, %d, %d \n", INPUT6_PITCHES[0], INPUT6_PITCHES[1], INPUT6_PITCHES[2], INPUT6_PITCHES[3]);
//printf("output pitches %d, %d, %d, %d \n", OUTPUT_PITCHES[0], OUTPUT_PITCHES[1], OUTPUT_PITCHES[2], OUTPUT_PITCHES[3]);
//printf("output1 pitches %d, %d, %d, %d \n", OUTPUT1_PITCHES[0], OUTPUT1_PITCHES[1], OUTPUT1_PITCHES[2], OUTPUT1_PITCHES[3]);
//printf("output2 pitches %d, %d, %d, %d \n", OUTPUT2_PITCHES[0], OUTPUT2_PITCHES[1], OUTPUT2_PITCHES[2], OUTPUT2_PITCHES[3]);
for(int i=0;i<HIDDEN_SIZE;i++){
for(int j=0;j<4;j++){
//printf("R are %f \n", R[INPUT5_GET_INDEX_SAFE(0,hidden_idx+j*HIDDEN_SIZE, 0, i)]);
}
}
for(int j=0;j<4;j++){
//printf("B are %f \n", B[INPUT6_GET_INDEX_SAFE(0, hidden_idx+j*HIDDEN_SIZE, 0, 0)]);
}
const int weight_offsets[4] = {GEMM_OFFSET_F, GEMM_OFFSET_I, GEMM_OFFSET_Z, GEMM_OFFSET_O};
const uint hidden_idxl = get_local_id(0);
const uint bl = get_local_id(1);
const int gate_num = 4;
float hidden_result[gate_num];
float input_result[gate_num];
Expand All @@ -70,75 +33,50 @@ KERNEL(lstm_seq)(
}
}
}
/*
for(int i=0;i<4;i++) {
printf("HIDDEN_SIZE is %d weight is %d MAX_SEQ_LENGTH %d INPUT_SIZE %d\n", HIDDEN_SIZE, weight_offsets[i], MAX_SEQ_LENGTH, INPUT_SIZE);
}
printf("offsets %d %d %d %d \n", GEMM_OFFSET_F, GEMM_OFFSET_I, GEMM_OFFSET_Z, GEMM_OFFSET_O);
printf("kernel usage %d seq len\n", MAX_SEQ_LENGTH);
*/
for(int i=0;i<MAX_SEQ_LENGTH;i++){
for(int k=0;k<gate_num;k++){
hidden_result[k] = 0;
input_result[k] = 0;
printf("I set to zero %p %p\n", &hidden_result[k], &input_result[k]);
}
for(int k=0;k<gate_num;k++){
for(int j=0;j<HIDDEN_SIZE;j++) {
if(i==0){
hidden_result[k] += initial_hidden_state[INPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)]*R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)];

//printf("mult %f %f \n", initial_hidden_state[INPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)], R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+GEMM_OFFSET_F, 0, 0)]);
}else{
hidden_result[k] += local_hidden_state*R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)];
//printf("hidden_result[k] %f %f\n",hidden_state[INPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)], R[INPUT5_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)]);
}
}

for(int j=0;j<INPUT_SIZE;j++) {
input_result[k] += x[INPUT0_GET_INDEX_SAFE(b, i, j, 0)]*W[INPUT4_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)];
//printf("input_result[b][hidden_idx][k] %f %f\n", x[INPUT0_GET_INDEX_SAFE(b, i, j, 0)], W[INPUT4_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)]);
}
for(int j=0;j<HIDDEN_SIZE;j++){
gate_output[k] = hidden_result[k] + input_result[k] + B[INPUT6_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], 0, 0)];
printf("gate_output[b][hidden_idx][k] %f %f %f for b %d and k %d\n", hidden_result[k], input_result[k], B[INPUT6_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], 0, 0)], b, k);
}
switch(k){
case 0:
case 1:
case 3:
gate_output[k] = ACTIVATION_F(ACTIVATION_CLIP(gate_output[k], ACTIVATION_PARAMS_CLIP), ACTIVATION_PARAMS_F);
//printf("03 gate output is %f\n", gate_output[b][hidden_idx][k]);
break;
case 2:
gate_output[k] = ACTIVATION_G(ACTIVATION_CLIP(gate_output[k], ACTIVATION_PARAMS_CLIP), ACTIVATION_PARAMS_G);
//printf("2 gate output is %f\n", gate_output[b][hidden_idx][k]);
break;
default:
break;
}
}

if (i==0){
cell_state[OUTPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = gate_output[0]*initial_cell_state[INPUT2_GET_INDEX_SAFE(b, 0, 0, 0)];
//printf("cell_stateeq %f %f for b %d %d %d %d %d\n" , gate_output[b][hidden_idx][0], initial_cell_state[INPUT2_GET_INDEX_SAFE(b, 0, 0, 0)], b, OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0), OUTPUT2_GET_INDEX_SAFE(0, b, 0, 0), OUTPUT2_GET_INDEX_SAFE(0, 0, b, 0), OUTPUT2_GET_INDEX_SAFE(0, 0, 0, b));
cell_state[OUTPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] += gate_output[1]*gate_output[2];
//printf("cell_stateplus %f %f OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0) %d for b %d \n" , gate_output[b][hidden_idx][1], gate_output[b][hidden_idx][2], OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0), b);
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = gate_output[0]*initial_cell_state[INPUT2_GET_INDEX_SAFE(b, 0, 0, 0)];
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] += gate_output[1]*gate_output[2];
}else{
cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] *= gate_output[0];
//printf("cell_stateeqq is %f\n", cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)]);
cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] += gate_output[1]*gate_output[2];
//printf("cell_stateppliu is %f\n", cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] );
cell_state[OUTPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] *= gate_output[0];
cell_state[OUTPUT1_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] += gate_output[1]*gate_output[2];
}
//hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)]
local_hidden_state = gate_output[3]*ACTIVATION_H(ACTIVATION_CLIP(cell_state[OUTPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], ACTIVATION_PARAMS_CLIP), ACTIVATION_PARAMS_H);
//printf("hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] is %f on b %d\n", hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], b);
hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, 0)] = local_hidden_state;
printf("hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)] is %f\n", hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)]);
}
//printf("cell state for %d is %f \n", OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0), cell_state[OUTPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)]);
printf("R is %p B is %p ; hidden history %p cell state %p batch %d\n", &R[0], &B[0], &hidden_history[0], &cell_state[0], b);
//hidden_history[OUTPUT_GET_INDEX_SAFE(b, 1, 0, 0)] = 69;
printf("result is %f %f \n", hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, 0, 0)], hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, 1, 0)]);
}
7 changes: 2 additions & 5 deletions src/plugins/intel_gpu/src/plugin/ops/rnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,20 +270,17 @@ static void CreateLSTMSequenceOp(ProgramBuilder& p, const std::shared_ptr<ov::op
const cldnn::mutable_data mutable_prim_1{mutable_id_1, shared_memory1};
p.add_primitive(*op, mutable_prim_1);

inputs.push_back(cldnn::input_info(mutable_id_1));
auto s_id = inputs.back().pid;
cldnn::lstm_seq prim(lstm_seq_id + ".out0", inputs[0], inputs[1], \
inputs[2], inputs[3], inputs[4], inputs[5], cldnn::input_info(bias), "", "", \
inputs[2], inputs[3], inputs[4], inputs[5], cldnn::input_info(bias), "", mutable_id_1, \
"", clip, 0, activations, activation_params, cldnn::lstm_weights_order::fizo, 0);
//prim.out1_prim_id = f_id;
prim.out2_prim_id = s_id;
p.add_primitive(*op, prim);
p.add_primitive(*op, cldnn::mutable_data(lstm_seq_id + ".out2", {cldnn::input_info(lstm_seq_id + ".out0")}, shared_memory1));
int b = op->get_input_shape(0)[0];
int seqlen = op->get_input_shape(0)[1];
int hidden_size = op->get_input_shape(1)[2];
p.add_primitive(*op, cldnn::crop(lstm_seq_id + ".out1", {cldnn::input_info(lstm_seq_id + ".out0")}, \
cldnn::tensor{ b, 1, seqlen-1, hidden_size}, cldnn::tensor{ 0, 0, 0, 0 }));
cldnn::tensor{ b, 1, hidden_size, 1}, cldnn::tensor{ 0, 0, 0, seqlen-1}));
}

REGISTER_FACTORY_IMPL(v4, LSTMCell);
Expand Down

0 comments on commit ab1307c

Please sign in to comment.