diff --git a/internal/execs/execs.go b/internal/execs/execs.go index c808365..44e1253 100644 --- a/internal/execs/execs.go +++ b/internal/execs/execs.go @@ -15,85 +15,77 @@ var ( resolvectl = Resolvectl ) -// RunCmd runs the cmd with args and sets stdin to s. -var RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { +// RunCmd runs the cmd with args and sets stdin to s, returns stdout and stderr. +var RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) (stdout, stderr []byte, err error) { c := exec.CommandContext(ctx, cmd, arg...) if s != "" { c.Stdin = bytes.NewBufferString(s) } - return c.Run() -} - -// RunCmdOutput runs the cmd with args and sets stdin to s, returns output. -var RunCmdOutput = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, error) { - c := exec.CommandContext(ctx, cmd, arg...) - if s != "" { - c.Stdin = bytes.NewBufferString(s) - } - return c.Output() + var outbuf, errbuf bytes.Buffer + c.Stdout = &outbuf + c.Stderr = &errbuf + err = c.Run() + stdout = outbuf.Bytes() + stderr = errbuf.Bytes() + return } // RunIP runs the "ip" command with args. -func RunIP(ctx context.Context, arg ...string) error { +func RunIP(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { return RunCmd(ctx, ip, "", arg...) } // RunIPLink runs the "ip link" command with args. -func RunIPLink(ctx context.Context, arg ...string) error { +func RunIPLink(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { a := append([]string{"link"}, arg...) return RunIP(ctx, a...) } // RunIPAddress runs the "ip address" command with args. -func RunIPAddress(ctx context.Context, arg ...string) error { +func RunIPAddress(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { a := append([]string{"address"}, arg...) return RunIP(ctx, a...) } // RunIP4Route runs the "ip -4 route" command with args. -func RunIP4Route(ctx context.Context, arg ...string) error { +func RunIP4Route(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { a := append([]string{"-4", "route"}, arg...) return RunIP(ctx, a...) } // RunIP6Route runs the "ip -6 route" command with args. -func RunIP6Route(ctx context.Context, arg ...string) error { +func RunIP6Route(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { a := append([]string{"-6", "route"}, arg...) return RunIP(ctx, a...) } // RunIP4Rule runs the "ip -4 rule" command with args. -func RunIP4Rule(ctx context.Context, arg ...string) error { +func RunIP4Rule(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { a := append([]string{"-4", "rule"}, arg...) return RunIP(ctx, a...) } // RunIP6Rule runs the "ip -6 rule" command with args. -func RunIP6Rule(ctx context.Context, arg ...string) error { +func RunIP6Rule(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { a := append([]string{"-6", "rule"}, arg...) return RunIP(ctx, a...) } // RunSysctl runs the "sysctl" command with args. -func RunSysctl(ctx context.Context, arg ...string) error { +func RunSysctl(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { return RunCmd(ctx, sysctl, "", arg...) } // RunNft runs the "nft -f -" command and sets stdin to s. -func RunNft(ctx context.Context, s string) error { +func RunNft(ctx context.Context, s string) (stdout, stderr []byte, err error) { return RunCmd(ctx, nft, s, "-f", "-") } // RunResolvectl runs the "resolvectl" command with args. -func RunResolvectl(ctx context.Context, arg ...string) error { +func RunResolvectl(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { return RunCmd(ctx, resolvectl, "", arg...) } -// RunResolvectlOutput runs the "resolvectl" command with args, returns output. -func RunResolvectlOutput(ctx context.Context, arg ...string) ([]byte, error) { - return RunCmdOutput(ctx, resolvectl, "", arg...) -} - // SetExecutables configures all executables from config. func SetExecutables(config *Config) { ip = config.IP diff --git a/internal/execs/execs_test.go b/internal/execs/execs_test.go index 09f3b56..9736c07 100644 --- a/internal/execs/execs_test.go +++ b/internal/execs/execs_test.go @@ -14,39 +14,30 @@ func TestRunCmd(t *testing.T) { // test not existing dir := t.TempDir() - if err := RunCmd(ctx, filepath.Join(dir, "does/not/exist"), ""); err == nil { + if _, _, err := RunCmd(ctx, filepath.Join(dir, "does/not/exist"), ""); err == nil { t.Errorf("running not existing command should fail: %v", err) } // test existing - if err := RunCmd(ctx, "echo", "", "this", "is", "a", "test"); err != nil { + if _, _, err := RunCmd(ctx, "echo", "", "this", "is", "a", "test"); err != nil { t.Errorf("running echo failed: %v", err) } // test with stdin - if err := RunCmd(ctx, "echo", "this is a test"); err != nil { + if _, _, err := RunCmd(ctx, "echo", "this is a test"); err != nil { t.Errorf("running echo failed: %v", err) } -} -// TestRunCmdOutput tests RunCmdOutput. -func TestRunCmdOutput(t *testing.T) { - ctx := context.Background() - dir := t.TempDir() - - // test with error - if _, err := RunCmdOutput(ctx, filepath.Join(dir, "does/not/exist"), ""); err == nil { - t.Errorf("running should fail: %v", err) + // test stdout + stdout, stderr, err := RunCmd(ctx, "cat", "this is a test") + if err != nil || string(stdout) != "this is a test" { + t.Errorf("running echo failed: %s, %s, %v", stdout, stderr, err) } - // test without error - if b, err := RunCmdOutput(ctx, "ls", "", "-d", dir); err != nil || len(b) == 0 { - t.Errorf("running should not fail: %v, %v", b, err) - } - - // test with stdin - if b, err := RunCmdOutput(ctx, "echo", "this is a test"); err != nil || len(b) == 0 { - t.Errorf("running should not fail: %v, %v", b, err) + // test stderr and error + stdout, stderr, err = RunCmd(ctx, "cat", "", "does/not/exist") + if err == nil || string(stderr) != "cat: does/not/exist: No such file or directory\n" { + t.Errorf("running echo failed: %s, %s, %v", stdout, stderr, err) } } @@ -56,13 +47,13 @@ func TestRunIP(t *testing.T) { got := []string{} oldRunCmd := RunCmd - RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil + return nil, nil, nil } defer func() { RunCmd = oldRunCmd }() - _ = RunIP(context.Background(), "address", "show") + _, _, _ = RunIP(context.Background(), "address", "show") if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } @@ -74,13 +65,13 @@ func TestRunIPLink(t *testing.T) { got := []string{} oldRunCmd := RunCmd - RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil + return nil, nil, nil } defer func() { RunCmd = oldRunCmd }() - _ = RunIPLink(context.Background(), "show") + _, _, _ = RunIPLink(context.Background(), "show") if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } @@ -92,13 +83,13 @@ func TestRunIPAddress(t *testing.T) { got := []string{} oldRunCmd := RunCmd - RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil + return nil, nil, nil } defer func() { RunCmd = oldRunCmd }() - _ = RunIPAddress(context.Background(), "show") + _, _, _ = RunIPAddress(context.Background(), "show") if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } @@ -110,13 +101,13 @@ func TestRunIP4Route(t *testing.T) { got := []string{} oldRunCmd := RunCmd - RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil + return nil, nil, nil } defer func() { RunCmd = oldRunCmd }() - _ = RunIP4Route(context.Background(), "show") + _, _, _ = RunIP4Route(context.Background(), "show") if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } @@ -128,13 +119,13 @@ func TestRunIP6Route(t *testing.T) { got := []string{} oldRunCmd := RunCmd - RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil + return nil, nil, nil } defer func() { RunCmd = oldRunCmd }() - _ = RunIP6Route(context.Background(), "show") + _, _, _ = RunIP6Route(context.Background(), "show") if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } @@ -146,13 +137,13 @@ func TestRunIP4Rule(t *testing.T) { got := []string{} oldRunCmd := RunCmd - RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil + return nil, nil, nil } defer func() { RunCmd = oldRunCmd }() - _ = RunIP4Rule(context.Background(), "show") + _, _, _ = RunIP4Rule(context.Background(), "show") if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } @@ -164,13 +155,13 @@ func TestRunIP6Rule(t *testing.T) { got := []string{} oldRunCmd := RunCmd - RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil + return nil, nil, nil } defer func() { RunCmd = oldRunCmd }() - _ = RunIP6Rule(context.Background(), "show") + _, _, _ = RunIP6Rule(context.Background(), "show") if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } @@ -182,13 +173,13 @@ func TestRunSysctl(t *testing.T) { got := []string{} oldRunCmd := RunCmd - RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil + return nil, nil, nil } defer func() { RunCmd = oldRunCmd }() - _ = RunSysctl(context.Background(), "-q", "net.ipv4.conf.all.src_valid_mark=1") + _, _, _ = RunSysctl(context.Background(), "-q", "net.ipv4.conf.all.src_valid_mark=1") if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } @@ -200,13 +191,13 @@ func TestRunNft(t *testing.T) { got := []string{} oldRunCmd := RunCmd - RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, cmd+" "+strings.Join(arg, " ")+" "+s) - return nil + return nil, nil, nil } defer func() { RunCmd = oldRunCmd }() - _ = RunNft(context.Background(), "list tables") + _, _, _ = RunNft(context.Background(), "list tables") if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } @@ -218,32 +209,14 @@ func TestRunResolvectl(t *testing.T) { got := []string{} oldRunCmd := RunCmd - RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil + return []byte("OK"), nil, nil } defer func() { RunCmd = oldRunCmd }() - _ = RunResolvectl(context.Background(), "dns") - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } -} - -// TestRunResolvectlOutput tests RunResolvectlOutput. -func TestRunResolvectlOutput(t *testing.T) { - want := []string{"resolvectl dns"} - got := []string{} - - oldRunCmdOutput := RunCmdOutput - RunCmdOutput = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, error) { - got = append(got, cmd+" "+strings.Join(arg, " ")) - return []byte("OK"), nil - } - defer func() { RunCmdOutput = oldRunCmdOutput }() - - if b, err := RunResolvectlOutput(context.Background(), "dns"); err != nil || string(b) != "OK" { - t.Errorf("invalid return values %v, %v", b, err) + if b, _, err := RunResolvectl(context.Background(), "dns"); err != nil || string(b) != "OK" { + t.Errorf("invalid return values %s, %v", b, err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) diff --git a/internal/splitrt/excludes_test.go b/internal/splitrt/excludes_test.go index 2278bd6..b290037 100644 --- a/internal/splitrt/excludes_test.go +++ b/internal/splitrt/excludes_test.go @@ -34,9 +34,9 @@ func TestExcludesAddStatic(t *testing.T) { // set testing runNft function got := []string{} - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, s) - return nil + return nil, nil, nil } // test adding excludes @@ -68,9 +68,9 @@ func TestExcludesAddDynamic(t *testing.T) { // set testing runNft function got := []string{} - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, s) - return nil + return nil, nil, nil } // test adding excludes @@ -103,9 +103,9 @@ func TestExcludesRemove(t *testing.T) { // set testing runNft function got := []string{} oldRunCmd := execs.RunCmd - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, s) - return nil + return nil, nil, nil } defer func() { execs.RunCmd = oldRunCmd }() @@ -159,9 +159,9 @@ func TestExcludesRemove(t *testing.T) { // test with nft error got = []string{} - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, s) - return errors.New("test error") + return nil, nil, errors.New("test error") } for _, exclude := range excludes { e.AddStatic(ctx, exclude) @@ -182,9 +182,9 @@ func TestExcludesCleanup(t *testing.T) { // set testing runNft function got := []string{} - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, s) - return nil + return nil, nil, nil } // test without excludes diff --git a/internal/splitrt/filter.go b/internal/splitrt/filter.go index fb8b25b..afcdb21 100644 --- a/internal/splitrt/filter.go +++ b/internal/splitrt/filter.go @@ -103,15 +103,21 @@ table inet oc-daemon-routing { ` r := strings.NewReplacer("$FWMARK", fwMark) rules := r.Replace(routeRules) - if err := execs.RunNft(ctx, rules); err != nil { - log.WithError(err).Error("SplitRouting error setting routing rules") + if stdout, stderr, err := execs.RunNft(ctx, rules); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error setting routing rules") } } // unsetRoutingRules removes the nftables rules for routing. func unsetRoutingRules(ctx context.Context) { - if err := execs.RunNft(ctx, "delete table inet oc-daemon-routing"); err != nil { - log.WithError(err).Error("SplitRouting error unsetting routing rules") + if stdout, stderr, err := execs.RunNft(ctx, "delete table inet oc-daemon-routing"); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error unsetting routing rules") } } @@ -129,8 +135,11 @@ func addLocalAddresses(ctx context.Context, device, family string, addresses []* nftconf += "fib saddr type != local counter drop\n" } - if err := execs.RunNft(ctx, nftconf); err != nil { - log.WithError(err).Error("SplitRouting error adding local addresses") + if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error adding local addresses") } } @@ -160,8 +169,11 @@ func rejectIPVersion(ctx context.Context, device, version string) { nftconf += "counter jump rejectipversion\n" } - if err := execs.RunNft(ctx, nftconf); err != nil { - log.WithError(err).Error("SplitRouting error setting ip version reject rules") + if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error setting ip version reject rules") } } @@ -186,8 +198,11 @@ func addExclude(ctx context.Context, address *net.IPNet) { nftconf := fmt.Sprintf("add element inet oc-daemon-routing %s { %s }", set, address) - if err := execs.RunNft(ctx, nftconf); err != nil { - log.WithError(err).Error("SplitRouting error adding exclude") + if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error adding exclude") } } @@ -210,15 +225,18 @@ func setExcludes(ctx context.Context, addresses []*net.IPNet) { } // run command - if err := execs.RunNft(ctx, nftconf); err != nil { - log.WithError(err).Error("SplitRouting error setting excludes") + if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error setting excludes") } } // cleanupRoutingRules cleans up the nftables rules for routing after a // failed shutdown. func cleanupRoutingRules(ctx context.Context) { - if err := execs.RunNft(ctx, "delete table inet oc-daemon-routing"); err == nil { + if _, _, err := execs.RunNft(ctx, "delete table inet oc-daemon-routing"); err == nil { log.Debug("SplitRouting cleaned up nft") } } diff --git a/internal/splitrt/route.go b/internal/splitrt/route.go index 936b1be..e32342f 100644 --- a/internal/splitrt/route.go +++ b/internal/splitrt/route.go @@ -10,96 +10,129 @@ import ( // addDefaultRouteIPv4 adds default routing for IPv4. func addDefaultRouteIPv4(ctx context.Context, device, rtTable, rulePrio1, fwMark, rulePrio2 string) { // set default route - if err := execs.RunIP4Route(ctx, "add", "0.0.0.0/0", "dev", device, + if stdout, stderr, err := execs.RunIP4Route(ctx, "add", "0.0.0.0/0", "dev", device, "table", rtTable); err != nil { - log.WithError(err).Error("SplitRouting error setting ipv4 default route") + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error setting ipv4 default route") } // set routing rules - if err := execs.RunIP4Rule(ctx, "add", "iif", device, "table", "main", + if stdout, stderr, err := execs.RunIP4Rule(ctx, "add", "iif", device, "table", "main", "pref", rulePrio1); err != nil { - log.WithError(err).Error("SplitRouting error setting ipv4 routing rule 1") + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error setting ipv4 routing rule 1") } - if err := execs.RunIP4Rule(ctx, "add", "not", "fwmark", fwMark, + if stdout, stderr, err := execs.RunIP4Rule(ctx, "add", "not", "fwmark", fwMark, "table", rtTable, "pref", rulePrio2); err != nil { - log.WithError(err).Error("SplitRouting error setting ipv4 routing rule 2") + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error setting ipv4 routing rule 2") } // set src_valid_mark with sysctl - if err := execs.RunSysctl(ctx, "-q", + if stdout, stderr, err := execs.RunSysctl(ctx, "-q", "net.ipv4.conf.all.src_valid_mark=1"); err != nil { - log.WithError(err).Error("SplitRouting error setting ipv4 sysctl") + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error setting ipv4 sysctl") } } // addDefaultRouteIPv6 adds default routing for IPv6. func addDefaultRouteIPv6(ctx context.Context, device, rtTable, rulePrio1, fwMark, rulePrio2 string) { // set default route - if err := execs.RunIP6Route(ctx, "add", "::/0", "dev", device, "table", + if stdout, stderr, err := execs.RunIP6Route(ctx, "add", "::/0", "dev", device, "table", rtTable); err != nil { - log.WithError(err).Error("SplitRouting error setting ipv6 default route") + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error setting ipv6 default route") } // set routing rules - if err := execs.RunIP6Rule(ctx, "add", "iif", device, "table", "main", + if stdout, stderr, err := execs.RunIP6Rule(ctx, "add", "iif", device, "table", "main", "pref", rulePrio1); err != nil { - log.WithError(err).Error("SplitRouting error setting ipv6 routing rule 1") + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error setting ipv6 routing rule 1") } - if err := execs.RunIP6Rule(ctx, "add", "not", "fwmark", fwMark, + if stdout, stderr, err := execs.RunIP6Rule(ctx, "add", "not", "fwmark", fwMark, "table", rtTable, "pref", rulePrio2); err != nil { - log.WithError(err).Error("SplitRouting error setting ipv6 routing rule 2") + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error setting ipv6 routing rule 2") } } // deleteDefaultRouteIPv4 removes default routing for IPv4. func deleteDefaultRouteIPv4(ctx context.Context, device, rtTable string) { // delete routing rules - if err := execs.RunIP4Rule(ctx, "delete", "table", rtTable); err != nil { - log.WithError(err).Error("SplitRouting error deleting ipv4 routing rule 2") + if stdout, stderr, err := execs.RunIP4Rule(ctx, "delete", "table", rtTable); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error deleting ipv4 routing rule 2") } - if err := execs.RunIP4Rule(ctx, "delete", "iif", device, "table", + if stdout, stderr, err := execs.RunIP4Rule(ctx, "delete", "iif", device, "table", "main"); err != nil { - log.WithError(err).Error("SplitRouting error deleting ipv4 routing rule 1") + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error deleting ipv4 routing rule 1") } } // deleteDefaultRouteIPv6 removes default routing for IPv6. func deleteDefaultRouteIPv6(ctx context.Context, device, rtTable string) { // delete routing rules - if err := execs.RunIP6Rule(ctx, "delete", "table", rtTable); err != nil { - log.WithError(err).Error("SplitRouting error deleting ipv6 routing rule 2") + if stdout, stderr, err := execs.RunIP6Rule(ctx, "delete", "table", rtTable); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error deleting ipv6 routing rule 2") } - if err := execs.RunIP6Rule(ctx, "delete", "iif", device, "table", + if stdout, stderr, err := execs.RunIP6Rule(ctx, "delete", "iif", device, "table", "main"); err != nil { - log.WithError(err).Error("SplitRouting error deleting ipv6 routing rule 1") + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("SplitRouting error deleting ipv6 routing rule 1") } } // cleanupRouting cleans up the routing configuration after a failed shutdown. func cleanupRouting(ctx context.Context, rtTable, rulePrio1, rulePrio2 string) { // delete ipv4 routing rules - if err := execs.RunIP4Rule(ctx, "delete", "pref", rulePrio1); err == nil { + if _, _, err := execs.RunIP4Rule(ctx, "delete", "pref", rulePrio1); err == nil { log.Debug("SplitRouting cleaned up ipv4 routing rule 1") } - if err := execs.RunIP4Rule(ctx, "delete", "pref", rulePrio2); err == nil { + if _, _, err := execs.RunIP4Rule(ctx, "delete", "pref", rulePrio2); err == nil { log.Debug("SplitRouting cleaned up ipv4 routing rule 2") } // delete ipv6 routing rules - if err := execs.RunIP6Rule(ctx, "delete", "pref", rulePrio1); err == nil { + if _, _, err := execs.RunIP6Rule(ctx, "delete", "pref", rulePrio1); err == nil { log.Debug("SplitRouting cleaned up ipv6 routing rule 1") } - if err := execs.RunIP6Rule(ctx, "delete", "pref", rulePrio2); err == nil { + if _, _, err := execs.RunIP6Rule(ctx, "delete", "pref", rulePrio2); err == nil { log.Debug("SplitRouting cleaned up ipv6 routing rule 2") } // flush ipv4 routing table - if err := execs.RunIP4Route(ctx, "flush", "table", rtTable); err == nil { + if _, _, err := execs.RunIP4Route(ctx, "flush", "table", rtTable); err == nil { log.Debug("SplitRouting cleaned up ipv4 routing table") } // flush ipv6 routing table - if err := execs.RunIP6Route(ctx, "flush", "table", rtTable); err == nil { + if _, _, err := execs.RunIP6Route(ctx, "flush", "table", rtTable); err == nil { log.Debug("SplitRouting cleaned up ipv6 routing table") } } diff --git a/internal/splitrt/splitrt_test.go b/internal/splitrt/splitrt_test.go index 73b78e9..a696102 100644 --- a/internal/splitrt/splitrt_test.go +++ b/internal/splitrt/splitrt_test.go @@ -25,9 +25,9 @@ func TestSplitRoutingHandleDeviceUpdate(t *testing.T) { got := []string{"nothing else"} oldRunCmd := execs.RunCmd - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, s) - return nil + return nil, nil, nil } defer func() { execs.RunCmd = oldRunCmd }() @@ -79,9 +79,9 @@ func TestSplitRoutingHandleAddressUpdate(t *testing.T) { got := []string{} oldRunCmd := execs.RunCmd - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, s) - return nil + return nil, nil, nil } defer func() { execs.RunCmd = oldRunCmd }() @@ -159,9 +159,9 @@ func TestSplitRoutingHandleDNSReport(t *testing.T) { got := []string{} oldRunCmd := execs.RunCmd - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, s) - return nil + return nil, nil, nil } defer func() { execs.RunCmd = oldRunCmd }() @@ -188,8 +188,8 @@ func TestSplitRoutingHandleDNSReport(t *testing.T) { func TestSplitRoutingStartStop(t *testing.T) { // set dummy low level functions for testing oldRunCmd := execs.RunCmd - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { - return nil + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { + return nil, nil, nil } defer func() { execs.RunCmd = oldRunCmd }() @@ -263,8 +263,8 @@ func TestSplitRoutingStartStop(t *testing.T) { s.Stop() // test with nft errors - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { - return errors.New("test error") + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { + return nil, nil, errors.New("test error") } s = NewSplitRouting(NewConfig(), vpnconfig.New()) if err := s.Start(); err != nil { @@ -312,13 +312,13 @@ func TestCleanup(t *testing.T) { got := []string{} oldRunCmd := execs.RunCmd - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { if s == "" { got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil + return nil, nil, nil } got = append(got, cmd+" "+strings.Join(arg, " ")+" "+s) - return nil + return nil, nil, nil } defer func() { execs.RunCmd = oldRunCmd }() diff --git a/internal/trafpol/allowdevs_test.go b/internal/trafpol/allowdevs_test.go index 423154d..599e6c9 100644 --- a/internal/trafpol/allowdevs_test.go +++ b/internal/trafpol/allowdevs_test.go @@ -14,9 +14,9 @@ func TestAllowDevsAdd(t *testing.T) { ctx := context.Background() got := []string{} - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, s) - return nil + return nil, nil, nil } // test adding @@ -42,9 +42,9 @@ func TestAllowDevsRemove(t *testing.T) { ctx := context.Background() got := []string{} - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, s) - return nil + return nil, nil, nil } // test removing device diff --git a/internal/trafpol/filter.go b/internal/trafpol/filter.go index 6ace89d..991e29c 100644 --- a/internal/trafpol/filter.go +++ b/internal/trafpol/filter.go @@ -155,31 +155,43 @@ table inet oc-daemon-filter { ` r := strings.NewReplacer("$FWMARK", fwMark) rules := r.Replace(filterRules) - if err := execs.RunNft(ctx, rules); err != nil { - log.WithError(err).Error("TrafPol error setting routing rules") + if stdout, stderr, err := execs.RunNft(ctx, rules); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("TrafPol error setting routing rules") } } // unsetFilterRules unsets the filter rules. func unsetFilterRules(ctx context.Context) { - if err := execs.RunNft(ctx, "delete table inet oc-daemon-filter"); err != nil { - log.WithError(err).Error("TrafPol error unsetting routing rules") + if stdout, stderr, err := execs.RunNft(ctx, "delete table inet oc-daemon-filter"); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("TrafPol error unsetting routing rules") } } // addAllowedDevice adds device to the allowed devices. func addAllowedDevice(ctx context.Context, device string) { nftconf := fmt.Sprintf("add element inet oc-daemon-filter allowdevs { %s }", device) - if err := execs.RunNft(ctx, nftconf); err != nil { - log.WithError(err).Error("TrafPol error adding allowed device") + if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("TrafPol error adding allowed device") } } // removeAllowedDevice removes device from the allowed devices. func removeAllowedDevice(ctx context.Context, device string) { nftconf := fmt.Sprintf("delete element inet oc-daemon-filter allowdevs { %s }", device) - if err := execs.RunNft(ctx, nftconf); err != nil { - log.WithError(err).Error("TrafPol error removing allowed device") + if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("TrafPol error removing allowed device") } } @@ -190,11 +202,17 @@ func setAllowedIPs(ctx context.Context, ips []*net.IPNet) { // runs into "file exists" errors even though we remove duplicates from // ips before calling this function and we flush the existing entries - if err := execs.RunNft(ctx, "flush set inet oc-daemon-filter allowhosts4"); err != nil { - log.WithError(err).Error("TrafPol error flushing allowed ipv4s") + if stdout, stderr, err := execs.RunNft(ctx, "flush set inet oc-daemon-filter allowhosts4"); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("TrafPol error flushing allowed ipv4s") } - if err := execs.RunNft(ctx, "flush set inet oc-daemon-filter allowhosts6"); err != nil { - log.WithError(err).Error("TrafPol error flushing allowed ipv6s") + if stdout, stderr, err := execs.RunNft(ctx, "flush set inet oc-daemon-filter allowhosts6"); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("TrafPol error flushing allowed ipv6s") } fmt4 := "add element inet oc-daemon-filter allowhosts4 { %s }" @@ -203,14 +221,20 @@ func setAllowedIPs(ctx context.Context, ips []*net.IPNet) { if ip.IP.To4() != nil { // ipv4 address nftconf := fmt.Sprintf(fmt4, ip) - if err := execs.RunNft(ctx, nftconf); err != nil { - log.WithError(err).Error("TrafPol error adding allowed ipv4") + if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("TrafPol error adding allowed ipv4") } } else { // ipv6 address nftconf := fmt.Sprintf(fmt6, ip) - if err := execs.RunNft(ctx, nftconf); err != nil { - log.WithError(err).Error("TrafPol error adding allowed ipv6") + if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("TrafPol error adding allowed ipv6") } } } @@ -229,8 +253,11 @@ func portsToString(ports []uint16) string { func addPortalPorts(ctx context.Context, ports []uint16) { p := portsToString(ports) nftconf := fmt.Sprintf("add element inet oc-daemon-filter allowports { %s }", p) - if err := execs.RunNft(ctx, nftconf); err != nil { - log.WithError(err).Error("TrafPol error adding portal ports") + if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("TrafPol error adding portal ports") } } @@ -238,14 +265,17 @@ func addPortalPorts(ctx context.Context, ports []uint16) { func removePortalPorts(ctx context.Context, ports []uint16) { p := portsToString(ports) nftconf := fmt.Sprintf("delete element inet oc-daemon-filter allowports { %s }", p) - if err := execs.RunNft(ctx, nftconf); err != nil { - log.WithError(err).Error("TrafPol error removing portal ports") + if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("TrafPol error removing portal ports") } } // cleanupFilterRules cleans up the filter rules after a failed shutdown. func cleanupFilterRules(ctx context.Context) { - if err := execs.RunNft(ctx, "delete table inet oc-daemon-filter"); err == nil { + if _, _, err := execs.RunNft(ctx, "delete table inet oc-daemon-filter"); err == nil { log.Debug("TrafPol cleaned up nft") } } diff --git a/internal/trafpol/filter_test.go b/internal/trafpol/filter_test.go index 1cd909d..10c2e7a 100644 --- a/internal/trafpol/filter_test.go +++ b/internal/trafpol/filter_test.go @@ -12,8 +12,8 @@ import ( // TestFilterFunctionsErrors tests filter functions, errors. func TestFilterFunctionsErrors(_ *testing.T) { oldRunCmd := execs.RunCmd - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { - return errors.New("test error") + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { + return nil, nil, errors.New("test error") } defer func() { execs.RunCmd = oldRunCmd }() diff --git a/internal/trafpol/trafpol_test.go b/internal/trafpol/trafpol_test.go index 5bcb5a5..3cfb5dd 100644 --- a/internal/trafpol/trafpol_test.go +++ b/internal/trafpol/trafpol_test.go @@ -51,11 +51,11 @@ func TestTrafPolHandleCPDReport(t *testing.T) { var nftMutex sync.Mutex nftCmds := []string{} oldRunCmd := execs.RunCmd - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { nftMutex.Lock() defer nftMutex.Unlock() nftCmds = append(nftCmds, s) - return nil + return nil, nil, nil } defer func() { execs.RunCmd = oldRunCmd }() @@ -157,9 +157,9 @@ func TestCleanup(t *testing.T) { "delete table inet oc-daemon-filter", } got := []string{} - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, s) - return nil + return nil, nil, nil } Cleanup(context.Background()) if !reflect.DeepEqual(got, want) { diff --git a/internal/vpnsetup/vpnsetup.go b/internal/vpnsetup/vpnsetup.go index 06fed47..7cf0cf4 100644 --- a/internal/vpnsetup/vpnsetup.go +++ b/internal/vpnsetup/vpnsetup.go @@ -68,18 +68,23 @@ func (v *VPNSetup) sendEvent(event *Event) { func setupVPNDevice(ctx context.Context, c *vpnconfig.Config) { // set mtu on device mtu := strconv.Itoa(c.Device.MTU) - if err := execs.RunIPLink(ctx, "set", c.Device.Name, "mtu", mtu); err != nil { + if stdout, stderr, err := execs.RunIPLink(ctx, "set", c.Device.Name, "mtu", mtu); err != nil { log.WithError(err).WithFields(log.Fields{ "device": c.Device.Name, "mtu": mtu, + "stdout": string(stdout), + "stderr": string(stderr), }).Error("Daemon could not set mtu on device") return } // set device up - if err := execs.RunIPLink(ctx, "set", c.Device.Name, "up"); err != nil { - log.WithError(err).WithField("device", c.Device.Name). - Error("Daemon could not set device up") + if stdout, stderr, err := execs.RunIPLink(ctx, "set", c.Device.Name, "up"); err != nil { + log.WithError(err).WithFields(log.Fields{ + "device": c.Device.Name, + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("Daemon could not set device up") return } @@ -91,10 +96,12 @@ func setupVPNDevice(ctx context.Context, c *vpnconfig.Config) { } dev := c.Device.Name addr := ipnet.String() - if err := execs.RunIPAddress(ctx, "add", addr, "dev", dev); err != nil { + if stdout, stderr, err := execs.RunIPAddress(ctx, "add", addr, "dev", dev); err != nil { log.WithError(err).WithFields(log.Fields{ "device": dev, "ip": addr, + "stdout": string(stdout), + "stderr": string(stderr), }).Error("Daemon could not set ip on device") return } @@ -111,9 +118,12 @@ func setupVPNDevice(ctx context.Context, c *vpnconfig.Config) { // teardownVPNDevice tears down the configured vpn device. func teardownVPNDevice(ctx context.Context, c *vpnconfig.Config) { // set device down - if err := execs.RunIPLink(ctx, "set", c.Device.Name, "down"); err != nil { - log.WithError(err).WithField("device", c.Device.Name). - Error("Daemon could not set device down") + if stdout, stderr, err := execs.RunIPLink(ctx, "set", c.Device.Name, "down"); err != nil { + log.WithError(err).WithFields(log.Fields{ + "device": c.Device.Name, + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("Daemon could not set device down") return } @@ -142,10 +152,12 @@ func (v *VPNSetup) teardownRouting() { // setupDNSServer sets the DNS server. func (v *VPNSetup) setupDNSServer(ctx context.Context, config *vpnconfig.Config) { device := config.Device.Name - if err := execs.RunResolvectl(ctx, "dns", device, v.dnsProxyConf.Address); err != nil { + if stdout, stderr, err := execs.RunResolvectl(ctx, "dns", device, v.dnsProxyConf.Address); err != nil { log.WithError(err).WithFields(log.Fields{ "device": device, "server": v.dnsProxyConf.Address, + "stdout": string(stdout), + "stderr": string(stderr), }).Error("VPNSetup error setting dns server") } } @@ -153,10 +165,12 @@ func (v *VPNSetup) setupDNSServer(ctx context.Context, config *vpnconfig.Config) // setupDNSDomains sets the DNS domains. func (v *VPNSetup) setupDNSDomains(ctx context.Context, config *vpnconfig.Config) { device := config.Device.Name - if err := execs.RunResolvectl(ctx, "domain", device, config.DNS.DefaultDomain, "~."); err != nil { + if stdout, stderr, err := execs.RunResolvectl(ctx, "domain", device, config.DNS.DefaultDomain, "~."); err != nil { log.WithError(err).WithFields(log.Fields{ "device": device, "domain": config.DNS.DefaultDomain, + "stdout": string(stdout), + "stderr": string(stderr), }).Error("VPNSetup error setting dns domains") } } @@ -164,9 +178,12 @@ func (v *VPNSetup) setupDNSDomains(ctx context.Context, config *vpnconfig.Config // setupDNSDefaultRoute sets the DNS default route. func (v *VPNSetup) setupDNSDefaultRoute(ctx context.Context, config *vpnconfig.Config) { device := config.Device.Name - if err := execs.RunResolvectl(ctx, "default-route", device, "yes"); err != nil { - log.WithError(err).WithField("device", device). - Error("VPNSetup error setting dns default route") + if stdout, stderr, err := execs.RunResolvectl(ctx, "default-route", device, "yes"); err != nil { + log.WithError(err).WithFields(log.Fields{ + "device": device, + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("VPNSetup error setting dns default route") } } @@ -196,13 +213,19 @@ func (v *VPNSetup) setupDNS(ctx context.Context, config *vpnconfig.Config) { v.setupDNSDefaultRoute(ctx, config) // flush dns caches - if err := execs.RunResolvectl(ctx, "flush-caches"); err != nil { - log.WithError(err).Error("VPNSetup error flushing dns caches during setup") + if stdout, stderr, err := execs.RunResolvectl(ctx, "flush-caches"); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("VPNSetup error flushing dns caches during setup") } // reset learnt server features - if err := execs.RunResolvectl(ctx, "reset-server-features"); err != nil { - log.WithError(err).Error("VPNSetup error resetting server features during setup") + if stdout, stderr, err := execs.RunResolvectl(ctx, "reset-server-features"); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("VPNSetup error resetting server features during setup") } } @@ -220,19 +243,28 @@ func (v *VPNSetup) teardownDNS(ctx context.Context, vpnconf *vpnconfig.Config) { // update dns configuration of host // undo device dns configuration - if err := execs.RunResolvectl(ctx, "revert", vpnconf.Device.Name); err != nil { - log.WithError(err).WithField("device", vpnconf.Device.Name). - Error("VPNSetup error reverting dns configuration") + if stdout, stderr, err := execs.RunResolvectl(ctx, "revert", vpnconf.Device.Name); err != nil { + log.WithError(err).WithFields(log.Fields{ + "device": vpnconf.Device.Name, + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("VPNSetup error reverting dns configuration") } // flush dns caches - if err := execs.RunResolvectl(ctx, "flush-caches"); err != nil { - log.WithError(err).Error("VPNSetup error flushing dns caches during teardown") + if stdout, stderr, err := execs.RunResolvectl(ctx, "flush-caches"); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("VPNSetup error flushing dns caches during teardown") } // reset learnt server features - if err := execs.RunResolvectl(ctx, "reset-server-features"); err != nil { - log.WithError(err).Error("VPNSetup error resetting server features during teardown") + if stdout, stderr, err := execs.RunResolvectl(ctx, "reset-server-features"); err != nil { + log.WithError(err).WithFields(log.Fields{ + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("VPNSetup error resetting server features during teardown") } } @@ -290,9 +322,13 @@ func (v *VPNSetup) ensureDNS(ctx context.Context, config *vpnconfig.Config) bool // get dns settings device := config.Device.Name - stdout, err := execs.RunResolvectlOutput(ctx, "status", device, "--no-pager") + stdout, stderr, err := execs.RunResolvectl(ctx, "status", device, "--no-pager") if err != nil { - log.WithError(err).WithField("device", device).Error("VPNSetup error getting DNS settings") + log.WithError(err).WithFields(log.Fields{ + "device": device, + "stdout": string(stdout), + "stderr": string(stderr), + }).Error("VPNSetup error getting DNS settings") return false } @@ -521,11 +557,11 @@ func NewVPNSetup( // Cleanup cleans up the configuration after a failed shutdown. func Cleanup(ctx context.Context, vpnDevice string, splitrtConfig *splitrt.Config) { // dns, device, split routing - if err := execs.RunResolvectl(ctx, "revert", vpnDevice); err == nil { + if _, _, err := execs.RunResolvectl(ctx, "revert", vpnDevice); err == nil { log.WithField("device", vpnDevice). Warn("VPNSetup cleaned up dns config") } - if err := execs.RunIPLink(ctx, "delete", vpnDevice); err == nil { + if _, _, err := execs.RunIPLink(ctx, "delete", vpnDevice); err == nil { log.WithField("device", vpnDevice). Warn("VPNSetup cleaned up vpn device") } diff --git a/internal/vpnsetup/vpnsetup_test.go b/internal/vpnsetup/vpnsetup_test.go index 55ca5cf..15013ab 100644 --- a/internal/vpnsetup/vpnsetup_test.go +++ b/internal/vpnsetup/vpnsetup_test.go @@ -40,9 +40,9 @@ func TestSetupVPNDevice(t *testing.T) { "address add 2001::1/64 dev tun0", } got := []string{} - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, strings.Join(arg, " ")) - return nil + return nil, nil, nil } // test @@ -57,15 +57,15 @@ func TestSetupVPNDevice(t *testing.T) { // depending on when execs.RunCmd failed. numRuns := 0 failAt := 0 - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { // fail after failAt runs if numRuns == failAt { - return errors.New("test error") + return nil, nil, errors.New("test error") } numRuns++ got = append(got, strings.Join(arg, " ")) - return nil + return nil, nil, nil } for _, f := range []int{0, 1, 2} { got = []string{} @@ -93,9 +93,9 @@ func TestTeardownVPNDevice(t *testing.T) { "link set tun0 down", } got := []string{} - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, strings.Join(arg, " ")) - return nil + return nil, nil, nil } // test @@ -105,8 +105,8 @@ func TestTeardownVPNDevice(t *testing.T) { } // test with execs error - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { - return errors.New("test error") + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { + return nil, nil, errors.New("test error") } teardownVPNDevice(context.Background(), c) } @@ -122,9 +122,9 @@ func TestVPNSetupSetupDNS(t *testing.T) { c.DNS.DefaultDomain = "mycompany.com" got := []string{} - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, strings.Join(arg, " ")) - return nil + return nil, nil, nil } v := NewVPNSetup(dnsproxy.NewConfig(), splitrt.NewConfig()) v.setupDNS(context.Background(), c) @@ -141,9 +141,9 @@ func TestVPNSetupSetupDNS(t *testing.T) { } // test with execs errors - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, strings.Join(arg, " ")) - return errors.New("test error") + return nil, nil, errors.New("test error") } got = []string{} @@ -163,9 +163,9 @@ func TestVPNSetupTeardownDNS(t *testing.T) { c.Device.Name = "tun0" got := []string{} - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, strings.Join(arg, " ")) - return nil + return nil, nil, nil } v := NewVPNSetup(dnsproxy.NewConfig(), splitrt.NewConfig()) @@ -181,9 +181,9 @@ func TestVPNSetupTeardownDNS(t *testing.T) { } // test with execs errors - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { got = append(got, strings.Join(arg, " ")) - return errors.New("test error") + return nil, nil, errors.New("test error") } got = []string{} @@ -272,8 +272,6 @@ func TestVPNSetupEnsureDNS(t *testing.T) { // clean up after tests oldRunCmd := execs.RunCmd defer func() { execs.RunCmd = oldRunCmd }() - oldRunCmdOutput := execs.RunCmdOutput - defer func() { execs.RunCmdOutput = oldRunCmdOutput }() v := NewVPNSetup(dnsproxy.NewConfig(), splitrt.NewConfig()) ctx := context.Background() @@ -283,14 +281,9 @@ func TestVPNSetupEnsureDNS(t *testing.T) { v.dnsProxyConf.Address = "127.0.0.1:4253" vpnconf.DNS.DefaultDomain = "test.example.com" - // override RunCmd - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { - return nil - } - // test resolvectl error - execs.RunCmdOutput = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, error) { - return nil, errors.New("test error") + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { + return nil, nil, errors.New("test error") } if ok := v.ensureDNS(ctx, vpnconf); ok { @@ -306,8 +299,8 @@ func TestVPNSetupEnsureDNS(t *testing.T) { []byte("header\nProtocols: +DefaultRoute\nDNS Servers: other\nDNS Domain: test.example.com ~.\n"), []byte("header\nProtocols: +DefaultRoute\nDNS Servers: 127.0.0.1:4253\nDNS Domain: other\n"), } { - execs.RunCmdOutput = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, error) { - return invalid, nil + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { + return invalid, nil, nil } if ok := v.ensureDNS(ctx, vpnconf); ok { @@ -321,8 +314,8 @@ func TestVPNSetupEnsureDNS(t *testing.T) { []byte("header\n Protocols: +DefaultRoute \nother\n " + "DNS Servers: 127.0.0.1:4253 \n DNS Domain: test.example.com ~.\nother\n"), } { - execs.RunCmdOutput = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, error) { - return valid, nil + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { + return valid, nil, nil } if ok := v.ensureDNS(ctx, vpnconf); !ok { @@ -342,17 +335,11 @@ func TestVPNSetupStartStop(_ *testing.T) { func TestVPNSetupSetupTeardown(_ *testing.T) { // override functions oldCmd := execs.RunCmd - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { - return nil + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { + return nil, nil, nil } defer func() { execs.RunCmd = oldCmd }() - oldCmdOutput := execs.RunCmdOutput - execs.RunCmdOutput = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, error) { - return nil, nil - } - defer func() { execs.RunCmdOutput = oldCmdOutput }() - oldRegisterAddrUpdates := addrmon.RegisterAddrUpdates addrmon.RegisterAddrUpdates = func(*addrmon.AddrMon) (chan netlink.AddrUpdate, error) { return nil, nil @@ -423,13 +410,13 @@ func TestNewVPNSetup(t *testing.T) { // TestCleanup tests Cleanup. func TestCleanup(t *testing.T) { got := []string{} - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) error { + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { if s == "" { got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil + return nil, nil, nil } got = append(got, cmd+" "+strings.Join(arg, " ")+" "+s) - return nil + return nil, nil, nil } Cleanup(context.Background(), "tun0", splitrt.NewConfig()) want := []string{