Skip to content

Commit

Permalink
wip, errors on half
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-miotk committed Aug 4, 2024
1 parent 125884d commit 08fb207
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ KERNEL (concatenation_gpu_ref)(
uint output_offset = FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR out_b, out_f, out_w, out_z, out_y, out_x);

INPUT0_TYPE result = input[input_offset];
printf("result is %f for input_offset %d from %d %d %d %d %d %d\n", result, input_offset, b, f, w, z, y, x);
#if HAS_FUSED_OPS
FUSED_OPS;
output[output_offset] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@ KERNEL(lstm_seq)(
const uint b = get_global_id(1);
const int weight_offsets[4] = {GEMM_OFFSET_F, GEMM_OFFSET_I, GEMM_OFFSET_Z, GEMM_OFFSET_O};
const int gate_num = 4;
//printf("b %d MAX_SEQ_LENGTH %d sequence_lengths %d\n", b, MAX_SEQ_LENGTH, sequence_lengths[INPUT3_GET_INDEX_SAFE(b, 0, 0, 0)]);
ACCUMULATOR_TYPE hidden_result[gate_num];
ACCUMULATOR_TYPE input_result[gate_num];
ACCUMULATOR_TYPE gate_output[gate_num];

ACCUMULATOR_TYPE temp_cell_state = 0;
for(int k=0;k<gate_num;k++){
gate_output[k] = 0;
}
//printf("DIRECTION %d \n", DIRECTION);

const int real_seq_length = sequence_lengths[INPUT3_GET_INDEX_SAFE(b, 0, 0, 0)];
for(int i=0;i<real_seq_length;i++){
for(int k=0;k<gate_num;k++){
Expand All @@ -56,7 +55,7 @@ KERNEL(lstm_seq)(
input_result[k] += x[INPUT0_GET_INDEX_SAFE(b, i, j, 0)]*W[INPUT4_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], j, 0)];
}
}
gate_output[k] = hidden_result[k] + input_result[k] + B[INPUT6_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], 0, 0)];
gate_output[k] = hidden_result[k] + input_result[k] + TO_ACCUMULATOR_TYPE(B[INPUT6_GET_INDEX_SAFE(0, hidden_idx+weight_offsets[k], 0, 0)]);

switch(k){
case 0:
Expand All @@ -73,26 +72,21 @@ KERNEL(lstm_seq)(
}

if (i==0){
//cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = ;
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = TO_OUTPUT_TYPE(gate_output[0]*initial_cell_state[INPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)]) + TO_OUTPUT_TYPE(gate_output[1]*gate_output[2]);
temp_cell_state = gate_output[0]*initial_cell_state[INPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] + gate_output[1]*gate_output[2];
}else{
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] *= TO_OUTPUT_TYPE(gate_output[0]);
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] += TO_OUTPUT_TYPE(gate_output[1]*gate_output[2]);
temp_cell_state *= gate_output[0];
temp_cell_state += gate_output[1]*gate_output[2];
}
int cur_history_idx = i;
if(DIRECTION == 1){ //reverse
cur_history_idx = real_seq_length - 1 - i ;
}
hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = TO_OUTPUT_TYPE(gate_output[3]*ACTIVATION_H(cell_state[OUTPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], ACTIVATION_PARAMS_H));
hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = gate_output[3]*ACTIVATION_H(temp_cell_state, ACTIVATION_PARAMS_H);
barrier(CLK_LOCAL_MEM_FENCE);
hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, cur_history_idx, hidden_idx)] = hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)];
barrier(CLK_LOCAL_MEM_FENCE);
if(i==real_seq_length-1){
cell_state[OUTPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = temp_cell_state;
}
}

//printf("R is %p B is %p ; hidden history %p cell state %p batch %d\n", &R[0], &B[0], &hidden_history[0], &cell_state[0], b);
for(int i=0;i<real_seq_length;i++){
//hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)] = i;
//printf("DIR %d result is %f for hididx %d b %d\n", DIRECTION, hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)], hidden_idx, b);
//printf("DIR %d hidden state is %f for hid idx %d b %d \n", DIRECTION, hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], hidden_idx, b);
}
}
6 changes: 3 additions & 3 deletions src/plugins/intel_gpu/src/plugin/ops/rnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,21 +250,21 @@ static void CreateLSTMSequenceOp(ProgramBuilder& p, const std::shared_ptr<ov::op
tensor_from_dims(op->get_output_shape(1)));

std::vector<cldnn::memory::ptr> shared_memories;
shared_memories.push_back(p.get_engine().allocate_memory(out12Layout, false));
shared_memories.push_back(p.get_engine().allocate_memory(out12Layout));
const cldnn::primitive_id mutable_id_1 = layerName + "_md_write1";
const cldnn::mutable_data mutable_prim_1{mutable_id_1, shared_memories.front()};
p.add_primitive(*op, mutable_prim_1);

std::cout << "layout is " << out12Layout << std::endl;
shared_memories.push_back(p.get_engine().allocate_memory(out12Layout, false));
shared_memories.push_back(p.get_engine().allocate_memory(out12Layout));
const cldnn::primitive_id mutable_id_2 = layerName + "_md_write2";
const cldnn::mutable_data mutable_prim_2{mutable_id_2, shared_memories.back()};
p.add_primitive(*op, mutable_prim_2);
int direction = op->get_direction() == ov::op::RecurrentSequenceDirection::REVERSE ? 1 : 0;
cldnn::lstm_seq prim(lstm_seq_id + ".out0", inputs[0], inputs[1], \
inputs[2], inputs[3], inputs[4], inputs[5], cldnn::input_info(bias), mutable_id_1, mutable_id_2, \
"", clip, 0, activations, activation_params, cldnn::lstm_weights_order::fizo, direction);
prim.output_data_types = get_output_data_types(op, {{ov::element::f32, ov::element::f16}});
//prim.output_data_types = get_output_data_types(op, {{ov::element::f32, ov::element::f16}});
//prim.out1_prim_id = f_id;
p.add_primitive(*op, prim);
p.add_primitive(*op, cldnn::mutable_data(lstm_seq_id + ".out1", {cldnn::input_info(lstm_seq_id + ".out0")}, shared_memories.front()));
Expand Down
1 change: 0 additions & 1 deletion src/plugins/intel_gpu/src/plugin/program_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ std::vector<cldnn::input_info> ProgramBuilder::GetInputInfo(const std::shared_pt
bool is_legacy_multiple_outputs = !allow_new_shape_infer
|| ov::is_type<ov::op::v1::Split>(prevOp)
|| ov::is_type<ov::op::v1::VariadicSplit>(prevOp)
|| ov::is_type<ov::op::v5::LSTMSequence>(prevOp)
|| ov::is_type<ov::op::v4::LSTMCell>(prevOp);
if (prevOp->get_output_size() > 1 && is_legacy_multiple_outputs) {
prevName += ".out" + std::to_string(op->get_input_source_output(i).get_index());
Expand Down

0 comments on commit 08fb207

Please sign in to comment.