Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add types to functions in printing.py #804

Merged
merged 3 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 43 additions & 29 deletions pytensor/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,35 @@
VALID_ASSOC = {"left", "right", "either"}


def char_from_number(number):
"""Convert numbers to strings by rendering it in base 26 using capital letters as digits."""
def char_from_number(number: int) -> str:
"""Convert a number to a string.

base = 26
It renders it in base 26 using capital letters as digits.
For example: 3·26² + 2·26¹ + 0·26⁰ → "DCA"

rval = ""
Parameters
----------
number : int
The number to be converted.

if number == 0:
rval = "A"
Returns
-------
str
The converted string.
"""

base = 26

remainders = []

while number != 0:
remainder = number % base
new_char = chr(ord("A") + remainder)
rval = new_char + rval
number //= base
number, remainder = number // base, number % base
remainders.append(remainder)

return rval
if not remainders:
remainders = [0]

return "".join(chr(ord("A") + r) for r in remainders[::-1])


@singledispatch
Expand Down Expand Up @@ -1188,18 +1200,18 @@

def pydotprint(
fct,
outfile=None,
compact=True,
format="png",
with_ids=False,
high_contrast=True,
outfile: str | None = None,
compact: bool = True,
format: str = "png",
with_ids: bool = False,
high_contrast: bool = True,
cond_highlight=None,
colorCodes=None,
max_label_size=70,
scan_graphs=False,
var_with_name_simple=False,
print_output_file=True,
return_image=False,
colorCodes: dict | None = None,
max_label_size: int = 70,
scan_graphs: bool = False,
var_with_name_simple: bool = False,
print_output_file: bool = True,
return_image: bool = False,
):
"""Print to a file the graph of a compiled pytensor function's ops. Supports
all pydot output formats, including png and svg.
Expand Down Expand Up @@ -1664,7 +1676,9 @@
return rval


def min_informative_str(obj, indent_level=0, _prev_obs=None, _tag_generator=None):
def min_informative_str(
obj, indent_level: int = 0, _prev_obs: dict | None = None, _tag_generator=None
) -> str:
"""
Returns a string specifying to the user what obj is
The string will print out as much of the graph as is needed
Expand Down Expand Up @@ -1764,7 +1778,7 @@
return rval


def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
def var_descriptor(obj, _prev_obs: dict | None = None, _tag_generator=None) -> str:
"""
Returns a string, with no endlines, fully specifying
how a variable is computed. Does not include any memory
Expand Down Expand Up @@ -1820,7 +1834,7 @@
return rval


def position_independent_str(obj):
def position_independent_str(obj) -> str:
if isinstance(obj, Variable):
rval = "pytensor_var"
rval += "{type=" + str(obj.type) + "}"
Expand All @@ -1830,18 +1844,18 @@
return rval


def hex_digest(x):
def hex_digest(x: np.ndarray) -> str:
"""
Returns a short, mostly hexadecimal hash of a numpy ndarray
"""
assert isinstance(x, np.ndarray)
rval = hashlib.sha256(x.tostring()).hexdigest()
rval = hashlib.sha256(x.tobytes()).hexdigest()

Check warning on line 1852 in pytensor/printing.py

View check run for this annotation

Codecov / codecov/patch

pytensor/printing.py#L1852

Added line #L1852 was not covered by tests
# hex digest must be annotated with strides to avoid collisions
# because the buffer interface only exposes the raw data, not
# any info about the semantics of how that data should be arranged
# into a tensor
rval = rval + "|strides=[" + ",".join(str(stride) for stride in x.strides) + "]"
rval = rval + "|shape=[" + ",".join(str(s) for s in x.shape) + "]"
rval += "|strides=[" + ",".join(str(stride) for stride in x.strides) + "]"
rval += "|shape=[" + ",".join(str(s) for s in x.shape) + "]"
return rval


Expand Down
17 changes: 17 additions & 0 deletions tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
PatternPrinter,
PPrinter,
Print,
char_from_number,
debugprint,
default_printer,
get_node_by_id,
Expand All @@ -30,6 +31,22 @@
from tests.graph.utils import MyInnerGraphOp, MyOp, MyVariable


@pytest.mark.parametrize(
"number,s",
[
(0, "A"),
(1, "B"),
(25, "Z"),
(26, "BA"),
(27, "BB"),
(3 * 26**2 + 2 * 26 + 0, "DCA"),
(42421337, "DOVPLX"),
],
)
def test_char_from_number(number: int, s: str):
assert char_from_number(number) == s


@pytest.mark.skipif(not pydot_imported, reason="pydot not available")
def test_pydotprint_cond_highlight():
# This is a REALLY PARTIAL TEST.
Expand Down
Loading