Skip to content

Commit

Permalink
feat: validate ops.main() call for operator framework charms
Browse files Browse the repository at this point in the history
  • Loading branch information
dimaqq committed Sep 6, 2024
1 parent 0d3b271 commit c32422e
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 0 deletions.
92 changes: 92 additions & 0 deletions charmcraft/linters.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import abc
import ast
import os
import re
import pathlib
import shlex
import typing
Expand Down Expand Up @@ -523,6 +524,96 @@ def run(self, basedir: pathlib.Path) -> str:
return self.Result.OK


class OpsMainCall(Linter):
"""Check that the entrypoint contains call to ops.main()."""

name = "ops-main-call"
url = f"{BASE_DOCS_URL}#heading--ops-main-call"
text = ""

def run(self, basedir: pathlib.Path) -> str:
entrypoint = get_entrypoint_from_dispatch(basedir)
if entrypoint is None:
self.text = "Cannot find a proper 'dispatch' script pointing to an entrypoint."
return self.Result.NONAPPLICABLE

if not entrypoint.exists():
self.text = f"Cannot find the entrypoint file: {str(entrypoint)!r}"
return self.Result.NONAPPLICABLE

if not self._check_main_calls(entrypoint.read_text()):
self.text = f"The ops.main() call missing from {str(entrypoint)!r}."
return self.Result.ERROR

if Framework().run(basedir) != Framework.Result.OPERATOR:
self.text = "The charm is not based on the operator framework"
return self.Result.NONAPPLICABLE

return self.Result.OK

def _check_main_calls(self, code: str):
tree = ast.parse(code)
imports = self._ops_main_imports(tree)
return self._detect_main_calls(tree, imports=imports)

def _ops_main_imports(self, tree: ast.AST) -> dict[str, str]:
rv = {}

class ImportVisitor(ast.NodeVisitor):
def visit_Import(self, node: ast.Import):
for alias in node.names:
# import ops
if alias.name == "ops":
rv[alias.asname or alias.name] = "ops"
if alias.name == "ops.main" and alias.asname:
rv[alias.asname] = "ops.main"
elif alias.name.startswith("ops.") and not alias.asname:
rv["ops"] = "ops"

def visit_ImportFrom(self, node: ast.ImportFrom):
for alias in node.names:
# from ops import main [as ops_main]
if node.module in ("ops", "ops.main") and alias.name == "main":
rv[alias.asname or alias.name] = f"{node.module}.main"

ImportVisitor().visit(tree)
return rv

def _detect_main_calls(self, tree: ast.AST, *, imports: dict[str, str]) -> bool:
main_call_sites = []

class OpsMainFinder(ast.NodeVisitor):
def visit_Call(self, node: ast.Call):
match node.func:
# ops.main.main(...)
case ast.Attribute(
value=ast.Attribute(value=ast.Name(id=first), attr=second),
attr=third,
):
call_site = f"{first}.{second}.{third}(...)"
# ops.main(...)
case ast.Attribute(value=ast.Name(id=first), attr=second):
call_site = f"{first}.{second}(...)"
# main(...)
case ast.Name(id=first):
call_site = f"{first}(...)"
case _:
call_site = "_dummy()"

match = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)(.*)", call_site)
assert match
alias, rest = match.groups()
resolved = f"{imports.get(alias, '_dummy')}{rest}"

if resolved in ("ops.main(...)", "ops.main.main(...)"):
main_call_sites.append(call_site)

self.generic_visit(node)

OpsMainFinder().visit(tree)
return any(main_call_sites)


class AdditionalFiles(Linter):
"""Check that the charm does not contain any additional files in the prime directory.
Expand Down Expand Up @@ -584,5 +675,6 @@ def run(self, basedir: pathlib.Path) -> str:
NamingConventions,
Framework,
Entrypoint,
OpsMainCall,
AdditionalFiles,
]
95 changes: 95 additions & 0 deletions tests/test_linters.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
JujuMetadata,
Language,
NamingConventions,
OpsMainCall,
check_dispatch_with_python_entrypoint,
get_entrypoint_from_dispatch,
)
Expand Down Expand Up @@ -1066,3 +1067,97 @@ def test_additional_files_checker_generated_ignore(tmp_path, file):

assert result == LintResult.OK
assert linter.text == "No additional files found in the charm."


CODE_SAMPLES = {
"canonical example": dedent(
"""
import ops
if __name__ == "__main__":
ops.main(SomeCharm)
"""
),
"recommended import style": dedent(
"""
import ops
ops.main(SomeCharm)
"""
),
"recommended import style, legacy call": dedent(
"""
import ops
ops.main.main(SomeCharm)
"""
),
"call with kwarg": dedent(
"""
import ops
ops.main(charm_class=SomeCharm)
"""
),
"function import": dedent(
"""
from ops import main
main(SomeCharm)
"""
),
"function import, legacy call": dedent(
"""
from ops import main
main.main(SomeCharm)
"""
),
"submodule import": dedent(
"""
import ops.main
ops.main(SomeCharm) # type: ignore
"""
),
"submodule import, legacy call": dedent(
"""
import ops.main
ops.main.main(SomeCharm)
"""
),
"multiple imports, simple": dedent(
"""
import ops
import ops.main
ops.main(SomeCharm)
"""
),
"multiple imports, earlier": dedent(
"""
import ops
from ops.main import main
ops.main(SomeCharm)
"""
),
"multiple imports, latter": dedent(
"""
import ops
from ops.main import main
main(SomeCharm)
"""
),
"function import from submodule": dedent(
"""
from ops.main import main
main(SomeCharm)
"""
),
"function import from submodule": dedent(
"""
from ops.main import main as alias
alias(SomeCharm)
"""
),
}


@pytest.mark.parametrize(
"code",
[pytest.param(v, id=k) for k, v in CODE_SAMPLES.items()],
)
def test_ops_main(code: str):
assert OpsMainCall()._check_main_calls(code)

0 comments on commit c32422e

Please sign in to comment.