diff --git a/cmds/dutctl/rpc.go b/cmds/dutctl/rpc.go index e6cc30a..81dbb48 100644 --- a/cmds/dutctl/rpc.go +++ b/cmds/dutctl/rpc.go @@ -9,6 +9,7 @@ import ( "io/fs" "log" "os" + "strings" "sync" "connectrpc.com/connect" @@ -116,6 +117,10 @@ func (app *application) runRPC(device, command string, cmdArgs []string) error { path := msg.FileRequest.GetPath() log.Printf("File request for: %q\n", path) + if !isPartOfArgs(cmdArgs, path) { + log.Fatalf("Invalid file request: Requested file %q was not named in the command's arguments", path) + } + content, err := os.ReadFile(path) if err != nil { log.Fatal(err) @@ -141,6 +146,10 @@ func (app *application) runRPC(device, command string, cmdArgs []string) error { log.Printf("Received file: %q\n", path) + if !isPartOfArgs(cmdArgs, path) { + log.Fatalf("Invalid file transmission: Sent file %q was not named in the command's arguments", path) + } + if len(content) == 0 { log.Println("Received empty file content") } @@ -192,3 +201,13 @@ func (app *application) runRPC(device, command string, cmdArgs []string) error { return nil } + +func isPartOfArgs(args []string, token string) bool { + for _, arg := range args { + if strings.Contains(arg, token) { + return true + } + } + + return false +} diff --git a/pkg/module/dummy/dummy_file_transfer.go b/pkg/module/dummy/dummy_file_transfer.go index 5a8ae79..0b6faec 100644 --- a/pkg/module/dummy/dummy_file_transfer.go +++ b/pkg/module/dummy/dummy_file_transfer.go @@ -53,12 +53,13 @@ func (d *FT) Run(_ context.Context, s module.Session, args ...string) error { str := fmt.Sprintf("Called with %d arguments", len(args)) s.Print(str) - if len(args) != 1 { - return fmt.Errorf("expected 1 argument, got %d", len(args)) + const expectedArgsCnt = 2 + if len(args) != expectedArgsCnt { + return fmt.Errorf("expected 2 arguments, got %d", len(args)) } inFile := args[0] - str = fmt.Sprintf("Requesting file %q passed in arg[0]", inFile) + str = fmt.Sprintf("Requesting file %q passed in arg[0] as input", inFile) s.Print(str) fileReader, err := s.RequestFile(inFile) @@ -83,14 +84,16 @@ func (d *FT) Run(_ context.Context, s module.Session, args ...string) error { return fmt.Errorf("failed to process file: %v", err) } - log.Print("dummy.FT module: Sending back processed file") + outFile := args[1] + log.Printf("dummy.FT module: Sending back processed file %q", outFile) - err = s.SendFile("processed.txt", bytes.NewBuffer(result)) + err = s.SendFile(outFile, bytes.NewBuffer(result)) if err != nil { return fmt.Errorf("failed to send file: %v", err) } - s.Print("File operated successfully, check processed.txt") + str = fmt.Sprintf("File operated successfully, delivered %q as passed in arg[1] as output", outFile) + s.Print(str) return nil }