Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type inference for optional and fold-related outputs. #261

Merged
merged 1 commit into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 57 additions & 19 deletions trustfall_core/src/ir/indexed.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use std::{collections::BTreeMap, convert::TryFrom, ptr, sync::Arc};
use std::{
collections::{BTreeMap, BTreeSet},
convert::TryFrom,
ptr,
sync::Arc,
};

use async_graphql_parser::types::{BaseType, Type};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -71,7 +76,7 @@ impl TryFrom<IRQuery> for IndexedQuery {
&mut outputs,
&ir_query.variables,
&ir_query.root_component,
0,
&mut vec![],
)?;

Ok(Self {
Expand All @@ -83,14 +88,45 @@ impl TryFrom<IRQuery> for IndexedQuery {
}
}

fn get_optional_vertices_in_component(component: &Arc<IRQueryComponent>) -> BTreeSet<Vid> {
let mut output = BTreeSet::new();
for edge in component.edges.values() {
if edge.optional || output.contains(&edge.from_vid) {
output.insert(edge.to_vid);
}
}
output
}

fn get_output_type(
output_at: Vid,
field_type: &Type,
component_optional_vertices: &BTreeSet<Vid>,
are_folds_optional: &[bool],
) -> Type {
let mut wrapped_output_type = field_type.clone();
if component_optional_vertices.contains(&output_at) {
wrapped_output_type.nullable = true;
}
for is_fold_optional in are_folds_optional.iter().rev() {
wrapped_output_type = Type {
base: BaseType::List(Box::new(wrapped_output_type)),
nullable: *is_fold_optional,
};
}
wrapped_output_type
}

fn add_data_from_component(
vids: &mut BTreeMap<Vid, Arc<IRQueryComponent>>,
eids: &mut BTreeMap<Eid, EdgeKind>,
outputs: &mut BTreeMap<Arc<str>, Output>,
variables: &BTreeMap<Arc<str>, Type>,
component: &Arc<IRQueryComponent>,
fold_depth: usize,
are_folds_optional: &mut Vec<bool>, // whether each level of @fold is inside an @optional
) -> Result<(), InvalidIRQueryError> {
let component_optional_vertices = get_optional_vertices_in_component(component);

// the root vertex Vid must belong to an existing vertex in the component
if component.vertices.get(&component.root).is_none() {
return Err(InvalidIRQueryError::GetBetterVariant(-1));
Expand Down Expand Up @@ -142,20 +178,13 @@ fn add_data_from_component(
return Err(InvalidIRQueryError::GetBetterVariant(2));
}

let output_type = if fold_depth == 0 {
field.field_type.clone()
} else {
let mut wrapped_output_type = field.field_type.clone();
for _ in 0..fold_depth {
wrapped_output_type = Type {
base: BaseType::List(Box::new(wrapped_output_type)),
nullable: false,
};
}
wrapped_output_type
};

let output_name = output_name.clone();
let output_type = get_output_type(
output_vid,
&field.field_type,
&component_optional_vertices,
are_folds_optional,
);
let output = Output {
name: output_name.clone(),
value_type: output_type,
Expand Down Expand Up @@ -193,7 +222,6 @@ fn add_data_from_component(
}
}

let new_fold_depth = fold_depth + 1;
for (eid, fold) in component.folds.iter() {
// The "to" vertex must have Vid equal to the folded edge's Eid + 1.
if usize::from(eid.0) + 1 != usize::from(fold.to_vid.0) {
Expand All @@ -220,26 +248,36 @@ fn add_data_from_component(

// Include fold-specific outputs in the list of outputs.
for (name, kind) in &fold.fold_specific_outputs {
let output_type = get_output_type(
fold.from_vid,
kind.field_type(),
&component_optional_vertices,
are_folds_optional,
);
outputs
.insert_or_error(
name.clone(),
Output {
name: name.clone(),
value_type: kind.field_type().clone(),
value_type: output_type,
vid: fold.to_vid,
},
)
.map_err(|_| InvalidIRQueryError::GetBetterVariant(15))?;
}

are_folds_optional.push(component_optional_vertices.contains(&fold.from_vid));
add_data_from_component(
vids,
eids,
outputs,
variables,
&fold.component,
new_fold_depth,
are_folds_optional,
)?;
are_folds_optional
.pop()
.expect("pushed value is no longer present");
}

Ok(())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ TestInterpreterOutputData(
outputs: {
"count": Output(
name: "count",
value_type: "Int!",
value_type: "Int",
vid: Vid(4),
),
"start": Output(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ TestInterpreterOutputData(
),
"multiplecount": Output(
name: "multiplecount",
value_type: "Int!",
value_type: "[Int!]!",
vid: Vid(3),
),
"multiples": Output(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ TestInterpreterOutputData(
),
"successor_counts": Output(
name: "successor_counts",
value_type: "Int!",
value_type: "Int",
vid: Vid(3),
),
"successors": Output(
name: "successors",
value_type: "[Int]!",
value_type: "[Int]",
vid: Vid(3),
),
"zero": Output(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ TestInterpreterOutputData(
outputs: {
"next_successor_counts": Output(
name: "next_successor_counts",
value_type: "Int!",
value_type: "[[Int!]!]!",
vid: Vid(4),
),
"next_successors": Output(
Expand All @@ -18,7 +18,7 @@ TestInterpreterOutputData(
),
"successor_counts": Output(
name: "successor_counts",
value_type: "Int!",
value_type: "[Int!]!",
vid: Vid(3),
),
"successors": Output(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ TestInterpreterOutputData(
),
"successor_counts": Output(
name: "successor_counts",
value_type: "Int!",
value_type: "[Int!]!",
vid: Vid(3),
),
"successors": Output(
Expand Down