From 185cac830746df9a22465e3267e6d47e01567bc3 Mon Sep 17 00:00:00 2001 From: Andrew Waterman Date: Mon, 27 Dec 2021 12:07:34 -0600 Subject: [PATCH] Add hypervisor extension (#2841) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR implements v0.6.2 of the RISC-V Hypervisor Extension. The implementation was inspired by José Martins' and colleagues' work described in [1]. Much of the microarchitecture and essentially all of the code is new, but their implementation served as our baseline. We thank them for trailblazing hypervisor support in rocket-chip. Note that this PR only includes the mechanisms to virtualize the hart itself. Virtualized interrupt controllers, IOMMUs, etc. are future work. Lots of future work. Note also that some features are (legally) not implemented. Currently, misa.H is not writable, something we may or may not choose to fix. The mtinst and htinst CSRs are hardwired to 0, placing additional onus on hypervisor software to use the HLVX instruction. [1] "A First Look at RISC-V Virtualization from an Embedded Systems Perspective", https://arxiv.org/abs/2103.14951 Co-authored-by: John Ingalls --- src/main/resources/vsrc/RoccBlackBox.v | 17 +- .../logicaltree/RocketLogicalTreeNode.scala | 2 +- .../model/ISASpecifications.scala | 5 +- .../diplomaticobjectmodel/model/OMISA.scala | 3 + .../diplomaticobjectmodel/model/OMPLIC.scala | 10 +- src/main/scala/rocket/CSR.scala | 532 +++++++++++++++--- src/main/scala/rocket/Consts.scala | 7 +- src/main/scala/rocket/DCache.scala | 8 +- src/main/scala/rocket/Frontend.scala | 23 +- src/main/scala/rocket/HellaCache.scala | 4 + src/main/scala/rocket/HellaCacheArbiter.scala | 2 + src/main/scala/rocket/IDecode.scala | 30 +- src/main/scala/rocket/Instructions.scala | 51 +- src/main/scala/rocket/NBDcache.scala | 2 + src/main/scala/rocket/PTW.scala | 277 +++++++-- src/main/scala/rocket/RocketCore.scala | 77 ++- src/main/scala/rocket/TLB.scala | 204 +++++-- src/main/scala/tile/BaseTile.scala | 23 +- src/main/scala/tile/Core.scala | 2 + src/main/scala/tile/LazyRoCC.scala | 2 + 20 files changed, 1045 insertions(+), 236 deletions(-) diff --git a/src/main/resources/vsrc/RoccBlackBox.v b/src/main/resources/vsrc/RoccBlackBox.v index b34bed0f91..fa54e243f1 100644 --- a/src/main/resources/vsrc/RoccBlackBox.v +++ b/src/main/resources/vsrc/RoccBlackBox.v @@ -10,6 +10,7 @@ module RoccBlackBox coreDataBits, coreDataBytes, paddrBits, + vaddrBitsExtended, FPConstants_RM_SZ, fLen, FPConstants_FLAGS_SZ ) @@ -32,9 +33,15 @@ module RoccBlackBox input rocc_cmd_bits_status_wfi, input [31:0] rocc_cmd_bits_status_isa, input [PRV_SZ-1:0] rocc_cmd_bits_status_dprv, + input rocc_cmd_bits_status_dv, input [PRV_SZ-1:0] rocc_cmd_bits_status_prv, + input rocc_cmd_bits_status_v, input rocc_cmd_bits_status_sd, - input [26:0] rocc_cmd_bits_status_zero2, + input [22:0] rocc_cmd_bits_status_zero2, + input rocc_cmd_bits_status_mpv, + input rocc_cmd_bits_status_gva, + input rocc_cmd_bits_status_mbe, + input rocc_cmd_bits_status_sbe, input [1:0] rocc_cmd_bits_status_sxl, input [1:0] rocc_cmd_bits_status_uxl, input rocc_cmd_bits_status_sd_rv32, @@ -51,7 +58,7 @@ module RoccBlackBox input [1:0] rocc_cmd_bits_status_mpp, input [0:0] rocc_cmd_bits_status_spp, input rocc_cmd_bits_status_mpie, - input rocc_cmd_bits_status_hpie, + input rocc_cmd_bits_status_ube, input rocc_cmd_bits_status_spie, input rocc_cmd_bits_status_upie, input rocc_cmd_bits_status_mie, @@ -73,6 +80,7 @@ module RoccBlackBox output rocc_mem_req_bits_no_alloc, output 
rocc_mem_req_bits_no_xcpt, output [1:0] rocc_mem_req_bits_dprv, + output rocc_mem_req_bits_dv, output [coreDataBits-1:0] rocc_mem_req_bits_data, output [coreDataBytes-1:0] rocc_mem_req_bits_mask, output rocc_mem_s1_kill, @@ -83,6 +91,8 @@ module RoccBlackBox output rocc_mem_s2_kill, input rocc_mem_s2_uncached, input [paddrBits-1:0] rocc_mem_s2_paddr, + input [vaddrBitsExtended-1:0] rocc_mem_s2_gpa, + input rocc_mem_s2_gpa_is_pte, input rocc_mem_resp_valid, input [coreMaxAddrBits-1:0] rocc_mem_resp_bits_addr, input [dcacheReqTagBits-1:0] rocc_mem_resp_bits_tag, @@ -97,11 +107,14 @@ module RoccBlackBox input [coreDataBits-1:0] rocc_mem_resp_bits_data_raw, input [coreDataBits-1:0] rocc_mem_resp_bits_store_data, input [1:0] rocc_mem_resp_bits_dprv, + input rocc_mem_resp_bits_dv, input rocc_mem_replay_next, input rocc_mem_s2_xcpt_ma_ld, input rocc_mem_s2_xcpt_ma_st, input rocc_mem_s2_xcpt_pf_ld, input rocc_mem_s2_xcpt_pf_st, + input rocc_mem_s2_xcpt_gf_ld, + input rocc_mem_s2_xcpt_gf_st, input rocc_mem_s2_xcpt_ae_ld, input rocc_mem_s2_xcpt_ae_st, input rocc_mem_ordered, diff --git a/src/main/scala/diplomaticobjectmodel/logicaltree/RocketLogicalTreeNode.scala b/src/main/scala/diplomaticobjectmodel/logicaltree/RocketLogicalTreeNode.scala index 8be9273c3d..92935251b8 100644 --- a/src/main/scala/diplomaticobjectmodel/logicaltree/RocketLogicalTreeNode.scala +++ b/src/main/scala/diplomaticobjectmodel/logicaltree/RocketLogicalTreeNode.scala @@ -81,7 +81,7 @@ class RocketLogicalTreeNode( def getOMInterruptTargets(): Seq[OMInterruptTarget] = { Seq(OMInterruptTarget( hartId = tile.rocketParams.hartId, - modes = OMModes.getModes(tile.rocketParams.core.hasSupervisorMode) + modes = OMModes.getModes(tile.rocketParams.core.hasSupervisorMode, tile.rocketParams.core.useHypervisor) )) } diff --git a/src/main/scala/diplomaticobjectmodel/model/ISASpecifications.scala b/src/main/scala/diplomaticobjectmodel/model/ISASpecifications.scala index c793c2cdc8..99855feb7a 100644 --- a/src/main/scala/diplomaticobjectmodel/model/ISASpecifications.scala +++ b/src/main/scala/diplomaticobjectmodel/model/ISASpecifications.scala @@ -5,11 +5,13 @@ package freechips.rocketchip.diplomaticobjectmodel.model sealed trait PrivilegedArchitectureExtension extends OMEnum case object MachineLevelISA extends PrivilegedArchitectureExtension +case object HypervisorLevelISA extends PrivilegedArchitectureExtension case object SupervisorLevelISA extends PrivilegedArchitectureExtension object PrivilegedArchitectureExtensions { val specifications = Map[PrivilegedArchitectureExtension, String]( MachineLevelISA -> "Machine-Level ISA", + HypervisorLevelISA -> "Hypervisor-Level ISA", SupervisorLevelISA -> "Supervisor-Level ISA" ) @@ -36,7 +38,8 @@ object ISAExtensions { C -> "C Standard Extension for Compressed Instruction", B -> "B Standard Extension for Bit Manipulation", U -> "The RISC‑V Instruction Set Manual, Volume II: Privileged Architecture", - S -> "Supervisor-Level ISA" + S -> "Supervisor-Level ISA", + H -> "H Standard Extension for Hypervisor", ) def specVersion(extension: OMExtensionType, version: String): OMSpecification = OMSpecification(specifications(extension), version) diff --git a/src/main/scala/diplomaticobjectmodel/model/OMISA.scala b/src/main/scala/diplomaticobjectmodel/model/OMISA.scala index ad8f6c4620..ea8ccab50c 100644 --- a/src/main/scala/diplomaticobjectmodel/model/OMISA.scala +++ b/src/main/scala/diplomaticobjectmodel/model/OMISA.scala @@ -15,6 +15,7 @@ case object C extends OMExtensionType case object B extends 
OMExtensionType case object U extends OMExtensionType case object S extends OMExtensionType +case object H extends OMExtensionType trait OMAddressTranslationMode extends OMEnum case object Bare extends OMAddressTranslationMode @@ -44,6 +45,7 @@ case class OMISA( v: Option[OMVectorExtension] = None, u: Option[OMSpecification], s: Option[OMSpecification], + h: Option[OMSpecification], addressTranslationModes: Seq[OMAddressTranslationMode], customExtensions: Seq[OMCustomExtensionSpecification], _types: Seq[String] = Seq("OMISA", "OMCompoundType") @@ -105,6 +107,7 @@ object OMISA { c = coreParams.useCompressed.option(isaExtSpec(C, "2.0")), u = (coreParams.hasSupervisorMode || coreParams.useUser).option(isaExtSpec(U, "1.10")), s = coreParams.hasSupervisorMode.option(isaExtSpec(S, "1.10")), + h = coreParams.useHypervisor.option(isaExtSpec(H, "0.6")), addressTranslationModes = Seq(addressTranslationModes), customExtensions = customExtensions ) diff --git a/src/main/scala/diplomaticobjectmodel/model/OMPLIC.scala b/src/main/scala/diplomaticobjectmodel/model/OMPLIC.scala index bdcfd09293..1f4d8df74f 100644 --- a/src/main/scala/diplomaticobjectmodel/model/OMPLIC.scala +++ b/src/main/scala/diplomaticobjectmodel/model/OMPLIC.scala @@ -4,15 +4,15 @@ package freechips.rocketchip.diplomaticobjectmodel.model sealed trait OMPrivilegeMode extends OMEnum case object OMMachineMode extends OMPrivilegeMode +case object OMHypervisorMode extends OMPrivilegeMode case object OMSupervisorMode extends OMPrivilegeMode case object OMUserMode extends OMPrivilegeMode object OMModes { - def getModes(hasSupervisorMode: Boolean): Seq[OMPrivilegeMode] = { - hasSupervisorMode match { - case false => Seq(OMMachineMode) - case true => Seq(OMMachineMode, OMSupervisorMode) - } + def getModes(hasSupervisorMode: Boolean, hasHypervisorMode: Boolean): Seq[OMPrivilegeMode] = { + Seq(OMMachineMode) ++ + (if (hasHypervisorMode) Seq(OMHypervisorMode) else Seq()) ++ + (if (hasSupervisorMode) Seq(OMSupervisorMode) else Seq()) } } diff --git a/src/main/scala/rocket/CSR.scala b/src/main/scala/rocket/CSR.scala index 99bdab7833..4e8dad3143 100644 --- a/src/main/scala/rocket/CSR.scala +++ b/src/main/scala/rocket/CSR.scala @@ -21,10 +21,17 @@ class MStatus extends Bundle { val wfi = Bool() val isa = UInt(width = 32) - val dprv = UInt(width = PRV.SZ) // effective privilege for data accesses - val prv = UInt(width = PRV.SZ) // not truly part of mstatus, but convenient + val dprv = UInt(width = PRV.SZ) // effective prv for data accesses + val dv = Bool() // effective v for data accesses + val prv = UInt(width = PRV.SZ) + val v = Bool() + val sd = Bool() - val zero2 = UInt(width = 27) + val zero2 = UInt(width = 23) + val mpv = Bool() + val gva = Bool() + val mbe = Bool() + val sbe = Bool() val sxl = UInt(width = 2) val uxl = UInt(width = 2) val sd_rv32 = Bool() @@ -41,7 +48,7 @@ class MStatus extends Bundle { val vs = UInt(width = 2) val spp = UInt(width = 1) val mpie = Bool() - val hpie = Bool() + val ube = Bool() val spie = Bool() val upie = Bool() val mie = Bool() @@ -50,6 +57,33 @@ class MStatus extends Bundle { val uie = Bool() } +class MNStatus extends Bundle { + val mpp = UInt(2.W) + val zero3 = UInt(3.W) + val mpv = Bool() + val zero2 = UInt(3.W) + val mie = Bool() + val zero1 = UInt(3.W) +} + +class HStatus extends Bundle { + val zero6 = UInt(width = 30) + val vsxl = UInt(width = 2) + val zero5 = UInt(width = 9) + val vtsr = Bool() + val vtw = Bool() + val vtvm = Bool() + val zero3 = UInt(width = 2) + val vgein = UInt(width = 6) + val zero2 
= UInt(width = 2) + val hu = Bool() + val spvp = Bool() + val spv = Bool() + val gva = Bool() + val vsbe = Bool() + val zero1 = UInt(width = 5) +} + class DCSR extends Bundle { val xdebugver = UInt(width = 2) val zero4 = UInt(width=2) @@ -62,7 +96,8 @@ class DCSR extends Bundle { val stopcycle = Bool() val stoptime = Bool() val cause = UInt(width = 3) - val zero1 = UInt(width=3) + val v = Bool() + val zero1 = UInt(width = 2) val step = Bool() val prv = UInt(width = PRV.SZ) } @@ -70,20 +105,20 @@ class DCSR extends Bundle { class MIP(implicit p: Parameters) extends CoreBundle()(p) with HasCoreParameters { val lip = Vec(coreParams.nLocalInterrupts, Bool()) - val zero2 = Bool() - val debug = Bool() // keep in sync with CSR.debugIntCause val zero1 = Bool() + val debug = Bool() // keep in sync with CSR.debugIntCause val rocc = Bool() + val sgeip = Bool() val meip = Bool() - val heip = Bool() + val vseip = Bool() val seip = Bool() val ueip = Bool() val mtip = Bool() - val htip = Bool() + val vstip = Bool() val stip = Bool() val utip = Bool() val msip = Bool() - val hsip = Bool() + val vssip = Bool() val ssip = Bool() val usip = Bool() } @@ -133,6 +168,11 @@ object CSR } val ADDRSZ = 12 + + def modeLSB: Int = 8 + def mode(addr: Int): Int = (addr >> modeLSB) % (1 << PRV.SZ) + def mode(addr: UInt): UInt = addr(modeLSB + PRV.SZ - 1, modeLSB) + def busErrorIntCause = 128 def debugIntCause = 14 // keep in sync with MIP.debug def debugTriggerCause = { @@ -180,8 +220,11 @@ class TraceAux extends Bundle { val stall = Bool() } -class CSRDecodeIO extends Bundle { - val csr = UInt(INPUT, CSR.ADDRSZ) +class CSRDecodeIO(implicit p: Parameters) extends CoreBundle { + val inst = Input(UInt(iLen.W)) + + def csr_addr = (inst >> 20)(CSR.ADDRSZ-1, 0) + val fp_illegal = Bool(OUTPUT) val vector_illegal = Bool(OUTPUT) val fp_csr = Bool(OUTPUT) @@ -190,6 +233,8 @@ class CSRDecodeIO extends Bundle { val write_illegal = Bool(OUTPUT) val write_flush = Bool(OUTPUT) val system_illegal = Bool(OUTPUT) + val virtual_access_illegal = Bool(OUTPUT) + val virtual_system_illegal = Bool(OUTPUT) } class CSRFileIO(implicit p: Parameters) extends CoreBundle @@ -211,13 +256,19 @@ class CSRFileIO(implicit p: Parameters) extends CoreBundle val singleStep = Bool(OUTPUT) val status = new MStatus().asOutput + val hstatus = new HStatus().asOutput + val gstatus = new MStatus().asOutput val ptbr = new PTBR().asOutput + val hgatp = new PTBR().asOutput + val vsatp = new PTBR().asOutput val evec = UInt(OUTPUT, vaddrBitsExtended) val exception = Bool(INPUT) val retire = UInt(INPUT, log2Up(1+retireWidth)) val cause = UInt(INPUT, xLen) val pc = UInt(INPUT, vaddrBitsExtended) val tval = UInt(INPUT, vaddrBitsExtended) + val htval = UInt(INPUT, (maxSVAddrBits + 1) min xLen) + val gva = Bool(INPUT) val time = UInt(OUTPUT, xLen) val fcsr_rm = Bits(OUTPUT, FPConstants.RM_SZ) val fcsr_flags = Valid(Bits(width = FPConstants.FLAGS_SZ)).flip @@ -329,20 +380,20 @@ class CSRFile( val sup = Wire(new MIP) sup.usip := false sup.ssip := Bool(usingSupervisor) - sup.hsip := false + sup.vssip := Bool(usingHypervisor) sup.msip := true sup.utip := false sup.stip := Bool(usingSupervisor) - sup.htip := false + sup.vstip := Bool(usingHypervisor) sup.mtip := true sup.ueip := false sup.seip := Bool(usingSupervisor) - sup.heip := false + sup.vseip := Bool(usingHypervisor) sup.meip := true + sup.sgeip := false sup.rocc := usingRoCC - sup.zero1 := false sup.debug := false - sup.zero2 := false + sup.zero1 := false sup.lip foreach { _ := true } val supported_high_interrupts = if 
(io.interrupts.buserror.nonEmpty && !usingNMI) UInt(BigInt(1) << CSR.busErrorIntCause) else 0.U @@ -362,7 +413,38 @@ class CSRFile( Causes.misaligned_load, Causes.misaligned_store, Causes.illegal_instruction, - Causes.user_ecall).map(1 << _).sum) + Causes.user_ecall, + Causes.virtual_supervisor_ecall, + Causes.fetch_guest_page_fault, + Causes.load_guest_page_fault, + Causes.virtual_instruction, + Causes.store_guest_page_fault).map(1 << _).sum) + + val hs_delegable_exceptions = UInt(Seq( + Causes.misaligned_fetch, + Causes.fetch_access, + Causes.illegal_instruction, + Causes.breakpoint, + Causes.misaligned_load, + Causes.load_access, + Causes.misaligned_store, + Causes.store_access, + Causes.user_ecall, + Causes.fetch_page_fault, + Causes.load_page_fault, + Causes.store_page_fault).map(1 << _).sum) + + val (hs_delegable_interrupts, mideleg_always_hs) = { + val always = Wire(new MIP().fromBits(0.U)) + always.vssip := Bool(usingHypervisor) + always.vstip := Bool(usingHypervisor) + always.vseip := Bool(usingHypervisor) + + val deleg = Wire(init = always) + deleg.lip.foreach { _ := Bool(usingHypervisor) } + + (deleg.asUInt, always.asUInt) + } val reg_debug = Reg(init=Bool(false)) val reg_dpc = Reg(UInt(width = vaddrBitsExtended)) @@ -380,7 +462,7 @@ class CSRFile( val reg_mie = Reg(UInt(width = xLen)) val (reg_mideleg, read_mideleg) = { val reg = Reg(UInt(xLen.W)) - (reg, Mux(usingSupervisor, reg & delegable_interrupts, 0.U)) + (reg, Mux(usingSupervisor, reg & delegable_interrupts | mideleg_always_hs, 0.U)) } val (reg_medeleg, read_medeleg) = { val reg = Reg(UInt(xLen.W)) @@ -390,6 +472,7 @@ class CSRFile( val reg_mepc = Reg(UInt(width = vaddrBitsExtended)) val reg_mcause = RegInit(0.U(xLen.W)) val reg_mtval = Reg(UInt(width = vaddrBitsExtended)) + val reg_mtval2 = Reg(UInt(((maxSVAddrBits + 1) min xLen).W)) val reg_mscratch = Reg(Bits(width = xLen)) val mtvecWidth = paddrBits min xLen val reg_mtvec = mtvecInit match { @@ -397,7 +480,7 @@ class CSRFile( case None => Reg(UInt(width = mtvecWidth)) } - val reset_mnstatus = Wire(init=new MStatus().fromBits(0)) + val reset_mnstatus = Wire(init=new MNStatus().fromBits(0)) reset_mnstatus.mpp := PRV.M val reg_mnscratch = Reg(Bits(width = xLen)) val reg_mnepc = Reg(UInt(width = vaddrBitsExtended)) @@ -416,11 +499,41 @@ class CSRFile( (reg, Mux(usingSupervisor, reg & delegable_counters, 0.U)) } + val (reg_hideleg, read_hideleg) = { + val reg = Reg(UInt(xLen.W)) + (reg, Mux(usingHypervisor, reg & hs_delegable_interrupts, 0.U)) + } + val (reg_hedeleg, read_hedeleg) = { + val reg = Reg(UInt(xLen.W)) + (reg, Mux(usingHypervisor, reg & hs_delegable_exceptions, 0.U)) + } + val hs_delegable_counters = delegable_counters + val (reg_hcounteren, read_hcounteren) = { + val reg = Reg(UInt(32.W)) + (reg, Mux(usingHypervisor, reg & hs_delegable_counters, 0.U)) + } + val reg_hstatus = RegInit(0.U.asTypeOf(new HStatus)) + val reg_hgatp = Reg(new PTBR) + val reg_htval = Reg(reg_mtval2.cloneType) + val read_hvip = reg_mip.asUInt & hs_delegable_interrupts + val read_hie = reg_mie & hs_delegable_interrupts + + val (reg_vstvec, read_vstvec) = { + val reg = Reg(UInt(width = vaddrBitsExtended)) + (reg, formTVec(reg).sextTo(xLen)) + } + val reg_vsstatus = Reg(new MStatus) + val reg_vsscratch = Reg(Bits(width = xLen)) + val reg_vsepc = Reg(UInt(width = vaddrBitsExtended)) + val reg_vscause = Reg(Bits(width = xLen)) + val reg_vstval = Reg(UInt(width = vaddrBitsExtended)) + val reg_vsatp = Reg(new PTBR) + val reg_sepc = Reg(UInt(width = vaddrBitsExtended)) val reg_scause = 
Reg(Bits(width = xLen)) val reg_stval = Reg(UInt(width = vaddrBitsExtended)) val reg_sscratch = Reg(Bits(width = xLen)) - val reg_stvec = Reg(UInt(width = vaddrBits)) + val reg_stvec = Reg(UInt(width = if (usingHypervisor) vaddrBitsExtended else vaddrBits)) val reg_satp = Reg(new PTBR) val reg_wfi = withClock(io.ungated_clock) { Reg(init=Bool(false)) } @@ -448,8 +561,11 @@ class CSRFile( mip.meip := io.interrupts.meip // seip is the OR of reg_mip.seip and the actual line from the PLIC io.interrupts.seip.foreach { mip.seip := reg_mip.seip || _ } + // Similar sort of thing would apply if the PLIC had a VSEIP line: + //io.interrupts.vseip.foreach { mip.vseip := reg_mip.vseip || _ } mip.rocc := io.rocc_interrupt val read_mip = mip.asUInt & supported_interrupts + val read_hip = read_mip & hs_delegable_interrupts val high_interrupts = (if (usingNMI) 0.U else io.interrupts.buserror.map(_ << CSR.busErrorIntCause).getOrElse(0.U)) val pending_interrupts = high_interrupts | (read_mip & reg_mie) @@ -459,8 +575,9 @@ io.interrupts.buserror.map(_ << CSR.rnmiBEUCause).getOrElse(0.U), !io.interrupts.debug && nmi.rnmi && reg_rnmie)).getOrElse(0.U, false.B) val m_interrupts = Mux(nmie && (reg_mstatus.prv <= PRV.S || reg_mstatus.mie), ~(~pending_interrupts | read_mideleg), UInt(0)) - val s_interrupts = Mux(nmie && (reg_mstatus.prv < PRV.S || (reg_mstatus.prv === PRV.S && reg_mstatus.sie)), pending_interrupts & read_mideleg, UInt(0)) - val (anyInterrupt, whichInterrupt) = chooseInterrupt(Seq(s_interrupts, m_interrupts, nmi_interrupts, d_interrupts)) + val s_interrupts = Mux(nmie && (reg_mstatus.v || reg_mstatus.prv < PRV.S || (reg_mstatus.prv === PRV.S && reg_mstatus.sie)), pending_interrupts & read_mideleg & ~read_hideleg, UInt(0)) + val vs_interrupts = Mux(nmie && (reg_mstatus.v && (reg_mstatus.prv < PRV.S || reg_mstatus.prv === PRV.S && reg_vsstatus.sie)), pending_interrupts & read_hideleg, UInt(0)) + val (anyInterrupt, whichInterrupt) = chooseInterrupt(Seq(vs_interrupts, s_interrupts, m_interrupts, nmi_interrupts, d_interrupts)) val interruptMSB = BigInt(1) << (xLen-1) val interruptCause = UInt(interruptMSB) + (nmiFlag << (xLen-2)) + whichInterrupt io.interrupt := (anyInterrupt && !io.singleStep || reg_singleStepped) && !(reg_debug || io.status.cease) @@ -482,6 +599,7 @@ class CSRFile( isaMaskString + "X" + // Custom extensions always present (e.g.
CEASE instruction) (if (usingSupervisor) "S" else "") + + (if (usingHypervisor) "H" else "") + (if (usingUser) "U" else "") val isaMax = (BigInt(log2Ceil(xLen) - 4) << (xLen-2)) | isaStringToMask(isaString) val reg_misa = Reg(init=UInt(isaMax)) @@ -511,8 +629,9 @@ class CSRFile( CSRs.dscratch -> reg_dscratch.asUInt) ++ reg_dscratch1.map(r => CSRs.dscratch1 -> r) - val read_mnstatus = WireInit(0.U.asTypeOf(new MStatus())) + val read_mnstatus = WireInit(0.U.asTypeOf(new MNStatus())) read_mnstatus.mpp := reg_mnstatus.mpp + read_mnstatus.mpv := reg_mnstatus.mpv read_mnstatus.mie := reg_rnmie val nmi_csrs = if (!usingNMI) LinkedHashMap() else LinkedHashMap[Int,Bits]( CSRs.mnscratch -> reg_mnscratch, @@ -578,9 +697,14 @@ class CSRFile( } } + val sie_mask = { + val sgeip_mask = WireInit(0.U.asTypeOf(new MIP)) + sgeip_mask.sgeip := true + read_mideleg & ~(hs_delegable_interrupts | sgeip_mask.asUInt) + } if (usingSupervisor) { - val read_sie = reg_mie & read_mideleg - val read_sip = read_mip & read_mideleg + val read_sie = reg_mie & sie_mask + val read_sip = read_mip & sie_mask val read_sstatus = Wire(init = 0.U.asTypeOf(new MStatus)) read_sstatus.sd := io.status.sd read_sstatus.uxl := io.status.uxl @@ -628,62 +752,145 @@ class CSRFile( reg } + if (usingHypervisor) { + read_mapping += CSRs.mtinst -> 0.U + read_mapping += CSRs.mtval2 -> reg_mtval2 + + val read_hstatus = io.hstatus.asUInt()(xLen-1,0) + + read_mapping += CSRs.hstatus -> read_hstatus + read_mapping += CSRs.hedeleg -> read_hedeleg + read_mapping += CSRs.hideleg -> read_hideleg + read_mapping += CSRs.hcounteren-> read_hcounteren + read_mapping += CSRs.hgatp -> reg_hgatp.asUInt + read_mapping += CSRs.hip -> read_hip + read_mapping += CSRs.hie -> read_hie + read_mapping += CSRs.hvip -> read_hvip + read_mapping += CSRs.hgeie -> 0.U + read_mapping += CSRs.hgeip -> 0.U + read_mapping += CSRs.htval -> reg_htval + read_mapping += CSRs.htinst -> 0.U + + val read_vsie = (read_hie & read_hideleg) >> 1 + val read_vsip = (read_hip & read_hideleg) >> 1 + val read_vsepc = readEPC(reg_vsepc).sextTo(xLen) + val read_vstval = reg_vstval.sextTo(xLen) + val read_vsstatus = io.gstatus.asUInt()(xLen-1,0) + + read_mapping += CSRs.vsstatus -> read_vsstatus + read_mapping += CSRs.vsip -> read_vsip + read_mapping += CSRs.vsie -> read_vsie + read_mapping += CSRs.vsscratch -> reg_vsscratch + read_mapping += CSRs.vscause -> reg_vscause + read_mapping += CSRs.vstval -> read_vstval + read_mapping += CSRs.vsatp -> reg_vsatp.asUInt + read_mapping += CSRs.vsepc -> read_vsepc + read_mapping += CSRs.vstvec -> read_vstvec + } + // mimpid, marchid, and mvendorid are 0 unless overridden by customCSRs Seq(CSRs.mimpid, CSRs.marchid, CSRs.mvendorid).foreach(id => read_mapping.getOrElseUpdate(id, 0.U)) - val decoded_addr = read_mapping map { case (k, v) => k -> (io.rw.addr === k) } + val decoded_addr = { + val addr = Cat(io.status.v, io.rw.addr) + val pats = for (((k, _), i) <- read_mapping.zipWithIndex) + yield (BitPat(k.U), (0 until read_mapping.size).map(j => BitPat((i == j).B))) + val decoded = DecodeLogic(addr, Seq.fill(read_mapping.size)(X), pats) + val unvirtualized_mapping = for (((k, _), v) <- read_mapping zip decoded) yield k -> v.asBool + + for ((k, v) <- unvirtualized_mapping) yield k -> { + val alt = CSR.mode(k) match { + case PRV.S => unvirtualized_mapping.lift(k + (1 << CSR.modeLSB)) + case PRV.H => unvirtualized_mapping.lift(k - (1 << CSR.modeLSB)) + case _ => None + } + alt.map(Mux(reg_mstatus.v, _, v)).getOrElse(v) + } + } + val wdata = 
readModifyWriteCSR(io.rw.cmd, io.rw.rdata, io.rw.wdata) val system_insn = io.rw.cmd === CSR.I - val decode_table = Seq( SCALL-> List(Y,N,N,N,N,N), - SBREAK-> List(N,Y,N,N,N,N), - MRET-> List(N,N,Y,N,N,N), - CEASE-> List(N,N,N,Y,N,N), - WFI-> List(N,N,N,N,Y,N)) ++ - usingDebug.option( DRET-> List(N,N,Y,N,N,N)) ++ - usingNMI.option( MNRET-> List(N,N,Y,N,N,N)) ++ - coreParams.haveCFlush.option(CFLUSH_D_L1-> List(N,N,N,N,N,N)) ++ - usingSupervisor.option( SRET-> List(N,N,Y,N,N,N)) ++ - usingVM.option( SFENCE_VMA-> List(N,N,N,N,N,Y)) - - val insn_call :: insn_break :: insn_ret :: insn_cease :: insn_wfi :: insn_sfence :: Nil = + val hlsv = Seq(HLV_B, HLV_BU, HLV_H, HLV_HU, HLV_W, HLV_WU, HLV_D, HSV_B, HSV_H, HSV_W, HSV_D, HLVX_HU, HLVX_WU) + val decode_table = Seq( SCALL-> List(Y,N,N,N,N,N,N,N,N), + SBREAK-> List(N,Y,N,N,N,N,N,N,N), + MRET-> List(N,N,Y,N,N,N,N,N,N), + CEASE-> List(N,N,N,Y,N,N,N,N,N), + WFI-> List(N,N,N,N,Y,N,N,N,N)) ++ + usingDebug.option( DRET-> List(N,N,Y,N,N,N,N,N,N)) ++ + usingNMI.option( MNRET-> List(N,N,Y,N,N,N,N,N,N)) ++ + coreParams.haveCFlush.option(CFLUSH_D_L1-> List(N,N,N,N,N,N,N,N,N)) ++ + usingSupervisor.option( SRET-> List(N,N,Y,N,N,N,N,N,N)) ++ + usingVM.option( SFENCE_VMA-> List(N,N,N,N,N,Y,N,N,N)) ++ + usingHypervisor.option( HFENCE_VVMA-> List(N,N,N,N,N,N,Y,N,N)) ++ + usingHypervisor.option( HFENCE_GVMA-> List(N,N,N,N,N,N,N,Y,N)) ++ + (if (usingHypervisor) hlsv.map(_-> List(N,N,N,N,N,N,N,N,Y)) else Seq()) + val insn_call :: insn_break :: insn_ret :: insn_cease :: insn_wfi :: _ :: _ :: _ :: _ :: Nil = DecodeLogic(io.rw.addr << 20, decode_table(0)._2.map(x=>X), decode_table).map(system_insn && _.asBool) for (io_dec <- io.decode) { - def decodeAny(m: LinkedHashMap[Int,Bits]): Bool = m.map { case(k: Int, _: Bits) => io_dec.csr === k }.reduce(_||_) - def decodeFast(s: Seq[Int]): Bool = DecodeLogic(io_dec.csr, s.map(_.U), (read_mapping -- s).keys.toList.map(_.U)) + val addr = io_dec.inst(31, 20) + + def decodeAny(m: LinkedHashMap[Int,Bits]): Bool = m.map { case(k: Int, _: Bits) => addr === k }.reduce(_||_) + def decodeFast(s: Seq[Int]): Bool = DecodeLogic(addr, s.map(_.U), (read_mapping -- s).keys.toList.map(_.U)) - val _ :: is_break :: is_ret :: _ :: is_wfi :: is_sfence :: Nil = - DecodeLogic(io_dec.csr << 20, decode_table(0)._2.map(x=>X), decode_table).map(_.asBool) + val _ :: is_break :: is_ret :: _ :: is_wfi :: is_sfence :: is_hfence_vvma :: is_hfence_gvma :: is_hlsv :: Nil = + DecodeLogic(io_dec.inst, decode_table(0)._2.map(x=>X), decode_table).map(_.asBool) + val is_counter = (addr.inRange(CSR.firstCtr, CSR.firstCtr + CSR.nCtr) || addr.inRange(CSR.firstCtrH, CSR.firstCtrH + CSR.nCtr)) - val allow_wfi = Bool(!usingSupervisor) || reg_mstatus.prv > PRV.S || !reg_mstatus.tw - val allow_sfence_vma = Bool(!usingVM) || reg_mstatus.prv > PRV.S || !reg_mstatus.tvm - val allow_sret = Bool(!usingSupervisor) || reg_mstatus.prv > PRV.S || !reg_mstatus.tsr - val counter_addr = io_dec.csr(log2Ceil(read_mcounteren.getWidth)-1, 0) + val allow_wfi = Bool(!usingSupervisor) || reg_mstatus.prv > PRV.S || !reg_mstatus.tw && (!reg_mstatus.v || !reg_hstatus.vtw) + val allow_sfence_vma = Bool(!usingVM) || reg_mstatus.prv > PRV.S || !Mux(reg_mstatus.v, reg_hstatus.vtvm, reg_mstatus.tvm) + val allow_hfence_vvma = Bool(!usingHypervisor) || !reg_mstatus.v && (reg_mstatus.prv >= PRV.S) + val allow_hlsv = Bool(!usingHypervisor) || !reg_mstatus.v && (reg_mstatus.prv >= PRV.S || reg_hstatus.hu) + val allow_sret = Bool(!usingSupervisor) || reg_mstatus.prv > PRV.S || !Mux(reg_mstatus.v, 
reg_hstatus.vtsr, reg_mstatus.tsr) + val counter_addr = addr(log2Ceil(read_mcounteren.getWidth)-1, 0) val allow_counter = (reg_mstatus.prv > PRV.S || read_mcounteren(counter_addr)) && - (!usingSupervisor || reg_mstatus.prv >= PRV.S || read_scounteren(counter_addr)) - io_dec.fp_illegal := io.status.fs === 0 || !reg_misa('f'-'a') - io_dec.vector_illegal := io.status.vs === 0 || !reg_misa('v'-'a') + (!usingSupervisor || reg_mstatus.prv >= PRV.S || read_scounteren(counter_addr)) && + (!usingHypervisor || !reg_mstatus.v || read_hcounteren(counter_addr)) + io_dec.fp_illegal := io.status.fs === 0 || reg_mstatus.v && reg_vsstatus.fs === 0 || !reg_misa('f'-'a') + io_dec.vector_illegal := io.status.vs === 0 || reg_mstatus.v && reg_vsstatus.vs === 0 || !reg_misa('v'-'a') io_dec.fp_csr := decodeFast(fp_csrs.keys.toList) - io_dec.rocc_illegal := io.status.xs === 0 || !reg_misa('x'-'a') - io_dec.read_illegal := reg_mstatus.prv < io_dec.csr(9,8) || - !decodeAny(read_mapping) || - io_dec.csr === CSRs.satp && !allow_sfence_vma || - (io_dec.csr.inRange(CSR.firstCtr, CSR.firstCtr + CSR.nCtr) || io_dec.csr.inRange(CSR.firstCtrH, CSR.firstCtrH + CSR.nCtr)) && !allow_counter || + io_dec.rocc_illegal := io.status.xs === 0 || reg_mstatus.v && reg_vsstatus.vs === 0 || !reg_misa('x'-'a') + val csr_addr_legal = reg_mstatus.prv >= CSR.mode(addr) || + usingHypervisor && !reg_mstatus.v && reg_mstatus.prv === PRV.S && CSR.mode(addr) === PRV.H + val csr_exists = decodeAny(read_mapping) + io_dec.read_illegal := !csr_addr_legal || + !csr_exists || + ((addr === CSRs.satp || addr === CSRs.hgatp) && !allow_sfence_vma) || + is_counter && !allow_counter || decodeFast(debug_csrs.keys.toList) && !reg_debug || decodeFast(vector_csrs.keys.toList) && io_dec.vector_illegal || io_dec.fp_csr && io_dec.fp_illegal - io_dec.write_illegal := io_dec.csr(11,10).andR - io_dec.write_flush := !(io_dec.csr >= CSRs.mscratch && io_dec.csr <= CSRs.mtval || io_dec.csr >= CSRs.sscratch && io_dec.csr <= CSRs.stval) - io_dec.system_illegal := reg_mstatus.prv < io_dec.csr(9,8) || + io_dec.write_illegal := addr(11,10).andR + io_dec.write_flush := { + val addr_m = addr | (PRV.M << CSR.modeLSB) + !(addr_m >= CSRs.mscratch && addr_m <= CSRs.mtval) + } + io_dec.system_illegal := !csr_addr_legal && !is_hlsv || is_wfi && !allow_wfi || is_ret && !allow_sret || - is_ret && io_dec.csr(10) && io_dec.csr(7) && !reg_debug || - is_sfence && !allow_sfence_vma + is_ret && addr(10) && addr(7) && !reg_debug || + (is_sfence || is_hfence_gvma) && !allow_sfence_vma || + is_hfence_vvma && !allow_hfence_vvma || + is_hlsv && !allow_hlsv + + io_dec.virtual_access_illegal := reg_mstatus.v && csr_exists && ( + CSR.mode(addr) === PRV.H || + is_counter && read_mcounteren(counter_addr) && (!read_hcounteren(counter_addr) || !reg_mstatus.prv(0) && !read_scounteren(counter_addr)) || + CSR.mode(addr) === PRV.S && !reg_mstatus.prv(0) || + addr === CSRs.satp && reg_mstatus.prv(0) && reg_hstatus.vtvm) + + io_dec.virtual_system_illegal := reg_mstatus.v && ( + is_hfence_vvma || + is_hfence_gvma || + is_hlsv || + is_wfi && (!reg_mstatus.prv(0) || !reg_mstatus.tw && reg_hstatus.vtw) || + is_ret && CSR.mode(addr) === PRV.S && (!reg_mstatus.prv(0) || reg_hstatus.vtsr) || + is_sfence && (!reg_mstatus.prv(0) || reg_hstatus.vtvm)) } val cause = - Mux(insn_call, reg_mstatus.prv + Causes.user_ecall, + Mux(insn_call, Causes.user_ecall + Mux(reg_mstatus.prv(0) && reg_mstatus.v, PRV.H: UInt, reg_mstatus.prv), Mux[UInt](insn_break, Causes.breakpoint, io.cause)) val cause_lsbs = cause(log2Ceil(1 + 
CSR.busErrorIntCause)-1, 0) val causeIsDebugInt = cause(xLen-1) && cause_lsbs === CSR.debugIntCause @@ -694,13 +901,14 @@ class CSRFile( val debugException = p(DebugModuleKey).map(_.debugException).getOrElse(BigInt(0x808)) val debugTVec = Mux(reg_debug, Mux(insn_break, debugEntry.U, debugException.U), debugEntry.U) val delegate = Bool(usingSupervisor) && reg_mstatus.prv <= PRV.S && Mux(cause(xLen-1), read_mideleg(cause_lsbs), read_medeleg(cause_lsbs)) + val delegateVS = reg_mstatus.v && delegate && Mux(cause(xLen-1), read_hideleg(cause_lsbs), read_hedeleg(cause_lsbs)) def mtvecBaseAlign = 2 def mtvecInterruptAlign = { require(reg_mip.getWidth <= xLen) log2Ceil(xLen) } val notDebugTVec = { - val base = Mux(delegate, read_stvec, read_mtvec) + val base = Mux(delegate, Mux(delegateVS, read_vstvec, read_stvec), read_mtvec) val interruptOffset = cause(mtvecInterruptAlign-1, 0) << mtvecBaseAlign val interruptVec = Cat(base >> (mtvecInterruptAlign + mtvecBaseAlign), interruptOffset) val doVector = base(0) && cause(cause.getWidth-1) && (cause_lsbs >> mtvecInterruptAlign) === 0 @@ -720,6 +928,8 @@ class CSRFile( val tvec = Mux(trapToDebug, debugTVec, Mux(trapToNmi, nmiTVec, notDebugTVec)) io.evec := tvec io.ptbr := reg_satp + io.hgatp := reg_hgatp + io.vsatp := reg_vsatp io.eret := insn_call || insn_break || insn_ret io.singleStep := reg_dcsr.step && !reg_debug io.status := reg_mstatus @@ -728,9 +938,17 @@ class CSRFile( io.status.isa := reg_misa io.status.uxl := (if (usingUser) log2Ceil(xLen) - 4 else 0) io.status.sxl := (if (usingSupervisor) log2Ceil(xLen) - 4 else 0) - io.status.dprv := Reg(next = Mux(reg_mstatus.mprv && !reg_debug, reg_mstatus.mpp, reg_mstatus.prv)) - if (xLen == 32) - io.status.sd_rv32 := io.status.sd + io.status.dprv := Mux(reg_mstatus.mprv && !reg_debug, reg_mstatus.mpp, reg_mstatus.prv) + io.status.dv := reg_mstatus.v || Mux(reg_mstatus.mprv && !reg_debug, reg_mstatus.mpv, false.B) + io.status.sd_rv32 := xLen == 32 && io.status.sd + io.status.mpv := reg_mstatus.mpv + io.status.gva := reg_mstatus.gva + io.hstatus := reg_hstatus + io.hstatus.vsxl := (if (usingSupervisor) log2Ceil(xLen) - 4 else 0) + io.gstatus := reg_vsstatus + io.gstatus.sd := io.gstatus.fs.andR || io.gstatus.xs.andR || io.gstatus.vs.andR + io.gstatus.uxl := (if (usingUser) log2Ceil(xLen) - 4 else 0) + io.gstatus.sd_rv32 := xLen == 32 && io.gstatus.sd val exception = insn_call || insn_break || io.exception assert(PopCount(insn_ret :: insn_call :: insn_break :: io.exception :: Nil) <= 1, "these conditions must be mutually exclusive") @@ -745,40 +963,58 @@ class CSRFile( assert(!reg_singleStepped || io.retire === UInt(0)) val epc = formEPC(io.pc) - val noCause :: mCause :: hCause :: sCause :: uCause :: Nil = Enum(5) - val xcause_dest = Wire(init = noCause) when (exception) { when (trapToDebug) { when (!reg_debug) { + reg_mstatus.v := false reg_debug := true reg_dpc := epc reg_dcsr.cause := Mux(reg_singleStepped, 4, Mux(causeIsDebugInt, 3, Mux[UInt](causeIsDebugTrigger, 2, 1))) reg_dcsr.prv := trimPrivilege(reg_mstatus.prv) + reg_dcsr.v := reg_mstatus.v new_prv := PRV.M } }.elsewhen (trapToNmiInt) { when (reg_rnmie) { + reg_mstatus.v := false + reg_mnstatus.mpv := reg_mstatus.v reg_rnmie := false.B reg_mnepc := epc reg_mncause := (BigInt(1) << (xLen-1)).U | Mux(causeIsRnmiBEU, 3.U, 2.U) reg_mnstatus.mpp := trimPrivilege(reg_mstatus.prv) new_prv := PRV.M } + }.elsewhen (delegateVS && nmie) { + reg_mstatus.v := true + reg_vsstatus.spp := reg_mstatus.prv + reg_vsepc := epc + reg_vscause := Mux(cause(xLen-1), 
Cat(cause(xLen-1, 2), 1.U(2.W)), cause) + reg_vstval := io.tval + reg_vsstatus.spie := reg_vsstatus.sie + reg_vsstatus.sie := false + new_prv := PRV.S }.elsewhen (delegate && nmie) { + reg_mstatus.v := false + reg_hstatus.spvp := Mux(reg_mstatus.v, reg_mstatus.prv(0),reg_hstatus.spvp) + reg_hstatus.gva := io.gva + reg_hstatus.spv := reg_mstatus.v reg_sepc := epc reg_scause := cause - xcause_dest := sCause reg_stval := io.tval + reg_htval := io.htval reg_mstatus.spie := reg_mstatus.sie reg_mstatus.spp := reg_mstatus.prv reg_mstatus.sie := false new_prv := PRV.S }.otherwise { + reg_mstatus.v := false + reg_mstatus.mpv := reg_mstatus.v + reg_mstatus.gva := io.gva reg_mepc := epc reg_mcause := cause - xcause_dest := mCause reg_mtval := io.tval + reg_mtval2 := io.htval reg_mstatus.mpie := reg_mstatus.mie reg_mstatus.mpp := trimPrivilege(reg_mstatus.prv) reg_mstatus.mie := false @@ -809,29 +1045,44 @@ class CSRFile( when (insn_ret) { val ret_prv = WireInit(UInt(), DontCare) when (Bool(usingSupervisor) && !io.rw.addr(9)) { - reg_mstatus.sie := reg_mstatus.spie - reg_mstatus.spie := true - reg_mstatus.spp := PRV.U - ret_prv := reg_mstatus.spp - io.evec := readEPC(reg_sepc) + when (!reg_mstatus.v) { + reg_mstatus.sie := reg_mstatus.spie + reg_mstatus.spie := true + reg_mstatus.spp := PRV.U + ret_prv := reg_mstatus.spp + reg_mstatus.v := usingHypervisor && reg_hstatus.spv + io.evec := readEPC(reg_sepc) + reg_hstatus.spv := false + }.otherwise { + reg_vsstatus.sie := reg_vsstatus.spie + reg_vsstatus.spie := true + reg_vsstatus.spp := PRV.U + ret_prv := reg_vsstatus.spp + reg_mstatus.v := usingHypervisor + io.evec := readEPC(reg_vsepc) + } }.elsewhen (Bool(usingDebug) && io.rw.addr(10) && io.rw.addr(7)) { ret_prv := reg_dcsr.prv + reg_mstatus.v := usingHypervisor && reg_dcsr.v && reg_dcsr.prv <= PRV.S reg_debug := false io.evec := readEPC(reg_dpc) }.elsewhen (Bool(usingNMI) && io.rw.addr(10) && !io.rw.addr(7)) { ret_prv := reg_mnstatus.mpp + reg_mstatus.v := usingHypervisor && reg_mnstatus.mpv && reg_mnstatus.mpp <= PRV.S reg_rnmie := true.B io.evec := readEPC(reg_mnepc) }.otherwise { reg_mstatus.mie := reg_mstatus.mpie reg_mstatus.mpie := true reg_mstatus.mpp := legalizePrivilege(PRV.U) + reg_mstatus.mpv := false ret_prv := reg_mstatus.mpp + reg_mstatus.v := usingHypervisor && reg_mstatus.mpv && reg_mstatus.mpp <= PRV.S io.evec := readEPC(reg_mepc) } new_prv := ret_prv - when (usingUser && ret_prv < PRV.M) { + when (usingUser && ret_prv <= PRV.S) { reg_mstatus.mprv := false } } @@ -865,6 +1116,7 @@ class CSRFile( io.vector.foreach { vio => when (set_vs_dirty) { assert(reg_mstatus.vs > 0) + when (reg_mstatus.v) { reg_vsstatus.vs := 3 } reg_mstatus.vs := 3 } } @@ -873,6 +1125,7 @@ class CSRFile( if (coreParams.haveFSDirty) { when (set_fs_dirty) { assert(reg_mstatus.fs > 0) + when (reg_mstatus.v) { reg_vsstatus.fs := 3 } reg_mstatus.fs := 3 } } @@ -893,6 +1146,9 @@ class CSRFile( val csr_wen = io.rw.cmd.isOneOf(CSR.S, CSR.C, CSR.W) io.csrw_counter := Mux(coreParams.haveBasicCounters && csr_wen && (io.rw.addr.inRange(CSRs.mcycle, CSRs.mcycle + CSR.nCtr) || io.rw.addr.inRange(CSRs.mcycleh, CSRs.mcycleh + CSR.nCtr)), UIntToOH(io.rw.addr(log2Ceil(CSR.nCtr+nPerfCounters)-1, 0)), 0.U) when (csr_wen) { + val scause_mask = ((BigInt(1) << (xLen-1)) + 31).U /* only implement 5 LSBs and MSB */ + val satp_valid_modes = 0 +: (minPgLevels to pgLevels).map(new PTBR().pgLevelsToMode(_)) + when (decoded_addr(CSRs.mstatus)) { val new_mstatus = new MStatus().fromBits(wdata) reg_mstatus.mie := new_mstatus.mie @@ -913,6 
+1169,10 @@ class CSRFile( reg_mstatus.sum := new_mstatus.sum reg_mstatus.tvm := new_mstatus.tvm } + if (usingHypervisor) { + reg_mstatus.mpv := new_mstatus.mpv + reg_mstatus.gva := new_mstatus.gva + } } if (usingSupervisor || usingFPU) reg_mstatus.fs := formFS(new_mstatus.fs) @@ -938,6 +1198,9 @@ class CSRFile( reg_mip.stip := new_mip.stip reg_mip.seip := new_mip.seip } + if (usingHypervisor) { + reg_mip.vssip := new_mip.vssip + } } when (decoded_addr(CSRs.mie)) { reg_mie := wdata & supported_interrupts } when (decoded_addr(CSRs.mepc)) { reg_mepc := formEPC(wdata) } @@ -945,15 +1208,16 @@ class CSRFile( if (mtvecWritable) when (decoded_addr(CSRs.mtvec)) { reg_mtvec := wdata } when (decoded_addr(CSRs.mcause)) { reg_mcause := wdata & UInt((BigInt(1) << (xLen-1)) + (BigInt(1) << whichInterrupt.getWidth) - 1) } - when (decoded_addr(CSRs.mtval)) { reg_mtval := wdata(vaddrBitsExtended-1,0) } + when (decoded_addr(CSRs.mtval)) { reg_mtval := wdata } if (usingNMI) { - val new_mnstatus = new MStatus().fromBits(wdata) + val new_mnstatus = new MNStatus().fromBits(wdata) when (decoded_addr(CSRs.mnscratch)) { reg_mnscratch := wdata } when (decoded_addr(CSRs.mnepc)) { reg_mnepc := formEPC(wdata) } when (decoded_addr(CSRs.mncause)) { reg_mncause := wdata & UInt((BigInt(1) << (xLen-1)) + BigInt(3)) } when (decoded_addr(CSRs.mnstatus)) { reg_mnstatus.mpp := legalizePrivilege(new_mnstatus.mpp) + reg_mnstatus.mpv := usingHypervisor && new_mnstatus.mpv reg_rnmie := reg_rnmie | new_mnstatus.mie // mnie bit settable but not clearable from software } } @@ -985,6 +1249,7 @@ class CSRFile( if (usingSupervisor) reg_dcsr.ebreaks := new_dcsr.ebreaks if (usingUser) reg_dcsr.ebreaku := new_dcsr.ebreaku if (usingUser) reg_dcsr.prv := legalizePrivilege(new_dcsr.prv) + if (usingHypervisor) reg_dcsr.v := new_dcsr.v } when (decoded_addr(CSRs.dpc)) { reg_dpc := formEPC(wdata) } when (decoded_addr(CSRs.dscratch)) { reg_dscratch := wdata } @@ -1012,24 +1277,95 @@ class CSRFile( when (decoded_addr(CSRs.satp)) { if (usingVM) { val new_satp = new PTBR().fromBits(wdata) - val valid_modes = 0 +: (minPgLevels to pgLevels).map(new_satp.pgLevelsToMode(_)) - when (new_satp.mode.isOneOf(valid_modes.map(_.U))) { - reg_satp.mode := new_satp.mode & valid_modes.reduce(_|_) + when (new_satp.mode.isOneOf(satp_valid_modes.map(_.U))) { + reg_satp.mode := new_satp.mode & satp_valid_modes.reduce(_|_) reg_satp.ppn := new_satp.ppn(ppnBits-1,0) if (asIdBits > 0) reg_satp.asid := new_satp.asid(asIdBits-1,0) } } } - when (decoded_addr(CSRs.sie)) { reg_mie := (reg_mie & ~read_mideleg) | (wdata & read_mideleg) } + when (decoded_addr(CSRs.sie)) { reg_mie := (reg_mie & ~sie_mask) | (wdata & sie_mask) } when (decoded_addr(CSRs.sscratch)) { reg_sscratch := wdata } when (decoded_addr(CSRs.sepc)) { reg_sepc := formEPC(wdata) } when (decoded_addr(CSRs.stvec)) { reg_stvec := wdata } - when (decoded_addr(CSRs.scause)) { reg_scause := wdata & UInt((BigInt(1) << (xLen-1)) + 31) /* only implement 5 LSBs and MSB */ } - when (decoded_addr(CSRs.stval)) { reg_stval := wdata(vaddrBitsExtended-1,0) } + when (decoded_addr(CSRs.scause)) { reg_scause := wdata & scause_mask } + when (decoded_addr(CSRs.stval)) { reg_stval := wdata } when (decoded_addr(CSRs.mideleg)) { reg_mideleg := wdata } when (decoded_addr(CSRs.medeleg)) { reg_medeleg := wdata } when (decoded_addr(CSRs.scounteren)) { reg_scounteren := wdata } } + + if (usingHypervisor) { + when (decoded_addr(CSRs.hstatus)) { + val new_hstatus = new HStatus().fromBits(wdata) + reg_hstatus.gva := new_hstatus.gva + 
reg_hstatus.spv := new_hstatus.spv + reg_hstatus.spvp := new_hstatus.spvp + reg_hstatus.hu := new_hstatus.hu + reg_hstatus.vtvm := new_hstatus.vtvm + reg_hstatus.vtw := new_hstatus.vtw + reg_hstatus.vtsr := new_hstatus.vtsr + reg_hstatus.vsxl := new_hstatus.vsxl + } + when (decoded_addr(CSRs.hideleg)) { reg_hideleg := wdata } + when (decoded_addr(CSRs.hedeleg)) { reg_hedeleg := wdata } + when (decoded_addr(CSRs.hgatp)) { + val new_hgatp = new PTBR().fromBits(wdata) + val valid_modes = 0 +: (minPgLevels to pgLevels).map(new_hgatp.pgLevelsToMode(_)) + when (new_hgatp.mode.isOneOf(valid_modes.map(_.U))) { + reg_hgatp.mode := new_hgatp.mode & valid_modes.reduce(_|_) + } + reg_hgatp.ppn := Cat(new_hgatp.ppn(ppnBits-1,2), 0.U(2.W)) + if (vmIdBits > 0) reg_hgatp.asid := new_hgatp.asid(vmIdBits-1,0) + } + when (decoded_addr(CSRs.hip)) { + val new_hip = new MIP().fromBits((read_mip & ~hs_delegable_interrupts) | (wdata & hs_delegable_interrupts)) + reg_mip.vssip := new_hip.vssip + } + when (decoded_addr(CSRs.hie)) { reg_mie := (reg_mie & ~hs_delegable_interrupts) | (wdata & hs_delegable_interrupts) } + when (decoded_addr(CSRs.hvip)) { + val new_sip = new MIP().fromBits((read_mip & ~hs_delegable_interrupts) | (wdata & hs_delegable_interrupts)) + reg_mip.vssip := new_sip.vssip + reg_mip.vstip := new_sip.vstip + reg_mip.vseip := new_sip.vseip + } + when (decoded_addr(CSRs.hcounteren)) { reg_hcounteren := wdata } + when (decoded_addr(CSRs.htval)) { reg_htval := wdata } + when (decoded_addr(CSRs.mtval2)) { reg_mtval2 := wdata } + + when (decoded_addr(CSRs.vsstatus)) { + val new_vsstatus = new MStatus().fromBits(wdata) + reg_vsstatus.sie := new_vsstatus.sie + reg_vsstatus.spie := new_vsstatus.spie + reg_vsstatus.spp := new_vsstatus.spp + reg_vsstatus.mxr := new_vsstatus.mxr + reg_vsstatus.sum := new_vsstatus.sum + reg_vsstatus.fs := formFS(new_vsstatus.fs) + reg_vsstatus.vs := formVS(new_vsstatus.vs) + if (usingRoCC) reg_vsstatus.xs := Fill(2, new_vsstatus.xs.orR) + } + when (decoded_addr(CSRs.vsip)) { + val new_vsip = new MIP().fromBits((read_hip & ~read_hideleg) | ((wdata << 1) & read_hideleg)) + reg_mip.vssip := new_vsip.vssip + } + when (decoded_addr(CSRs.vsatp)) { + val new_vsatp = new PTBR().fromBits(wdata) + val mode_ok = new_vsatp.mode.isOneOf(satp_valid_modes.map(_.U)) + when (mode_ok) { + reg_vsatp.mode := new_vsatp.mode & satp_valid_modes.reduce(_|_) + } + when (mode_ok || !reg_mstatus.v) { + reg_vsatp.ppn := new_vsatp.ppn(vpnBits-1,0) + } + if (asIdBits > 0) reg_vsatp.asid := new_vsatp.asid(asIdBits-1,0) + } + when (decoded_addr(CSRs.vsie)) { reg_mie := (reg_mie & ~read_hideleg) | ((wdata << 1) & read_hideleg) } + when (decoded_addr(CSRs.vsscratch)) { reg_vsscratch := wdata } + when (decoded_addr(CSRs.vsepc)) { reg_vsepc := formEPC(wdata) } + when (decoded_addr(CSRs.vstvec)) { reg_vstvec := wdata } + when (decoded_addr(CSRs.vscause)) { reg_vscause := wdata & scause_mask } + when (decoded_addr(CSRs.vstval)) { reg_vstval := wdata } + } if (usingUser) { when (decoded_addr(CSRs.mcounteren)) { reg_mcounteren := wdata } } @@ -1122,10 +1458,30 @@ class CSRFile( } } - reg_satp.asid := 0 + when(reset.asBool) { + reg_satp.mode := 0.U + reg_vsatp.mode := 0.U + reg_hgatp.mode := 0.U + } if (!usingVM) { - reg_satp.mode := 0 - reg_satp.ppn := 0 + reg_satp.mode := 0.U + reg_satp.ppn := 0.U + reg_satp.asid := 0.U + } + if (!usingHypervisor) { + reg_vsatp.mode := 0.U + reg_vsatp.ppn := 0.U + reg_vsatp.asid := 0.U + reg_hgatp.mode := 0.U + reg_hgatp.ppn := 0.U + reg_hgatp.asid := 0.U + } + if (!(asIdBits > 0)) 
{ + reg_satp.asid := 0.U + reg_vsatp.asid := 0.U + } + if (!(vmIdBits > 0)) { + reg_hgatp.asid := 0.U } if (nBreakpoints <= 1) reg_tselect := 0 @@ -1171,8 +1527,8 @@ class CSRFile( def chooseInterrupt(masksIn: Seq[UInt]): (Bool, UInt) = { val nonstandard = supported_interrupts.getWidth-1 to 12 by -1 - // MEI, MSI, MTI, SEI, SSI, STI, UEI, USI, UTI - val standard = Seq(11, 3, 7, 9, 1, 5, 8, 0, 4) + // MEI, MSI, MTI, SEI, SSI, STI, VSEI, VSSI, VSTI, UEI, USI, UTI + val standard = Seq(11, 3, 7, 9, 1, 5, 10, 2, 6, 8, 0, 4) val priority = nonstandard ++ standard val masks = masksIn.reverse val any = masks.flatMap(m => priority.filter(_ < m.getWidth).map(i => m(i))).reduce(_||_) diff --git a/src/main/scala/rocket/Consts.scala b/src/main/scala/rocket/Consts.scala index 4618731d8b..2a28a17f92 100644 --- a/src/main/scala/rocket/Consts.scala +++ b/src/main/scala/rocket/Consts.scala @@ -71,14 +71,17 @@ trait MemoryOpConstants { def M_PWR = UInt("b10001") // partial (masked) store def M_PRODUCE = UInt("b10010") // write back dirty data and cede W permissions def M_CLEAN = UInt("b10011") // write back dirty data and retain R/W permissions - def M_SFENCE = UInt("b10100") // flush TLB + def M_SFENCE = UInt("b10100") // SFENCE.VMA + def M_HFENCEV = UInt("b10101") // HFENCE.VVMA + def M_HFENCEG = UInt("b10110") // HFENCE.GVMA def M_WOK = UInt("b10111") // check write permissions but don't perform a write + def M_HLVX = UInt("b10000") // HLVX instruction def isAMOLogical(cmd: UInt) = cmd.isOneOf(M_XA_SWAP, M_XA_XOR, M_XA_OR, M_XA_AND) def isAMOArithmetic(cmd: UInt) = cmd.isOneOf(M_XA_ADD, M_XA_MIN, M_XA_MAX, M_XA_MINU, M_XA_MAXU) def isAMO(cmd: UInt) = isAMOLogical(cmd) || isAMOArithmetic(cmd) def isPrefetch(cmd: UInt) = cmd === M_PFR || cmd === M_PFW - def isRead(cmd: UInt) = cmd === M_XRD || cmd === M_XLR || cmd === M_XSC || isAMO(cmd) + def isRead(cmd: UInt) = cmd.isOneOf(M_XRD, M_HLVX, M_XLR, M_XSC) || isAMO(cmd) def isWrite(cmd: UInt) = cmd === M_XWR || cmd === M_PWR || cmd === M_XSC || isAMO(cmd) def isWriteIntent(cmd: UInt) = isWrite(cmd) || cmd === M_PFW || cmd === M_XLR } diff --git a/src/main/scala/rocket/DCache.scala b/src/main/scala/rocket/DCache.scala index a70b7aab5d..18d0ed5bce 100644 --- a/src/main/scala/rocket/DCache.scala +++ b/src/main/scala/rocket/DCache.scala @@ -182,13 +182,15 @@ class DCacheModule(outer: DCache) extends HellaCacheModule(outer) { s0_tlb_req.vaddr := s0_req.addr s0_tlb_req.size := s0_req.size s0_tlb_req.cmd := s0_req.cmd + s0_tlb_req.prv := s0_req.dprv + s0_tlb_req.v := s0_req.dv } val s1_tlb_req = RegEnable(s0_tlb_req, s0_clk_en || tlb_port.req.valid) val s1_read = isRead(s1_req.cmd) val s1_write = isWrite(s1_req.cmd) val s1_readwrite = s1_read || s1_write - val s1_sfence = s1_req.cmd === M_SFENCE + val s1_sfence = s1_req.cmd === M_SFENCE || s1_req.cmd === M_HFENCEV || s1_req.cmd === M_HFENCEG val s1_flush_line = s1_req.cmd === M_FLUSH_ALL && s1_req.size(0) val s1_flush_valid = Reg(Bool()) val s1_waw_hazard = Wire(Bool()) @@ -258,6 +260,8 @@ class DCacheModule(outer: DCache) extends HellaCacheModule(outer) { tlb.io.sfence.bits.rs2 := s1_req.size(1) tlb.io.sfence.bits.asid := io.cpu.s1_data.data tlb.io.sfence.bits.addr := s1_req.addr + tlb.io.sfence.bits.hv := s1_req.cmd === M_HFENCEV + tlb.io.sfence.bits.hg := s1_req.cmd === M_HFENCEG tlb_port.req.ready := clock_en_reg tlb_port.s1_resp := tlb.io.resp @@ -890,6 +894,8 @@ class DCacheModule(outer: DCache) extends HellaCacheModule(outer) { io.cpu.resp.bits.replay := false io.cpu.s2_uncached := s2_uncached && !s2_hit 
io.cpu.s2_paddr := s2_req.addr + io.cpu.s2_gpa := s2_tlb_xcpt.gpa + io.cpu.s2_gpa_is_pte := s2_tlb_xcpt.gpa_is_pte // report whether there are any outstanding accesses. disregard any // slave-port accesses, since they don't affect local memory ordering. diff --git a/src/main/scala/rocket/Frontend.scala b/src/main/scala/rocket/Frontend.scala index 26822ad6ba..c175955185 100644 --- a/src/main/scala/rocket/Frontend.scala +++ b/src/main/scala/rocket/Frontend.scala @@ -24,6 +24,9 @@ class FrontendExceptions extends Bundle { val pf = new Bundle { val inst = Bool() } + val gf = new Bundle { + val inst = Bool() + } val ae = new Bundle { val inst = Bool() } @@ -49,6 +52,7 @@ class FrontendIO(implicit p: Parameters) extends CoreBundle()(p) { val req = Valid(new FrontendReq) val sfence = Valid(new SFenceReq) val resp = Decoupled(new FrontendResp).flip + val gpa = Flipped(Valid(UInt(vaddrBitsExtended.W))) val btb_update = Valid(new BTBUpdate) val bht_update = Valid(new BHTUpdate) val ras_update = Valid(new RASUpdate) @@ -112,7 +116,7 @@ class FrontendModule(outer: Frontend) extends LazyModuleImp(outer) val s2_btb_resp_bits = Reg(new BTBResp) val s2_btb_taken = s2_btb_resp_valid && s2_btb_resp_bits.taken val s2_tlb_resp = Reg(tlb.io.resp) - val s2_xcpt = s2_tlb_resp.ae.inst || s2_tlb_resp.pf.inst + val s2_xcpt = s2_tlb_resp.ae.inst || s2_tlb_resp.pf.inst || s2_tlb_resp.gf.inst val s2_speculative = Reg(init=Bool(false)) val s2_partial_insn_valid = RegInit(false.B) val s2_partial_insn = Reg(UInt(width = coreInstBits)) @@ -149,6 +153,8 @@ class FrontendModule(outer: Frontend) extends LazyModuleImp(outer) tlb.io.req.bits.vaddr := s1_pc tlb.io.req.bits.passthrough := Bool(false) tlb.io.req.bits.size := log2Ceil(coreInstBytes*fetchWidth) + tlb.io.req.bits.prv := io.ptw.status.prv + tlb.io.req.bits.v := io.ptw.status.v tlb.io.sfence := io.cpu.sfence tlb.io.kill := !s2_valid @@ -324,6 +330,21 @@ class FrontendModule(outer: Frontend) extends LazyModuleImp(outer) io.cpu.resp <> fq.io.deq + // supply guest physical address to commit stage + val gpa_valid = Reg(Bool()) + val gpa = Reg(UInt(vaddrBitsExtended.W)) + when (fq.io.enq.fire() && s2_tlb_resp.gf.inst) { + when (!gpa_valid) { + gpa := s2_tlb_resp.gpa + } + gpa_valid := true + } + when (io.cpu.req.valid) { + gpa_valid := false + } + io.cpu.gpa.valid := gpa_valid + io.cpu.gpa.bits := gpa + // performance events io.cpu.perf := icache.io.perf io.cpu.perf.tlbMiss := io.ptw.req.fire() diff --git a/src/main/scala/rocket/HellaCache.scala b/src/main/scala/rocket/HellaCache.scala index 6fbdec1a84..8976d75eea 100644 --- a/src/main/scala/rocket/HellaCache.scala +++ b/src/main/scala/rocket/HellaCache.scala @@ -109,6 +109,7 @@ trait HasCoreMemOp extends HasL1HellaCacheParameters { val size = Bits(width = log2Ceil(coreDataBytes.log2 + 1)) val signed = Bool() val dprv = UInt(width = PRV.SZ) + val dv = Bool() } trait HasCoreData extends HasCoreParameters { @@ -142,6 +143,7 @@ class AlignmentExceptions extends Bundle { class HellaCacheExceptions extends Bundle { val ma = new AlignmentExceptions val pf = new AlignmentExceptions + val gf = new AlignmentExceptions val ae = new AlignmentExceptions } @@ -174,6 +176,8 @@ class HellaCacheIO(implicit p: Parameters) extends CoreBundle()(p) { val resp = Valid(new HellaCacheResp).flip val replay_next = Bool(INPUT) val s2_xcpt = (new HellaCacheExceptions).asInput + val s2_gpa = UInt(vaddrBitsExtended.W).asInput + val s2_gpa_is_pte = Bool(INPUT) val uncached_resp = tileParams.dcache.get.separateUncachedResp.option(Decoupled(new 
HellaCacheResp).flip) val ordered = Bool(INPUT) val perf = new HellaCachePerfEvents().asInput diff --git a/src/main/scala/rocket/HellaCacheArbiter.scala b/src/main/scala/rocket/HellaCacheArbiter.scala index b2b65e827f..5a4510ec1c 100644 --- a/src/main/scala/rocket/HellaCacheArbiter.scala +++ b/src/main/scala/rocket/HellaCacheArbiter.scala @@ -59,6 +59,8 @@ class HellaCacheArbiter(n: Int)(implicit p: Parameters) extends Module val tag_hit = io.mem.resp.bits.tag(log2Up(n)-1,0) === UInt(i) resp.valid := io.mem.resp.valid && tag_hit io.requestor(i).s2_xcpt := io.mem.s2_xcpt + io.requestor(i).s2_gpa := io.mem.s2_gpa + io.requestor(i).s2_gpa_is_pte := io.mem.s2_gpa_is_pte io.requestor(i).ordered := io.mem.ordered io.requestor(i).perf := io.mem.perf io.requestor(i).s2_nack := io.mem.s2_nack && s2_id === UInt(i) diff --git a/src/main/scala/rocket/IDecode.scala b/src/main/scala/rocket/IDecode.scala index f4bc8c314b..4533a67058 100644 --- a/src/main/scala/rocket/IDecode.scala +++ b/src/main/scala/rocket/IDecode.scala @@ -145,7 +145,7 @@ class CFlushDecode(supportsFlushLine: Boolean)(implicit val p: Parameters) exten class SVMDecode(implicit val p: Parameters) extends DecodeConstants { val table: Array[(BitPat, List[BitPat])] = Array( - SFENCE_VMA->List(Y,N,N,N,N,N,Y,Y,N,A2_ZERO,A1_RS1, IMM_X, DW_XPR,FN_ADD, Y,M_SFENCE, N,N,N,N,N,N,N,CSR.N,N,N,N,N)) + SFENCE_VMA->List(Y,N,N,N,N,N,Y,Y,N,A2_ZERO,A1_RS1, IMM_X, DW_XPR,FN_ADD, Y,M_SFENCE, N,N,N,N,N,N,N,CSR.I,N,N,N,N)) } class SDecode(implicit val p: Parameters) extends DecodeConstants @@ -154,6 +154,26 @@ class SDecode(implicit val p: Parameters) extends DecodeConstants SRET-> List(Y,N,N,N,N,N,N,X,N,A2_X, A1_X, IMM_X, DW_X, FN_X, N,M_X, N,N,N,N,N,N,N,CSR.I,N,N,N,N)) } +class HypervisorDecode(implicit val p: Parameters) extends DecodeConstants +{ + val table: Array[(BitPat, List[BitPat])] = Array( + + HFENCE_VVMA->List(Y,N,N,N,N,N,Y,Y,N,A2_ZERO,A1_RS1, IMM_X, DW_XPR, FN_ADD, Y,M_HFENCEV, N,N,N,N,N,N,N,CSR.I,N,N,N,N), + HFENCE_GVMA->List(Y,N,N,N,N,N,Y,Y,N,A2_ZERO,A1_RS1, IMM_X, DW_XPR, FN_ADD, Y,M_HFENCEG, N,N,N,N,N,N,N,CSR.I,N,N,N,N), + + HLV_B -> List(Y,N,N,N,N,N,N,Y,N,A2_ZERO, A1_RS1, IMM_X, DW_XPR, FN_ADD, Y,M_XRD, N,N,N,N,N,N,Y,CSR.I,N,N,N,N), + HLV_BU-> List(Y,N,N,N,N,N,N,Y,N,A2_ZERO, A1_RS1, IMM_X, DW_XPR, FN_ADD, Y,M_XRD, N,N,N,N,N,N,Y,CSR.I,N,N,N,N), + HLV_H -> List(Y,N,N,N,N,N,N,Y,N,A2_ZERO, A1_RS1, IMM_X, DW_XPR, FN_ADD, Y,M_XRD, N,N,N,N,N,N,Y,CSR.I,N,N,N,N), + HLV_HU-> List(Y,N,N,N,N,N,N,Y,N,A2_ZERO, A1_RS1, IMM_X, DW_XPR, FN_ADD, Y,M_XRD, N,N,N,N,N,N,Y,CSR.I,N,N,N,N), + HLVX_HU-> List(Y,N,N,N,N,N,N,Y,N,A2_ZERO, A1_RS1, IMM_X, DW_XPR, FN_ADD, Y,M_HLVX, N,N,N,N,N,N,Y,CSR.I,N,N,N,N), + HLV_W-> List(Y,N,N,N,N,N,N,Y,N,A2_ZERO, A1_RS1, IMM_X, DW_XPR, FN_ADD, Y,M_XRD, N,N,N,N,N,N,Y,CSR.I,N,N,N,N), + HLVX_WU-> List(Y,N,N,N,N,N,N,Y,N,A2_ZERO, A1_RS1, IMM_X, DW_XPR, FN_ADD, Y,M_HLVX, N,N,N,N,N,N,Y,CSR.I,N,N,N,N), + + HSV_B-> List(Y,N,N,N,N,N,Y,Y,N,A2_ZERO, A1_RS1, IMM_I, DW_XPR, FN_ADD, Y,M_XWR, N,N,N,N,N,N,N,CSR.I,N,N,N,N), + HSV_H-> List(Y,N,N,N,N,N,Y,Y,N,A2_ZERO, A1_RS1, IMM_I, DW_XPR, FN_ADD, Y,M_XWR, N,N,N,N,N,N,N,CSR.I,N,N,N,N), + HSV_W-> List(Y,N,N,N,N,N,Y,Y,N,A2_ZERO, A1_RS1, IMM_I, DW_XPR, FN_ADD, Y,M_XWR, N,N,N,N,N,N,N,CSR.I,N,N,N,N)) +} + class DebugDecode(implicit val p: Parameters) extends DecodeConstants { val table: Array[(BitPat, List[BitPat])] = Array( @@ -196,6 +216,14 @@ class I64Decode(implicit val p: Parameters) extends DecodeConstants SRAW-> List(Y,N,N,N,N,N,Y,Y,N,A2_RS2, A1_RS1, IMM_X, DW_32,FN_SRA, N,M_X, 
N,N,N,N,N,N,Y,CSR.N,N,N,N,N)) } +class Hypervisor64Decode(implicit val p: Parameters) extends DecodeConstants +{ + val table: Array[(BitPat, List[BitPat])] = Array( + HLV_D-> List(Y,N,N,N,N,N,N,Y,N,A2_ZERO, A1_RS1, IMM_X, DW_XPR, FN_ADD, Y,M_XRD, N,N,N,N,N,N,Y,CSR.I,N,N,N,N), + HSV_D-> List(Y,N,N,N,N,N,Y,Y,N,A2_ZERO, A1_RS1, IMM_I, DW_XPR, FN_ADD, Y,M_XWR, N,N,N,N,N,N,N,CSR.I,N,N,N,N), + HLV_WU-> List(Y,N,N,N,N,N,N,Y,N,A2_ZERO, A1_RS1, IMM_X, DW_XPR, FN_ADD, Y,M_XRD, N,N,N,N,N,N,Y,CSR.I,N,N,N,N)) +} + class MDecode(pipelinedMul: Boolean)(implicit val p: Parameters) extends DecodeConstants { val M = if (pipelinedMul) Y else N diff --git a/src/main/scala/rocket/Instructions.scala b/src/main/scala/rocket/Instructions.scala index 28246b0e61..579a812d2b 100644 --- a/src/main/scala/rocket/Instructions.scala +++ b/src/main/scala/rocket/Instructions.scala @@ -225,6 +225,19 @@ object Instructions { def CSRRCI = BitPat("b?????????????????111?????1110011") def HFENCE_VVMA = BitPat("b0010001??????????000000001110011") def HFENCE_GVMA = BitPat("b0110001??????????000000001110011") + def HLV_B = BitPat("b011000000000?????100?????1110011") + def HLV_BU = BitPat("b011000000001?????100?????1110011") + def HLV_H = BitPat("b011001000000?????100?????1110011") + def HLV_HU = BitPat("b011001000001?????100?????1110011") + def HLVX_HU = BitPat("b011001000011?????100?????1110011") + def HLV_W = BitPat("b011010000000?????100?????1110011") + def HLVX_WU = BitPat("b011010000011?????100?????1110011") + def HSV_B = BitPat("b0110001??????????100000001110011") + def HSV_H = BitPat("b0110011??????????100000001110011") + def HSV_W = BitPat("b0110101??????????100000001110011") + def HLV_WU = BitPat("b011010000001?????100?????1110011") + def HLV_D = BitPat("b011011000000?????100?????1110011") + def HSV_D = BitPat("b0110111??????????100000001110011") def FADD_S = BitPat("b0000000??????????????????1010011") def FSUB_S = BitPat("b0000100??????????????????1010011") def FMUL_S = BitPat("b0001000??????????????????1010011") @@ -874,11 +887,15 @@ object Causes { val store_access = 0x7 val user_ecall = 0x8 val supervisor_ecall = 0x9 - val hypervisor_ecall = 0xa + val virtual_supervisor_ecall = 0xa val machine_ecall = 0xb val fetch_page_fault = 0xc val load_page_fault = 0xd val store_page_fault = 0xf + val fetch_guest_page_fault = 0x14 + val load_guest_page_fault = 0x15 + val virtual_instruction = 0x16 + val store_guest_page_fault = 0x17 val all = { val res = collection.mutable.ArrayBuffer[Int]() res += misaligned_fetch @@ -891,11 +908,15 @@ object Causes { res += store_access res += user_ecall res += supervisor_ecall - res += hypervisor_ecall + res += virtual_supervisor_ecall res += machine_ecall res += fetch_page_fault res += load_page_fault res += store_page_fault + res += fetch_guest_page_fault + res += load_guest_page_fault + res += virtual_instruction + res += store_guest_page_fault res.toArray } } @@ -972,8 +993,16 @@ object CSRs { val hstatus = 0x600 val hedeleg = 0x602 val hideleg = 0x603 + val hie = 0x604 + val htimedelta = 0x605 val hcounteren = 0x606 + val hgeie = 0x607 + val htval = 0x643 + val hip = 0x644 + val hvip = 0x645 + val htinst = 0x64a val hgatp = 0x680 + val hgeip = 0xe12 val utvt = 0x7 val unxti = 0x45 val uintstatus = 0x46 @@ -1005,6 +1034,8 @@ object CSRs { val mnepc = 0x351 val mncause = 0x352 val mnstatus = 0x353 + val mtinst = 0x34a + val mtval2 = 0x34b val pmpcfg0 = 0x3a0 val pmpcfg1 = 0x3a1 val pmpcfg2 = 0x3a2 @@ -1101,6 +1132,7 @@ object CSRs { val marchid = 0xf12 val mimpid = 0xf13 val mhartid = 0xf14 
+ val htimedeltah = 0x615 val cycleh = 0xc80 val timeh = 0xc81 val instreth = 0xc82 @@ -1237,8 +1269,16 @@ object CSRs { res += hstatus res += hedeleg res += hideleg + res += hie + res += htimedelta res += hcounteren + res += hgeie + res += htval + res += hip + res += hvip + res += htinst res += hgatp + res += hgeip res += utvt res += unxti res += uintstatus @@ -1266,6 +1306,12 @@ object CSRs { res += mcause res += mtval res += mip + res += mnscratch + res += mnepc + res += mncause + res += mnstatus + res += mtinst + res += mtval2 res += pmpcfg0 res += pmpcfg1 res += pmpcfg2 @@ -1366,6 +1412,7 @@ object CSRs { } val all32 = { val res = collection.mutable.ArrayBuffer(all:_*) + res += htimedeltah res += cycleh res += timeh res += instreth diff --git a/src/main/scala/rocket/NBDcache.scala b/src/main/scala/rocket/NBDcache.scala index d60bbccd5d..e3c92d62c7 100644 --- a/src/main/scala/rocket/NBDcache.scala +++ b/src/main/scala/rocket/NBDcache.scala @@ -726,6 +726,8 @@ class NonBlockingDCacheModule(outer: NonBlockingDCache) extends HellaCacheModule dtlb.io.req.bits.vaddr := s1_req.addr dtlb.io.req.bits.size := s1_req.size dtlb.io.req.bits.cmd := s1_req.cmd + dtlb.io.req.bits.prv := s1_req.dprv + dtlb.io.req.bits.v := s1_req.dv when (!dtlb.io.req.ready && !io.cpu.req.bits.phys) { io.cpu.req.ready := Bool(false) } dtlb.io.sfence.valid := s1_valid && !io.cpu.s1_kill && s1_sfence diff --git a/src/main/scala/rocket/PTW.scala b/src/main/scala/rocket/PTW.scala index 1ac0bc1bee..a97f1b102f 100644 --- a/src/main/scala/rocket/PTW.scala +++ b/src/main/scala/rocket/PTW.scala @@ -19,15 +19,23 @@ import scala.collection.mutable.ListBuffer class PTWReq(implicit p: Parameters) extends CoreBundle()(p) { val addr = UInt(width = vpnBits) + val need_gpa = Bool() + val vstage1 = Bool() + val stage2 = Bool() } class PTWResp(implicit p: Parameters) extends CoreBundle()(p) { val ae_ptw = Bool() val ae_final = Bool() + val gf = Bool() + val hr = Bool() + val hw = Bool() + val hx = Bool() val pte = new PTE val level = UInt(width = log2Ceil(pgLevels)) val fragmented_superpage = Bool() val homogeneous = Bool() + val gpa = Valid(UInt(vaddrBits.W)) } class TLBPTWIO(implicit p: Parameters) extends CoreBundle()(p) @@ -35,7 +43,11 @@ class TLBPTWIO(implicit p: Parameters) extends CoreBundle()(p) val req = Decoupled(Valid(new PTWReq)) val resp = Valid(new PTWResp).flip val ptbr = new PTBR().asInput + val hgatp = new PTBR().asInput + val vsatp = new PTBR().asInput val status = new MStatus().asInput + val hstatus = new HStatus().asInput + val gstatus = new MStatus().asInput val pmp = Vec(nPMPs, new PMP).asInput val customCSRs = coreParams.customCSRs.asInput } @@ -50,8 +62,12 @@ class PTWPerfEvents extends Bundle { class DatapathPTWIO(implicit p: Parameters) extends CoreBundle()(p) with HasCoreParameters { val ptbr = new PTBR().asInput + val hgatp = new PTBR().asInput + val vsatp = new PTBR().asInput val sfence = Valid(new SFenceReq).flip val status = new MStatus().asInput + val hstatus = new HStatus().asInput + val gstatus = new MStatus().asInput val pmp = Vec(nPMPs, new PMP).asInput val perf = new PTWPerfEvents().asOutput val customCSRs = coreParams.customCSRs.asInput @@ -78,12 +94,13 @@ class PTE(implicit p: Parameters) extends CoreBundle()(p) { def sr(dummy: Int = 0) = leaf() && r def sw(dummy: Int = 0) = leaf() && w && d def sx(dummy: Int = 0) = leaf() && x + def isFullPerm(dummy: Int = 0) = uw() && ux() } class L2TLBEntry(nSets: Int)(implicit p: Parameters) extends CoreBundle()(p) with HasCoreParameters { val idxBits = 
log2Ceil(nSets) - val tagBits = vpnBits - idxBits + val tagBits = maxSVAddrBits - pgIdxBits - idxBits + (if (usingHypervisor) 1 else 0) val tag = UInt(width = tagBits) val ppn = UInt(width = ppnBits) val d = Bool() @@ -124,14 +141,31 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( withClock (gated_clock) { // entering gated-clock domain val invalidated = Reg(Bool()) - val count = Reg(UInt(width = log2Up(pgLevels))) - val resp_ae_ptw = RegNext(false.B) - val resp_ae_final = RegNext(false.B) - val resp_fragmented_superpage = RegNext(false.B) + val count = Reg(UInt(width = log2Ceil(pgLevels))) + val resp_ae_ptw = Reg(Bool()) + val resp_ae_final = Reg(Bool()) + val resp_gf = Reg(Bool()) + val resp_hr = Reg(Bool()) + val resp_hw = Reg(Bool()) + val resp_hx = Reg(Bool()) + val resp_fragmented_superpage = Reg(Bool()) val r_req = Reg(new PTWReq) val r_req_dest = Reg(Bits()) val r_pte = Reg(new PTE) + val r_hgatp = Reg(new PTBR) + + val aux_count = Reg(UInt(log2Ceil(pgLevels).W)) + val aux_pte = Reg(new PTE) + val gpa_pgoff = Reg(UInt(pgIdxBits.W)) // only valid in resp_gf case + val stage2 = Reg(Bool()) + val stage2_final = Reg(Bool()) + + val satp = Mux(arb.io.out.bits.bits.vstage1, io.dpath.vsatp, io.dpath.ptbr) + val r_hgatp_initial_count = pgLevels - minPgLevels - r_hgatp.additionalPgLevels + val do_both_stages = r_req.vstage1 && r_req.stage2 + val max_count = count max aux_count + val vpn = Mux(r_req.vstage1 && stage2, aux_pte.ppn, r_req.addr) val mem_resp_valid = RegNext(io.mem.resp.valid) val mem_resp_data = RegNext(io.mem.resp.bits.data) @@ -147,53 +181,78 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( val (pte, invalid_paddr) = { val tmp = new PTE().fromBits(mem_resp_data) val res = Wire(init = tmp) - res.ppn := tmp.ppn(ppnBits-1, 0) + res.ppn := Mux(do_both_stages && !stage2, tmp.ppn(vpnBits-1, 0), tmp.ppn(ppnBits-1, 0)) when (tmp.r || tmp.w || tmp.x) { // for superpage mappings, make sure PPN LSBs are zero for (i <- 0 until pgLevels-1) when (count <= i && tmp.ppn((pgLevels-1-i)*pgLevelBits-1, (pgLevels-2-i)*pgLevelBits) =/= 0) { res.v := false } } - (res, (tmp.ppn >> ppnBits) =/= 0) + (res, Mux(do_both_stages && !stage2, (tmp.ppn >> vpnBits) =/= 0, (tmp.ppn >> ppnBits) =/= 0)) } val traverse = pte.table() && !invalid_paddr && count < pgLevels-1 val pte_addr = if (!usingVM) 0.U else { - val vpn_idxs = (0 until pgLevels).map(i => (r_req.addr >> (pgLevels-i-1)*pgLevelBits)(pgLevelBits-1,0)) + val vpn_idxs = (0 until pgLevels).map { i => + val width = pgLevelBits + (if (i <= pgLevels - minPgLevels) hypervisorExtraAddrBits else 0) + (vpn >> (pgLevels - i - 1) * pgLevelBits)(width - 1, 0) + } + val mask = Mux(stage2 && count === r_hgatp_initial_count, ((1 << (hypervisorExtraAddrBits + pgLevelBits)) - 1).U, ((1 << pgLevelBits) - 1).U) + val vpn_idx = vpn_idxs(count) & mask + val size = if (usingHypervisor) vaddrBits else paddrBits + (((r_pte.ppn << pgLevelBits) | vpn_idx) << log2Ceil(xLen / 8))(size - 1, 0) + } + val pte_cache_addr = if (!usingHypervisor) pte_addr else { + val vpn_idxs = (0 until pgLevels-1).map(i => (aux_pte.ppn >> (pgLevels-i-1)*pgLevelBits)(pgLevelBits-1,0)) val vpn_idx = vpn_idxs(count) - Cat(r_pte.ppn, vpn_idx) << log2Ceil(xLen/8) + (Cat(r_pte.ppn, vpn_idx) << log2Ceil(xLen/8))(vaddrBits-1, 0) } - val fragmented_superpage_ppn = { - val choices = (pgLevels-1 until 0 by -1).map(i => Cat(r_pte.ppn >> (pgLevelBits*i), r_req.addr(((pgLevelBits*i) min vpnBits)-1, 0).padTo(pgLevelBits*i))) - choices(count) 
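[Sketch, not part of the patch: the index arithmetic of pte_addr above, restated as plain Scala. It assumes Sv39-style parameters (pgLevelBits = 9, hypervisorExtraAddrBits = 2, 8-byte PTEs); WalkIndexSketch, pteAddr, and the parameter names are illustrative only.]

    object WalkIndexSketch {
      val pgLevelBits = 9   // VPN bits consumed per level (Sv39/Sv48)
      val extraBits   = 2   // hypervisorExtraAddrBits: GPAs are 2 bits wider than VAs
      val pteBytes    = 8   // xLen / 8

      // Byte address of the next PTE to fetch. Only the root step of an
      // hgatp (G-stage) walk keeps the two extra guest-physical bits, so the
      // root table has 11-bit indices (a 16 KiB, SvXXx4-style table); ppn is
      // then assumed 16 KiB-aligned, as the spec requires of hgatp.
      def pteAddr(ppn: BigInt, vpn: BigInt, level: Int, levels: Int, gRoot: Boolean): BigInt = {
        val idxWidth = pgLevelBits + (if (gRoot) extraBits else 0)
        val idx = (vpn >> ((levels - 1 - level) * pgLevelBits)) & ((BigInt(1) << idxWidth) - 1)
        ((ppn << pgLevelBits) | idx) * pteBytes
      }
    }

[For Sv39x4, pteAddr(root, gpa >> 12, 0, 3, gRoot = true) selects GPA bits [40:30], the 11-bit index into the 4x-larger root table; the RTL gets the same effect from the widened vpn_idx mask in pte_addr together with the ppn folding done by makeHypervisorRootPTE further down.]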
+ val stage2_pte_cache_addr = if (!usingHypervisor) 0.U else { + val vpn_idxs = (0 until pgLevels - 1).map(i => (r_req.addr >> (pgLevels - i - 1) * pgLevelBits)(pgLevelBits - 1, 0)) + val vpn_idx = vpn_idxs(aux_count) + (Cat(aux_pte.ppn, vpn_idx) << log2Ceil(xLen / 8))(vaddrBits - 1, 0) } - when (arb.io.out.fire()) { - r_req := arb.io.out.bits.bits - r_req_dest := arb.io.chosen + def makeFragmentedSuperpagePPN(ppn: UInt): Seq[UInt] = { + (pgLevels-1 until 0 by -1).map(i => Cat(ppn >> (pgLevelBits*i), r_req.addr(((pgLevelBits*i) min vpnBits)-1, 0).padTo(pgLevelBits*i))) } - val (pte_cache_hit, pte_cache_data) = { - val size = 1 << log2Up(pgLevels * 2) - val plru = new PseudoLRU(size) - val valid = RegInit(0.U(size.W)) - val tags = Reg(Vec(size, UInt(width = paddrBits))) - val data = Reg(Vec(size, UInt(width = ppnBits))) - - val hits = tags.map(_ === pte_addr).asUInt & valid - val hit = hits.orR - when (mem_resp_valid && traverse && !hit && !invalidated) { + def makePTECache(s2: Boolean): (Bool, UInt) = { + val plru = new PseudoLRU(coreParams.nPTECacheEntries) + val valid = RegInit(0.U(coreParams.nPTECacheEntries.W)) + val tags = Reg(Vec(coreParams.nPTECacheEntries, UInt((if (usingHypervisor) 1 + vaddrBits else paddrBits).W))) + val data = Reg(Vec(coreParams.nPTECacheEntries, UInt((if (usingHypervisor && s2) vpnBits else ppnBits).W))) + val can_hit = + if (s2) count === r_hgatp_initial_count && aux_count < pgLevels-1 && r_req.vstage1 && stage2 && !stage2_final + else count < pgLevels-1 && Mux(r_req.vstage1, stage2, !r_req.stage2) + val can_refill = + if (s2) do_both_stages && !stage2 && !stage2_final + else can_hit + val tag = + if (s2) Cat(true.B, stage2_pte_cache_addr) + else Cat(r_req.vstage1, pte_cache_addr) + + val hits = tags.map(_ === tag).asUInt & valid + val hit = hits.orR && can_hit + when (mem_resp_valid && traverse && can_refill && !hits.orR && !invalidated) { val r = Mux(valid.andR, plru.way, PriorityEncoder(~valid)) valid := valid | UIntToOH(r) - tags(r) := pte_addr + tags(r) := tag data(r) := pte.ppn + plru.access(r) } when (hit && state === s_req) { plru.access(OHToUInt(hits)) } - when (io.dpath.sfence.valid && !io.dpath.sfence.bits.rs1) { valid := 0.U } + when (io.dpath.sfence.valid && (!io.dpath.sfence.bits.rs1 || usingHypervisor && io.dpath.sfence.bits.hg)) { valid := 0.U } - for (i <- 0 until pgLevels-1) - ccover(hit && state === s_req && count === i, s"PTE_CACHE_HIT_L$i", s"PTE cache hit, level $i") + val lcount = if (s2) aux_count else count + for (i <- 0 until pgLevels-1) { + ccover(hit && state === s_req && lcount === i, s"PTE_CACHE_HIT_L$i", s"PTE cache hit, level $i") + } - (hit && count < pgLevels-1, Mux1H(hits, data)) + (hit, Mux1H(hits, data)) } + + val (pte_cache_hit, pte_cache_data) = makePTECache(false) + val (stage2_pte_cache_hit, stage2_pte_cache_data) = makePTECache(true) + val pte_hit = RegNext(false.B) io.dpath.perf.pte_miss := false io.dpath.perf.pte_hit := pte_hit && (state === s_req) && !io.dpath.perf.l2hit @@ -224,7 +283,7 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( val g = Reg(Vec(coreParams.nL2TLBWays, UInt(width = nL2TLBSets))) val valid = RegInit(Vec(Seq.fill(coreParams.nL2TLBWays)(0.U(nL2TLBSets.W)))) - val (r_tag, r_idx) = Split(r_req.addr, idxBits) + val (r_tag, r_idx) = Split(Cat(r_req.vstage1, r_req.addr(maxSVAddrBits-pgIdxBits-1, 0)), idxBits) val r_valid_vec = valid.map(_(r_idx)).asUInt val r_valid_vec_q = Reg(UInt(coreParams.nL2TLBWays.W)) val r_l2_plru_way = Reg(UInt(log2Ceil(coreParams.nL2TLBWays 
max 1).W)) @@ -246,16 +305,20 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( } } } + when (io.dpath.sfence.valid) { + val hg = usingHypervisor && io.dpath.sfence.bits.hg for (way <- 0 until coreParams.nL2TLBWays) { valid(way) := - Mux(io.dpath.sfence.bits.rs1, valid(way) & ~UIntToOH(io.dpath.sfence.bits.addr(idxBits+pgIdxBits-1, pgIdxBits)), - Mux(io.dpath.sfence.bits.rs2, valid(way) & g(way), 0.U)) + Mux(!hg && io.dpath.sfence.bits.rs1, valid(way) & ~UIntToOH(io.dpath.sfence.bits.addr(idxBits+pgIdxBits-1, pgIdxBits)), + Mux(!hg && io.dpath.sfence.bits.rs2, valid(way) & g(way), + 0.U)) } } val s0_valid = !l2_refill && arb.io.out.fire() - val s1_valid = RegNext(s0_valid && arb.io.out.bits.valid) + val s0_suitable = arb.io.out.bits.bits.vstage1 === arb.io.out.bits.bits.stage2 && !arb.io.out.bits.bits.need_gpa + val s1_valid = RegNext(s0_valid && s0_suitable && arb.io.out.bits.valid) val s2_valid = RegNext(s1_valid) val s1_rdata = ram.read(arb.io.out.bits.bits.addr(idxBits-1, 0), s0_valid) val s2_rdata = s1_rdata.map(s1_rdway => code.decode(RegEnable(s1_rdway, s1_valid))) @@ -298,50 +361,93 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( io.mem.req.bits.addr := pte_addr io.mem.req.bits.idx.foreach(_ := pte_addr) io.mem.req.bits.dprv := PRV.S.U // PTW accesses are S-mode by definition + io.mem.req.bits.dv := do_both_stages && !stage2 io.mem.s1_kill := l2_hit || state =/= s_wait1 io.mem.s2_kill := Bool(false) val pageGranularityPMPs = pmpGranularity >= (1 << pgIdxBits) + require(!usingHypervisor || pageGranularityPMPs, s"hypervisor requires pmpGranularity >= ${1<<pgIdxBits}") val pmaPgLevelHomogeneous = (0 until pgLevels).map { i => val pgSize = BigInt(1) << (pgIdxBits + ((pgLevels - 1 - i) * pgLevelBits)) if (pageGranularityPMPs && i == pgLevels - 1) { require(TLBPageLookup.homogeneous(edge.manager.managers, pgSize), s"All memory regions must be $pgSize-byte aligned") true.B } else { - TLBPageLookup(edge.manager.managers, xLen, p(CacheBlockBytes), pgSize)(pte_addr).homogeneous + TLBPageLookup(edge.manager.managers, xLen, p(CacheBlockBytes), pgSize)(r_pte.ppn << pgIdxBits).homogeneous } } val pmaHomogeneous = pmaPgLevelHomogeneous(count) - val pmpHomogeneous = new PMPHomogeneityChecker(io.dpath.pmp).apply(pte_addr >> pgIdxBits << pgIdxBits, count) + val pmpHomogeneous = new PMPHomogeneityChecker(io.dpath.pmp).apply(r_pte.ppn << pgIdxBits, count) val homogeneous = pmaHomogeneous && pmpHomogeneous for (i <- 0 until io.requestor.size) { io.requestor(i).resp.valid := resp_valid(i) io.requestor(i).resp.bits.ae_ptw := resp_ae_ptw io.requestor(i).resp.bits.ae_final := resp_ae_final + io.requestor(i).resp.bits.gf := resp_gf + io.requestor(i).resp.bits.hr := resp_hr + io.requestor(i).resp.bits.hw := resp_hw + io.requestor(i).resp.bits.hx := resp_hx io.requestor(i).resp.bits.pte := r_pte - io.requestor(i).resp.bits.level := count + io.requestor(i).resp.bits.level := max_count io.requestor(i).resp.bits.homogeneous := homogeneous || pageGranularityPMPs io.requestor(i).resp.bits.fragmented_superpage := resp_fragmented_superpage && pageGranularityPMPs + io.requestor(i).resp.bits.gpa.valid := r_req.need_gpa + io.requestor(i).resp.bits.gpa.bits := + Cat(Mux(!stage2_final || !r_req.vstage1 || aux_count === (pgLevels - 1), aux_pte.ppn, makeFragmentedSuperpagePPN(aux_pte.ppn)(aux_count)), gpa_pgoff) io.requestor(i).ptbr := io.dpath.ptbr + io.requestor(i).hgatp := io.dpath.hgatp + io.requestor(i).vsatp := io.dpath.vsatp io.requestor(i).customCSRs := io.dpath.customCSRs io.requestor(i).status := io.dpath.status +
io.requestor(i).hstatus := io.dpath.hstatus + io.requestor(i).gstatus := io.dpath.gstatus io.requestor(i).pmp := io.dpath.pmp } // control state machine val next_state = Wire(init = state) state := OptimizationBarrier(next_state) + val do_switch = Wire(init = false.B) switch (state) { is (s_ready) { when (arb.io.out.fire()) { + val satp_initial_count = pgLevels - minPgLevels - satp.additionalPgLevels + val vsatp_initial_count = pgLevels - minPgLevels - io.dpath.vsatp.additionalPgLevels + val hgatp_initial_count = pgLevels - minPgLevels - io.dpath.hgatp.additionalPgLevels + + r_req := arb.io.out.bits.bits + r_req_dest := arb.io.chosen next_state := Mux(arb.io.out.bits.valid, s_req, s_ready) + stage2 := arb.io.out.bits.bits.stage2 + stage2_final := arb.io.out.bits.bits.stage2 && !arb.io.out.bits.bits.vstage1 + count := Mux(arb.io.out.bits.bits.stage2, hgatp_initial_count, satp_initial_count) + aux_count := Mux(arb.io.out.bits.bits.vstage1, vsatp_initial_count, 0.U) + aux_pte.ppn := Mux(arb.io.out.bits.bits.vstage1, io.dpath.vsatp.ppn, arb.io.out.bits.bits.addr) + resp_ae_ptw := false + resp_ae_final := false + resp_gf := false + resp_hr := true + resp_hw := true + resp_hx := true + resp_fragmented_superpage := false + r_hgatp := io.dpath.hgatp + + assert(!arb.io.out.bits.bits.need_gpa || arb.io.out.bits.bits.stage2) } - count := pgLevels - minPgLevels - io.dpath.ptbr.additionalPgLevels } is (s_req) { - when (pte_cache_hit) { + when(stage2 && count === r_hgatp_initial_count) { + gpa_pgoff := Mux(aux_count === pgLevels-1, r_req.addr << (xLen/8).log2, stage2_pte_cache_addr) + } + + when (stage2_pte_cache_hit) { + aux_count := aux_count + 1 + aux_pte.ppn := stage2_pte_cache_data + pte_hit := true + }.elsewhen (pte_cache_hit) { count := count + 1 pte_hit := true }.otherwise { @@ -357,7 +463,6 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( io.dpath.perf.pte_miss := count < pgLevels-1 when (io.mem.s2_xcpt.ae.ld) { resp_ae_ptw := true - resp_ae_final := false next_state := s_ready resp_valid(r_req_dest) := true } @@ -365,51 +470,76 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( is (s_fragment_superpage) { next_state := s_ready resp_valid(r_req_dest) := true - resp_ae_ptw := false - resp_ae_final := false when (!homogeneous) { count := pgLevels-1 resp_fragmented_superpage := true } + when (do_both_stages) { + resp_fragmented_superpage := true + } } } - def makePTE(ppn: UInt, default: PTE) = { - val pte = Wire(init = default) - pte.ppn := ppn - pte + val merged_pte = { + val superpage_masks = (0 until pgLevels).map(i => ((BigInt(1) << pte.ppn.getWidth) - (BigInt(1) << (pgLevels-1-i)*pgLevelBits)).U) + val superpage_mask = superpage_masks(Mux(stage2_final, max_count, (pgLevels-1).U)) + val stage1_ppns = (0 until pgLevels-1).map(i => Cat(pte.ppn(pte.ppn.getWidth-1, (pgLevels-i-1)*pgLevelBits), aux_pte.ppn((pgLevels-i-1)*pgLevelBits-1,0))) :+ pte.ppn + val stage1_ppn = stage1_ppns(count) + makePTE(stage1_ppn & superpage_mask, aux_pte) } + r_pte := OptimizationBarrier( - Mux(mem_resp_valid, pte, Mux(l2_hit && !l2_error, l2_pte, - Mux(state === s_fragment_superpage && !homogeneous, makePTE(fragmented_superpage_ppn, r_pte), - Mux(state === s_req && pte_cache_hit, makePTE(pte_cache_data, l2_pte), - Mux(arb.io.out.fire(), makePTE(io.dpath.ptbr.ppn, r_pte), - r_pte)))))) + Mux(state === s_req && !stage2_pte_cache_hit && pte_cache_hit, makePTE(pte_cache_data, l2_pte), + Mux(do_switch, makeHypervisorRootPTE(r_hgatp, pte.ppn, r_pte), + 
Mux(mem_resp_valid, Mux(!traverse && (r_req.vstage1 && stage2), merged_pte, pte), + Mux(state === s_fragment_superpage && !homogeneous, makePTE(makeFragmentedSuperpagePPN(r_pte.ppn)(count), r_pte), + Mux(arb.io.out.fire(), Mux(arb.io.out.bits.bits.stage2, makeHypervisorRootPTE(io.dpath.hgatp, io.dpath.vsatp.ppn, r_pte), makePTE(satp.ppn, r_pte)), + r_pte))))))) when (l2_hit && !l2_error) { assert(state === s_req || state === s_wait1) next_state := s_ready resp_valid(r_req_dest) := true - resp_ae_ptw := false - resp_ae_final := false count := pgLevels-1 } when (mem_resp_valid) { assert(state === s_wait3) + next_state := s_req when (traverse) { - next_state := s_req + when (do_both_stages && !stage2) { do_switch := true } count := count + 1 }.otherwise { - l2_refill := pte.v && !invalid_paddr && count === pgLevels-1 + val gf = stage2 && !stage2_final && !pte.ur() val ae = pte.v && invalid_paddr - resp_ae_final := ae - resp_ae_ptw := false - when (pageGranularityPMPs && count =/= pgLevels-1 && !ae) { - next_state := s_fragment_superpage + val success = pte.v && !ae && !gf + + when (do_both_stages && !stage2_final && success) { + when (stage2) { + stage2 := false + count := aux_count + }.otherwise { + stage2_final := true + do_switch := true + } }.otherwise { - next_state := s_ready - resp_valid(r_req_dest) := true + l2_refill := success && count === pgLevels-1 && !r_req.need_gpa && + (!r_req.vstage1 && !r_req.stage2 || + do_both_stages && aux_count === pgLevels-1 && pte.isFullPerm()) + count := max_count + + when (pageGranularityPMPs && !(count === pgLevels-1 && (!do_both_stages || aux_count === pgLevels-1))) { + next_state := s_fragment_superpage + }.otherwise { + next_state := s_ready + resp_valid(r_req_dest) := true + } + + resp_ae_final := ae + resp_gf := gf + resp_hr := !stage2 || !gf && pte.ur() + resp_hw := !stage2 || !gf && pte.uw() + resp_hx := !stage2 || !gf && pte.ux() } } } @@ -418,6 +548,16 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( next_state := s_req } + when (do_switch) { + aux_count := Mux(traverse, count + 1, count) + count := r_hgatp_initial_count + aux_pte := Mux(traverse, pte, { + val s1_ppns = (0 until pgLevels-1).map(i => Cat(pte.ppn(pte.ppn.getWidth-1, (pgLevels-i-1)*pgLevelBits), r_req.addr(((pgLevels-i-1)*pgLevelBits min vpnBits)-1,0).padTo((pgLevels-i-1)*pgLevelBits))) :+ pte.ppn + makePTE(s1_ppns(count), pte) + }) + stage2 := true + } + for (i <- 0 until pgLevels) { val leaf = mem_resp_valid && !traverse && count === i ccover(leaf && pte.v && !invalid_paddr, s"L$i", s"successful page-table access, level $i") @@ -434,6 +574,21 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( private def ccover(cond: Bool, label: String, desc: String)(implicit sourceInfo: SourceInfo) = if (usingVM) cover(cond, s"PTW_$label", "MemorySystem;;" + desc) + + private def makePTE(ppn: UInt, default: PTE) = { + val pte = Wire(init = default) + pte.ppn := ppn + pte + } + + private def makeHypervisorRootPTE(hgatp: PTBR, vpn: UInt, default: PTE) = { + val count = pgLevels - minPgLevels - hgatp.additionalPgLevels + val idxs = (0 to pgLevels-minPgLevels).map(i => (vpn >> (pgLevels-i)*pgLevelBits)) + val lsbs = Wire(t = UInt(maxHypervisorExtraAddrBits.W), init = idxs(count)) + val pte = Wire(init = default) + pte.ppn := Cat(hgatp.ppn >> maxHypervisorExtraAddrBits, lsbs) + pte + } } /** Mix-ins for constructing tiles that might have a PTW */ diff --git a/src/main/scala/rocket/RocketCore.scala 
b/src/main/scala/rocket/RocketCore.scala index 4806e0a81c..b5a551db1b 100644 --- a/src/main/scala/rocket/RocketCore.scala +++ b/src/main/scala/rocket/RocketCore.scala @@ -19,6 +19,7 @@ case class RocketCoreParams( useVM: Boolean = true, useUser: Boolean = false, useSupervisor: Boolean = false, + useHypervisor: Boolean = false, useDebug: Boolean = true, useAtomics: Boolean = true, useAtomicsOnlyForIO: Boolean = false, @@ -38,6 +39,7 @@ case class RocketCoreParams( misaWritable: Boolean = true, nL2TLBEntries: Int = 0, nL2TLBWays: Int = 1, + nPTECacheEntries: Int = 8, mtvecInit: Option[BigInt] = Some(BigInt(0)), mtvecWritable: Boolean = true, fastLoadWord: Boolean = true, @@ -51,7 +53,7 @@ case class RocketCoreParams( ) extends CoreParams { val lgPauseCycles = 5 val haveFSDirty = false - val pmpGranularity: Int = 4 + val pmpGranularity: Int = if (useHypervisor) 4096 else 4 val fetchWidth: Int = if (useCompressed) 2 else 1 // fetchWidth doubled, but coreInstBytes halved, for RVC: val decodeWidth: Int = fetchWidth / (if (useCompressed) 2 else 1) @@ -180,6 +182,8 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) (if (xLen == 32) new I32Decode else new I64Decode) +: (usingVM.option(new SVMDecode)) ++: (usingSupervisor.option(new SDecode)) ++: + (usingHypervisor.option(new HypervisorDecode)) ++: + ((usingHypervisor && (xLen == 64)).option(new Hypervisor64Decode)) ++: (usingDebug.option(new DebugDecode)) ++: (usingNMI.option(new NMIDecode)) ++: Seq(new FenceIDecode(tile.dcache.flushOnFenceI)) ++: @@ -202,6 +206,7 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) val ex_reg_replay = Reg(Bool()) val ex_reg_pc = Reg(UInt()) val ex_reg_mem_size = Reg(UInt()) + val ex_reg_hls = Reg(Bool()) val ex_reg_inst = Reg(Bits()) val ex_reg_raw_inst = Reg(UInt()) val ex_scie_unpipelined = Reg(Bool()) @@ -223,6 +228,7 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) val mem_reg_pc = Reg(UInt()) val mem_reg_inst = Reg(Bits()) val mem_reg_mem_size = Reg(UInt()) + val mem_reg_hls_or_dv = Reg(Bool()) val mem_reg_raw_inst = Reg(UInt()) val mem_scie_unpipelined = Reg(Bool()) val mem_scie_pipelined = Reg(Bool()) @@ -240,6 +246,9 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) val wb_reg_sfence = Reg(Bool()) val wb_reg_pc = Reg(UInt()) val wb_reg_mem_size = Reg(UInt()) + val wb_reg_hls_or_dv = Reg(Bool()) + val wb_reg_hfence_v = Reg(Bool()) + val wb_reg_hfence_g = Reg(Bool()) val wb_reg_inst = Reg(Bits()) val wb_reg_raw_inst = Reg(UInt()) val wb_reg_wdata = Reg(Bits()) @@ -283,9 +292,8 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) val id_csr_en = id_ctrl.csr.isOneOf(CSR.S, CSR.C, CSR.W) val id_system_insn = id_ctrl.csr === CSR.I val id_csr_ren = id_ctrl.csr.isOneOf(CSR.S, CSR.C) && id_expanded_inst(0).rs1 === UInt(0) - val id_csr = Mux(id_csr_ren, CSR.R, id_ctrl.csr) - val id_sfence = id_ctrl.mem && id_ctrl.mem_cmd === M_SFENCE - val id_csr_flush = id_sfence || id_system_insn || (id_csr_en && !id_csr_ren && csr.io.decode(0).write_flush) + val id_csr = Mux(id_system_insn && id_ctrl.mem, CSR.N, Mux(id_csr_ren, CSR.R, id_ctrl.csr)) + val id_csr_flush = id_system_insn || (id_csr_en && !id_csr_ren && csr.io.decode(0).write_flush) val id_scie_decoder = if (!rocketParams.useSCIE) Wire(new SCIEDecoderInterface) else { val d = Module(new SCIEDecoder) @@ -305,7 +313,10 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) id_ctrl.rocc && 
csr.io.decode(0).rocc_illegal || id_ctrl.scie && !(id_scie_decoder.unpipelined || id_scie_decoder.pipelined) || id_csr_en && (csr.io.decode(0).read_illegal || !id_csr_ren && csr.io.decode(0).write_illegal) || - !ibuf.io.inst(0).bits.rvc && ((id_sfence || id_system_insn) && csr.io.decode(0).system_illegal) + !ibuf.io.inst(0).bits.rvc && (id_system_insn && csr.io.decode(0).system_illegal) + val id_virtual_insn = id_ctrl.legal && + ((id_csr_en && !(!id_csr_ren && csr.io.decode(0).write_illegal) && csr.io.decode(0).virtual_access_illegal) || + (!ibuf.io.inst(0).bits.rvc && id_system_insn && csr.io.decode(0).virtual_system_illegal)) // stall decode for fences (now, for AMO.rl; later, for AMO.aq and FENCE) val id_amo_aq = id_inst(0)(26) val id_amo_rl = id_inst(0)(25) @@ -335,9 +346,12 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) (bpu.io.debug_if, UInt(CSR.debugTriggerCause)), (bpu.io.xcpt_if, UInt(Causes.breakpoint)), (id_xcpt0.pf.inst, UInt(Causes.fetch_page_fault)), + (id_xcpt0.gf.inst, UInt(Causes.fetch_guest_page_fault)), (id_xcpt0.ae.inst, UInt(Causes.fetch_access)), (id_xcpt1.pf.inst, UInt(Causes.fetch_page_fault)), + (id_xcpt1.gf.inst, UInt(Causes.fetch_guest_page_fault)), (id_xcpt1.ae.inst, UInt(Causes.fetch_access)), + (id_virtual_insn, UInt(Causes.virtual_instruction)), (id_illegal_insn, UInt(Causes.illegal_instruction)))) val idCoverCauses = List( @@ -450,10 +464,14 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) } ex_reg_flush_pipe := id_ctrl.fence_i || id_csr_flush ex_reg_load_use := id_load_use - ex_reg_mem_size := id_inst(0)(13, 12) - when (id_ctrl.mem_cmd.isOneOf(M_SFENCE, M_FLUSH_ALL)) { + ex_reg_hls := usingHypervisor && id_system_insn && id_ctrl.mem_cmd.isOneOf(M_XRD, M_XWR, M_HLVX) + ex_reg_mem_size := Mux(usingHypervisor && id_system_insn, id_inst(0)(27, 26), id_inst(0)(13, 12)) + when (id_ctrl.mem_cmd.isOneOf(M_SFENCE, M_HFENCEV, M_HFENCEG, M_FLUSH_ALL)) { ex_reg_mem_size := Cat(id_raddr2 =/= UInt(0), id_raddr1 =/= UInt(0)) } + when (id_ctrl.mem_cmd === M_SFENCE && csr.io.status.v) { + ex_ctrl.mem_cmd := M_HFENCEV + } if (tile.dcache.flushOnFenceI) { when (id_ctrl.fence_i) { ex_reg_mem_size := 0 @@ -470,7 +488,7 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) ex_reg_rs_msb(i) := id_rs(i) >> log2Ceil(bypass_sources.size) } } - when (id_illegal_insn) { + when (id_illegal_insn || id_virtual_insn) { val inst = Mux(ibuf.io.inst(0).bits.rvc, id_raw_inst(0)(15, 0), id_raw_inst(0)) ex_reg_rs_bypass(0) := false ex_reg_rs_lsb(0) := inst(log2Ceil(bypass_sources.size)-1, 0) @@ -496,7 +514,7 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) val ctrl_killx = take_pc_mem_wb || replay_ex || !ex_reg_valid // detect 2-cycle load-use delay for LB/LH/SC val ex_slow_bypass = ex_ctrl.mem_cmd === M_XSC || ex_reg_mem_size < 2 - val ex_sfence = Bool(usingVM) && ex_ctrl.mem && ex_ctrl.mem_cmd === M_SFENCE + val ex_sfence = Bool(usingVM) && ex_ctrl.mem && (ex_ctrl.mem_cmd === M_SFENCE || ex_ctrl.mem_cmd === M_HFENCEV || ex_ctrl.mem_cmd === M_HFENCEG) val (ex_xcpt, ex_cause) = checkExceptions(List( (ex_reg_xcpt_interrupt || ex_reg_xcpt, ex_reg_cause))) @@ -520,7 +538,7 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) val mem_cfi_taken = (mem_ctrl.branch && mem_br_taken) || mem_ctrl.jalr || mem_ctrl.jal val mem_direction_misprediction = mem_ctrl.branch && mem_br_taken =/= (usingBTB && mem_reg_btb_resp.taken) val mem_misprediction = if 
(usingBTB) mem_wrong_npc else mem_cfi_taken - take_pc_mem := mem_reg_valid && (mem_misprediction || mem_reg_sfence) + take_pc_mem := mem_reg_valid && !mem_reg_xcpt && (mem_misprediction || mem_reg_sfence) mem_reg_valid := !ctrl_killx mem_reg_replay := !take_pc_mem_wb && replay_ex @@ -548,6 +566,7 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) mem_reg_inst := ex_reg_inst mem_reg_raw_inst := ex_reg_raw_inst mem_reg_mem_size := ex_reg_mem_size + mem_reg_hls_or_dv := io.dmem.req.bits.dv mem_reg_pc := ex_reg_pc mem_reg_wdata := Mux(ex_scie_unpipelined, ex_scie_unpipelined_wdata, alu.io.out) mem_br_taken := alu.io.cmp_out @@ -605,6 +624,9 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) wb_reg_inst := mem_reg_inst wb_reg_raw_inst := mem_reg_raw_inst wb_reg_mem_size := mem_reg_mem_size + wb_reg_hls_or_dv := mem_reg_hls_or_dv + wb_reg_hfence_v := mem_ctrl.mem_cmd === M_HFENCEV + wb_reg_hfence_g := mem_ctrl.mem_cmd === M_HFENCEG wb_reg_pc := mem_reg_pc wb_reg_wphit := mem_reg_wphit | bpu.io.bpwatch.map { bpw => (bpw.rvalid(0) && mem_reg_load) || (bpw.wvalid(0) && mem_reg_store) } @@ -616,6 +638,8 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) (wb_reg_valid && wb_ctrl.mem && io.dmem.s2_xcpt.ma.ld, UInt(Causes.misaligned_load)), (wb_reg_valid && wb_ctrl.mem && io.dmem.s2_xcpt.pf.st, UInt(Causes.store_page_fault)), (wb_reg_valid && wb_ctrl.mem && io.dmem.s2_xcpt.pf.ld, UInt(Causes.load_page_fault)), + (wb_reg_valid && wb_ctrl.mem && io.dmem.s2_xcpt.gf.st, UInt(Causes.store_guest_page_fault)), + (wb_reg_valid && wb_ctrl.mem && io.dmem.s2_xcpt.gf.ld, UInt(Causes.load_guest_page_fault)), (wb_reg_valid && wb_ctrl.mem && io.dmem.s2_xcpt.ae.st, UInt(Causes.store_access)), (wb_reg_valid && wb_ctrl.mem && io.dmem.s2_xcpt.ae.ld, UInt(Causes.load_access)) )) @@ -680,7 +704,7 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) // hook up control/status regfile csr.io.ungated_clock := clock - csr.io.decode(0).csr := id_raw_inst(0)(31,20) + csr.io.decode(0).inst := id_inst(0) csr.io.exception := wb_xcpt csr.io.cause := wb_cause csr.io.retire := wb_valid @@ -693,14 +717,30 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) io.fpu.hartid := io.hartid csr.io.rocc_interrupt := io.rocc.interrupt csr.io.pc := wb_reg_pc - val tval_valid = wb_xcpt && wb_cause.isOneOf(Causes.illegal_instruction, Causes.breakpoint, - Causes.misaligned_load, Causes.misaligned_store, - Causes.load_access, Causes.store_access, Causes.fetch_access, - Causes.load_page_fault, Causes.store_page_fault, Causes.fetch_page_fault) + val tval_dmem_addr = !wb_reg_xcpt + val tval_any_addr = tval_dmem_addr || + wb_reg_cause.isOneOf(Causes.breakpoint, Causes.fetch_access, Causes.fetch_page_fault, Causes.fetch_guest_page_fault) + val tval_inst = wb_reg_cause === Causes.illegal_instruction + val tval_valid = wb_xcpt && (tval_any_addr || tval_inst) + csr.io.gva := wb_xcpt && (tval_any_addr && csr.io.status.v || tval_dmem_addr && wb_reg_hls_or_dv) csr.io.tval := Mux(tval_valid, encodeVirtualAddress(wb_reg_wdata, wb_reg_wdata), 0.U) + csr.io.htval := { + val htval_valid_imem = wb_reg_xcpt && wb_reg_cause === Causes.fetch_guest_page_fault + val htval_imem = Mux(htval_valid_imem, io.imem.gpa.bits, 0.U) + assert(!htval_valid_imem || io.imem.gpa.valid) + + val htval_valid_dmem = wb_xcpt && tval_dmem_addr && io.dmem.s2_xcpt.gf.asUInt.orR && !(io.dmem.s2_xcpt.ma.asUInt.orR || io.dmem.s2_xcpt.pf.asUInt.orR) 
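+ // htval (like mtval2) reports the faulting guest-physical address
+ // shifted right by 2: the spec stores GPA >> 2 so an SvXXx4 guest-physical
+ // address fits an XLEN-bit CSR, and hypervisorExtraAddrBits is that shift.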
+ val htval_dmem = Mux(htval_valid_dmem, io.dmem.s2_gpa, 0.U) + + (htval_dmem | htval_imem) >> hypervisorExtraAddrBits + } io.ptw.ptbr := csr.io.ptbr + io.ptw.hgatp := csr.io.hgatp + io.ptw.vsatp := csr.io.vsatp (io.ptw.customCSRs.csrs zip csr.io.customCSRs).map { case (lhs, rhs) => lhs := rhs } io.ptw.status := csr.io.status + io.ptw.hstatus := csr.io.hstatus + io.ptw.gstatus := csr.io.gstatus io.ptw.pmp := csr.io.pmp csr.io.rw.addr := wb_reg_inst(31,20) csr.io.rw.cmd := CSR.maskCmd(wb_reg_valid, wb_ctrl.csr) @@ -799,6 +839,8 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) io.imem.sfence.bits.rs2 := wb_reg_mem_size(1) io.imem.sfence.bits.addr := wb_reg_wdata io.imem.sfence.bits.asid := wb_reg_rs2 + io.imem.sfence.bits.hv := wb_reg_hfence_v + io.imem.sfence.bits.hg := wb_reg_hfence_g io.ptw.sfence := io.imem.sfence ibuf.io.inst(0).ready := !ctrl_stalld @@ -839,11 +881,12 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p) io.dmem.req.bits.tag := ex_dcache_tag io.dmem.req.bits.cmd := ex_ctrl.mem_cmd io.dmem.req.bits.size := ex_reg_mem_size - io.dmem.req.bits.signed := !ex_reg_inst(14) + io.dmem.req.bits.signed := !Mux(ex_reg_hls, ex_reg_inst(20), ex_reg_inst(14)) io.dmem.req.bits.phys := Bool(false) io.dmem.req.bits.addr := encodeVirtualAddress(ex_rs(0), alu.io.adder_out) io.dmem.req.bits.idx.foreach(_ := io.dmem.req.bits.addr) - io.dmem.req.bits.dprv := csr.io.status.dprv + io.dmem.req.bits.dprv := Mux(ex_reg_hls, csr.io.hstatus.spvp, csr.io.status.dprv) + io.dmem.req.bits.dv := ex_reg_hls || csr.io.status.dv io.dmem.s1_data.data := (if (fLen == 0) mem_reg_rs2 else Mux(mem_ctrl.fp, Fill((xLen max fLen) / fLen, io.fpu.store_data), mem_reg_rs2)) io.dmem.s1_kill := killm_common || mem_ldst_xcpt || fpu_kill_mem io.dmem.s2_kill := false diff --git a/src/main/scala/rocket/TLB.scala b/src/main/scala/rocket/TLB.scala index 792e822869..f7d5102362 100644 --- a/src/main/scala/rocket/TLB.scala +++ b/src/main/scala/rocket/TLB.scala @@ -18,12 +18,16 @@ import chisel3.internal.sourceinfo.SourceInfo case object PgLevels extends Field[Int](2) case object ASIdBits extends Field[Int](0) +case object VMIdBits extends Field[Int](0) + class SFenceReq(implicit p: Parameters) extends CoreBundle()(p) { val rs1 = Bool() val rs2 = Bool() val addr = UInt(width = vaddrBits) val asid = UInt(width = asIdBits max 1) // TODO zero-width + val hv = Bool() + val hg = Bool() } class TLBReq(lgMaxSize: Int)(implicit p: Parameters) extends CoreBundle()(p) { @@ -31,21 +35,27 @@ class TLBReq(lgMaxSize: Int)(implicit p: Parameters) extends CoreBundle()(p) { val passthrough = Bool() val size = UInt(width = log2Ceil(lgMaxSize + 1)) val cmd = Bits(width = M_SZ) + val prv = UInt(PRV.SZ.W) + val v = Bool() override def cloneType = new TLBReq(lgMaxSize).asInstanceOf[this.type] } -class TLBExceptions extends Bundle { +class TLBExceptions(implicit p: Parameters) extends CoreBundle()(p) { val ld = Bool() val st = Bool() val inst = Bool() + val v = Bool() } class TLBResp(implicit p: Parameters) extends CoreBundle()(p) { // lookup responses val miss = Bool() val paddr = UInt(width = paddrBits) + val gpa = UInt(vaddrBitsExtended.W) + val gpa_is_pte = Bool() val pf = new TLBExceptions + val gf = new TLBExceptions val ae = new TLBExceptions val ma = new TLBExceptions val cacheable = Bool() @@ -55,13 +65,18 @@ class TLBResp(implicit p: Parameters) extends CoreBundle()(p) { class TLBEntryData(implicit p: Parameters) extends CoreBundle()(p) { val ppn = UInt(width = ppnBits) + val v = 
Bool() val u = Bool() val g = Bool() val ae_ptw = Bool() val ae_final = Bool() + val gf = Bool() val sw = Bool() val sx = Bool() val sr = Bool() + val hw = Bool() + val hx = Bool() + val hr = Bool() val pw = Bool() val px = Bool() val pr = Bool() @@ -87,26 +102,29 @@ class TLBEntry(val nSectors: Int, val superpage: Boolean, val superpageOnly: Boo def getData(vpn: UInt) = OptimizationBarrier(data(sectorIdx(vpn)).asTypeOf(new TLBEntryData)) def sectorHit(vpn: UInt) = valid.orR && sectorTagMatch(vpn) def sectorTagMatch(vpn: UInt) = ((tag ^ vpn) >> nSectors.log2) === 0 - def hit(vpn: UInt) = { + def hit(vpn: UInt, virtual: Bool): Bool = { if (superpage && usingVM) { - var tagMatch = valid.head + var tagMatch = valid.head && entry_data.head.v === virtual for (j <- 0 until pgLevels) { - val base = vpnBits - (j + 1) * pgLevelBits + val base = (pgLevels - 1 - j) * pgLevelBits + val n = pgLevelBits + (if (j == 0) hypervisorExtraAddrBits else 0) val ignore = level < j || superpageOnly && j == pgLevels - 1 - tagMatch = tagMatch && (ignore || tag(base + pgLevelBits - 1, base) === vpn(base + pgLevelBits - 1, base)) + tagMatch = tagMatch && (ignore || (tag ^ vpn)(base + n - 1, base) === 0) } tagMatch } else { val idx = sectorIdx(vpn) - valid(idx) && sectorTagMatch(vpn) + val virtualMatch = entry_data.map(_.v === virtual) + valid(idx) && virtualMatch(idx) && sectorTagMatch(vpn) } } def ppn(vpn: UInt, data: TLBEntryData) = { + val supervisorVPNBits = pgLevels * pgLevelBits if (superpage && usingVM) { var res = data.ppn >> pgLevelBits*(pgLevels - 1) for (j <- 1 until pgLevels) { val ignore = level < j || superpageOnly && j == pgLevels - 1 - res = Cat(res, (Mux(ignore, vpn, 0.U) | data.ppn)(vpnBits - j*pgLevelBits - 1, vpnBits - (j + 1)*pgLevelBits)) + res = Cat(res, (Mux(ignore, vpn, 0.U) | data.ppn)(supervisorVPNBits - j*pgLevelBits - 1, supervisorVPNBits - (j + 1)*pgLevelBits)) } res } else { @@ -124,23 +142,30 @@ class TLBEntry(val nSectors: Int, val superpage: Boolean, val superpageOnly: Boo } def invalidate(): Unit = { valid.foreach(_ := false) } - def invalidateVPN(vpn: UInt): Unit = { + def invalidate(virtual: Bool, guestPhys: Bool): Unit = { + for ((v, e) <- valid zip entry_data) + when (e.v === virtual || e.v && guestPhys) { v := false } + } + def invalidateVPN(vpn: UInt, virtual: Bool): Unit = { if (superpage) { - when (hit(vpn)) { invalidate() } + when (hit(vpn, virtual)) { invalidate() } } else { - when (sectorTagMatch(vpn)) { valid(sectorIdx(vpn)) := false } + when (sectorTagMatch(vpn)) { + for (((v, e), i) <- (valid zip entry_data).zipWithIndex) + when (e.v === virtual && i === sectorIdx(vpn)) { v := false } + } // For fragmented superpage mappings, we assume the worst (largest) // case, and zap entries whose most-significant VPNs match when (((tag ^ vpn) >> (pgLevelBits * (pgLevels - 1))) === 0) { for ((v, e) <- valid zip entry_data) - when (e.fragmented_superpage) { v := false } + when (e.v === virtual && e.fragmented_superpage) { v := false } } } } - def invalidateNonGlobal(): Unit = { + def invalidateNonGlobal(virtual: Bool): Unit = { for ((v, e) <- valid zip entry_data) - when (!e.g) { v := false } + when (e.v === virtual && !e.g) { v := false } } } @@ -173,14 +198,30 @@ class TLB(instruction: Boolean, lgMaxSize: Int, cfg: TLBConfig)(implicit edge: T val state = Reg(init=s_ready) val r_refill_tag = Reg(UInt(width = vpnBits)) val r_superpage_repl_addr = Reg(UInt(log2Ceil(superpage_entries.size).W)) - val r_sectored_repl_addr = Reg(UInt(log2Ceil(sectored_entries(0).size).W)) - val 
r_sectored_hit_addr = Reg(UInt(log2Ceil(sectored_entries(0).size).W)) - val r_sectored_hit = Reg(Bool()) - - val priv = if (instruction) io.ptw.status.prv else io.ptw.status.dprv + val r_sectored_repl_addr = Reg(UInt(log2Ceil(sectored_entries.head.size).W)) + val r_sectored_hit = Reg(Valid(UInt(log2Ceil(sectored_entries.head.size).W))) + val r_superpage_hit = Reg(Valid(UInt(log2Ceil(superpage_entries.size).W))) + val r_vstage1_en = Reg(Bool()) + val r_stage2_en = Reg(Bool()) + val r_need_gpa = Reg(Bool()) + val r_gpa_valid = Reg(Bool()) + val r_gpa = Reg(UInt(vaddrBits.W)) + val r_gpa_vpn = Reg(UInt(vpnBits.W)) + val r_gpa_gf = Reg(Bool()) + + val priv = io.req.bits.prv + val priv_v = usingHypervisor && io.req.bits.v val priv_s = priv(0) val priv_uses_vm = priv <= PRV.S - val vm_enabled = Bool(usingVM) && io.ptw.ptbr.mode(io.ptw.ptbr.mode.getWidth-1) && priv_uses_vm && !io.req.bits.passthrough + val satp = Mux(priv_v, io.ptw.vsatp, io.ptw.ptbr) + val stage1_en = Bool(usingVM) && satp.mode(satp.mode.getWidth-1) + val vstage1_en = Bool(usingHypervisor) && priv_v && io.ptw.vsatp.mode(io.ptw.vsatp.mode.getWidth-1) + val stage2_en = Bool(usingHypervisor) && priv_v && io.ptw.hgatp.mode(io.ptw.hgatp.mode.getWidth-1) + val vm_enabled = (stage1_en || stage2_en) && priv_uses_vm && !io.req.bits.passthrough + + // flush guest entries on vsatp.MODE Bare <-> SvXX transitions + val v_entries_use_stage1 = RegInit(false.B) + val vsatp_mode_mismatch = priv_v && (vstage1_en =/= v_entries_use_stage1) && !io.req.bits.passthrough // share a single physical memory attribute checker (unshare if critical path) val refill_ppn = io.ptw.resp.bits.pte.ppn(ppnBits-1, 0) @@ -210,8 +251,8 @@ class TLB(instruction: Boolean, lgMaxSize: Int, cfg: TLBConfig)(implicit edge: T val prot_eff = fastCheck(Seq(RegionType.PUT_EFFECTS, RegionType.GET_EFFECTS) contains _.regionType) val sector_hits = sectored_entries(memIdx).map(_.sectorHit(vpn)) - val superpage_hits = superpage_entries.map(_.hit(vpn)) - val hitsVec = all_entries.map(vm_enabled && _.hit(vpn)) + val superpage_hits = superpage_entries.map(_.hit(vpn, priv_v)) + val hitsVec = all_entries.map(vm_enabled && _.hit(vpn, priv_v)) val real_hits = hitsVec.asUInt val hits = Cat(!vm_enabled, real_hits) @@ -220,11 +261,16 @@ class TLB(instruction: Boolean, lgMaxSize: Int, cfg: TLBConfig)(implicit edge: T val pte = io.ptw.resp.bits.pte val newEntry = Wire(new TLBEntryData) newEntry.ppn := pte.ppn + newEntry.v := r_vstage1_en || r_stage2_en newEntry.c := cacheable newEntry.u := pte.u newEntry.g := pte.g && pte.v newEntry.ae_ptw := io.ptw.resp.bits.ae_ptw newEntry.ae_final := io.ptw.resp.bits.ae_final + newEntry.gf := io.ptw.resp.bits.gf + newEntry.hr := io.ptw.resp.bits.hr + newEntry.hw := io.ptw.resp.bits.hw + newEntry.hx := io.ptw.resp.bits.hx newEntry.sr := pte.sr() newEntry.sw := pte.sw() newEntry.sx := pte.sx() @@ -238,24 +284,26 @@ class TLB(instruction: Boolean, lgMaxSize: Int, cfg: TLBConfig)(implicit edge: T newEntry.fragmented_superpage := io.ptw.resp.bits.fragmented_superpage when (special_entry.nonEmpty && !io.ptw.resp.bits.homogeneous) { - special_entry.foreach { e => - e.insert(r_refill_tag, io.ptw.resp.bits.level, newEntry) - when (invalidate_refill) { e.invalidate() } - } + special_entry.foreach(_.insert(r_refill_tag, io.ptw.resp.bits.level, newEntry)) }.elsewhen (io.ptw.resp.bits.level < pgLevels-1) { + val waddr = Mux(r_superpage_hit.valid && usingHypervisor, r_superpage_hit.bits, r_superpage_repl_addr) for ((e, i) <- superpage_entries.zipWithIndex) when 
(waddr === i) { e.insert(r_refill_tag, io.ptw.resp.bits.level, newEntry) when (invalidate_refill) { e.invalidate() } } }.otherwise { val r_memIdx = r_refill_tag.extract(cfg.nSectors.log2 + cfg.nSets.log2 - 1, cfg.nSectors.log2) - val waddr = Mux(r_sectored_hit, r_sectored_hit_addr, r_sectored_repl_addr) + val waddr = Mux(r_sectored_hit.valid, r_sectored_hit.bits, r_sectored_repl_addr) for ((e, i) <- sectored_entries(r_memIdx).zipWithIndex) when (waddr === i) { - when (!r_sectored_hit) { e.invalidate() } + when (!r_sectored_hit.valid) { e.invalidate() } e.insert(r_refill_tag, 0.U, newEntry) when (invalidate_refill) { e.invalidate() } } } + + r_gpa_valid := io.ptw.resp.bits.gpa.valid + r_gpa := io.ptw.resp.bits.gpa.bits + r_gpa_gf := io.ptw.resp.bits.gf } val entries = all_entries.map(_.getData(vpn)) @@ -265,11 +313,19 @@ class TLB(instruction: Boolean, lgMaxSize: Int, cfg: TLBConfig)(implicit edge: T val nPhysicalEntries = 1 + special_entry.size val ptw_ae_array = Cat(false.B, entries.map(_.ae_ptw).asUInt) val final_ae_array = Cat(false.B, entries.map(_.ae_final).asUInt) - val priv_rw_ok = Mux(!priv_s || io.ptw.status.sum, entries.map(_.u).asUInt, 0.U) | Mux(priv_s, ~entries.map(_.u).asUInt, 0.U) + val ptw_gf_array = Cat(false.B, entries.map(_.gf).asUInt) + val sum = Mux(priv_v, io.ptw.gstatus.sum, io.ptw.status.sum) + val priv_rw_ok = Mux(!priv_s || sum, entries.map(_.u).asUInt, 0.U) | Mux(priv_s, ~entries.map(_.u).asUInt, 0.U) val priv_x_ok = Mux(priv_s, ~entries.map(_.u).asUInt, entries.map(_.u).asUInt) - val r_array = Cat(true.B, priv_rw_ok & (entries.map(_.sr).asUInt | Mux(io.ptw.status.mxr, entries.map(_.sx).asUInt, UInt(0)))) - val w_array = Cat(true.B, priv_rw_ok & entries.map(_.sw).asUInt) - val x_array = Cat(true.B, priv_x_ok & entries.map(_.sx).asUInt) + val stage1_bypass = Fill(entries.size, usingHypervisor && !stage1_en) + val mxr = io.ptw.status.mxr | Mux(priv_v, io.ptw.gstatus.mxr, false.B) + val r_array = Cat(true.B, (priv_rw_ok & (entries.map(_.sr).asUInt | Mux(mxr, entries.map(_.sx).asUInt, UInt(0)))) | stage1_bypass) + val w_array = Cat(true.B, (priv_rw_ok & entries.map(_.sw).asUInt) | stage1_bypass) + val x_array = Cat(true.B, (priv_x_ok & entries.map(_.sx).asUInt) | stage1_bypass) + val stage2_bypass = Fill(entries.size, !stage2_en) + val hr_array = Cat(true.B, entries.map(_.hr).asUInt | Mux(io.ptw.status.mxr, entries.map(_.hx).asUInt, UInt(0)) | stage2_bypass) + val hw_array = Cat(true.B, entries.map(_.hw).asUInt | stage2_bypass) + val hx_array = Cat(true.B, entries.map(_.hx).asUInt | stage2_bypass) val pr_array = Cat(Fill(nPhysicalEntries, prot_r), normal_entries.map(_.pr).asUInt) & ~(ptw_ae_array | final_ae_array) val pw_array = Cat(Fill(nPhysicalEntries, prot_w), normal_entries.map(_.pw).asUInt) & ~(ptw_ae_array | final_ae_array) val px_array = Cat(Fill(nPhysicalEntries, prot_x), normal_entries.map(_.px).asUInt) & ~(ptw_ae_array | final_ae_array) @@ -284,21 +340,31 @@ class TLB(instruction: Boolean, lgMaxSize: Int, cfg: TLBConfig)(implicit edge: T val prefetchable_array = Cat((cacheable && homogeneous) << (nPhysicalEntries-1), normal_entries.map(_.c).asUInt) val misaligned = (io.req.bits.vaddr & (UIntToOH(io.req.bits.size) - 1)).orR - val bad_va = if (!usingVM || (minPgLevels == pgLevels && vaddrBits == vaddrBitsExtended)) false.B else vm_enabled && { + def badVA(guestPA: Boolean): Bool = { + val additionalPgLevels = (if (guestPA) io.ptw.hgatp else satp).additionalPgLevels + val extraBits = if (guestPA) hypervisorExtraAddrBits else 0 + val signed =
!guestPA val nPgLevelChoices = pgLevels - minPgLevels + 1 - val minVAddrBits = pgIdxBits + minPgLevels * pgLevelBits + val minVAddrBits = pgIdxBits + minPgLevels * pgLevelBits + extraBits (for (i <- 0 until nPgLevelChoices) yield { - val mask = ((BigInt(1) << vaddrBitsExtended) - (BigInt(1) << (minVAddrBits + i * pgLevelBits - 1))).U + val mask = ((BigInt(1) << vaddrBitsExtended) - (BigInt(1) << (minVAddrBits + i * pgLevelBits - signed.toInt))).U val maskedVAddr = io.req.bits.vaddr & mask - io.ptw.ptbr.additionalPgLevels === i && !(maskedVAddr === 0 || maskedVAddr === mask) + additionalPgLevels === i && !(maskedVAddr === 0 || signed && maskedVAddr === mask) }).orR } + val bad_gpa = + if (!usingHypervisor) false.B + else vm_enabled && !stage1_en && badVA(true) + val bad_va = + if (!usingVM || (minPgLevels == pgLevels && vaddrBits == vaddrBitsExtended)) false.B + else vm_enabled && stage1_en && badVA(false) val cmd_lrsc = Bool(usingAtomics) && io.req.bits.cmd.isOneOf(M_XLR, M_XSC) val cmd_amo_logical = Bool(usingAtomics) && isAMOLogical(io.req.bits.cmd) val cmd_amo_arithmetic = Bool(usingAtomics) && isAMOArithmetic(io.req.bits.cmd) val cmd_put_partial = io.req.bits.cmd === M_PWR val cmd_read = isRead(io.req.bits.cmd) + val cmd_readx = usingHypervisor && io.req.bits.cmd === M_HLVX val cmd_write = isWrite(io.req.bits.cmd) val cmd_write_perms = cmd_write || io.req.bits.cmd.isOneOf(M_FLUSH_ALL, M_WOK) // not a write, but needs write permissions @@ -318,16 +384,27 @@ class TLB(instruction: Boolean, lgMaxSize: Int, cfg: TLBConfig)(implicit edge: T Mux(cmd_amo_logical, ~paa_array, 0.U) | Mux(cmd_amo_arithmetic, ~pal_array, 0.U) | Mux(cmd_lrsc, ~0.U(pal_array.getWidth.W), 0.U) - val ma_ld_array = Mux(misaligned && cmd_read, ~eff_array & ~(ptw_ae_array | final_ae_array), 0.U) - val ma_st_array = Mux(misaligned && cmd_write, ~eff_array & ~(ptw_ae_array | final_ae_array), 0.U) - val pf_ld_array = Mux(cmd_read, ~(r_array | ptw_ae_array), 0.U) - val pf_st_array = Mux(cmd_write_perms, ~(w_array | ptw_ae_array), 0.U) - val pf_inst_array = ~(x_array | ptw_ae_array) + val ma_ld_array = Mux(misaligned && cmd_read, ~eff_array & ~(ptw_ae_array | final_ae_array | ptw_gf_array), 0.U) + val ma_st_array = Mux(misaligned && cmd_write, ~eff_array & ~(ptw_ae_array | final_ae_array | ptw_gf_array), 0.U) + val pf_ld_array = Mux(cmd_read, ~(Mux(cmd_readx, x_array, r_array) | (ptw_ae_array | ptw_gf_array)), 0.U) + val pf_st_array = Mux(cmd_write_perms, ~(w_array | (ptw_ae_array | ptw_gf_array)), 0.U) + val pf_inst_array = ~(x_array | (ptw_ae_array | ptw_gf_array)) + val gf_ld_array = Mux(priv_v && cmd_read, ~(Mux(cmd_readx, hx_array, hr_array) | ptw_ae_array), 0.U) + val gf_st_array = Mux(priv_v && cmd_write_perms, ~(hw_array | ptw_ae_array), 0.U) + val gf_inst_array = Mux(priv_v, ~(hx_array | ptw_ae_array), 0.U) + + val gpa_hits = { + val need_gpa_mask = if (instruction) gf_inst_array else gf_ld_array | gf_st_array + val hit_mask = Fill(ordinary_entries.size, r_gpa_valid && r_gpa_vpn === vpn) | Fill(all_entries.size, !vstage1_en) + hit_mask | ~need_gpa_mask(all_entries.size-1, 0) + } - val tlb_hit = real_hits.orR - val tlb_miss = vm_enabled && !bad_va && !tlb_hit + val tlb_hit_if_not_gpa_miss = real_hits.orR + val tlb_hit = (real_hits & gpa_hits).orR - val sectored_plru = new SetAssocLRU(cfg.nSets, sectored_entries(0).size, "plru") + val tlb_miss = vm_enabled && !vsatp_mode_mismatch && !bad_va && !tlb_hit + + val sectored_plru = new SetAssocLRU(cfg.nSets, sectored_entries.head.size, "plru") val superpage_plru = new 
PseudoLRU(superpage_entries.size) when (io.req.valid && vm_enabled) { when (sector_hits.orR) { sectored_plru.access(memIdx, OHToUInt(sector_hits)) } @@ -345,6 +422,9 @@ class TLB(instruction: Boolean, lgMaxSize: Int, cfg: TLBConfig)(implicit edge: T io.resp.pf.ld := (bad_va && cmd_read) || (pf_ld_array & hits).orR io.resp.pf.st := (bad_va && cmd_write_perms) || (pf_st_array & hits).orR io.resp.pf.inst := bad_va || (pf_inst_array & hits).orR + io.resp.gf.ld := (bad_gpa && cmd_read) || (gf_ld_array & hits).orR + io.resp.gf.st := (bad_gpa && cmd_write_perms) || (gf_st_array & hits).orR + io.resp.gf.inst := bad_gpa || (gf_inst_array & hits).orR io.resp.ae.ld := (ae_ld_array & hits).orR io.resp.ae.st := (ae_st_array & hits).orR io.resp.ae.inst := (~px_array & hits).orR @@ -354,23 +434,41 @@ class TLB(instruction: Boolean, lgMaxSize: Int, cfg: TLBConfig)(implicit edge: T io.resp.cacheable := (c_array & hits).orR io.resp.must_alloc := (must_alloc_array & hits).orR io.resp.prefetchable := (prefetchable_array & hits).orR && edge.manager.managers.forall(m => !m.supportsAcquireB || m.supportsHint) - io.resp.miss := do_refill || tlb_miss || multipleHits + io.resp.miss := do_refill || vsatp_mode_mismatch || tlb_miss || multipleHits io.resp.paddr := Cat(ppn, io.req.bits.vaddr(pgIdxBits-1, 0)) + io.resp.gpa_is_pte := vstage1_en && r_gpa_gf + io.resp.gpa := { + val page = Mux(!vstage1_en, Cat(bad_va, vpn), r_gpa >> pgIdxBits) + val offset = Mux(io.resp.gpa_is_pte, r_gpa(pgIdxBits-1, 0), io.req.bits.vaddr(pgIdxBits-1, 0)) + Cat(page, offset) + } io.ptw.req.valid := state === s_request io.ptw.req.bits.valid := !io.kill io.ptw.req.bits.bits.addr := r_refill_tag + io.ptw.req.bits.bits.vstage1 := r_vstage1_en + io.ptw.req.bits.bits.stage2 := r_stage2_en + io.ptw.req.bits.bits.need_gpa := r_need_gpa if (usingVM) { + when(io.ptw.req.fire() && io.ptw.req.bits.valid) { + r_gpa_valid := false + r_gpa_vpn := r_refill_tag + } + val sfence = io.sfence.valid when (io.req.fire() && tlb_miss) { state := s_request r_refill_tag := vpn - + r_need_gpa := tlb_hit_if_not_gpa_miss + r_vstage1_en := vstage1_en + r_stage2_en := stage2_en r_superpage_repl_addr := replacementEntry(superpage_entries, superpage_plru.way) r_sectored_repl_addr := replacementEntry(sectored_entries(memIdx), sectored_plru.way(memIdx)) - r_sectored_hit_addr := OHToUInt(sector_hits) - r_sectored_hit := sector_hits.orR + r_sectored_hit.valid := sector_hits.orR + r_sectored_hit.bits := OHToUInt(sector_hits) + r_superpage_hit.valid := superpage_hits.orR + r_superpage_hit.bits := OHToUInt(superpage_hits) } when (state === s_request) { when (sfence) { state := s_ready } @@ -387,11 +485,17 @@ class TLB(instruction: Boolean, lgMaxSize: Int, cfg: TLBConfig)(implicit edge: T when (sfence) { assert(!io.sfence.bits.rs1 || (io.sfence.bits.addr >> pgIdxBits) === vpn) for (e <- all_real_entries) { - when (io.sfence.bits.rs1) { e.invalidateVPN(vpn) } - .elsewhen (io.sfence.bits.rs2) { e.invalidateNonGlobal() } - .otherwise { e.invalidate() } + val hv = usingHypervisor && io.sfence.bits.hv + val hg = usingHypervisor && io.sfence.bits.hg + when (!hg && io.sfence.bits.rs1) { e.invalidateVPN(vpn, hv) } + .elsewhen (!hg && io.sfence.bits.rs2) { e.invalidateNonGlobal(hv) } + .otherwise { e.invalidate(hv, hg) } } } + when(io.req.fire() && vsatp_mode_mismatch) { + all_real_entries.foreach(_.invalidate(true, false)) + v_entries_use_stage1 := vstage1_en + } when (multipleHits || reset) { all_real_entries.foreach(_.invalidate()) } diff --git a/src/main/scala/tile/BaseTile.scala 
b/src/main/scala/tile/BaseTile.scala index 9651e352f9..ef4ed5e5ad 100644 --- a/src/main/scala/tile/BaseTile.scala +++ b/src/main/scala/tile/BaseTile.scala @@ -47,6 +47,7 @@ trait HasNonDiplomaticTileParameters { def usingVM: Boolean = tileParams.core.useVM def usingUser: Boolean = tileParams.core.useUser || usingSupervisor def usingSupervisor: Boolean = tileParams.core.hasSupervisorMode + def usingHypervisor: Boolean = usingVM && tileParams.core.useHypervisor def usingDebug: Boolean = tileParams.core.useDebug def usingRoCC: Boolean = !p(BuildRoCC).isEmpty def usingBTB: Boolean = tileParams.btb.isDefined && tileParams.btb.get.nEntries > 0 @@ -60,12 +61,19 @@ trait HasNonDiplomaticTileParameters { def pgLevelBits: Int = 10 - log2Ceil(xLen / 32) def pgLevels: Int = p(PgLevels) def maxSVAddrBits: Int = pgIdxBits + pgLevels * pgLevelBits + def maxHypervisorExtraAddrBits: Int = 2 + def hypervisorExtraAddrBits: Int = { + if (usingHypervisor) maxHypervisorExtraAddrBits + else 0 + } + def maxHVAddrBits: Int = maxSVAddrBits + hypervisorExtraAddrBits def minPgLevels: Int = { val res = xLen match { case 32 => 2; case 64 => 3 } require(pgLevels >= res) res } def asIdBits: Int = p(ASIdBits) + def vmIdBits: Int = p(VMIdBits) lazy val maxPAddrBits: Int = { require(xLen == 32 || xLen == 64, s"Only XLENs of 32 or 64 are supported, but got $xLen") xLen match { case 32 => 34; case 64 => 56 } @@ -126,9 +134,16 @@ trait HasNonDiplomaticTileParameters { "i-tlb-size" -> (i.nTLBWays * i.nTLBSets).asProperty, "i-tlb-sets" -> i.nTLBSets.asProperty)).getOrElse(Nil) - val mmu = if (!tileParams.core.useVM) Nil else Map( - "tlb-split" -> Nil, - "mmu-type" -> s"riscv,sv$maxSVAddrBits".asProperty) + val mmu = + if (tileParams.core.useVM) { + if (tileParams.core.useHypervisor) { + Map("tlb-split" -> Nil, "mmu-type" -> s"riscv,sv${maxSVAddrBits},sv${maxSVAddrBits}x4".asProperty) + } else { + Map("tlb-split" -> Nil, "mmu-type" -> s"riscv,sv$maxSVAddrBits".asProperty) + } + } else { + Nil + } val pmp = if (tileParams.core.nPMPs > 0) Map( "riscv,pmpregions" -> tileParams.core.nPMPs.asProperty, @@ -152,7 +167,7 @@ trait HasTileParameters extends HasNonDiplomaticTileParameters { } def vaddrBits: Int = if (usingVM) { - val v = maxSVAddrBits + val v = maxHVAddrBits require(v == xLen || xLen > v && v > paddrBits) v } else { diff --git a/src/main/scala/tile/Core.scala b/src/main/scala/tile/Core.scala index 7142a72633..09fa4ad029 100644 --- a/src/main/scala/tile/Core.scala +++ b/src/main/scala/tile/Core.scala @@ -15,6 +15,7 @@ case object MaxHartIdBits extends Field[Int] trait CoreParams { val bootFreqHz: BigInt val useVM: Boolean + val useHypervisor: Boolean val useUser: Boolean val useSupervisor: Boolean val useDebug: Boolean @@ -46,6 +47,7 @@ trait CoreParams { val haveCFlush: Boolean val nL2TLBEntries: Int val nL2TLBWays: Int + val nPTECacheEntries: Int val mtvecInit: Option[BigInt] val mtvecWritable: Boolean def customCSRs(implicit p: Parameters): CustomCSRs = new CustomCSRs diff --git a/src/main/scala/tile/LazyRoCC.scala b/src/main/scala/tile/LazyRoCC.scala index d58b1f003a..ae36dcc59f 100644 --- a/src/main/scala/tile/LazyRoCC.scala +++ b/src/main/scala/tile/LazyRoCC.scala @@ -185,6 +185,7 @@ class AccumulatorExampleModuleImp(outer: AccumulatorExample)(implicit p: Paramet io.mem.req.bits.data := 0.U // we're not performing any stores... 
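// dv, forwarded just below alongside dprv, is the effective virtualization
// mode of a data access: when set, the data TLB translates the request
// through the two-stage (VS-stage plus G-stage) tables instead of the host satp.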
io.mem.req.bits.phys := false.B io.mem.req.bits.dprv := cmd.bits.status.dprv + io.mem.req.bits.dv := cmd.bits.status.dv } class TranslatorExample(opcodes: OpcodeSet)(implicit p: Parameters) extends LazyRoCC(opcodes, nPTWPorts = 1) { @@ -351,6 +352,7 @@ class BlackBoxExampleModuleImp(outer: BlackBoxExample, blackBoxFile: String)(imp "coreDataBits" -> IntParam(coreDataBits), "coreDataBytes" -> IntParam(coreDataBytes), "paddrBits" -> IntParam(paddrBits), + "vaddrBitsExtended" -> IntParam(vaddrBitsExtended), "FPConstants_RM_SZ" -> IntParam(FPConstants.RM_SZ), "fLen" -> IntParam(fLen), "FPConstants_FLAGS_SZ" -> IntParam(FPConstants.FLAGS_SZ)
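[Sketch, not part of the patch: the smallest way to exercise the new knob from this diff. RocketCoreParams gains useHypervisor (default false) and derives pmpGranularity from it; HParamsSketch is an illustrative name, and threading the flag into a full subsystem Config is plumbing outside this patch.]

    import freechips.rocketchip.rocket.RocketCoreParams

    object HParamsSketch extends App {
      // All other RocketCoreParams fields keep their defaults (useVM = true,
      // so usingHypervisor = useVM && useHypervisor holds at the tile level).
      val hCore = RocketCoreParams(useHypervisor = true)
      // The patch forces page-granularity PMPs whenever H is enabled:
      assert(hCore.pmpGranularity == 4096)
      println(s"useHypervisor=${hCore.useHypervisor} pmpGranularity=${hCore.pmpGranularity}")
    }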