Skip to content

Commit

Permalink
Fix generic pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
StarsX committed Apr 17, 2024
1 parent be0477b commit c2b1049
Show file tree
Hide file tree
Showing 20 changed files with 1,360 additions and 181 deletions.
44 changes: 29 additions & 15 deletions XUSG/Core/XUSG.h
Original file line number Diff line number Diff line change
Expand Up @@ -2340,12 +2340,30 @@ namespace XUSG
virtual Pipeline CreatePipeline(PipelineLib* pPipelineLib, const wchar_t* name = nullptr) const = 0;
virtual Pipeline GetPipeline(PipelineLib* pPipelineLib, const wchar_t* name = nullptr) const = 0;

// Get the key of the pipeline cache for XUSG pipeline lib
virtual const std::string& GetKey() const = 0;
virtual PipelineLayout GetPipelineLayout() const = 0;
virtual Blob GetShader(Shader::Stage stage) const = 0;
virtual Blob GetCachedPipeline() const = 0;
virtual uint32_t GetNodeMask() const = 0;
virtual PipelineFlag GetFlags() const = 0;

virtual uint32_t OMGetSampleMask() const = 0;
virtual const Blend* OMGetBlendState() const = 0;
virtual const Rasterizer* RSGetState() const = 0;
virtual const DepthStencil* DSGetState() const = 0;

virtual const InputLayout* IAGetInputLayout() const = 0;
virtual PrimitiveTopologyType IAGetPrimitiveTopologyType() const = 0;
virtual IBStripCutValue IAGetIndexBufferStripCutValue() const = 0;

virtual uint8_t OMGetNumRenderTargets() const = 0;
virtual Format OMGetRTVFormat(uint8_t i) const = 0;
virtual Format OMGetDSVFormat() const = 0;
virtual uint8_t OMGetSampleCount() const = 0;
virtual uint8_t OMGetSampleQuality() const = 0;

// Get API native desc of this PSO handle
// pInputElements should be a pointer to std::vector<[API]_INPUT_ELEMENT_DESC>
virtual void GetHandleDesc(void* pHandleDesc, void* pInputElements, PipelineLib* pPipelineLib) const = 0;
virtual void GetHandleDesc(void* pHandleDesc, void* pInputElements) const = 0;

using uptr = std::unique_ptr<State>;
using sptr = std::shared_ptr<State>;
Expand All @@ -2362,7 +2380,7 @@ namespace XUSG
virtual ~PipelineLib() {};

virtual void SetDevice(const Device* pDevice) = 0;
virtual void SetPipeline(const std::string& key, const Pipeline& pipeline) = 0;
virtual void SetPipeline(const State* pState, const Pipeline& pipeline) = 0;

virtual void SetInputLayout(uint32_t index, const InputElement* pElements, uint32_t numElements) = 0;
virtual const InputLayout* GetInputLayout(uint32_t index) const = 0;
Expand All @@ -2375,10 +2393,6 @@ namespace XUSG
virtual const Rasterizer* GetRasterizer(RasterizerPreset preset) = 0;
virtual const DepthStencil* GetDepthStencil(DepthStencilPreset preset) = 0;

// Get API native desc of this PSO handle
// pInputElements should be a pointer to std::vector<[API]_INPUT_ELEMENT_DESC>
virtual void GetHandleDesc(void* pHandleDesc, void* pInputElements, const std::string& key) = 0;

static DepthStencil DepthStencilDefault();
static DepthStencil DepthStencilNone();
static DepthStencil DepthRead();
Expand Down Expand Up @@ -2439,11 +2453,14 @@ namespace XUSG
virtual Pipeline CreatePipeline(PipelineLib* pPipelineLib, const wchar_t* name = nullptr) const = 0;
virtual Pipeline GetPipeline(PipelineLib* pPipelineLib, const wchar_t* name = nullptr) const = 0;

// Get the key of the pipeline cache for XUSG pipeline lib
virtual const std::string& GetKey() const = 0;
virtual PipelineLayout GetPipelineLayout() const = 0;
virtual Blob GetShader() const = 0;
virtual Blob GetCachedPipeline() const = 0;
virtual uint32_t GetNodeMask() const = 0;
virtual PipelineFlag GetFlags() const = 0;

// Get API native desc of this PSO handle
virtual void GetHandleDesc(void* pHandleDesc, PipelineLib* pPipelineLib) const = 0;
virtual void GetHandleDesc(void* pHandleDesc) const = 0;

