Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return more specific QNames for assignments #477

Merged
merged 3 commits into from
Apr 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,13 +522,17 @@ def f(self) -> "c":
if isinstance(assignment, Assignment):
assignment_node = assignment.node
if isinstance(assignment_node, (cst.Import, cst.ImportFrom)):
results |= _NameUtil.find_qualified_name_for_import_alike(
names = _NameUtil.find_qualified_name_for_import_alike(
assignment_node, full_name
)
else:
results |= _NameUtil.find_qualified_name_for_non_import(
names = _NameUtil.find_qualified_name_for_non_import(
assignment, full_name
)
if not isinstance(node, str) and _is_assignment(node, assignment_node):
return names
else:
results |= names
elif isinstance(assignment, BuiltinAssignment):
results.add(
QualifiedName(
Expand Down Expand Up @@ -747,6 +751,30 @@ def _gen_dotted_names(
yield from name_values


def _is_assignment(node: cst.CSTNode, assignment_node: cst.CSTNode) -> bool:
"""
Returns true if ``node`` is part of the assignment at ``assignment_node``.

Normally this is just a simple identity check, except for imports where the
assignment is attached to the entire import statement but we are interested in
``Name`` nodes inside the statement.
"""
if node is assignment_node:
return True
if isinstance(assignment_node, (cst.Import, cst.ImportFrom)):
aliases = assignment_node.names
if isinstance(aliases, cst.ImportStar):
return False
for alias in aliases:
if alias.name is node:
return True
asname = alias.asname
if asname is not None:
if asname.name is node:
return True
return False


class ScopeVisitor(cst.CSTVisitor):
# since it's probably not useful. That can makes this visitor cleaner.
def __init__(self, provider: "ScopeProvider") -> None:
Expand Down
67 changes: 67 additions & 0 deletions libcst/metadata/tests/test_name_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,73 @@ class Foo:
names[attribute], {QualifiedName("a.aa.aaa", QualifiedNameSource.IMPORT)}
)

def test_multiple_qualified_names(self) -> None:
m, names = get_qualified_name_metadata_provider(
"""
if False:
def f(): pass
elif False:
from b import f
else:
import f
import a.b as f

f()
"""
)
if_ = ensure_type(m.body[0], cst.If)
first_f = ensure_type(if_.body.body[0], cst.FunctionDef)
second_f_alias = ensure_type(
ensure_type(
ensure_type(if_.orelse, cst.If).body.body[0],
cst.SimpleStatementLine,
).body[0],
cst.ImportFrom,
).names
self.assertFalse(isinstance(second_f_alias, cst.ImportStar))
second_f = second_f_alias[0].name
third_f_alias = ensure_type(
ensure_type(
ensure_type(ensure_type(if_.orelse, cst.If).orelse, cst.Else).body.body[
0
],
cst.SimpleStatementLine,
).body[0],
cst.Import,
).names
self.assertFalse(isinstance(third_f_alias, cst.ImportStar))
third_f = third_f_alias[0].name
fourth_f = ensure_type(
ensure_type(
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Import
)
.names[0]
.asname,
cst.AsName,
).name
call = ensure_type(
ensure_type(
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.Call,
)

self.assertEqual(
names[first_f], {QualifiedName("f", QualifiedNameSource.LOCAL)}
)
self.assertEqual(names[second_f], set())
self.assertEqual(names[third_f], set())
self.assertEqual(names[fourth_f], set())
self.assertEqual(
names[call],
{
QualifiedName("f", QualifiedNameSource.IMPORT),
QualifiedName("b.f", QualifiedNameSource.IMPORT),
QualifiedName("f", QualifiedNameSource.LOCAL),
QualifiedName("a.b", QualifiedNameSource.IMPORT),
},
)


class FullyQualifiedNameProviderTest(UnitTest):
def test_builtins(self) -> None:
Expand Down