Skip to content

Commit

Permalink
51 improve code coverage (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertohernandez1995 authored Jan 16, 2025
1 parent 881efbc commit a4259ae
Show file tree
Hide file tree
Showing 122 changed files with 3,490 additions and 1,618 deletions.
31 changes: 11 additions & 20 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,25 @@ def main():
"""

data_structures = {
'datasets': [
{'name': 'DS_1',
'DataStructure': [
{'name': 'Id_1',
'type':
'Integer',
'role': 'Identifier',
'nullable': False},
{'name': 'Me_1',
'type': 'Number',
'role': 'Measure',
'nullable': True}
]
}
"datasets": [
{
"name": "DS_1",
"DataStructure": [
{"name": "Id_1", "type": "Integer", "role": "Identifier", "nullable": False},
{"name": "Me_1", "type": "Number", "role": "Measure", "nullable": True},
],
}
]
}

data_df = pd.DataFrame(
{"Id_1": [1, 2, 3],
"Me_1": [10, 20, 30]})
data_df = pd.DataFrame({"Id_1": [1, 2, 3], "Me_1": [10, 20, 30]})

datapoints = {"DS_1": data_df}

run_result = run(script=script, data_structures=data_structures,
datapoints=datapoints)
run_result = run(script=script, data_structures=data_structures, datapoints=datapoints)

print(run_result)


if __name__ == '__main__':
if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ mypy = "^1.11.2"
pandas-stubs = "^2.2.3.241009"
stubs = "^1.0.0"
toml = "^0.10.2"
ruff = "^0.7.1"
ruff = "^0.8.3"


[tool.ruff]
line-length = 100
Expand Down
8 changes: 5 additions & 3 deletions src/vtlengine/API/_InternalApi.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def _load_dataset_from_structure(structures: Dict[str, Any]) -> Dict[str, Any]:
for scalar_json in structures["scalars"]:
scalar_name = scalar_json["name"]
scalar = Scalar(
name=scalar_name, data_type=SCALAR_TYPES[scalar_json["type"]], value=None
name=scalar_name,
data_type=SCALAR_TYPES[scalar_json["type"]],
value=None,
)
datasets[scalar_name] = scalar # type: ignore[assignment]
return datasets
Expand Down Expand Up @@ -115,7 +117,7 @@ def _load_single_datapoint(datapoint: Union[str, Path]) -> Dict[str, Any]:


def _load_datapoints_path(
datapoints: Union[Path, str, List[Union[str, Path]]]
datapoints: Union[Path, str, List[Union[str, Path]]],
) -> Dict[str, Dataset]:
"""
Returns a dict with the data given from a Path.
Expand Down Expand Up @@ -156,7 +158,7 @@ def _load_datastructure_single(data_structure: Union[Dict[str, Any], Path]) -> D


def load_datasets(
data_structure: Union[Dict[str, Any], Path, List[Union[Dict[str, Any], Path]]]
data_structure: Union[Dict[str, Any], Path, List[Union[Dict[str, Any], Path]]],
) -> Dict[str, Dataset]:
"""
Loads multiple datasets.
Expand Down
13 changes: 11 additions & 2 deletions src/vtlengine/API/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@ class __VTLSingleErrorListener(ErrorListener): # type: ignore[misc]
""" """

def syntaxError(
self, recognizer: Any, offendingSymbol: str, line: str, column: str, msg: str, e: Any
self,
recognizer: Any,
offendingSymbol: str,
line: str,
column: str,
msg: str,
e: Any,
) -> None:
raise Exception(
f"Not valid VTL Syntax \n "
Expand Down Expand Up @@ -150,7 +156,10 @@ class takes all of this information and checks it with the ast generated to

# Running the interpreter
interpreter = InterpreterAnalyzer(
datasets=structures, value_domains=vd, external_routines=ext_routines, only_semantic=True
datasets=structures,
value_domains=vd,
external_routines=ext_routines,
only_semantic=True,
)
with pd.option_context("future.no_silent_downcasting", True):
result = interpreter.visit(ast)
Expand Down
9 changes: 5 additions & 4 deletions src/vtlengine/AST/ASTConstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,8 @@ def visitHierRuleSignature(self, ctx: Parser.HierRuleSignatureContext):
if conditions:
identifiers_list = [
DefIdentifier(
value=elto.alias if getattr(elto, "alias", None) else elto.value, kind=kind
value=elto.alias if getattr(elto, "alias", None) else elto.value,
kind=kind,
)
for elto in conditions[0]
]
Expand All @@ -395,7 +396,7 @@ def visitHierRuleSignature(self, ctx: Parser.HierRuleSignatureContext):
def visitValueDomainSignature(self, ctx: Parser.ValueDomainSignatureContext):
"""
valueDomainSignature: CONDITION IDENTIFIER (AS IDENTIFIER)? (',' IDENTIFIER (AS IDENTIFIER)?)* ;
""" # noqa E501
""" # noqa E501
# AST_ASTCONSTRUCTOR.7
ctx_list = list(ctx.getChildren())
component_nodes = [
Expand Down Expand Up @@ -459,7 +460,7 @@ def visitCodeItemRelation(self, ctx: Parser.CodeItemRelationContext):
codeItemRelation: ( WHEN expr THEN )? codeItemRef codeItemRelationClause (codeItemRelationClause)* ;
( WHEN exprComponent THEN )? codetemRef=valueDomainValue comparisonOperand? codeItemRelationClause (codeItemRelationClause)*
""" # noqa E501
""" # noqa E501

ctx_list = list(ctx.getChildren())

Expand Down Expand Up @@ -512,7 +513,7 @@ def visitCodeItemRelation(self, ctx: Parser.CodeItemRelationContext):
def visitCodeItemRelationClause(self, ctx: Parser.CodeItemRelationClauseContext):
"""
(opAdd=( PLUS | MINUS ))? rightCodeItem=valueDomainValue ( QLPAREN rightCondition=exprComponent QRPAREN )?
""" # noqa E501
""" # noqa E501
ctx_list = list(ctx.getChildren())

expr = [expr for expr in ctx_list if isinstance(expr, Parser.ExprContext)]
Expand Down
75 changes: 47 additions & 28 deletions src/vtlengine/AST/ASTConstructorModules/Expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class Expr(VtlVisitor):
Expr Definition.
_______________________________________________________________________________________"""
_______________________________________________________________________________________
"""

def visitExpr(self, ctx: Parser.ExprContext):
"""
Expand All @@ -66,7 +67,7 @@ def visitExpr(self, ctx: Parser.ExprContext):
| constant # constantExpr
| varID # varIdExpr
;
""" # noqa E501
""" # noqa E501
ctx_list = list(ctx.getChildren())
c = ctx_list[0]

Expand Down Expand Up @@ -121,7 +122,6 @@ def visitExpr(self, ctx: Parser.ExprContext):

# CASE WHEN expr THEN expr ELSE expr END # caseExpr
elif isinstance(c, TerminalNodeImpl) and (c.getSymbol().type == Parser.CASE):

if len(ctx_list) % 4 != 3:
raise ValueError("Syntax error.")

Expand Down Expand Up @@ -221,7 +221,6 @@ def visitMembershipExpr(self, ctx: Parser.MembershipExprContext):
return previous_node

def visitClauseExpr(self, ctx: Parser.ClauseExprContext):

ctx_list = list(ctx.getChildren())

dataset = self.visitExpr(ctx_list[0])
Expand Down Expand Up @@ -347,7 +346,7 @@ def visitJoinClauseItem(self, ctx: Parser.JoinClauseItemContext):

def visitJoinClause(self, ctx: Parser.JoinClauseContext):
"""
joinClauseItem (COMMA joinClauseItem)* (USING componentID (COMMA componentID)*)?
JoinClauseItem (COMMA joinClauseItem)* (USING componentID (COMMA componentID)*)?
"""
ctx_list = list(ctx.getChildren())

Expand All @@ -373,7 +372,7 @@ def visitJoinClause(self, ctx: Parser.JoinClauseContext):
def visitJoinClauseWithoutUsing(self, ctx: Parser.JoinClauseWithoutUsingContext):
"""
joinClause: joinClauseItem (COMMA joinClauseItem)* (USING componentID (COMMA componentID)*)? ;
""" # noqa E501
""" # noqa E501
ctx_list = list(ctx.getChildren())

clause_nodes = []
Expand All @@ -388,7 +387,7 @@ def visitJoinClauseWithoutUsing(self, ctx: Parser.JoinClauseWithoutUsingContext)
def visitJoinBody(self, ctx: Parser.JoinBodyContext):
"""
joinBody: filterClause? (calcClause|joinApplyClause|aggrClause)? (keepOrDropClause)? renameClause?
""" # noqa E501
""" # noqa E501
ctx_list = list(ctx.getChildren())

body_nodes = []
Expand Down Expand Up @@ -457,7 +456,7 @@ def visitCallDataset(self, ctx: Parser.CallDatasetContext):
def visitEvalAtom(self, ctx: Parser.EvalAtomContext):
"""
| EVAL LPAREN routineName LPAREN (varID|scalarItem)? (COMMA (varID|scalarItem))* RPAREN (LANGUAGE STRING_CONSTANT)? (RETURNS evalDatasetType)? RPAREN # evalAtom
""" # noqa E501
""" # noqa E501
ctx_list = list(ctx.getChildren())

routine_name = Terminals().visitRoutineName(ctx_list[2])
Expand Down Expand Up @@ -505,7 +504,7 @@ def visitEvalAtom(self, ctx: Parser.EvalAtomContext):
def visitCastExprDataset(self, ctx: Parser.CastExprDatasetContext):
"""
| CAST LPAREN expr COMMA (basicScalarType|valueDomainName) (COMMA STRING_CONSTANT)? RPAREN # castExprDataset
""" # noqa E501
""" # noqa E501
ctx_list = list(ctx.getChildren())
c = ctx_list[0]

Expand Down Expand Up @@ -795,15 +794,19 @@ def visitTimeFunctions(self, ctx: Parser.TimeFunctionsContext):
return self.visitTimeDiffAtom(ctx)
elif isinstance(ctx, Parser.DateAddAtomContext):
return self.visitTimeAddAtom(ctx)
elif isinstance(ctx, (Parser.YearAtomContext,
Parser.MonthAtomContext,
Parser.DayOfMonthAtomContext,
Parser.DayOfYearAtomContext,
Parser.DayToYearAtomContext,
Parser.DayToMonthAtomContext,
Parser.YearTodayAtomContext,
Parser.MonthTodayAtomContext)):

elif isinstance(
ctx,
(
Parser.YearAtomContext,
Parser.MonthAtomContext,
Parser.DayOfMonthAtomContext,
Parser.DayOfYearAtomContext,
Parser.DayToYearAtomContext,
Parser.DayToMonthAtomContext,
Parser.YearTodayAtomContext,
Parser.MonthTodayAtomContext,
),
):
return self.visitTimeUnaryAtom(ctx)
else:
raise NotImplementedError
Expand Down Expand Up @@ -878,7 +881,7 @@ def visitFillTimeAtom(self, ctx: Parser.FillTimeAtomContext):
def visitTimeAggAtom(self, ctx: Parser.TimeAggAtomContext):
"""
TIME_AGG LPAREN periodIndTo=STRING_CONSTANT (COMMA periodIndFrom=(STRING_CONSTANT| OPTIONAL ))? (COMMA op=optionalExpr)? (COMMA (FIRST|LAST))? RPAREN # timeAggAtom
""" # noqa E501
""" # noqa E501
ctx_list = list(ctx.getChildren())
c = ctx_list[0]

Expand Down Expand Up @@ -911,7 +914,11 @@ def visitTimeAggAtom(self, ctx: Parser.TimeAggAtomContext):
# AST_ASTCONSTRUCTOR.17
raise Exception("Optional as expression node is not allowed in Time Aggregation")
return TimeAggregation(
op=op, operand=operand_node, period_to=period_to, period_from=period_from, conf=conf
op=op,
operand=operand_node,
period_to=period_to,
period_from=period_from,
conf=conf,
)

def visitFlowAtom(self, ctx: Parser.FlowAtomContext):
Expand Down Expand Up @@ -988,7 +995,7 @@ def visitSetFunctions(self, ctx: Parser.SetFunctionsContext):
setExpr: UNION LPAREN left=expr (COMMA expr)+ RPAREN # unionAtom
| INTERSECT LPAREN left=expr (COMMA expr)+ RPAREN # intersectAtom
| op=(SETDIFF|SYMDIFF) LPAREN left=expr COMMA right=expr RPAREN # setOrSYmDiffAtom
""" # noqa E501
""" # noqa E501
if isinstance(ctx, Parser.UnionAtomContext):
return self.visitUnionAtom(ctx)
elif isinstance(ctx, Parser.IntersectAtomContext):
Expand Down Expand Up @@ -1031,7 +1038,7 @@ def visitSetOrSYmDiffAtom(self, ctx: Parser.SetOrSYmDiffAtomContext):
def visitHierarchyFunctions(self, ctx: Parser.HierarchyFunctionsContext):
"""
HIERARCHY LPAREN op=expr COMMA hrName=IDENTIFIER (conditionClause)? (RULE ruleComponent=componentID)? (validationMode)? (inputModeHierarchy)? outputModeHierarchy? RPAREN
""" # noqa E501
""" # noqa E501
ctx_list = list(ctx.getChildren())
c = ctx_list[0]

Expand Down Expand Up @@ -1102,7 +1109,7 @@ def visitValidationFunctions(self, ctx: Parser.ValidationFunctionsContext):
def visitValidateDPruleset(self, ctx: Parser.ValidateDPrulesetContext):
"""
validationDatapoint: CHECK_DATAPOINT '(' expr ',' IDENTIFIER (COMPONENTS componentID (',' componentID)*)? (INVALID|ALL_MEASURES|ALL)? ')' ;
""" # noqa E501
""" # noqa E501
ctx_list = list(ctx.getChildren())
c = ctx_list[0]

Expand Down Expand Up @@ -1137,7 +1144,7 @@ def visitValidateDPruleset(self, ctx: Parser.ValidateDPrulesetContext):
def visitValidateHRruleset(self, ctx: Parser.ValidateHRrulesetContext):
"""
CHECK_HIERARCHY LPAREN op=expr COMMA hrName=IDENTIFIER conditionClause? (RULE componentID)? validationMode? inputMode? validationOutput? RPAREN # validateHRruleset
""" # noqa E501
""" # noqa E501

ctx_list = list(ctx.getChildren())
c = ctx_list[0]
Expand Down Expand Up @@ -1199,7 +1206,7 @@ def visitValidateHRruleset(self, ctx: Parser.ValidateHRrulesetContext):
def visitValidationSimple(self, ctx: Parser.ValidationSimpleContext):
"""
| CHECK LPAREN op=expr (codeErr=erCode)? (levelCode=erLevel)? imbalanceExpr? output=(INVALID|ALL)? RPAREN # validationSimple
""" # noqa E501
""" # noqa E501
ctx_list = list(ctx.getChildren())
c = ctx_list[0]
token = c.getSymbol()
Expand Down Expand Up @@ -1331,11 +1338,19 @@ def visitAnSimpleFunction(self, ctx: Parser.AnSimpleFunctionContext):

if window is None:
window = Windowing(
type_="data", start=-1, stop=0, start_mode="preceding", stop_mode="current"
type_="data",
start=-1,
stop=0,
start_mode="preceding",
stop_mode="current",
)

return Analytic(
op=op_node, operand=operand, partition_by=partition_by, order_by=order_by, window=window
op=op_node,
operand=operand,
partition_by=partition_by,
order_by=order_by,
window=window,
)

def visitLagOrLeadAn(self, ctx: Parser.LagOrLeadAnContext):
Expand Down Expand Up @@ -1369,7 +1384,11 @@ def visitLagOrLeadAn(self, ctx: Parser.LagOrLeadAnContext):
raise Exception(f"{op_node} requires an offset parameter.")

return Analytic(
op=op_node, operand=operand, partition_by=partition_by, order_by=order_by, params=params
op=op_node,
operand=operand,
partition_by=partition_by,
order_by=order_by,
params=params,
)

def visitRatioToReportAn(self, ctx: Parser.RatioToReportAnContext):
Expand Down
Loading

0 comments on commit a4259ae

Please sign in to comment.