using uptr = std::unique_ptr<State>;
using sptr = std::shared_ptr<State>;
Expand All @@ -2460,14 +2477,11 @@ namespace XUSG
virtual ~PipelineLib() {};

virtual void SetDevice(const Device* pDevice) = 0;
virtual void SetPipeline(const std::string& key, const Pipeline& pipeline) = 0;
virtual void SetPipeline(const State* pState, const Pipeline& pipeline) = 0;

virtual Pipeline CreatePipeline(const State* pState, const wchar_t* name = nullptr) = 0;
virtual Pipeline GetPipeline(const State* pState, const wchar_t* name = nullptr) = 0;

// Get API native desc of this PSO handle
virtual void GetHandleDesc(void* pHandleDesc, const std::string& key) = 0;

using uptr = std::unique_ptr<PipelineLib>;
using sptr = std::shared_ptr<PipelineLib>;

Expand Down
50 changes: 42 additions & 8 deletions XUSG/Core/XUSGComputeState_DX12.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,39 @@ Pipeline State_DX12::GetPipeline(PipelineLib* pPipelineLib, const wchar_t* name)
return pPipelineLib->GetPipeline(this, name);
}

const string& State_DX12::GetKey() const
PipelineLayout State_DX12::GetPipelineLayout() const
{
return m_key;
return m_pKey->Layout;
}

Blob State_DX12::GetShader() const
{
return m_pKey->Shader;
}

Blob State_DX12::GetCachedPipeline() const
{
return m_pKey->CachedPipeline;
}

uint32_t State_DX12::GetNodeMask() const
{
return m_pKey->NodeMask;
}

void State_DX12::GetHandleDesc(void* pHandleDesc, PipelineLib* pPipelineLib) const
PipelineFlag State_DX12::GetFlags() const
{
pPipelineLib->GetHandleDesc(pHandleDesc, GetKey());
return m_pKey->Flags;
}

void State_DX12::GetHandleDesc(void* pHandleDesc) const
{
PipelineLib_DX12::GetHandleDesc(pHandleDesc, GetKey());
}

const string& State_DX12::GetKey() const
{
return m_key;
}

//--------------------------------------------------------------------------------------
Expand All @@ -91,19 +116,28 @@ void PipelineLib_DX12::SetDevice(const Device* pDevice)
assert(m_device);
}

void PipelineLib_DX12::SetPipeline(const string& key, const Pipeline& pipeline)
void PipelineLib_DX12::SetPipeline(const State* pState, const Pipeline& pipeline)
{
m_pipelines[key] = pipeline;
const auto p = dynamic_cast<const State_DX12*>(pState);
assert(p);

m_pipelines[p->GetKey()] = pipeline;
}

Pipeline PipelineLib_DX12::CreatePipeline(const State* pState, const wchar_t* name)
{
return createPipeline(pState->GetKey(), name);
const auto p = dynamic_cast<const State_DX12*>(pState);
assert(p);

return createPipeline(p->GetKey(), name);
}

Pipeline PipelineLib_DX12::GetPipeline(const State* pState, const wchar_t* name)
{
return getPipeline(pState->GetKey(), name);
const auto p = dynamic_cast<const State_DX12*>(pState);
assert(p);

return getPipeline(p->GetKey(), name);
}

Pipeline PipelineLib_DX12::createPipeline(const string& key, const wchar_t* name)
Expand Down
12 changes: 9 additions & 3 deletions XUSG/Core/XUSGComputeState_DX12.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ namespace XUSG

const std::string& GetKey() const;

void GetHandleDesc(void* pHandleDesc, PipelineLib* pPipelineLib) const;
PipelineLayout GetPipelineLayout() const;
Blob GetShader() const;
Blob GetCachedPipeline() const;
uint32_t GetNodeMask() const;
PipelineFlag GetFlags() const;

void GetHandleDesc(void* pHandleDesc) const;

protected:
PipelineDesc* m_pKey;
Expand All @@ -53,12 +59,12 @@ namespace XUSG
virtual ~PipelineLib_DX12();

void SetDevice(const Device* pDevice);
void SetPipeline(const std::string& key, const Pipeline& pipeline);
void SetPipeline(const State* pState, const Pipeline& pipeline);

Pipeline CreatePipeline(const State* pState, const wchar_t* name = nullptr);
Pipeline GetPipeline(const State* pState, const wchar_t* name = nullptr);

