Skip to content

Commit

Permalink
Merge pull request #63 from hmorimitsu/rapidflow
Browse files Browse the repository at this point in the history
Rapidflow
  • Loading branch information
hmorimitsu authored Mar 18, 2024
2 parents 1f77f90 + 88a6dff commit f08faef
Show file tree
Hide file tree
Showing 12 changed files with 404 additions and 68 deletions.
83 changes: 55 additions & 28 deletions ptlflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
from argparse import Namespace
from pathlib import Path
from typing import List, Optional
from typing import Any, Dict, List, Optional

import requests
import torch
Expand Down Expand Up @@ -250,32 +250,7 @@ def get_model(
pretrained_ckpt = args.pretrained_ckpt

if pretrained_ckpt is not None:
if Path(pretrained_ckpt).exists():
ckpt_path = pretrained_ckpt
elif hasattr(model_ref, "pretrained_checkpoints"):
ckpt_path = model_ref.pretrained_checkpoints.get(pretrained_ckpt)
if ckpt_path is None:
raise ValueError(
f"Invalid checkpoint name {pretrained_ckpt}. "
f'Choose one from {{{",".join(model.pretrained_checkpoints.keys())}}}'
)
else:
raise ValueError(
f"Cannot find checkpoint {pretrained_ckpt} for model {model_name}"
)

device = "cuda" if torch.cuda.is_available() else "cpu"

if Path(ckpt_path).exists():
ckpt = torch.load(ckpt_path, map_location=torch.device(device))
else:
model_dir = Path(hub.get_dir()) / "ptlflow" / "checkpoints"
ckpt = hub.load_state_dict_from_url(
ckpt_path,
model_dir=model_dir,
map_location=torch.device(device),
check_hash=True,
)
ckpt = load_checkpoint(pretrained_ckpt, model_ref, model_name)

state_dict = ckpt["state_dict"]
if "hyper_parameters" in ckpt:
Expand Down Expand Up @@ -330,10 +305,62 @@ def get_trainable_model_names() -> List[str]:
This function return the names of the model that have a loss function defined.
Returns
=======
-------
List[str]
The list of the model names that can be trained.
"""
return [
mname for mname in models_dict.keys() if get_model(mname).loss_fn is not None
]


def load_checkpoint(
pretrained_ckpt: str, model_ref: BaseModel, model_name: str
) -> Dict[str, Any]:
"""Try to load the checkpoint specified in pretrained_ckpt.
Parameters
----------
pretrained_ckpt : str
Path to a local file or name of a pretrained checkpoint.
model_ref : BaseModel
A reference to the model class. See the function get_model_reference() for more details.
model_name : str
A string representing the name of the model, just for debugging purposes.
Returns
-------
Dict[str, Any]
A dictionary of the loaded checkpoint. The output of torch.load().
See Also
--------
get_model_reference : To get a reference to the class of a model.
"""
if Path(pretrained_ckpt).exists():
ckpt_path = pretrained_ckpt
elif hasattr(model_ref, "pretrained_checkpoints"):
ckpt_path = model_ref.pretrained_checkpoints.get(pretrained_ckpt)
if ckpt_path is None:
raise ValueError(
f"Invalid checkpoint name {pretrained_ckpt}. "
f'Choose one from {{{",".join(model_ref.pretrained_checkpoints.keys())}}}'
)
else:
raise ValueError(
f"Cannot find checkpoint {pretrained_ckpt} for model {model_name}"
)

device = "cuda" if torch.cuda.is_available() else "cpu"

if Path(ckpt_path).exists():
ckpt = torch.load(ckpt_path, map_location=torch.device(device))
else:
model_dir = Path(hub.get_dir()) / "ptlflow" / "checkpoints"
ckpt = hub.load_state_dict_from_url(
ckpt_path,
model_dir=model_dir,
map_location=torch.device(device),
check_hash=True,
)
return ckpt
25 changes: 25 additions & 0 deletions ptlflow/models/rapidflow/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,31 @@ python test.py rapidflow --iters 12 --pretrained_ckpt sintel --test_dataset sint
python test.py rapidflow --iters 12 --pretrained_ckpt kitti --test_dataset kitti-2015 --input_pad_one_side
```

## Converting model to ONNX

The script [convert_to_onnx.py](convert_to_onnx.py) provides a simple example of how to convert RAPIDFlow models to ONNX format.
For example, to convert the 12 iterations version with the checkpoint trained on the Sintel dataset, you can run:
```bash
python convert_to_onnx.py rapidflow_it12 --checkpoint sintel
```

We also provide the script [onnx_infer.py](onnx_infer.py) to quickly test the converted ONNX model.
To test the model converted above, just run:
```bash
python onnx_infer.py rapidflow_it12.onnx
```

You can also provide your own images to test by providing an additional argument:
```bash
python onnx_infer.py rapidflow_it12.onnx --image_paths /path/to/first/image /path/to/second/image
```

### ONNX example limitations

Directly converting the model to ONNX as shown in this example will work, but it is not optimal.
To obtain the best convertion, it would be necessary to rewrite some parts of the code to remove conditions and operations that may change according to the input size.
Also, ONNX convertion only supports `--corr_mode allpairs`, which is not suitable for large images.

## Code license

The source code is released under the [Apache 2.0 LICENSE](LICENSE).
Expand Down
117 changes: 117 additions & 0 deletions ptlflow/models/rapidflow/convert_to_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Validate optical flow estimation performance on standard datasets."""

# =============================================================================
# Copyright 2021 Henrique Morimitsu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import sys
from argparse import ArgumentParser
from pathlib import Path

import torch
import torch.onnx

this_dir = Path(__file__).parent.resolve()
sys.path.insert(0, str(this_dir.parent.parent.parent))

from ptlflow import get_model, load_checkpoint
from ptlflow.models.rapidflow.rapidflow import RAPIDFlow


def _init_parser() -> ArgumentParser:
parser = ArgumentParser()
parser.add_argument(
"model",
type=str,
choices=(
"rapidflow",
"rapidflow_it1",
"rapidflow_it2",
"rapidflow_it3",
"rapidflow_it6",
"rapidflow_it12",
),
help="Name of the model to use.",
)
parser.add_argument(
"--checkpoint",
type=str,
default=None,
help="Path to the checkpoint to be loaded. It can also be one of the following names: \{chairs, things, sintel, kitti\}, in which case the respective pretrained checkpoint will be downloaded.",
)
parser.add_argument(
"--output_path",
type=str,
default=".",
help="Path to the directory where the converted onnx model will be saved.",
)
parser.add_argument(
"--input_size",
type=int,
nargs=2,
default=(384, 1280),
help="Size of the input image.",
)
return parser


def fuse_checkpoint_next1d_layers(state_dict):
fused_sd = {}
hv_pairs = {}
for name, param in state_dict.items():
if name.endswith("weight_h") or name.endswith("weight_v"):
name_prefix = name[: -(len("weight_h") + 1)]
orientation = name[-1]
if name_prefix not in hv_pairs:
hv_pairs[name_prefix] = {}
hv_pairs[name_prefix][orientation] = param
else:
fused_sd[name] = param

for name_prefix, param_pairs in hv_pairs.items():
weight = torch.einsum("cijk,cimj->cimk", param_pairs["h"], param_pairs["v"])
fused_sd[f"{name_prefix}.weight"] = weight
return fused_sd


def load_model(args):
model = get_model(args.model, args=args)
ckpt = load_checkpoint(args.checkpoint, RAPIDFlow, "rapidflow")
state_dict = fuse_checkpoint_next1d_layers(ckpt["state_dict"])
model.load_state_dict(state_dict, strict=True)
return model


if __name__ == "__main__":
parser = _init_parser()
parser = RAPIDFlow.add_model_specific_args(parser)
args = parser.parse_args()
args.corr_mode = "allpairs"
args.fuse_next1d_weights = True
args.simple_io = True

model = load_model(args)
sample_inputs = torch.randn(1, 2, 3, args.input_size[0], args.input_size[1])
if torch.cuda.is_available():
model = model.cuda()
sample_inputs = sample_inputs.cuda()

output_dir = Path(args.output_path)
output_dir.mkdir(parents=True, exist_ok=True)
output_path = str(output_dir / f"{args.model}.onnx")
torch.onnx.export(
model, sample_inputs, output_path, verbose=False, opset_version=16
)
print(f"ONNX model saved to: {output_path}")
8 changes: 5 additions & 3 deletions ptlflow/models/rapidflow/corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# Code adapted from RAFT: https://github.com/princeton-vl/RAFT/blob/master/core/corr.py
# =============================================================================

import math

import torch
import torch.nn.functional as F
from .utils import bilinear_sampler
Expand Down Expand Up @@ -79,7 +81,7 @@ def corr(fmap1, fmap2):

corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim))
return corr / math.sqrt(dim)


class AlternateCorrBlock:
Expand Down Expand Up @@ -116,7 +118,7 @@ def __call__(self, coords):

corr = torch.stack(corr_list, dim=1)
corr = corr.reshape(B, -1, H, W)
return corr / torch.sqrt(torch.tensor(dim))
return corr / math.sqrt(dim)


def get_corr_block(
Expand All @@ -127,7 +129,7 @@ def get_corr_block(
alternate_corr: bool = False,
):
if alternate_corr:
if alt_cuda_corr is None:
if alt_cuda_corr is None or fmap1.device == torch.device("cpu"):
corr_fn = IterativeCorrBlock
else:
corr_fn = AlternateCorrBlock
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions ptlflow/models/rapidflow/image_samples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
These image samples correspond to the first image pair of the training set of the KITTI 2015 dataset.

The complete dataset is available at: [https://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow](https://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow).
Loading

0 comments on commit f08faef

Please sign in to comment.