Skip to content

Commit 34435db

Browse files
authored
implement Context-local functions and ergonomic wrappers marking functions as safe (Keats#666)
1 parent 3be477b commit 34435db

File tree

5 files changed

+113
-9
lines changed

5 files changed

+113
-9
lines changed

src/builtins/functions.rs

+57
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,26 @@ use serde_json::value::{from_value, to_value, Value};
88

99
use crate::errors::{Error, Result};
1010

11+
/// The context-local function type definition
12+
pub trait FunctionRelaxed {
13+
/// The context-local function type definition
14+
fn call(&self, args: &HashMap<String, Value>) -> Result<Value>;
15+
16+
/// Whether the current function's output should be treated as safe, defaults to `false`
17+
fn is_safe(&self) -> bool {
18+
false
19+
}
20+
}
21+
22+
impl<F> FunctionRelaxed for F
23+
where
24+
F: Fn(&HashMap<String, Value>) -> Result<Value>,
25+
{
26+
fn call(&self, args: &HashMap<String, Value>) -> Result<Value> {
27+
self(args)
28+
}
29+
}
30+
1131
/// The global function type definition
1232
pub trait Function: Sync + Send {
1333
/// The global function type definition
@@ -28,6 +48,43 @@ where
2848
}
2949
}
3050

