Skip to content

Commit

Permalink
exp save: add recursive arg
Browse files Browse the repository at this point in the history
  • Loading branch information
dberenbaum authored and BradyJ27 committed Apr 22, 2024
1 parent 288304e commit 51cf227
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 11 deletions.
17 changes: 9 additions & 8 deletions dvc/commands/experiments/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def run(self):
ref = self.repo.experiments.save(
targets=self.args.targets,
name=self.args.name,
recursive=self.args.recursive,
force=self.args.force,
include_untracked=self.args.include_untracked,
message=self.args.message,
Expand Down Expand Up @@ -45,15 +46,15 @@ def add_parser(experiments_subparsers, parent_parser):
save_parser.add_argument(
"targets",
nargs="*",
help="""\
Stages to save. 'dvc.yaml' by default.
The targets can be path to a dvc.yaml file or `.dvc` file,
or a stage name from dvc.yaml file from
current working directory. To save a stage from dvc.yaml
from other directories, the target must be a path followed by colon `:`
and then the stage name name.
""",
help=("Limit DVC caching to these .dvc files and stage names."),
).complete = completion.DVCFILES_AND_STAGE
save_parser.add_argument(
"-R",
"--recursive",
action="store_true",
default=False,
help="Cache subdirectories of the specified directory.",
)
save_parser.add_argument(
"-f",
"--force",
Expand Down
9 changes: 7 additions & 2 deletions dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def save(
cls,
info: "ExecutorInfo",
targets: Optional[Iterable[str]] = None,
recursive: bool = False,
force: bool = False,
include_untracked: Optional[List[str]] = None,
message: Optional[str] = None,
Expand Down Expand Up @@ -297,9 +298,13 @@ def save(
stages = []
if targets:
for target in targets:
stages.append(dvc.commit(target, force=True, relink=False))
stages.append(
dvc.commit(
target, recursive=recursive, force=True, relink=False
)
)
else:
stages = dvc.commit([], force=True, relink=False)
stages = dvc.commit([], recursive=recursive, force=True, relink=False)
exp_hash = cls.hash_exp(stages)
if include_untracked:
dvc.scm.add(include_untracked, force=True) # type: ignore[call-arg]
Expand Down
2 changes: 2 additions & 0 deletions dvc/repo/experiments/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def save(
repo: "Repo",
targets: Optional[Iterable[str]] = None,
name: Optional[str] = None,
recursive: bool = False,
force: bool = False,
include_untracked: Optional[List[str]] = None,
message: Optional[str] = None,
Expand All @@ -34,6 +35,7 @@ def save(
save_result = executor.save(
executor.info,
targets=targets,
recursive=recursive,
force=force,
include_untracked=include_untracked,
message=message,
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,9 @@ def test_experiments_rename_invalid(dvc, scm, mocker, capsys, caplog):


def test_experiments_save(dvc, scm, mocker):
cli_args = parse_args(["exp", "save", "target", "--name", "exp-name", "--force"])
cli_args = parse_args(
["exp", "save", "target", "--name", "exp-name", "--recursive", "--force"]
)
assert cli_args.func == CmdExperimentsSave

cmd = cli_args.func(cli_args)
Expand All @@ -487,6 +489,7 @@ def test_experiments_save(dvc, scm, mocker):
cmd.repo,
targets=["target"],
name="exp-name",
recursive=True,
force=True,
include_untracked=[],
message=None,
Expand All @@ -507,6 +510,7 @@ def test_experiments_save_message(dvc, scm, mocker, flag):
cmd.repo,
targets=[],
name=None,
recursive=False,
force=False,
include_untracked=[],
message="custom commit message",
Expand Down

0 comments on commit 51cf227

Please sign in to comment.