Skip to content

Commit

Permalink
feat: module resolution (#2567)
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen authored May 12, 2023
1 parent c440ae2 commit 442e1d2
Show file tree
Hide file tree
Showing 15 changed files with 299 additions and 152 deletions.
8 changes: 6 additions & 2 deletions prql-compiler/prqlc/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,13 @@ impl Command {
let stmts = prql_to_pl(source)?;

// resolve
let (stmts, _) = semantic::resolve_only(stmts, None)?;
let (_, ctx) = semantic::resolve_only(stmts, None)?;

let frames = collect_frames(stmts);
let frames = if let Some(main) = ctx.find_main() {
collect_frames(main.clone())
} else {
vec![]
};

// combine with source
combine_prql_and_frames(source, frames).as_bytes().to_vec()
Expand Down
11 changes: 11 additions & 0 deletions prql-compiler/src/ast/pl/fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ pub trait AstFold {
value: ty_def.value.map(|x| self.fold_expr(x)).transpose()?,
})
}
fn fold_module_def(&mut self, module_def: ModuleDef) -> Result<ModuleDef> {
fold_module_def(self, module_def)
}
fn fold_pipeline(&mut self, pipeline: Pipeline) -> Result<Pipeline> {
fold_pipeline(self, pipeline)
}
Expand Down Expand Up @@ -130,10 +133,18 @@ pub fn fold_stmt_kind<T: ?Sized + AstFold>(fold: &mut T, stmt_kind: StmtKind) ->
VarDef(var_def) => VarDef(fold.fold_var_def(var_def)?),
TypeDef(type_def) => TypeDef(fold.fold_type_def(type_def)?),
Main(expr) => Main(Box::new(fold.fold_expr(*expr)?)),
ModuleDef(module_def) => ModuleDef(fold.fold_module_def(module_def)?),
QueryDef(_) => stmt_kind,
})
}

fn fold_module_def<F: ?Sized + AstFold>(fold: &mut F, module_def: ModuleDef) -> Result<ModuleDef> {
Ok(ModuleDef {
name: module_def.name,
stmts: fold.fold_stmts(module_def.stmts)?,
})
}

pub fn fold_window<F: ?Sized + AstFold>(fold: &mut F, window: WindowFrame) -> Result<WindowFrame> {
Ok(WindowFrame {
kind: window.kind,
Expand Down
14 changes: 14 additions & 0 deletions prql-compiler/src/ast/pl/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub enum StmtKind {
FuncDef(FuncDef),
VarDef(VarDef),
TypeDef(TypeDef),
ModuleDef(ModuleDef),
Main(Box<Expr>),
}

Expand Down Expand Up @@ -71,6 +72,12 @@ pub struct TypeDef {
pub value: Option<Expr>,
}

#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct ModuleDef {
pub name: String,
pub stmts: Vec<Stmt>,
}

impl From<StmtKind> for Stmt {
fn from(kind: StmtKind) -> Self {
Stmt {
Expand Down Expand Up @@ -142,6 +149,13 @@ impl Display for StmtKind {
write!(f, "type {}\n\n", ty_def.name)?;
}
}
StmtKind::ModuleDef(module_def) => {
write!(f, "module {} {{", module_def.name)?;
for stmt in &module_def.stmts {
write!(f, "{}", stmt.kind)?;
}
write!(f, "}}\n\n")?;
}
}
Ok(())
}
Expand Down
3 changes: 2 additions & 1 deletion prql-compiler/src/parser/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub fn lexer() -> impl Parser<char, Vec<(Token, std::ops::Range<usize>)>, Error
just("??").to(Token::Coalesce),
));

let control = one_of("></%=+-*[]().,:|!").map(Token::Control);
let control = one_of("></%=+-*[]().,:|!{}").map(Token::Control);

let ident = ident_part().map(Token::Ident);

Expand All @@ -59,6 +59,7 @@ pub fn lexer() -> impl Parser<char, Vec<(Token, std::ops::Range<usize>)>, Error
just("case"),
just("prql"),
just("type"),
just("module"),
))
.then_ignore(end_expr())
.map(|x| x.to_string())
Expand Down
26 changes: 18 additions & 8 deletions prql-compiler/src/parser/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,28 @@ use super::lexer::Token;
pub fn source() -> impl Parser<Token, Vec<Stmt>, Error = Simple<Token>> {
query_def()
.or_not()
.chain::<Stmt, _, _>(
choice((type_def(), var_def(), function_def()))
.map_with_span(into_stmt)
.separated_by(new_line().repeated())
.allow_leading()
.allow_trailing(),
)
.chain(main_pipeline().or_not())
.chain(module_contents())
.then_ignore(end())
.labelled("source file")
}