void GetHandleDesc(void* pHandleDesc, const std::string& key);
static void GetHandleDesc(void* pHandleDesc, const std::string& key);

protected:
virtual Pipeline createPipeline(const std::string& key, const wchar_t* name);
Expand Down
142 changes: 129 additions & 13 deletions XUSG/Core/XUSGGraphicsState_DX12.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,103 @@ Pipeline State_DX12::GetPipeline(PipelineLib* pPipelineLib, const wchar_t* name)
return pPipelineLib->GetPipeline(this, name);
}

const string& State_DX12::GetKey() const
PipelineLayout State_DX12::GetPipelineLayout() const
{
return m_key;
return m_pKey->Layout;
}

Blob State_DX12::GetShader(Shader::Stage stage) const
{
assert(stage < Shader::Stage::NUM_GRAPHICS);

return m_pKey->Shaders[stage];
}

Blob State_DX12::GetCachedPipeline() const
{
return m_pKey->CachedPipeline;
}

uint32_t State_DX12::GetNodeMask() const
{
return m_pKey->NodeMask;
}

PipelineFlag State_DX12::GetFlags() const
{
return m_pKey->Flags;
}

uint32_t State_DX12::OMGetSampleMask() const
{
return m_pKey->SampleMask;
}

void State_DX12::GetHandleDesc(void* pHandleDesc, void* pInputElements, PipelineLib* pPipelineLib) const
const Graphics::Blend* State_DX12::OMGetBlendState() const
{
pPipelineLib->GetHandleDesc(pHandleDesc, pInputElements, GetKey());
return m_pKey->pBlend;
}

const Graphics::Rasterizer* State_DX12::RSGetState() const
{
return m_pKey->pRasterizer;
}

const Graphics::DepthStencil* State_DX12::DSGetState() const
{
return m_pKey->pDepthStencil;
}

const InputLayout* State_DX12::IAGetInputLayout() const
{
return m_pKey->pInputLayout;
}

PrimitiveTopologyType State_DX12::IAGetPrimitiveTopologyType() const
{
return m_pKey->PrimTopologyType;
}

IBStripCutValue State_DX12::IAGetIndexBufferStripCutValue() const
{
return static_cast<IBStripCutValue>(m_pKey->IBStripCutValue);
}

uint8_t State_DX12::OMGetNumRenderTargets() const
{
return m_pKey->NumRenderTargets;
}

Format State_DX12::OMGetRTVFormat(uint8_t i) const
{
assert(i < m_pKey->NumRenderTargets);

return m_pKey->RTVFormats[i];
}

Format State_DX12::OMGetDSVFormat() const
{
return m_pKey->DSVFormat;
}

uint8_t State_DX12::OMGetSampleCount() const
{
return m_pKey->SampleCount;
}

uint8_t State_DX12::OMGetSampleQuality() const
{
return m_pKey->SampleQuality;
}

void State_DX12::GetHandleDesc(void* pHandleDesc, void* pInputElements) const
{
PipelineLib_DX12::GetHandleDesc(pHandleDesc, pInputElements, GetKey());
}

const string& State_DX12::GetKey() const
{
return m_key;
}

//--------------------------------------------------------------------------------------
Expand Down Expand Up @@ -206,9 +295,12 @@ void PipelineLib_DX12::SetDevice(const Device* pDevice)
assert(m_device);
}

void PipelineLib_DX12::SetPipeline(const string& key, const Pipeline& pipeline)
void PipelineLib_DX12::SetPipeline(const State* pState, const Pipeline& pipeline)
{
m_pipelines[key] = pipeline;
const auto p = dynamic_cast<const State_DX12*>(pState);
assert(p);

m_pipelines[p->GetKey()] = pipeline;
}

