Skip to content

Commit

Permalink
Merge branch 'gemv-support' of github.com:ucb-bar/gemmini into gemv-s…
Browse files Browse the repository at this point in the history
…upport
  • Loading branch information
lelzeiny committed May 7, 2024
2 parents 5f8d310 + 32009d2 commit a0c36f4
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 33 deletions.
2 changes: 1 addition & 1 deletion software/libgemmini
Submodule libgemmini updated 1 files
+1 −1 gemmini.cc
37 changes: 25 additions & 12 deletions src/main/scala/gemmini/AccumulatorScale.scala
Original file line number Diff line number Diff line change
Expand Up @@ -386,22 +386,35 @@ object AccumulatorScale {
}

def iexp[T <: Data](q: T, qln2: T, qln2_inv: T, qb: T, qc: T)(implicit ev: Arithmetic[T]): T = {
// import ev._

// val zero = q.zero
// def neg(x: T) = zero-x

// // qln2_inv needs scale to be 1 / (2 ** 16) / S
// // qln2_inv / S / (2 ** 16) = 1 / ln2
// // q * qln2_inv = x / S / ln2 * S * (2 ** 16) = x / ln2 * (2 ** 16)
// val neg_q_iexp = neg(q)
// val z_iexp = (neg_q_iexp * qln2_inv).asUInt.do_>>(16).asTypeOf(q) // q is non-positive
// val z_iexp_saturated = Wire(z_iexp.cloneType)
// z_iexp_saturated := Mux((5 until 16).map(z_iexp.asUInt(_)).reduce(_ | _), 32.S, z_iexp)
// val qp_iexp = q.mac(z_iexp, qln2).withWidthOf(q)
// val q_poly_iexp = qc.mac(qp_iexp + qb, qp_iexp + qb).withWidthOf(q)
// // we dont want a rounding shift
// // TODO: z overflow
// (q_poly_iexp.asUInt.do_>>(z_iexp_saturated.asUInt)).asTypeOf(q)

import ev._

val zero = q.zero
val one = q.identity
def neg(x: T) = zero-x

// qln2_inv needs scale to be 1 / (2 ** 16) / S
// qln2_inv / S / (2 ** 16) = 1 / ln2
// q * qln2_inv = x / S / ln2 * S * (2 ** 16) = x / ln2 * (2 ** 16)
val neg_q_iexp = neg(q)
val z_iexp = (neg_q_iexp * qln2_inv).asUInt.do_>>(16).asTypeOf(q) // q is non-positive
val z_iexp_saturated = Wire(z_iexp.cloneType)
z_iexp_saturated := Mux((5 until 16).map(z_iexp.asUInt(_)).reduce(_ | _), 32.S, z_iexp)
val qp_iexp = q.mac(z_iexp, qln2).withWidthOf(q)
val q_poly_iexp = qc.mac(qp_iexp + qb, qp_iexp + qb).withWidthOf(q)
// we dont want a rounding shift
// TODO: z overflow
(q_poly_iexp.asUInt.do_>>(z_iexp_saturated.asUInt)).asTypeOf(q)
val q_sign = Mux(q.zero > q, neg(one), one)
val q_abs = Mux(q.zero > q, neg(q), q)
val q_clipped = Mux(q_abs > neg(qb), neg(qb), q_abs)
val q_poly = qc.mac(q_clipped + qb, q_clipped + qb).withWidthOf(q)
val q_erf = (q_sign * q_poly).withWidthOf(q)
(q * (q_erf + qc)).withWidthOf(q)
}}

38 changes: 24 additions & 14 deletions src/main/scala/gemmini/ExecuteController.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
val bd_transpose = Reg(Bool())
val config_initialized = RegInit(false.B)

val is_gemv = WireInit(true.B)

val a_should_be_fed_into_transposer = Mux(current_dataflow === Dataflow.OS.id.U, !a_transpose, a_transpose)
val a_address_place = Mux(preload_cmd_place === 0.U, 1.U, Mux(a_should_be_fed_into_transposer, 2.U, 0.U))

Expand Down Expand Up @@ -252,10 +254,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
val c_addr_stride = Reg(UInt(16.W)) // TODO magic numbers

val a_address = (0 until tileColumns).map(i => a_address_rs1(i) + a_addr_offset(i))
val b_address = b_address_rs2 + b_fire_counter
dontTouch(b_address)
dontTouch(b_address_rs2)
val d_address = d_address_rs1 + (block_size.U - 1.U - d_fire_counter)
val b_address = Mux(is_gemv, b_address_rs2, b_address_rs2 + b_fire_counter)
val d_address = Mux(is_gemv, d_address_rs1, d_address_rs1 + (block_size.U - 1.U - d_fire_counter))
dontTouch(d_address)

