From 84bddf59772d64e24342a68a1564fdc65931a658 Mon Sep 17 00:00:00 2001 From: SeahK Date: Mon, 23 Sep 2024 16:18:52 -0700 Subject: [PATCH] output fence --- src/main/scala/gemmini/Configs.scala | 4 +- src/main/scala/gemmini/Controller.scala | 2 + src/main/scala/gemmini/GemminiISA.scala | 5 ++- src/main/scala/gemmini/Scratchpad.scala | 2 + src/main/scala/gemmini/StoreController.scala | 42 +++++++++++++++++--- 5 files changed, 46 insertions(+), 9 deletions(-) diff --git a/src/main/scala/gemmini/Configs.scala b/src/main/scala/gemmini/Configs.scala index 4c83ff8e..e4b1370c 100644 --- a/src/main/scala/gemmini/Configs.scala +++ b/src/main/scala/gemmini/Configs.scala @@ -32,10 +32,10 @@ object GemminiConfigs { meshColumns = 16, // Spatial array PE options - dataflow = Dataflow.BOTH, + dataflow = Dataflow.WS, // Scratchpad and accumulator - sp_capacity = CapacityInKilobytes(256), + sp_capacity = CapacityInKilobytes(128), acc_capacity = CapacityInKilobytes(64), sp_banks = 4, diff --git a/src/main/scala/gemmini/Controller.scala b/src/main/scala/gemmini/Controller.scala index c67f92df..e0998f19 100644 --- a/src/main/scala/gemmini/Controller.scala +++ b/src/main/scala/gemmini/Controller.scala @@ -193,6 +193,8 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] counters.io.event_io.collect(store_controller.io.counter) counters.io.event_io.collect(ex_controller.io.counter) + store_controller.io.dma_writer_busy := spad.module.io.writer_busy + /* tiler.io.issue.load.ready := false.B tiler.io.issue.store.ready := false.B diff --git a/src/main/scala/gemmini/GemminiISA.scala b/src/main/scala/gemmini/GemminiISA.scala index 7bca089b..26bef26a 100644 --- a/src/main/scala/gemmini/GemminiISA.scala +++ b/src/main/scala/gemmini/GemminiISA.scala @@ -113,7 +113,8 @@ object GemminiISA { val CONFIG_MVOUT_RS1_MAX_POOLING_WINDOW_SIZE_WIDTH = 2 val CONFIG_MVOUT_RS1_UPPER_ZERO_PADDING_WIDTH = 2 val CONFIG_MVOUT_RS1_LEFT_ZERO_PADDING_WIDTH = 2 - val CONFIG_MVOUT_RS1_SPACER_WIDTH = (24 - 2 * 6) + //val CONFIG_MVOUT_RS1_SPACER_WIDTH = (24 - 2 * 6) + val CONFIG_MVOUT_RS1_PIPE_BLK = (24 - 2 * 6) val CONFIG_MVOUT_RS1_POOL_OUT_DIM_WIDTH = 8 val CONFIG_MVOUT_RS1_POOL_OUT_ROWS_WIDTH = 8 val CONFIG_MVOUT_RS1_POOL_OUT_COLS_WIDTH = 8 @@ -126,7 +127,7 @@ object GemminiISA { val pocols = UInt(CONFIG_MVOUT_RS1_POOL_OUT_COLS_WIDTH.W) val porows = UInt(CONFIG_MVOUT_RS1_POOL_OUT_ROWS_WIDTH.W) val pool_out_dim = UInt(CONFIG_MVOUT_RS1_POOL_OUT_DIM_WIDTH.W) - val _spacer = UInt(CONFIG_MVOUT_RS1_SPACER_WIDTH.W) + val pipeline_block = UInt(CONFIG_MVOUT_RS1_PIPE_BLK.W) val lpad = UInt(CONFIG_MVOUT_RS1_LEFT_ZERO_PADDING_WIDTH.W) val upad = UInt(CONFIG_MVOUT_RS1_UPPER_ZERO_PADDING_WIDTH.W) val pool_size = UInt(CONFIG_MVOUT_RS1_MAX_POOLING_WINDOW_SIZE_WIDTH.W) diff --git a/src/main/scala/gemmini/Scratchpad.scala b/src/main/scala/gemmini/Scratchpad.scala index d07614b3..444ddb7e 100644 --- a/src/main/scala/gemmini/Scratchpad.scala +++ b/src/main/scala/gemmini/Scratchpad.scala @@ -241,6 +241,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // Misc. ports val busy = Output(Bool()) + val writer_busy = Output(Bool()) // whether writer is empty val flush = Input(Bool()) val counter = new CounterEventIO() }) @@ -442,6 +443,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, reader.module.io.flush := io.flush io.busy := writer.module.io.busy || reader.module.io.busy || write_issue_q.io.deq.valid || write_norm_q.io.deq.valid || write_scale_q.io.deq.valid || write_dispatch_q.valid + io.writer_busy := writer.module.io.busy val spad_mems = { val banks = Seq.fill(sp_banks) { Module(new ScratchpadBank( diff --git a/src/main/scala/gemmini/StoreController.scala b/src/main/scala/gemmini/StoreController.scala index 72cd761b..5da80a22 100644 --- a/src/main/scala/gemmini/StoreController.scala +++ b/src/main/scala/gemmini/StoreController.scala @@ -21,22 +21,25 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm val dma = new ScratchpadWriteMemIO(local_addr_t, accType.getWidth, acc_scale_t_bits) val completed = Decoupled(UInt(log2Up(reservation_station_entries).W)) - + val busy = Output(Bool()) + val dma_writer_busy = Input(Bool()) // whether dma is busy + val counter = new CounterEventIO() }) // val waiting_for_command :: waiting_for_dma_req_ready :: sending_rows :: Nil = Enum(3) object State extends ChiselEnum { - val waiting_for_command, waiting_for_dma_req_ready, sending_rows, pooling = Value + val waiting_for_command, waiting_for_dma_req_ready, sending_rows, pooling, synchronizing = Value } import State._ val control_state = RegInit(waiting_for_command) val stride = Reg(UInt(coreMaxAddrBits.W)) + val pip_block = Reg(UInt(12.W)) val block_rows = meshRows * tileRows val block_stride = block_rows.U val block_cols = meshColumns * tileColumns @@ -53,6 +56,8 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm //val row_counter = RegInit(0.U(log2Ceil(block_rows).W)) val row_counter = RegInit(0.U(12.W)) // TODO magic number val block_counter = RegInit(0.U(8.W)) // TODO magic number + // to track how many blocks are done for HW level sync + val pip_block_counter = RegInit(UInt(CONFIG_MVOUT_RS1_PIPE_BLK.W), 0.U) // Pooling variables val pool_stride = Reg(UInt(CONFIG_MVOUT_RS1_MAX_POOLING_STRIDE_WIDTH.W)) // When this is 0, pooling is disabled @@ -106,6 +111,9 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm val config_upad = config_mvout_rs1.upad val config_lpad = config_mvout_rs1.lpad + val config_pip_block = config_mvout_rs1.pipeline_block + dontTouch(config_pip_block) + val config_norm_rs1 = cmd.bits.cmd.rs1.asTypeOf(new ConfigNormRs1(accType.getWidth)) val config_norm_rs2 = cmd.bits.cmd.rs2.asTypeOf(new ConfigNormRs2(accType.getWidth)) val config_stats_id = config_norm_rs1.norm_stats_id @@ -224,6 +232,7 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm when (cmd.valid) { when(DoConfig) { stride := config_stride + pip_block := config_pip_block activation := config_activation when (!config_acc_scale.asUInt.andR) { @@ -282,8 +291,15 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm val only_one_dma_req = block_counter === 0.U && row_counter === 0.U // This is a special case when only one DMA request is made when ((last_block && last_row) || only_one_dma_req) { - control_state := waiting_for_command - cmd.ready := true.B + when(pip_block_counter === pip_block - 1.U && pip_block > 0.U){ + control_state := synchronizing + cmd.ready := false.B + pip_block_counter := 0.U + }.otherwise{ + control_state := waiting_for_command + cmd.ready := true.B + pip_block_counter := pip_block_counter + 1.U + } } } @@ -294,8 +310,24 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm wrow_counter === pool_size - 1.U && wcol_counter === pool_size - 1.U && io.dma.req.fire) when (last_row) { + when(pip_block_counter === pip_block - 1.U && pip_block > 0.U){ + control_state := synchronizing + cmd.ready := false.B + pip_block_counter := 0.U + }.otherwise{ + control_state := waiting_for_command + cmd.ready := true.B + pip_block_counter := pip_block_counter + 1.U + } + } + } + + is (synchronizing){ + // TODO: need more for synchronization? + // resolve when DMA is done with current pipeline block + when(io.dma_writer_busy === false.B){ control_state := waiting_for_command - cmd.ready := true.B + cmd.ready := true.B // else, not take any new command } } }