void PipelineLib_DX12::SetInputLayout(uint32_t index, const InputElement* pElements, uint32_t numElements)
Expand All @@ -228,23 +320,29 @@ const InputLayout* PipelineLib_DX12::CreateInputLayout(const InputElement* pElem

Pipeline PipelineLib_DX12::CreatePipeline(const State* pState, const wchar_t* name)
{
return createPipeline(pState->GetKey(), name);
const auto p = dynamic_cast<const State_DX12*>(pState);
assert(p);

return createPipeline(p->GetKey(), name);
}

Pipeline PipelineLib_DX12::GetPipeline(const State* pState, const wchar_t* name)
{
return getPipeline(pState->GetKey(), name);
const auto p = dynamic_cast<const State_DX12*>(pState);
assert(p);

return getPipeline(p->GetKey(), name);
}

const Blend* PipelineLib_DX12::GetBlend(BlendPreset preset, uint8_t numColorRTs)
const Graphics::Blend* PipelineLib_DX12::GetBlend(BlendPreset preset, uint8_t numColorRTs)
{
if (m_blends[preset] == nullptr)
m_blends[preset] = make_unique<Blend>(m_pfnBlends[preset](numColorRTs));

return m_blends[preset].get();
}

const Rasterizer* PipelineLib_DX12::GetRasterizer(RasterizerPreset preset)
const Graphics::Rasterizer* PipelineLib_DX12::GetRasterizer(RasterizerPreset preset)
{
if (m_rasterizers[preset] == nullptr)
m_rasterizers[preset] = make_unique<Rasterizer>(m_pfnRasterizers[preset]());
Expand Down Expand Up @@ -308,7 +406,13 @@ void PipelineLib_DX12::GetHandleDesc(void* pHandleDesc, void* pInputElements, co
desc.GS = CD3DX12_SHADER_BYTECODE(static_cast<ID3DBlob*>(pDesc->Shaders[Shader::Stage::GS]));

// Blend state
const auto pBlend = pDesc->pBlend ? pDesc->pBlend : GetBlend(BlendPreset::DEFAULT_OPAQUE);
unique_ptr<Blend> blend;
auto pBlend = pDesc->pBlend;
if (pBlend == nullptr)
{
blend = make_unique<Blend>(DefaultOpaque(pDesc->NumRenderTargets));
pBlend = blend.get();
}
desc.BlendState.AlphaToCoverageEnable = pBlend->AlphaToCoverageEnable ? TRUE : FALSE;
desc.BlendState.IndependentBlendEnable = pBlend->IndependentBlendEnable ? TRUE : FALSE;
for (uint8_t i = 0; i < D3D12_SIMULTANEOUS_RENDER_TARGET_COUNT; ++i)
Expand All @@ -329,7 +433,13 @@ void PipelineLib_DX12::GetHandleDesc(void* pHandleDesc, void* pInputElements, co
desc.SampleMask = pDesc->SampleMask;

// Rasterizer state
const auto pRasterizer = pDesc->pRasterizer ? pDesc->pRasterizer : GetRasterizer(RasterizerPreset::CULL_BACK);
unique_ptr<Rasterizer> rasterizer;
auto pRasterizer = pDesc->pRasterizer;
if (pRasterizer == nullptr)
{
rasterizer = make_unique<Rasterizer>(CullBack());
pRasterizer = rasterizer.get();
}
desc.RasterizerState.FillMode = GetDX12FillMode(pRasterizer->Fill);
desc.RasterizerState.CullMode = GetDX12CullMode(pRasterizer->Cull);
desc.RasterizerState.FrontCounterClockwise = pRasterizer->FrontCounterClockwise ? TRUE : FALSE;
Expand All @@ -348,7 +458,13 @@ void PipelineLib_DX12::GetHandleDesc(void* pHandleDesc, void* pInputElements, co
D3D12_CONSERVATIVE_RASTERIZATION_MODE_ON : D3D12_CONSERVATIVE_RASTERIZATION_MODE_OFF;

// Depth-stencil state
const auto pDepthStencil = pDesc->pDepthStencil ? pDesc->pDepthStencil : GetDepthStencil(DepthStencilPreset::DEFAULT_LESS);
unique_ptr<DepthStencil> depthStencil;
auto pDepthStencil = pDesc->pDepthStencil;
if (pDepthStencil == nullptr)
{
depthStencil = make_unique<DepthStencil>(DepthStencilDefault());
pDepthStencil = depthStencil.get();
}
desc.DepthStencilState.DepthEnable = pDepthStencil->DepthEnable ? TRUE : FALSE;
desc.DepthStencilState.DepthWriteMask = pDepthStencil->DepthWriteMask ? D3D12_DEPTH_WRITE_MASK_ALL : D3D12_DEPTH_WRITE_MASK_ZERO;
desc.DepthStencilState.DepthFunc = GetDX12ComparisonFunc(pDepthStencil->Comparison);
Expand Down
Loading

0 comments on commit c2b1049

Please sign in to comment.