fn module_contents() -> impl Parser<Token, Vec<Stmt>, Error = Simple<Token>> {
recursive(|module_contents| {
let module_def = keyword("module")
.ignore_then(ident_part())
.then(module_contents.delimited_by(ctrl('{'), ctrl('}')))
.map(|(name, stmts)| StmtKind::ModuleDef(ModuleDef { name, stmts }))
.labelled("module definition");

choice((type_def(), var_def(), function_def(), module_def))
.map_with_span(into_stmt)
.separated_by(new_line().repeated())
.allow_leading()
.allow_trailing()
.chain(main_pipeline().or_not())
})
}

fn main_pipeline() -> impl Parser<Token, Stmt, Error = Simple<Token>> {
pipeline(expr_call())
.map_with_span(into_expr)
Expand Down
73 changes: 51 additions & 22 deletions prql-compiler/src/semantic/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,9 @@ pub enum TableColumn {
}

impl Context {
pub fn declare_func(&mut self, func_def: FuncDef, id: Option<usize>) {
let name = func_def.name.clone();

let path = vec![NS_STD.to_string()];
let ident = Ident { name, path };

pub fn declare(&mut self, ident: Ident, decl: DeclKind, id: Option<usize>) {
let decl = Decl {
kind: DeclKind::FuncDef(func_def),
kind: decl,
declared_at: id,
order: 0,
};
Expand All @@ -102,16 +97,14 @@ impl Context {

pub fn declare_var(
&mut self,
var_def: VarDef,
ident: Ident,
value: Box<Expr>,
id: Option<usize>,
span: Option<Span>,
) -> Result<()> {
let name = var_def.name;
let mut path = Vec::new();

let decl = match &var_def.value.ty {
let decl = match &value.ty {
Some(Ty::Table(_) | Ty::Infer) => {
let mut value = var_def.value;
let mut value = value;

let ty = value.ty.clone().unwrap();
let frame = ty.into_table().unwrap_or_else(|_| {
Expand All @@ -122,8 +115,6 @@ impl Context {
assumed.into_table().unwrap()
});

path = vec![NS_DEFAULT_DB.to_string()];

let columns = (frame.columns.iter())
.map(|col| match col {
FrameColumn::All { .. } => RelationColumn::Wildcard,
Expand All @@ -137,10 +128,9 @@ impl Context {
DeclKind::TableDecl(TableDecl { columns, expr })
}
Some(_) => {
let mut value = var_def.value;
let mut value = value;

// TODO: check that declaring module is std
if let Some(kind) = get_stdlib_decl(name.as_str()) {
if let Some(kind) = get_stdlib_decl(&ident) {
value.kind = kind;
}

Expand All @@ -161,13 +151,16 @@ impl Context {
order: 0,
};

let ident = Ident { name, path };
self.root_mod.insert(ident, decl).unwrap();

Ok(())
}

pub fn resolve_ident(&mut self, ident: &Ident) -> Result<Ident, String> {
pub fn resolve_ident(
&mut self,
ident: &Ident,
default_namespace: Option<&String>,
) -> Result<Ident, String> {
// special case: wildcard
if ident.name == "*" {
// TODO: we may want to raise an error if someone has passed `download*` in
Expand Down Expand Up @@ -200,6 +193,29 @@ impl Context {
}
}

let ident = if let Some(default_namespace) = default_namespace {
let ident = ident.clone().prepend(default_namespace.clone());

let decls = self.root_mod.lookup(&ident);
match decls.len() {
// no match: try match *
0 => ident,

// single match, great!
1 => return Ok(decls.into_iter().next().unwrap()),

// ambiguous
_ => {
return Err({
let decls = decls.into_iter().map(|d| d.to_string()).join(", ");
format!("Ambiguous name. Could be from any of {decls}")
})
}
}
} else {
ident.clone()
};

// fallback case: try to match with NS_INFER and infer the declaration from the original ident.
match self.resolve_ident_fallback(ident.clone(), NS_INFER) {
// The declaration and all needed parent modules were created
Expand Down Expand Up @@ -466,10 +482,23 @@ impl Context {
let table_fq = default_db_ident + Ident::from_name(global_name);
self.table_decl_to_frame(&table_fq, input_name, id)
}

pub fn find_main(&self) -> Option<&Expr> {
let main = Ident::from_name("main");
let decl = self.root_mod.get(&main)?;

let decl = decl.kind.as_table_decl()?;

Some(decl.expr.as_relation_var()?.as_ref())
}
}

fn get_stdlib_decl(name: &str) -> Option<ExprKind> {
let ty_lit = match name {
fn get_stdlib_decl(ident: &Ident) -> Option<ExprKind> {
if !ident.starts_with_part(NS_STD) {
return None;
}

let ty_lit = match ident.name.as_str() {
"int" => TyLit::Int,
"float" => TyLit::Float,
"bool" => TyLit::Bool,
Expand Down
46 changes: 26 additions & 20 deletions prql-compiler/src/semantic/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ use itertools::Itertools;

use crate::ast::pl::fold::AstFold;
use crate::ast::pl::{
self, Expr, ExprKind, Frame, FrameColumn, Ident, InterpolateItem, Range, SwitchCase, Ty,
WindowFrame,
self, Expr, ExprKind, Frame, FrameColumn, Ident, InterpolateItem, Range, StmtKind, SwitchCase,
Ty, WindowFrame,
};
use crate::ast::rq::{self, CId, Query, RelationColumn, TId, TableDecl, Transform};
use crate::error::{Error, Reason, Span, WithErrorInfo};
use crate::semantic::context::TableExpr;
use crate::semantic::module::Module;
use crate::utils::{toposort, IdGenerator};
use crate::utils::{toposort, IdGenerator, Pluck};

use super::context::{self, Context, DeclKind};
use super::module::NS_DEFAULT_DB;
Expand All @@ -29,30 +29,36 @@ pub fn lower_ast_to_ir(statements: Vec<pl::Stmt>, context: Context) -> Result<Qu

TableExtractor::extract(&mut l)?;

let mut query_def = None;
let mut main_pipeline = None;

for statement in statements {
use pl::StmtKind::*;
let def = statements
.into_iter()
.find_map(|stmt| match stmt.kind {
StmtKind::QueryDef(def) => Some(def),
_ => None,
})
.unwrap_or_default();

match statement.kind {
QueryDef(def) => query_def = Some(def),
Main(expr) => {
let relation = l.lower_relation(*expr)?;
main_pipeline = Some(relation);
}
FuncDef(_) | VarDef(_) | TypeDef(_) => {}
}
}
let relation = find_main_relation(&mut l)
.ok_or_else(|| Error::new_simple("Missing query").with_code("E0001"))?;

Ok(Query {
def: query_def.unwrap_or_default(),
def,
tables: l.table_buffer,
relation: main_pipeline
.ok_or_else(|| Error::new_simple("Missing query").with_code("E0001"))?,
relation,
})
}

fn find_main_relation(l: &mut Lowerer) -> Option<rq::Relation> {
let main = Ident::from_name("main");
let main_tid = l.table_mapping.get(&main)?;

let main = l
.table_buffer
.pluck(|t| if &t.id == main_tid { Ok(t) } else { Err(t) });

let main = main.into_iter().next()?;
Some(main.relation)
}

#[derive(Debug)]
struct Lowerer {
cid: IdGenerator<CId>,
Expand Down
7 changes: 4 additions & 3 deletions prql-compiler/src/semantic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub use self::module::Module;
use crate::ast::pl::frame::{Frame, FrameColumn};
use crate::ast::pl::Stmt;
use crate::ast::rq::Query;
use crate::semantic::module::NS_STD;
use crate::PRQL_VERSION;

use anyhow::{bail, Result};
Expand All @@ -24,7 +25,7 @@ use semver::{Version, VersionReq};
pub fn resolve(statements: Vec<Stmt>) -> Result<Query> {
let context = load_std_lib();

let (statements, context) = resolver::resolve(statements, context)?;
let (statements, context) = resolver::resolve(statements, vec![], context)?;

let query = lowering::lower_ast_to_ir(statements, context)?;

Expand All @@ -42,7 +43,7 @@ pub fn resolve_only(
) -> Result<(Vec<Stmt>, Context)> {
let context = context.unwrap_or_else(load_std_lib);

resolver::resolve(statements, context)
resolver::resolve(statements, vec![], context)
}

pub fn load_std_lib() -> Context {
Expand All @@ -55,7 +56,7 @@ pub fn load_std_lib() -> Context {
..Context::default()
};

let (_, context) = resolver::resolve(statements, context).unwrap();
let (_, context) = resolver::resolve(statements, vec![NS_STD.to_string()], context).unwrap();
context
}

Expand Down
4 changes: 2 additions & 2 deletions prql-compiler/src/semantic/reporting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ impl<'a> AstFold for Labeler<'a> {
}
}

pub fn collect_frames(stmts: Vec<Stmt>) -> Vec<(Span, Frame)> {
pub fn collect_frames(expr: Expr) -> Vec<(Span, Frame)> {
let mut collector = FrameCollector { frames: vec![] };

collector.fold_stmts(stmts).unwrap();
collector.fold_expr(expr).unwrap();

collector.frames.reverse();
collector.frames
Expand Down
Loading

0 comments on commit 442e1d2

Please sign in to comment.