From 75b30bb1f36ce986371a0f3d5cb4270bbe076d75 Mon Sep 17 00:00:00 2001 From: j178 <10510431+j178@users.noreply.github.com> Date: Tue, 2 Apr 2024 18:02:06 +0800 Subject: [PATCH] Fix recursive chained script Based on #982, need to rebase after that get merged. --- rye/src/cli/run.rs | 11 ++++++++--- rye/tests/test_run.rs | 11 +++++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/rye/src/cli/run.rs b/rye/src/cli/run.rs index 394c025827..eab545f21e 100644 --- a/rye/src/cli/run.rs +++ b/rye/src/cli/run.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::env::{self, join_paths, split_paths}; use std::ffi::OsString; use std::path::PathBuf; @@ -62,7 +62,7 @@ pub fn execute(cmd: Args) -> Result<(), Error> { None => unreachable!(), }; - invoke_script(&pyproject, args, true)?; + invoke_script(&pyproject, args, true, &mut HashSet::new())?; unreachable!(); } @@ -70,6 +70,7 @@ fn invoke_script( pyproject: &PyProject, mut args: Vec, exec: bool, + seen_chain: &mut HashSet, ) -> Result { let venv_bin = pyproject.venv_bin_path(); let mut env_overrides = None; @@ -126,9 +127,13 @@ fn invoke_script( if args.len() != 1 { bail!("extra arguments to chained commands are not allowed"); } + if seen_chain.contains(&args[0]) { + bail!("found recursive chain script"); + } + seen_chain.insert(args[0].clone()); for args in commands { let status = - invoke_script(pyproject, args.into_iter().map(Into::into).collect(), false)?; + invoke_script(pyproject, args.into_iter().map(Into::into).collect(), false, seen_chain)?; if !status.success() { if !exec { return Ok(status); diff --git a/rye/tests/test_run.rs b/rye/tests/test_run.rs index 7fd5fabc38..1362a53a6d 100644 --- a/rye/tests/test_run.rs +++ b/rye/tests/test_run.rs @@ -209,7 +209,7 @@ fn test_script_chain() { // A nested `chain` script scripts["script_6"]["chain"] = value(Array::from_iter(["script_1", "script_4", "script_5"])); - // NEED FIX: A recursive `chain` script + // A recursive `chain` script scripts["script_7"]["chain"] = value(Array::from_iter(["script_7"])); doc["tool"]["rye"]["scripts"] = scripts; @@ -249,7 +249,14 @@ fn test_script_chain() { ----- stderr ----- error: script failed with exit code: 1 "###); - // rye_cmd_snapshot!(space.rye_cmd().arg("run").arg("script_7"), @r###""###); + rye_cmd_snapshot!(space.rye_cmd().arg("run").arg("script_7"), @r###" + success: false + exit_code: 1 + ----- stdout ----- + + ----- stderr ----- + error: found recursive chain script + "###); } #[test]