Skip to content

Commit

Permalink
Fix some typos in 'tracing.py'
Browse files Browse the repository at this point in the history
  • Loading branch information
rauletorresc committed Jan 20, 2025
1 parent 4a1122b commit e3b4f2d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions frontend/catalyst/jax_extras/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,14 +656,14 @@ def infer_output_type(
expansion_strategy: ExpansionStrategy,
num_implicit_inputs: int | None = None,
) -> Tuple[List[TracerLike], OutputType]:
"""Deduce the Jax ``OutputType`` of a part of program (typically, a function) given its
"""Deduce the Jax ``OutputType`` of a part of a program (typically, a function) given its
constants, input and ouput tracers or variables. Return the expanded outputs along with the
output type calculated.
The core task of this function is to find out which tracers have dynamic dimensions and
translate this information into the language of the De Brujin indices residing in Jax types. In
order to do this, we scan the outputs and mind what dimensions are already known (from the
intputs) and what are not known. The known dimensions are marked with InDBIdx and the unknown
inputs) and what are not known. The known dimensions are marked with InDBIdx and the unknown
dimensions are treated as calculated and marked using OutDBIdx.
Expand Down

0 comments on commit e3b4f2d

Please sign in to comment.