diff --git a/seqhasher_test.go b/seqhasher_test.go index 01bea50..33f58df 100644 --- a/seqhasher_test.go +++ b/seqhasher_test.go @@ -433,20 +433,93 @@ func TestAll(t *testing.T) { } func TestMainFunction(t *testing.T) { - // Set up arguments - oldArgs := os.Args - os.Args = []string{"cmd", "-version"} + tests := []struct { + name string + args []string + expectedOutput string + expectedError bool + checkUsage bool + }{ + { + name: "Version flag", + args: []string{"cmd", "-version"}, + expectedOutput: fmt.Sprintf("SeqHasher %s\n", version), + expectedError: false, + }, + { + name: "No input file", + args: []string{"cmd"}, + checkUsage: true, + }, + { + name: "Non-existent input file", + args: []string{"cmd", "nonexistent_file.fasta"}, + expectedOutput: "Error opening input: open nonexistent_file.fasta: no such file or directory\n", + expectedError: true, + }, + { + name: "Valid input file", + args: []string{"cmd", testFastaPath}, + expectedOutput: "", // The actual output will be the processed sequences + expectedError: false, + }, + } - // Call run() instead of main() - output := run() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset flags before each test + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) - // Restore arguments - os.Args = oldArgs + // Set up arguments + oldArgs := os.Args + os.Args = tt.args - // Check if version is printed - expectedOutput := fmt.Sprintf("SeqHasher %s\n", version) - if output != expectedOutput { - t.Errorf("Expected output %q, got %q", expectedOutput, output) + // Capture stdout and stderr + oldStdout := os.Stdout + oldStderr := os.Stderr + r, w, _ := os.Pipe() + os.Stdout = w + os.Stderr = w + + // Call run() instead of main() + output := run() + + // Restore stdout and stderr + w.Close() + os.Stdout = oldStdout + os.Stderr = oldStderr + + // Read captured output + var buf bytes.Buffer + io.Copy(&buf, r) + capturedOutput := buf.String() + + // Restore arguments + os.Args = oldArgs + + // Check output + if tt.expectedError { + if output != tt.expectedOutput { + t.Errorf("Expected error output %q, got %q", tt.expectedOutput, output) + } + } else if tt.checkUsage { + if !strings.Contains(capturedOutput, "Usage:") { + t.Errorf("Expected usage information, but it was not printed") + } + } else { + if tt.expectedOutput != "" && output != tt.expectedOutput { + t.Errorf("Expected output %q, got %q", tt.expectedOutput, output) + } + if tt.expectedOutput == "" && output == "" { + t.Errorf("Expected non-empty output, got empty string") + } + } + + // For the "Valid input file" case, check if the output contains processed sequences + if tt.name == "Valid input file" && !strings.Contains(output, ";seq1\n") { + t.Errorf("Expected processed sequences in output, but they were not found") + } + }) } }