diff --git a/src/sqlfluff_easy_ql/LT01.py b/src/sqlfluff_easy_ql/LT01.py index 8f4b9cc..08d4383 100644 --- a/src/sqlfluff_easy_ql/LT01.py +++ b/src/sqlfluff_easy_ql/LT01.py @@ -4,6 +4,7 @@ from sqlfluff.core.parser import NewlineSegment, WhitespaceSegment from sqlfluff.core.rules import BaseRule, LintFix, LintResult, RuleContext +from sqlfluff.core.parser.segments.base import BaseSegment from sqlfluff.core.rules.crawlers import SegmentSeekerCrawler @@ -21,9 +22,14 @@ class Rule_EasyQL_LT01(BaseRule): }) is_fix_compatible = True - def _aux_fix(self, segments, object_name_segment, object_name_idx) -> List[LintFix]: + def _aux_fix( + self, + segments, + object_name_segment, + object_name_idx + ) -> List[LintFix]: """Applies the fix logic for the different possible cases.""" - fix_segments = list() + fix_segments: List[BaseSegment] = list() # check whether there is a newline if segments[object_name_idx-2].type != "newline" and \ @@ -47,19 +53,23 @@ def _eval(self, context: RuleContext) -> List[LintResult]: obj_reference = context.segment.type.split("_")[1] + "_reference" table_name_idx, table_name_segment = next( ((idx, s) for idx, s in enumerate(segments) - if s.type == obj_reference or s.type == "function_name") # in procedures is function_name + if s.type == obj_reference or s.type == "function_name") ) # assert that there is a newline and 4 spaces before the name if segments[table_name_idx-2].type == "newline" and \ segments[table_name_idx-1].type == "whitespace" and \ len(segments[table_name_idx-1].raw) == 4: - return None + return list() else: # apply fixes and return lint result - fixes_to_apply = self._aux_fix(segments, table_name_segment, table_name_idx) + fixes_to_apply = self._aux_fix( + segments, table_name_segment, table_name_idx + ) - return LintResult( + return [ + LintResult( anchor=table_name_segment, fixes=fixes_to_apply, - description="The name of the created object must be in a new line and indented." + description="The name of the created object must be in a new line and indented." # noqa: E501 ) + ] diff --git a/src/sqlfluff_easy_ql/rules.py b/src/sqlfluff_easy_ql/rules.py index 395d013..fa72dae 100644 --- a/src/sqlfluff_easy_ql/rules.py +++ b/src/sqlfluff_easy_ql/rules.py @@ -53,4 +53,3 @@ def _eval(self, context: RuleContext): if "1=1" in seg.raw_upper.replace(" ", ""): return LintResult(anchor=seg) return None -