val dataAbank = a_address.map(address => address.sp_bank())
Expand Down Expand Up @@ -458,8 +458,6 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
for (i <- 0 until sp_banks) {
// val matching_a = dataAbank.indexOf(i.U)
val matching_a = if (i < 4) i else -1; // TODO temp fix bc indexOf() doesn't work for some reason
val matching_a_wire = WireInit(matching_a.S(4.W));
dontTouch(matching_a_wire)
val read_a = if (matching_a == -1) false.B else a_valid(matching_a) && !a_read_from_acc && start_inputting_a && !multiply_garbage && a_row_is_not_all_zeros(matching_a) && !(im2col_wire&&im2col_en)
val read_b = b_valid && !b_read_from_acc && dataBbank === i.U && start_inputting_b && !accumulate_zeros && b_row_is_not_all_zeros //&& !im2col_wire
dontTouch(b_valid)
Expand Down Expand Up @@ -607,6 +605,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
}
}

// is_gemv := config_ex_rs1.is_gemv.asBool

a_addr_stride := config_ex_rs1.a_stride // TODO this needs to be kept in sync with ROB.scala
c_addr_stride := config_ex_rs2.c_stride // TODO this needs to be kept in sync with ROB.scala
config_initialized := true.B
Expand Down Expand Up @@ -634,7 +634,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In

//start_inputting_a := current_dataflow === Dataflow.OS.id.U
//start_inputting_d := true.B

start_inputting_a := a_should_be_fed_into_transposer
start_inputting_b := b_should_be_fed_into_transposer
start_inputting_d := true.B
Expand Down Expand Up @@ -930,37 +930,47 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
}
}

// TODO integrate this fully
val gemv_mode = RegInit(true.B)
dontTouch(dataB)
dontTouch(cntl_valid)
dontTouch(mesh.io.a.valid)
dontTouch(dataD)
dontTouch(is_gemv)

when (gemv_mode) {
when (is_gemv) {
when ((current_dataflow === Dataflow.WS.id.U).asBool) {
// transpose A
for (tc <- 0 until tileColumns) {
for (mr <- 0 until meshRows) {
for (tr <- 0 until tileRows) {
mesh.io.a.bits(mr)(tc)(tr) := dataA.asTypeOf(Vec(tileColumns, Vec(meshRows, Vec(tileRows, inputType))))(tc)(mr)(tr)
}
}
}
// pass in duplicated elements of weights vector in reverse order
for (tc <- 0 until tileColumns) {
for (mc <- 0 until meshColumns) {
mesh.io.d.bits(mc)(tc) := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(0)
mesh.io.d.bits(mc)(tc) := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(meshRows.U - d_fire_counter)
}
}
// duplicate one element of the bias vector to the mesh
for (tc <- 0 until tileColumns) {
for (mc <- 0 until meshColumns) {
mesh.io.b.bits(mc)(tc) := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(b_fire_counter-1.U)
}
}
mesh.io.b.bits := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))
}.otherwise {
// TODO this only works when casted this way
mesh.io.a.bits := dataA.asTypeOf(Vec(meshRows, Vec(tileColumns, Vec(tileRows, inputType))))
for (tc <- 0 until tileColumns) {
for (mc <- 0 until meshColumns) {
mesh.io.b.bits(mc)(tc) := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(0)
mesh.io.b.bits(mc)(tc) := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(b_fire_counter-1.U)
}
}
for (tc <- 0 until tileColumns) {
for (mc <- 0 until meshColumns) {
mesh.io.d.bits(mc)(tc) := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(d_fire_counter-1.U)
}
}
mesh.io.d.bits := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))
}
}.otherwise {
for (tc <- 0 until tileColumns) {
Expand Down
10 changes: 5 additions & 5 deletions src/main/scala/gemmini/Scratchpad.scala
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,9 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
// From acc are ordered
val write_norm_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t, accType.getWidth, acc_scale_t_bits), spad_read_delay+2))
val write_scale_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t, accType.getWidth, acc_scale_t_bits), spad_read_delay+2))
val write_issue_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t, accType.getWidth, acc_scale_t_bits), spad_read_delay+1, pipe=true))
val read_issue_q = Module(new Queue(new ScratchpadMemReadRequest(local_addr_t, mvin_scale_t_bits), spad_read_delay+1, pipe=true)) // TODO can't this just be a normal queue?

val write_issue_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t, accType.getWidth, acc_scale_t_bits), spad_read_delay+1, pipe=true, flow=true))
val read_issue_q = Module(new Queue(new ScratchpadMemReadRequest(local_addr_t, mvin_scale_t_bits), spad_read_delay+1, pipe=true))
write_dispatch_q.ready := false.B

write_norm_q.io.enq.valid := false.B
Expand Down Expand Up @@ -444,10 +444,10 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
io.busy := writer.module.io.busy || reader.module.io.busy || write_issue_q.io.deq.valid || write_norm_q.io.deq.valid || write_scale_q.io.deq.valid || write_dispatch_q.valid

val spad_mems = {
val banks = Seq.fill(sp_banks) { Module(new ScratchpadBank(
val banks = Seq.tabulate(sp_banks) { bankId => Module(new ScratchpadBank(
sp_bank_entries, spad_w,
aligned_to, config.sp_singleported,
use_shared_ext_mem, is_dummy
use_shared_ext_mem, is_dummy=bankId > 5
)) }
val bank_ios = VecInit(banks.map(_.io))
// Reading from the SRAM banks
Expand Down

0 comments on commit a0c36f4

Please sign in to comment.