Skip to content

Commit

Permalink
let .pth the default extension in the cli
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Feb 10, 2024
1 parent 9f170e0 commit 9e53152
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
4 changes: 2 additions & 2 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def main_parser() -> argparse.ArgumentParser:
"--model",
default="frozen_model",
type=str,
help="Frozen model file (prefix) to import. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pt",
help="Frozen model file (prefix) to import. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth.",
)
parser_tst_subgroup = parser_tst.add_mutually_exclusive_group()
parser_tst_subgroup.add_argument(
Expand Down Expand Up @@ -512,7 +512,7 @@ def main_parser() -> argparse.ArgumentParser:
default=["graph.000", "graph.001", "graph.002", "graph.003"],
nargs="+",
type=str,
help="Frozen models file (prefix) to import. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pt.",
help="Frozen models file (prefix) to import. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth.",
)
parser_model_devi.add_argument(
"-s",
Expand Down
10 changes: 7 additions & 3 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,11 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
if FLAGS.command == "train":
train(FLAGS)
elif FLAGS.command == "test":
dict_args["output"] = str(Path(FLAGS.model).with_suffix(".pt"))
dict_args["output"] = (
str(Path(FLAGS.model).with_suffix(".pth"))
if Path(FLAGS.model).suffix not in (".pt", ".pth")
else FLAGS.model
)
test(**dict_args)
elif FLAGS.command == "freeze":
if Path(FLAGS.checkpoint_folder).is_dir():
Expand All @@ -316,8 +320,8 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
doc_train_input(**dict_args)
elif FLAGS.command == "model-devi":
dict_args["models"] = [
str(Path(mm).with_suffix(".pt"))
if Path(mm).suffix not in (".pb", ".pt")
str(Path(mm).with_suffix(".pth"))
if Path(mm).suffix not in (".pb", ".pt", ".pth")
else mm
for mm in dict_args["models"]
]
Expand Down

0 comments on commit 9e53152

Please sign in to comment.