This repo contains the source code for experiments for our PDMCF paper:
Solving Large Multicommodity Network Flow Problems on GPUs
Fangzhao Zhang, Stephen Boyd
Paper: https://web.stanford.edu/~boyd/papers/pdmcf.html
We provide pdmcf.py for our torch implementation, pdmcf_jax.py for our jax implementation, and warm_start.py for reproducing our warm start results.
Clone the repo and run the following commands:
conda create -n pdmcf python=3.12
conda activate pdmcf
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
pip install -r requirement.txt
conda install pytorch-scatter -c pyg
Run the PDMCF method (torch implementation)
python pdmcf.py --n 100 --q 10
where --n specifies the number of nodes and --q specifies the number of neighbors. Add --mosek_check to compare against the MOSEK result; note that this requires a MOSEK license. Add --float64 to switch from float32 to float64, which gives more precise numerical results. Use --eps to change the user-specified stopping criterion, which is set to 1e-2 by default.
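For readers who want to mirror this command-line interface in their own scripts, the flags described above could be declared roughly as follows. This is a hypothetical sketch, not the parser used in pdmcf.py: the function name build_parser, the argument types, and the choice to make --n and --q required are all assumptions; only the flag names and the 1e-2 default for --eps come from the description above.

```python
import argparse


def build_parser():
    """Hypothetical parser reproducing the documented flags (not the repo's code)."""
    parser = argparse.ArgumentParser(description="PDMCF experiment options (sketch)")
    # Problem size: number of nodes and number of neighbors per node.
    parser.add_argument("--n", type=int, required=True, help="number of nodes")
    parser.add_argument("--q", type=int, required=True, help="number of neighbors")
    # Optional switches; both are off unless given on the command line.
    parser.add_argument("--mosek_check", action="store_true",
                        help="compare against MOSEK (requires a MOSEK license)")
    parser.add_argument("--float64", action="store_true",
                        help="use float64 instead of float32")
    # Stopping criterion; 1e-2 is the documented default.
    parser.add_argument("--eps", type=float, default=1e-2,
                        help="user-specified stopping criterion")
    return parser


# Equivalent to: python pdmcf.py --n 100 --q 10 --float64
args = build_parser().parse_args(["--n", "100", "--q", "10", "--float64"])
print(args.n, args.q, args.float64, args.mosek_check, args.eps)
```

Using store_true for the two switches matches how they are described (flags that "can be added"), so omitting them leaves float32 and no MOSEK comparison.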
Install JAX
pip install -U "jax[cuda12]"
Run the PDMCF method (jax implementation)
python pdmcf_jax.py --n 100 --q 10
Similarly, --mosek_check can be added to compare against the MOSEK result, --float64 can be added to switch to higher precision, and --eps changes the user-specified stopping criterion.
Run the PDMCF method with warm start
python warm_start.py --n 1000 --q 10 --nu 0.1
where --nu specifies the weight perturbation ratio (0.1 by default), and --float64 can be added to switch to higher precision.
We also provide a script for user-specified utility functions; see pdmcf_custom.py and custom_utils.py. Specifically, users need to provide the following functions in custom_utils.py (we provide weighted log and weighted square root as examples in commented-out lines).
- prox_util: how to compute the proximal operator of the conjugate of the negative utilities, i.e., $\mathrm{prox}_{(-u_{ij})^\ast}$
- eval_f: how to evaluate the negative utility functions, i.e., $-u_{ij}$
- nabla: how to compute the derivative of the utility functions, i.e., $u'_{ij}$
- mosek_solve: code for solving with MOSEK, used as a benchmark comparison
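To make the first three functions concrete, here is a minimal scalar sketch for the weighted log utility $u(x) = w \log x$. It is an illustration only: the signatures, the explicit weight argument w, and the prox step size tau are assumptions, and the functions in custom_utils.py may take different arguments and operate on tensors rather than floats. For this utility, $(-u)^\ast(y) = -w + w\log w - w\log(-y)$ for $y < 0$, and setting the derivative of $\tau(-u)^\ast(y) + \tfrac{1}{2}(y - v)^2$ to zero gives $y^2 - vy - w\tau = 0$, whose negative root is the prox.

```python
import math


def eval_f(x, w):
    """Negative utility -u(x) for u(x) = w * log(x), defined for x > 0."""
    return -w * math.log(x)


def nabla(x, w):
    """Derivative of the utility, u'(x) = w / x."""
    return w / x


def prox_util(v, w, tau):
    """Proximal operator of tau * (-u)^* at v (tau is an assumed step size).

    Minimizes -w * log(-y) + (y - v)^2 / (2 * tau) over y < 0; the minimizer
    is the negative root of y^2 - v * y - w * tau = 0.
    """
    return 0.5 * (v - math.sqrt(v * v + 4.0 * w * tau))


# Example: with v = 1.5, w = 2.0, tau = 0.5 the prox evaluates to -0.5.
y = prox_util(1.5, 2.0, 0.5)
print(y)
```

The result is always negative, which keeps it inside the domain of the conjugate; the same closed form applies elementwise if the scalars are replaced by arrays or tensors and math.sqrt by the corresponding array routine.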
The arguments --n, --q, --mosek_check, --float64, and --eps are valid for pdmcf_custom.py.