Skip to content

Commit

Permalink
[glsl-in] Don't clone HirExpr when lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
JCapucho authored and kvark committed Aug 10, 2021
1 parent 07c2862 commit e384bce
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 116 deletions.
5 changes: 5 additions & 0 deletions src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ impl<T> Arena<T> {
marker: PhantomData,
}
}

/// Clears the arena keeping all allocations
pub fn clear(&mut self) {
self.data.clear()
}
}

impl<T> ops::Index<Handle<T>> for Arena<T> {
Expand Down
142 changes: 117 additions & 25 deletions src/front/glsl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ pub struct Context {
pub lookup_global_var_exps: FastHashMap<String, VariableReference>,
pub samplers: FastHashMap<Handle<Expression>, Handle<Expression>>,

pub hir_exprs: Arena<HirExpr>,
pub typifier: Typifier,
emitter: Emitter,
stmt_ctx: Option<StmtContext>,
}

impl Context {
Expand All @@ -51,9 +51,9 @@ impl Context {
),
samplers: FastHashMap::default(),

hir_exprs: Arena::default(),
typifier: Typifier::new(),
emitter: Emitter::default(),
stmt_ctx: Some(StmtContext::new()),
};

this.emit_start();
Expand Down Expand Up @@ -280,14 +280,71 @@ impl Context {
self.scopes.pop();
}

/// Returns a [`StmtContext`](StmtContext) to be used in parsing and lowering
///
/// # Panics
/// - If more than one [`StmtContext`](StmtContext) are active at the same
/// time or if the previous call didn't use it in lowering.
#[must_use]
pub fn stmt_ctx(&mut self) -> StmtContext {
self.stmt_ctx.take().unwrap()
}

/// Lowers a [`HirExpr`](HirExpr) which might produce a [`Expression`](Expression).
///
/// consumes a [`StmtContext`](StmtContext) returning it to the context so
/// that it can be used again later.
pub fn lower(
&mut self,
mut stmt: StmtContext,
parser: &mut Parser,
expr: Handle<HirExpr>,
lhs: bool,
body: &mut Block,
) -> Result<(Option<Handle<Expression>>, SourceMetadata)> {
let res = self.lower_inner(&stmt, parser, expr, lhs, body);

stmt.hir_exprs.clear();
self.stmt_ctx = Some(stmt);

res
}

/// Similar to [`lower`](Self::lower) but returns an error if the expression
/// returns void (ie. doesn't produce a [`Expression`](Expression)).
///
/// consumes a [`StmtContext`](StmtContext) returning it to the context so
/// that it can be used again later.
pub fn lower_expect(
&mut self,
mut stmt: StmtContext,
parser: &mut Parser,
expr: Handle<HirExpr>,
lhs: bool,
body: &mut Block,
) -> Result<(Handle<Expression>, SourceMetadata)> {
let res = self.lower_expect_inner(&stmt, parser, expr, lhs, body);

stmt.hir_exprs.clear();
self.stmt_ctx = Some(stmt);

res
}

/// internal implementation of [`lower_expect`](Self::lower_expect)
///
/// this method is only public because it's used in
/// [`function_call`](Parser::function_call), unless you know what
/// you're doing use [`lower_expect`](Self::lower_expect)
pub fn lower_expect_inner(
&mut self,
stmt: &StmtContext,
parser: &mut Parser,
expr: Handle<HirExpr>,
lhs: bool,
body: &mut Block,
) -> Result<(Handle<Expression>, SourceMetadata)> {
let (maybe_expr, meta) = self.lower(parser, expr, lhs, body)?;
let (maybe_expr, meta) = self.lower_inner(stmt, parser, expr, lhs, body)?;

let expr = match maybe_expr {
Some(e) => e,
Expand All @@ -302,19 +359,22 @@ impl Context {
Ok((expr, meta))
}

pub fn lower(
/// Internal implementation of [`lower`](Self::lower)
fn lower_inner(
&mut self,
stmt: &StmtContext,
parser: &mut Parser,
expr: Handle<HirExpr>,
lhs: bool,
body: &mut Block,
) -> Result<(Option<Handle<Expression>>, SourceMetadata)> {
let HirExpr { kind, meta } = self.hir_exprs[expr].clone();
let HirExpr { ref kind, meta } = stmt.hir_exprs[expr];

let handle = match kind {
let handle = match *kind {
HirExprKind::Access { base, index } => {
let base = self.lower_expect(parser, base, true, body)?.0;
let (index, index_meta) = self.lower_expect(parser, index, false, body)?;
let base = self.lower_expect_inner(stmt, parser, base, true, body)?.0;
let (index, index_meta) =
self.lower_expect_inner(stmt, parser, index, false, body)?;

let pointer = parser
.solve_constant(self, index, index_meta)
Expand Down Expand Up @@ -354,17 +414,19 @@ impl Context {

pointer
}
HirExprKind::Select { base, field } => {
let base = self.lower_expect(parser, base, lhs, body)?.0;
HirExprKind::Select { base, ref field } => {
let base = self.lower_expect_inner(stmt, parser, base, lhs, body)?.0;

parser.field_selection(self, lhs, body, base, &field, meta)?
parser.field_selection(self, lhs, body, base, field, meta)?
}
HirExprKind::Constant(constant) if !lhs => {
self.add_expression(Expression::Constant(constant), body)
}
HirExprKind::Binary { left, op, right } if !lhs => {
let (mut left, left_meta) = self.lower_expect(parser, left, false, body)?;
let (mut right, right_meta) = self.lower_expect(parser, right, false, body)?;
let (mut left, left_meta) =
self.lower_expect_inner(stmt, parser, left, false, body)?;
let (mut right, right_meta) =
self.lower_expect_inner(stmt, parser, right, false, body)?;

match op {
BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => self
Expand Down Expand Up @@ -439,11 +501,11 @@ impl Context {
}
}
HirExprKind::Unary { op, expr } if !lhs => {
let expr = self.lower_expect(parser, expr, false, body)?.0;
let expr = self.lower_expect_inner(stmt, parser, expr, false, body)?.0;

self.add_expression(Expression::Unary { op, expr }, body)
}
HirExprKind::Variable(var) => {
HirExprKind::Variable(ref var) => {
if lhs {
if !var.mutable {
parser.errors.push(Error {
Expand All @@ -461,19 +523,29 @@ impl Context {
var.expr
}
}
HirExprKind::Call(call) if !lhs => {
let maybe_expr =
parser.function_or_constructor_call(self, body, call.kind, &call.args, meta)?;
HirExprKind::Call(ref call) if !lhs => {
let maybe_expr = parser.function_or_constructor_call(
self,
stmt,
body,
call.kind.clone(),
&call.args,
meta,
)?;
return Ok((maybe_expr, meta));
}
HirExprKind::Conditional {
condition,
accept,
reject,
} if !lhs => {
let condition = self.lower_expect(parser, condition, false, body)?.0;
let (mut accept, accept_meta) = self.lower_expect(parser, accept, false, body)?;
let (mut reject, reject_meta) = self.lower_expect(parser, reject, false, body)?;
let condition = self
.lower_expect_inner(stmt, parser, condition, false, body)?
.0;
let (mut accept, accept_meta) =
self.lower_expect_inner(stmt, parser, accept, false, body)?;
let (mut reject, reject_meta) =
self.lower_expect_inner(stmt, parser, reject, false, body)?;

self.binary_implicit_conversion(
parser,
Expand All @@ -493,8 +565,9 @@ impl Context {
)
}
HirExprKind::Assign { tgt, value } if !lhs => {
let (pointer, ptr_meta) = self.lower_expect(parser, tgt, true, body)?;
let (mut value, value_meta) = self.lower_expect(parser, value, false, body)?;
let (pointer, ptr_meta) = self.lower_expect_inner(stmt, parser, tgt, true, body)?;
let (mut value, value_meta) =
self.lower_expect_inner(stmt, parser, value, false, body)?;

let scalar_components = self.expr_scalar_components(parser, pointer, ptr_meta)?;

Expand Down Expand Up @@ -564,7 +637,7 @@ impl Context {
false => BinaryOperator::Subtract,
};

let pointer = self.lower_expect(parser, expr, true, body)?.0;
let pointer = self.lower_expect_inner(stmt, parser, expr, true, body)?.0;
let left = self.add_expression(Expression::Load { pointer }, body);

let uint = match parser.resolve_type(self, left, meta)?.scalar_kind() {
Expand Down Expand Up @@ -641,7 +714,7 @@ impl Context {
_ => {
return Err(Error {
kind: ErrorKind::SemanticError(
format!("{:?} cannot be in the left hand side", self.hir_exprs[expr])
format!("{:?} cannot be in the left hand side", stmt.hir_exprs[expr])
.into(),
),
meta,
Expand Down Expand Up @@ -779,3 +852,22 @@ impl Index<Handle<Expression>> for Context {
&self.expressions[index]
}
}

/// Helper struct passed when parsing expressions
///
/// This struct should only be obtained trough [`stmt_ctx`](Context::stmt_ctx)
/// and only one of these may be active at any time per context.
#[derive(Debug)]
pub struct StmtContext {
/// A arena of high level expressions which can be lowered trough a
/// [`Context`](Context) to naga's [`Expression`](crate::Expression)s
pub hir_exprs: Arena<HirExpr>,
}

impl StmtContext {
fn new() -> Self {
StmtContext {
hir_exprs: Arena::new(),
}
}
}
45 changes: 21 additions & 24 deletions src/front/glsl/functions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
ast::*,
context::Context,
context::{Context, StmtContext},
error::{Error, ErrorKind},
types::{scalar_components, type_power},
Parser, Result, SourceMetadata,
Expand Down Expand Up @@ -41,14 +41,15 @@ impl Parser {
pub(crate) fn function_or_constructor_call(
&mut self,
ctx: &mut Context,
stmt: &StmtContext,
body: &mut Block,
fc: FunctionCallKind,
raw_args: &[Handle<HirExpr>],
meta: SourceMetadata,
) -> Result<Option<Handle<Expression>>> {
let args: Vec<_> = raw_args
.iter()
.map(|e| ctx.lower_expect(self, *e, false, body))
.map(|e| ctx.lower_expect_inner(stmt, self, *e, false, body))
.collect::<Result<_>>()?;

match fc {
Expand Down Expand Up @@ -307,14 +308,16 @@ impl Parser {
Ok(Some(h))
}
FunctionCallKind::Function(name) => {
self.function_call(ctx, body, name, args, raw_args, meta)
self.function_call(ctx, stmt, body, name, args, raw_args, meta)
}
}
}

#[allow(clippy::too_many_arguments)]
fn function_call(
&mut self,
ctx: &mut Context,
stmt: &StmtContext,
body: &mut Block,
name: String,
mut args: Vec<(Handle<Expression>, SourceMetadata)>,
Expand Down Expand Up @@ -969,8 +972,13 @@ impl Parser {
.iter()
.zip(raw_args.iter().zip(parameters.iter()))
{
let (mut handle, meta) =
ctx.lower_expect(self, *expr, parameter_info.qualifier.is_lhs(), body)?;
let (mut handle, meta) = ctx.lower_expect_inner(
stmt,
self,
*expr,
parameter_info.qualifier.is_lhs(),
body,
)?;

if let TypeInner::Vector { size, kind, width } =
*self.resolve_type(ctx, handle, meta)?
Expand Down Expand Up @@ -1025,27 +1033,16 @@ impl Parser {

ctx.emit_start();
for (tgt, pointer) in proxy_writes {
let temp_ref = ctx.hir_exprs.append(HirExpr {
kind: HirExprKind::Variable(VariableReference {
expr: pointer,
load: true,
mutable: true,
entry_arg: None,
}),
meta,
});
let assign = ctx.hir_exprs.append(HirExpr {
kind: HirExprKind::Assign {
tgt,
value: temp_ref,
},
meta,
});
let value = ctx.add_expression(Expression::Load { pointer }, body);
let target = ctx.lower_expect_inner(stmt, self, tgt, true, body)?.0;

let _ = ctx.lower_expect(self, assign, false, body)?;
ctx.emit_flush(body);
body.push(Statement::Store {
pointer: target,
value,
});
ctx.emit_start();
}
ctx.emit_flush(body);
ctx.emit_start();

Ok(result)
}
Expand Down
5 changes: 3 additions & 2 deletions src/front/glsl/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,9 @@ impl<'source> ParsingContext<'source> {

let mut ctx = Context::new(parser, &mut block);

let expr = self.parse_conditional(parser, &mut ctx, &mut block, None)?;
let (root, meta) = ctx.lower_expect(parser, expr, false, &mut block)?;
let mut stmt_ctx = ctx.stmt_ctx();
let expr = self.parse_conditional(parser, &mut ctx, &mut stmt_ctx, &mut block, None)?;
let (root, meta) = ctx.lower_expect(stmt_ctx, parser, expr, false, &mut block)?;

Ok((parser.solve_constant(&ctx, root, meta)?, meta))
}
Expand Down
5 changes: 3 additions & 2 deletions src/front/glsl/parser/declarations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,9 @@ impl<'source> ParsingContext<'source> {
meta,
))
} else {
let expr = self.parse_assignment(parser, ctx, body)?;
let (mut init, init_meta) = ctx.lower_expect(parser, expr, false, body)?;
let mut stmt = ctx.stmt_ctx();
let expr = self.parse_assignment(parser, ctx, &mut stmt, body)?;
let (mut init, init_meta) = ctx.lower_expect(stmt, parser, expr, false, body)?;

let scalar_components = scalar_components(&parser.module.types[ty].inner);
if let Some((kind, width)) = scalar_components {
Expand Down
Loading

0 comments on commit e384bce

Please sign in to comment.