51+
macro_rules! safe_function {
52+
($function_trait:ident, $function_wrapper:ident) => {
53+
/// Wrapper to make `is_safe` return `true` instead of `false` for a trait implementation.
54+
pub struct $function_wrapper<F>
55+
where
56+
F: $function_trait,
57+
{
58+
inner: F,
59+
}
60+
61+
impl<F> $function_trait for $function_wrapper<F>
62+
where
63+
F: $function_trait,
64+
{
65+
fn call(&self, args: &HashMap<String, Value>) -> Result<Value> {
66+
self.inner.call(args)
67+
}
68+
69+
fn is_safe(&self) -> bool {
70+
true
71+
}
72+
}
73+
74+
impl<F> From<F> for $function_wrapper<F>
75+
where
76+
F: $function_trait,
77+
{
78+
fn from(func: F) -> Self {
79+
$function_wrapper { inner: func }
80+
}
81+
}
82+
};
83+
}
84+
85+
safe_function!(FunctionRelaxed, FunctionRelaxedSafe);
86+
safe_function!(Function, FunctionSafe);
87+
3188
pub fn range(args: &HashMap<String, Value>) -> Result<Value> {
3289
let start = match args.get("start") {
3390
Some(val) => match from_value::<usize>(val.clone()) {

src/context.rs

+37-3
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,39 @@ use serde::ser::Serialize;
55
use serde_json::value::{to_value, Map, Value};
66

77
use crate::errors::{Error, Result as TeraResult};
8+
use crate::FunctionRelaxed;
9+
use std::sync::Arc;
810

911
/// The struct that holds the context of a template rendering.
1012
///
1113
/// Light wrapper around a `BTreeMap` for easier insertions of Serializable
1214
/// values
13-
#[derive(Debug, Clone, PartialEq)]
15+
#[derive(Clone)]
1416
pub struct Context {
1517
data: BTreeMap<String, Value>,
18+
/// Ignored by PartialEq!
19+
functions: BTreeMap<String, Arc<dyn FunctionRelaxed>>,
20+
}
21+
22+
impl std::fmt::Debug for Context {
23+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24+
f.debug_struct("Context")
25+
.field("data", &self.data)
26+
.field("functions", &self.functions.keys())
27+
.finish()
28+
}
29+
}
30+
31+
impl PartialEq for Context {
32+
fn eq(&self, other: &Self) -> bool {
33+
self.data.eq(&other.data)
34+
}
1635
}
1736

1837
impl Context {
1938
/// Initializes an empty context
2039
pub fn new() -> Self {
21-
Context { data: BTreeMap::new() }
40+
Context { data: BTreeMap::new(), functions: Default::default() }
2241
}
2342

2443
/// Converts the `val` parameter to `Value` and insert it into the context.
@@ -63,6 +82,15 @@ impl Context {
6382
Ok(())
6483
}
6584

85+
/// Registers Context-local function
86+
pub fn register_function<T: FunctionRelaxed + 'static, S: Into<String>>(
87+
&mut self,
88+
key: S,
89+
val: T,
90+
) {
91+
self.functions.insert(key.into(), Arc::new(val));
92+
}
93+
6694
/// Appends the data of the `source` parameter to `self`, overwriting existing keys.
6795
/// The source context will be dropped.
6896
///
@@ -97,7 +125,7 @@ impl Context {
97125
for (key, value) in m {
98126
data.insert(key, value);
99127
}
100-
Ok(Context { data })
128+
Ok(Context { data, functions: Default::default() })
101129
}
102130
_ => Err(Error::msg(
103131
"Creating a Context from a Value/Serialize requires it being a JSON object",
@@ -127,6 +155,12 @@ impl Context {
127155
pub fn contains_key(&self, index: &str) -> bool {
128156
self.data.contains_key(index)
129157
}
158+
159+
/// Looks up Context-local registered function
160+
#[inline]
161+
pub fn get_function(&self, fn_name: &str) -> Option<&Arc<dyn FunctionRelaxed>> {
162+
self.functions.get(fn_name)
163+
}
130164
}
131165

132166
impl Default for Context {

src/lib.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//! See the [site](https://tera.netlify.com) for features and to get started.
88
99
#![deny(missing_docs)]
10+
#![deny(warnings)]
1011

1112
#[macro_use]
1213
mod macros;
@@ -23,7 +24,9 @@ mod utils;
2324
// Library exports.
2425

2526
pub use crate::builtins::filters::Filter;
26-
pub use crate::builtins::functions::Function;
27+
pub use crate::builtins::functions::{
28+
Function, FunctionRelaxed, FunctionRelaxedSafe, FunctionSafe,
29+
};
2730
pub use crate::builtins::testers::Test;
2831
pub use crate::context::Context;
2932
pub use crate::errors::{Error, ErrorKind, Result};

src/renderer/call_stack.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ use crate::errors::{Error, Result};
88
use crate::renderer::for_loop::{ForLoop, ForLoopState};
99
use crate::renderer::stack_frame::{FrameContext, FrameType, StackFrame, Val};
1010
use crate::template::Template;
11-
use crate::Context;
11+
use crate::{Context, FunctionRelaxed};
12+
use std::sync::Arc;
1213

1314
/// Contains the user data and allows no mutation
1415
#[derive(Debug)]
@@ -131,6 +132,10 @@ impl<'a> CallStack<'a> {
131132
None
132133
}
133134

135+
pub fn lookup_function(&self, fn_name: &str) -> Option<&Arc<dyn FunctionRelaxed>> {
136+
self.context.inner.get_function(fn_name)
137+
}
138+
134139
/// Add an assignment value (via {% set ... %} and {% set_global ... %} )
135140
pub fn add_assignment(&mut self, key: &'a str, global: bool, value: Val<'a>) {
136141
if global {

src/renderer/processor.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -491,9 +491,6 @@ impl<'a> Processor<'a> {
491491
function_call: &'a FunctionCall,
492492
needs_escape: &mut bool,
493493
) -> Result<Val<'a>> {
494-
let tera_fn = self.tera.get_function(&function_call.name)?;
495-
*needs_escape = !tera_fn.is_safe();
496-
497494
let err_wrap = |e| Error::call_function(&function_call.name, e);
498495

499496
let mut args = HashMap::new();
@@ -504,7 +501,15 @@ impl<'a> Processor<'a> {
504501
);
505502
}
506503

507-
Ok(Cow::Owned(tera_fn.call(&args).map_err(err_wrap)?))
504+
if let Some(tera_fn) = self.call_stack.lookup_function(&function_call.name) {
505+
*needs_escape = !tera_fn.is_safe();
506+
Ok(Cow::Owned(tera_fn.call(&args).map_err(err_wrap)?))
507+
} else {
508+
let tera_fn = self.tera.get_function(&function_call.name)?;
509+
*needs_escape = !tera_fn.is_safe();
510+
511+
Ok(Cow::Owned(tera_fn.call(&args).map_err(err_wrap)?))
512+
}
508513
}
509514

510515
fn eval_macro_call(&mut self, macro_call: &'a MacroCall, write: &mut impl Write) -> Result<()> {

0 commit comments

Comments
 (0)