Skip to content

Commit

Permalink
[#831] cmd: fix: --config parameter handling (#834)
Browse files Browse the repository at this point in the history
  • Loading branch information
yohamta authored Feb 13, 2025
1 parent 11f4649 commit bbff142
Show file tree
Hide file tree
Showing 26 changed files with 343 additions and 205 deletions.
88 changes: 88 additions & 0 deletions cmd/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package main

import (
"fmt"

"github.com/spf13/cobra"
"github.com/spf13/viper"
)

// Default values for the server.
const (
defaultHost = "localhost"
defaultPort = "8080"
)

type commandLineFlag struct {
name, shorthand, defaultValue, usage string
required bool
}

var (
configFlag = commandLineFlag{
name: "config",
shorthand: "c",
usage: "config file (default is $HOME/.config/dagu/config.yaml)",
}
dagsFlag = commandLineFlag{
name: "dags",
shorthand: "d",
usage: "location of DAG files (default is $HOME/.config/dagu/dags)",
}
hostFlag = commandLineFlag{
name: "host",
shorthand: "s",
defaultValue: defaultHost,
usage: "server host",
}
portFlag = commandLineFlag{
name: "port",
shorthand: "p",
defaultValue: defaultPort,
usage: "server port",
}
paramsFlag = commandLineFlag{
name: "params",
shorthand: "p",
usage: "parameters to pass to the DAG",
}
requestIDFlag = commandLineFlag{
name: "req",
shorthand: "r",
usage: "request ID",
}
)

func withRequired(flag commandLineFlag) commandLineFlag {
flag.required = true
flag.usage = fmt.Sprintf("%s (required)", flag.usage)
return flag
}

func withUsage(flag commandLineFlag, usage string) commandLineFlag {
flag.usage = usage
return flag
}

func initCommonFlags(cmd *cobra.Command, addFlags []commandLineFlag) {
addFlags = append(addFlags, configFlag)
for _, flag := range addFlags {
cmd.Flags().StringP(flag.name, flag.shorthand, flag.defaultValue, flag.usage)
if flag.required {
if err := cmd.MarkFlagRequired(flag.name); err != nil {
fmt.Printf("failed to mark flag %s as required: %v\n", flag.name, err)
}
}
}
}

func bindCommonFlags(cmd *cobra.Command, addFlags []string) error {
flags := []string{"config"}
flags = append(flags, addFlags...)
for _, flag := range flags {
if err := viper.BindPFlag(flag, cmd.Flags().Lookup(flag)); err != nil {
return fmt.Errorf("failed to bind flag %s: %w", flag, err)
}
}
return nil
}
28 changes: 15 additions & 13 deletions cmd/dry.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,21 @@ import (
"github.com/spf13/cobra"
)

const (
dryPrefix = "dry_"
)

func dryCmd() *cobra.Command {
return &cobra.Command{
cmd := &cobra.Command{
Use: "dry [flags] /path/to/spec.yaml",
Short: "Dry-runs specified DAG",
Long: `dagu dry /path/to/spec.yaml -- params1 params2`,
Args: cobra.MinimumNArgs(1),
RunE: wrapRunE(runDry),
PreRunE: func(cmd *cobra.Command, _ []string) error {
return bindCommonFlags(cmd, nil)
},
RunE: wrapRunE(runDry),
}

initCommonFlags(cmd, []commandLineFlag{paramsFlag})

return cmd
}

func runDry(cmd *cobra.Command, args []string) error {
Expand All @@ -29,8 +32,6 @@ func runDry(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to create setup: %w", err)
}

cmd.Flags().StringP("params", "p", "", "parameters")

ctx := setup.loggerContext(cmd.Context(), false)

loadOpts := []digraph.LoadOption{
Expand Down Expand Up @@ -60,7 +61,8 @@ func runDry(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to generate request ID: %w", err)
}

logFile, err := setup.openLogFile(ctx, dryPrefix, dag, requestID)
const logPrefix = "dry_"
logFile, err := setup.openLogFile(ctx, logPrefix, dag, requestID)
if err != nil {
return fmt.Errorf("failed to initialize log file for DAG %s: %w", dag.Name, err)
}
Expand All @@ -78,7 +80,7 @@ func runDry(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to initialize client: %w", err)
}

agt := agent.New(
agentInstance := agent.New(
requestID,
dag,
filepath.Dir(logFile.Name()),
Expand All @@ -89,13 +91,13 @@ func runDry(cmd *cobra.Command, args []string) error {
agent.Options{Dry: true},
)

listenSignals(ctx, agt)
listenSignals(ctx, agentInstance)

if err := agt.Run(ctx); err != nil {
if err := agentInstance.Run(ctx); err != nil {
return fmt.Errorf("failed to execute DAG %s (requestID: %s): %w", dag.Name, requestID, err)
}

agt.PrintSummary(ctx)
agentInstance.PrintSummary(ctx)

return nil
}
5 changes: 5 additions & 0 deletions cmd/dry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ func TestDryCommand(t *testing.T) {
args: []string{"dry", th.DAG(t, "cmd/dry.yaml").Location},
expectedOut: []string{"Dry-run finished"},
},
{
name: "DryRunDAGWithParams",
args: []string{"dry", th.DAG(t, "cmd/dry_with_params.yaml").Location, "--params", "p3 p4"},
expectedOut: []string{`[p3 p4]`},
},
{
name: "DryRunDAGWithParamsAfterDash",
args: []string{"dry", th.DAG(t, "cmd/dry_with_params.yaml").Location, "--", "p5", "p6"},
Expand Down
16 changes: 0 additions & 16 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,11 @@ import (

"github.com/dagu-org/dagu/internal/build"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)

var (
// version is set at build time
version = "0.0.0"

cfgFile string

rootCmd = &cobra.Command{
Use: build.Slug,
Short: "YAML-based DAG scheduling tool.",
Expand All @@ -31,18 +27,6 @@ func init() {
build.Version = version

registerCommands()

rootCmd.PersistentFlags().
StringVar(
&cfgFile, "config", "",
"config file (default is $HOME/.config/dagu/config.yaml)",
)

cobra.OnInitialize(func() {
if cfgFile != "" {
viper.SetConfigFile(cfgFile)
}
})
}

func registerCommands() {
Expand Down
26 changes: 12 additions & 14 deletions cmd/restart.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,18 @@ import (
"github.com/spf13/cobra"
)

const (
restartPrefix = "restart_"
stopPollInterval = 100 * time.Millisecond
)

func restartCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "restart /path/to/spec.yaml",
Short: "Stop the running DAG and restart it",
Long: `dagu restart /path/to/spec.yaml`,
Args: cobra.ExactArgs(1),
RunE: wrapRunE(runRestart),
PreRunE: func(cmd *cobra.Command, _ []string) error {
return bindCommonFlags(cmd, nil)
},
RunE: wrapRunE(runRestart),
}
initCommonFlags(cmd, nil)
cmd.Flags().BoolP("quiet", "q", false, "suppress output")
return cmd
}
Expand All @@ -47,15 +46,12 @@ func runRestart(cmd *cobra.Command, args []string) error {
ctx := setup.loggerContext(cmd.Context(), quiet)

specFilePath := args[0]

// Load initial DAG configuration
dag, err := digraph.Load(ctx, specFilePath, digraph.WithBaseConfig(setup.cfg.Paths.BaseConfig))
if err != nil {
logger.Error(ctx, "Failed to load DAG", "path", specFilePath, "err", err)
return fmt.Errorf("failed to load DAG from %s: %w", specFilePath, err)
}

// Handle the restart process
if err := handleRestartProcess(ctx, setup, dag, quiet, specFilePath); err != nil {
logger.Error(ctx, "Failed to restart process", "path", specFilePath, "err", err)
return fmt.Errorf("restart process failed for DAG %s: %w", dag.Name, err)
Expand Down Expand Up @@ -111,7 +107,8 @@ func executeDAG(ctx context.Context, cli client.Client, setup *setup,
return fmt.Errorf("failed to generate request ID: %w", err)
}

logFile, err := setup.openLogFile(ctx, restartPrefix, dag, requestID)
const logPrefix = "restart_"
logFile, err := setup.openLogFile(ctx, logPrefix, dag, requestID)
if err != nil {
return fmt.Errorf("failed to initialize log file: %w", err)
}
Expand All @@ -127,7 +124,7 @@ func executeDAG(ctx context.Context, cli client.Client, setup *setup,
return fmt.Errorf("failed to initialize DAG store: %w", err)
}

agt := agent.New(
agentInstance := agent.New(
requestID,
dag,
filepath.Dir(logFile.Name()),
Expand All @@ -137,12 +134,12 @@ func executeDAG(ctx context.Context, cli client.Client, setup *setup,
setup.historyStore(),
agent.Options{Dry: false})

listenSignals(ctx, agt)
if err := agt.Run(ctx); err != nil {
listenSignals(ctx, agentInstance)
if err := agentInstance.Run(ctx); err != nil {
if quiet {
os.Exit(1)
} else {
agt.PrintSummary(ctx)
agentInstance.PrintSummary(ctx)
return fmt.Errorf("DAG execution failed: %w", err)
}
}
Expand All @@ -166,6 +163,7 @@ func stopDAGIfRunning(ctx context.Context, cli client.Client, dag *digraph.DAG)
}

func stopRunningDAG(ctx context.Context, cli client.Client, dag *digraph.DAG) error {
const stopPollInterval = 100 * time.Millisecond
for {
status, err := cli.GetCurrentStatus(ctx, dag)
if err != nil {
Expand Down
32 changes: 18 additions & 14 deletions cmd/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,28 @@ import (
"github.com/spf13/cobra"
)

const (
retryPrefix = "retry_"
)

func retryCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "retry --req=<request-id> /path/to/spec.yaml",
Short: "Retry the DAG execution",
Long: `dagu retry --req=<request-id> /path/to/spec.yaml`,
Args: cobra.ExactArgs(1),
RunE: wrapRunE(runRetry),
PreRunE: func(cmd *cobra.Command, _ []string) error {
return bindCommonFlags(cmd, nil)
},
RunE: wrapRunE(runRetry),
}

cmd.Flags().StringP("req", "r", "", "request-id")
_ = cmd.MarkFlagRequired("req")
cmd.Flags().BoolP("quiet", "q", false, "suppress output")
initRetryFlags(cmd)

return cmd
}

func initRetryFlags(cmd *cobra.Command) {
initCommonFlags(cmd, []commandLineFlag{withRequired(requestIDFlag)})
cmd.Flags().BoolP("quiet", "q", false, "suppress output")
}

func runRetry(cmd *cobra.Command, args []string) error {
setup, err := createSetup()
if err != nil {
Expand Down Expand Up @@ -99,7 +102,8 @@ func executeRetry(ctx context.Context, dag *digraph.DAG, setup *setup, originalS
return fmt.Errorf("failed to generate new request ID: %w", err)
}

logFile, err := setup.openLogFile(ctx, retryPrefix, dag, newRequestID)
const logPrefix = "retry_"
logFile, err := setup.openLogFile(ctx, logPrefix, dag, newRequestID)
if err != nil {
return fmt.Errorf("failed to initialize log file for DAG %s: %w", dag.Name, err)
}
Expand All @@ -121,7 +125,7 @@ func executeRetry(ctx context.Context, dag *digraph.DAG, setup *setup, originalS
return fmt.Errorf("failed to initialize client: %w", err)
}

agt := agent.New(
agentInstance := agent.New(
newRequestID,
dag,
filepath.Dir(logFile.Name()),
Expand All @@ -132,19 +136,19 @@ func executeRetry(ctx context.Context, dag *digraph.DAG, setup *setup, originalS
agent.Options{RetryTarget: &originalStatus.Status},
)

listenSignals(ctx, agt)
listenSignals(ctx, agentInstance)

if err := agt.Run(ctx); err != nil {
if err := agentInstance.Run(ctx); err != nil {
if quiet {
os.Exit(1)
} else {
agt.PrintSummary(ctx)
agentInstance.PrintSummary(ctx)
return fmt.Errorf("failed to execute DAG %s (requestID: %s): %w", dag.Name, newRequestID, err)
}
}

if !quiet {
agt.PrintSummary(ctx)
agentInstance.PrintSummary(ctx)
}

return nil
Expand Down
Loading

0 comments on commit bbff142

Please sign in to comment.