This library provides a collection of utilities for machine learning experiments. MLXU is a thin wrapper on top of absl-py, ml_collections and wandb. It also provides some convenient JAX utils.
This library includes the following modules:
- config: Experiment configuration and command line flag utils
- logging: W&B logging utils
- jax_utils: JAX specific utils
MLXU can be installed via pip. To install from PyPI:
pip install mlxu
To install the latest version from GitHub
pip install git+https://github.com/young-geng/mlxu.git
Here are some examples of the utilities provided in MLXU.
MLXU provides convenient wrappers around absl-py and wandb to make command line arg parsing and logging easy.
import mlxu
class ConfigurableModule(object):
    """Example module that carries its own default configuration.

    Because the defaults are exposed as a config dict, this module can be
    configured directly from the command line when that dict is plugged
    into the FLAGS.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Build the default config and combine it with ``updates``.

        Args:
            updates: Optional value handed to ``mlxu.update_config_dict``
                together with the defaults (``None`` when omitted).

        Returns:
            Whatever ``mlxu.update_config_dict`` returns for the defaults
            and ``updates``.
        """
        defaults = mlxu.config_dict(
            integer_value=10,
            float_value=1.0,
            string_value='hello',
            boolean_value=True,
        )
        return mlxu.update_config_dict(defaults, updates)

    def __init__(self, config):
        # Route the caller-supplied config through the defaults builder.
        self.config = self.get_default_config(config)
# Define absl command line flags in one function, with automatic type inference.
# Each keyword argument declares one flag named after the keyword. The result
# is unpacked into FLAGS (read by attribute, e.g. FLAGS.seed) and FLAGS_DEF;
# both are handed together to mlxu.print_flags and mlxu.get_user_flags in main().
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    name='example_experiment',  # string flag
    seed=42,  # integer flag
    learning_rate=1e-3,  # floating point flag
    use_mlxu=True,  # boolean flag
    num_gpus=int,  # integer flag without default value (a type is passed instead of a value)
    weight_decay=float,  # floating point flag without default value
    save_checkpoints=bool,  # boolean flag without default value
    epochs=(10, 'Number of epochs'),  # (default, help string) tuple: we can also specify help strings
    network_architecture=mlxu.config_dict(
        activation='relu',
        hidden_dim=128,
        hidden_layers=5,
    ),  # nested ml_collections config_dict
    configurable_module=ConfigurableModule.get_default_config(),  # nested custom config_dict
    logger=mlxu.WandBLogger.get_default_config(),  # logger configuration
)
def main(argv):
    """Show how to read the defined flags and log through the W&B logger.

    ``argv`` is accepted but not used in this example.
    """
    # Dump the command line flags.
    mlxu.print_flags(FLAGS, FLAGS_DEF)

    # Top-level flags are plain attributes of FLAGS.
    experiment_name, random_seed = FLAGS.name, FLAGS.seed

    # Nested flags are reached through chained attribute access.
    architecture = FLAGS.network_architecture
    activation_fn = architecture.activation
    hidden_size = architecture.hidden_dim
    module = ConfigurableModule(FLAGS.configurable_module)

    # Build the logger from its own config plus the user flags, then log a
    # metrics dict and save a list through save_pickle.
    logger_config = FLAGS.logger
    user_flags = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
    wandb_logger = mlxu.WandBLogger(logger_config, user_flags)
    wandb_logger.log({'step': 1, 'loss': 10.5})
    wandb_logger.save_pickle([1, 2, 4, 5], 'checkpoint.pkl')
# Run the main function via mlxu.run, but only when this file is executed
# as a script (the guard is false when the file is imported as a module).
if __name__ == "__main__":
    mlxu.run(main)
The flags can be passed in via command line arguments:
python examples/cli_logging.py \
--name='example' \
--seed=24 \
--learning_rate=1.0 \
--use_mlxu=True \
--network_architecture.activation='gelu' \
--network_architecture.hidden_dim=126 \
--network_architecture.hidden_layers=2 \
--configurable_module.integer_value=20 \
--configurable_module.float_value=2.0 \
--logger.online=True \
--logger.project='mlxu_example'
Specifically, the logger.online
option controls whether the logger will upload
the data to W&B, and the logger.project
option controls the name of the W&B
project.
MLXU also provides a convenient wrapper around JAX's random number generators to make them much easier to use:
import jax
import jax.numpy as jnp
import mlxu
import mlxu.jax_utils as jax_utils
@jax.jit
def sum_of_random_uniform(rng_key):
    """Sum four (2, 2) uniform random samples generated from ``rng_key``.

    The key is captured in a JaxRNG object, which acts as a stateful rng
    key generator. As long as the JaxRNG object is not passed across the
    function boundary, the function is still pure and jittable. The JaxRNG
    object also supports the same tuple and dictionary usage as the
    jax_utils.next_rng function.
    """
    key_stream = jax_utils.JaxRNG(rng_key)
    shape = (2, 2)
    total = jnp.zeros(shape)
    for _ in range(4):
        # Every call hands back a new key and alters key_stream's internal state.
        total = total + jax.random.uniform(key_stream(), shape)
    return total
def main(argv):
    """Walk through the global rng helpers in jax_utils.

    ``argv`` is accepted but not used in this example.
    """
    # Set up the global rng generator.
    jax_utils.init_rng(42)

    # Two consecutive requests; the second key should be different from
    # the first one.
    first_key = jax_utils.next_rng()
    print(first_key)
    second_key = jax_utils.next_rng()
    print(second_key)

    # Passing N gives back a tuple of N rng keys.
    key_a, key_b, key_c = jax_utils.next_rng(3)
    print(key_a, ', ', key_b, ', ', key_c)

    # Passing a list of names gives back a dictionary of keys.
    named_keys = jax_utils.next_rng(['k1', 'k2'])
    print(named_keys)

    # Hand a fresh key to the jitted function, which makes use of a
    # stateful JaxRNG object internally.
    fresh_key = jax_utils.next_rng()
    result = sum_of_random_uniform(fresh_key)
    print(result)
# Hand main to mlxu.run only when this file is executed as a script.
if __name__ == "__main__":
    mlxu.run(main)