Skip to content

Commit

Permalink
Fix when variable is None for include_stats is True or False
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 720569480
  • Loading branch information
CLU Authors authored and copybara-github committed Jan 30, 2025
1 parent 77b0602 commit 294a87b
Showing 1 changed file with 31 additions and 23 deletions.
54 changes: 31 additions & 23 deletions clu/parameter_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,17 @@ def flatten_dict(
def _count_parameters(params: _ParamsContainer) -> int:
"""Returns the count of variables for the module or parameter dictionary."""
params = flatten_dict(params)
return sum(np.prod(v.shape) for v in params.values())
return sum(np.prod(v.shape) for v in params.values() if v is not None)


def _parameters_size(params: _ParamsContainer) -> int:
"""Returns total size (bytes) for the module or parameter dictionary."""
params = flatten_dict(params)
return sum(np.prod(v.shape) * v.dtype.itemsize for v in params.values())
return sum(
np.prod(v.shape) * v.dtype.itemsize
for v in params.values()
if v is not None
)


def count_parameters(params: _ParamsContainer) -> int:
Expand Down Expand Up @@ -127,6 +131,8 @@ def _make_row_with_sharding(name, value) -> _ParamRowWithSharding:

def _make_row_with_stats(name, value, mean, std) -> _ParamRowWithStats:
row = _make_row(name, value)
mean = mean or 0.0
std = std or 0.0
return _ParamRowWithStats(
**dataclasses.asdict(row),
mean=float(jax.device_get(mean)),
Expand Down Expand Up @@ -156,12 +162,11 @@ def _get_parameter_rows(
params: Dictionary with parameters as NumPy arrays. The dictionary can be
nested. Alternatively a `tf.Module` can be provided, in which case the
`trainable_variables` of the module will be used.
include_stats: If True, add columns with mean and std for each variable.
If the string "sharding", add column a column with the sharding of the
variable.
If the string "global", params are sharded global arrays and this
function assumes it is called on every host, i.e. can use collectives.
The sharding of the variables is also added as a column.
include_stats: If True, add columns with mean and std for each variable. If
the string "sharding", add column a column with the sharding of the
variable. If the string "global", params are sharded global arrays and
this function assumes it is called on every host, i.e. can use
collectives. The sharding of the variables is also added as a column.
Returns:
A list of `ParamRow`, or `ParamRowWithStats`, depending on the passed value
Expand All @@ -185,12 +190,14 @@ def _get_parameter_rows(
case True:
mean_and_std = _mean_std(values)
return jax.tree_util.tree_map(
_make_row_with_stats, names, values, *mean_and_std)
_make_row_with_stats, names, values, *mean_and_std
)

case "global":
mean_and_std = _mean_std_jit(values)
return jax.tree_util.tree_map(
_make_row_with_stats_and_sharding, names, values, *mean_and_std)
_make_row_with_stats_and_sharding, names, values, *mean_and_std
)

case "sharding":
return jax.tree_util.tree_map(_make_row_with_sharding, names, values)
Expand Down Expand Up @@ -256,8 +263,7 @@ def __init__(self, name, values):
column_names = [field.name for field in dataclasses.fields(rows[0])]

columns = [
Column(name, [value_formatter(getattr(row, name))
for row in rows])
Column(name, [value_formatter(getattr(row, name)) for row in rows])
for name in column_names
]

Expand Down Expand Up @@ -312,12 +318,11 @@ def get_parameter_overview(
Args:
params: Dictionary with parameters as NumPy arrays. The dictionary can be
nested.
include_stats: If True, add columns with mean and std for each variable.
If the string "sharding", add column a column with the sharding of the
variable.
If the string "global", params are sharded global arrays and this
function assumes it is called on every host, i.e. can use collectives.
The sharding of the variables is also added as a column.
include_stats: If True, add columns with mean and std for each variable. If
the string "sharding", add column a column with the sharding of the
variable. If the string "global", params are sharded global arrays and
this function assumes it is called on every host, i.e. can use
collectives. The sharding of the variables is also added as a column.
max_lines: If not `None`, the maximum number of variables to include.
Returns:
Expand Down Expand Up @@ -375,16 +380,19 @@ def log_parameter_overview(
Args:
params: Dictionary with parameters as NumPy arrays. The dictionary can be
nested.
include_stats: If True, add columns with mean and std for each variable.
If the string "global", params are sharded global arrays and this
function assumes it is called on every host, i.e. can use collectives.
include_stats: If True, add columns with mean and std for each variable. If
the string "global", params are sharded global arrays and this function
assumes it is called on every host, i.e. can use collectives.
max_lines: If not `None`, the maximum number of variables to include.
msg: Message to be logged before the overview.
jax_logging_process: Which JAX process ID should do the logging. None = all.
Use this to avoid logspam when include_stats="global".
"""

_log_parameter_overview(
params, include_stats=include_stats, max_lines=max_lines, msg=msg,
jax_logging_process=jax_logging_process
params,
include_stats=include_stats,
max_lines=max_lines,
msg=msg,
jax_logging_process=jax_logging_process,
)

0 comments on commit 294a87b

Please sign in to comment.