Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement Context-local functions and ergonomic wrappers marking functions as safe #666

Merged
merged 1 commit into from
Oct 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions src/builtins/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,26 @@ use serde_json::value::{from_value, to_value, Value};

use crate::errors::{Error, Result};

/// The context-local function type definition
pub trait FunctionRelaxed {
/// The context-local function type definition
fn call(&self, args: &HashMap<String, Value>) -> Result<Value>;

/// Whether the current function's output should be treated as safe, defaults to `false`
fn is_safe(&self) -> bool {
false
}
}

impl<F> FunctionRelaxed for F
where
F: Fn(&HashMap<String, Value>) -> Result<Value>,
{
fn call(&self, args: &HashMap<String, Value>) -> Result<Value> {
self(args)
}
}

/// The global function type definition
pub trait Function: Sync + Send {
/// The global function type definition
Expand All @@ -28,6 +48,43 @@ where
}
}

macro_rules! safe_function {
($function_trait:ident, $function_wrapper:ident) => {
/// Wrapper to make `is_safe` return `true` instead of `false` for a trait implementation.
pub struct $function_wrapper<F>
where
F: $function_trait,
{
inner: F,
}

impl<F> $function_trait for $function_wrapper<F>
where
F: $function_trait,
{
fn call(&self, args: &HashMap<String, Value>) -> Result<Value> {
self.inner.call(args)
}

fn is_safe(&self) -> bool {
true
}
}

impl<F> From<F> for $function_wrapper<F>
where
F: $function_trait,
{
fn from(func: F) -> Self {
$function_wrapper { inner: func }
}
}
};
}

safe_function!(FunctionRelaxed, FunctionRelaxedSafe);
safe_function!(Function, FunctionSafe);

pub fn range(args: &HashMap<String, Value>) -> Result<Value> {
let start = match args.get("start") {
Some(val) => match from_value::<usize>(val.clone()) {
Expand Down
40 changes: 37 additions & 3 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,39 @@ use serde::ser::Serialize;
use serde_json::value::{to_value, Map, Value};

use crate::errors::{Error, Result as TeraResult};
use crate::FunctionRelaxed;
use std::sync::Arc;

/// The struct that holds the context of a template rendering.
///
/// Light wrapper around a `BTreeMap` for easier insertions of Serializable
/// values
#[derive(Debug, Clone, PartialEq)]
#[derive(Clone)]
pub struct Context {
data: BTreeMap<String, Value>,
/// Ignored by PartialEq!
functions: BTreeMap<String, Arc<dyn FunctionRelaxed>>,
}

impl std::fmt::Debug for Context {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Context")
.field("data", &self.data)
.field("functions", &self.functions.keys())
.finish()
}
}

impl PartialEq for Context {
fn eq(&self, other: &Self) -> bool {
self.data.eq(&other.data)
}
}

impl Context {
/// Initializes an empty context
pub fn new() -> Self {
Context { data: BTreeMap::new() }
Context { data: BTreeMap::new(), functions: Default::default() }
}

/// Converts the `val` parameter to `Value` and insert it into the context.
Expand Down Expand Up @@ -63,6 +82,15 @@ impl Context {
Ok(())
}

/// Registers Context-local function
pub fn register_function<T: FunctionRelaxed + 'static, S: Into<String>>(
&mut self,
key: S,
val: T,
) {
self.functions.insert(key.into(), Arc::new(val));
}

/// Appends the data of the `source` parameter to `self`, overwriting existing keys.
/// The source context will be dropped.
///
Expand Down Expand Up @@ -97,7 +125,7 @@ impl Context {
for (key, value) in m {
data.insert(key, value);
}
Ok(Context { data })
Ok(Context { data, functions: Default::default() })
}
_ => Err(Error::msg(
"Creating a Context from a Value/Serialize requires it being a JSON object",
Expand Down Expand Up @@ -127,6 +155,12 @@ impl Context {
pub fn contains_key(&self, index: &str) -> bool {
self.data.contains_key(index)
}

/// Looks up Context-local registered function
#[inline]
pub fn get_function(&self, fn_name: &str) -> Option<&Arc<dyn FunctionRelaxed>> {
self.functions.get(fn_name)
}
}

impl Default for Context {
Expand Down
5 changes: 4 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//! See the [site](https://tera.netlify.com) for features and to get started.

#![deny(missing_docs)]
#![deny(warnings)]

#[macro_use]
mod macros;
Expand All @@ -24,7 +25,9 @@ mod utils;

// Template is meant to be used internally only but is exported for test/bench.
pub use crate::builtins::filters::Filter;
pub use crate::builtins::functions::Function;
pub use crate::builtins::functions::{
Function, FunctionRelaxed, FunctionRelaxedSafe, FunctionSafe,
};
pub use crate::builtins::testers::Test;
pub use crate::context::Context;
pub use crate::errors::{Error, ErrorKind, Result};
Expand Down
7 changes: 6 additions & 1 deletion src/renderer/call_stack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use crate::errors::{Error, Result};
use crate::renderer::for_loop::{ForLoop, ForLoopState};
use crate::renderer::stack_frame::{FrameContext, FrameType, StackFrame, Val};
use crate::template::Template;
use crate::Context;
use crate::{Context, FunctionRelaxed};
use std::sync::Arc;

/// Contains the user data and allows no mutation
#[derive(Debug)]
Expand Down Expand Up @@ -131,6 +132,10 @@ impl<'a> CallStack<'a> {
None
}

pub fn lookup_function(&self, fn_name: &str) -> Option<&Arc<dyn FunctionRelaxed>> {
self.context.inner.get_function(fn_name)
}

/// Add an assignment value (via {% set ... %} and {% set_global ... %} )
pub fn add_assignment(&mut self, key: &'a str, global: bool, value: Val<'a>) {
if global {
Expand Down
13 changes: 9 additions & 4 deletions src/renderer/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,6 @@ impl<'a> Processor<'a> {
function_call: &'a FunctionCall,
needs_escape: &mut bool,
) -> Result<Val<'a>> {
let tera_fn = self.tera.get_function(&function_call.name)?;
*needs_escape = !tera_fn.is_safe();

let err_wrap = |e| Error::call_function(&function_call.name, e);

let mut args = HashMap::new();
Expand All @@ -504,7 +501,15 @@ impl<'a> Processor<'a> {
);
}

Ok(Cow::Owned(tera_fn.call(&args).map_err(err_wrap)?))
if let Some(tera_fn) = self.call_stack.lookup_function(&function_call.name) {
*needs_escape = !tera_fn.is_safe();
Ok(Cow::Owned(tera_fn.call(&args).map_err(err_wrap)?))
} else {
let tera_fn = self.tera.get_function(&function_call.name)?;
*needs_escape = !tera_fn.is_safe();

Ok(Cow::Owned(tera_fn.call(&args).map_err(err_wrap)?))
}
}

fn eval_macro_call(&mut self, macro_call: &'a MacroCall, write: &mut impl Write) -> Result<()> {
Expand Down