Skip to content

Commit

Permalink
aghos: use filewalker
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Aug 13, 2021
1 parent e4f2964 commit 8459208
Show file tree
Hide file tree
Showing 12 changed files with 57 additions and 115 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.16

require (
github.com/AdguardTeam/dnsproxy v0.39.2
github.com/AdguardTeam/golibs v0.9.0
github.com/AdguardTeam/golibs v0.9.1
github.com/AdguardTeam/urlfilter v0.14.6
github.com/NYTimes/gziphandler v1.1.1
github.com/ameshkov/dnscrypt/v2 v2.2.1
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKU
github.com/AdguardTeam/golibs v0.8.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.9.0 h1:QwmHqeZOVs9XpkmPb2iYpZ35OBArjgTesE8gLtEFRFg=
github.com/AdguardTeam/golibs v0.9.0/go.mod h1:fCAMwPBJ8S7YMYbTWvYS+eeTLblP5E04IDtNAo7y7IY=
github.com/AdguardTeam/golibs v0.9.1 h1:mHSN4LfaY1uGmHPsl97paAND/VeSnM5r9XQ7pSYx93o=
github.com/AdguardTeam/golibs v0.9.1/go.mod h1:fCAMwPBJ8S7YMYbTWvYS+eeTLblP5E04IDtNAo7y7IY=
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
github.com/AdguardTeam/urlfilter v0.14.6 h1:emqoKZElooHACYehRBYENeKVN1a/rspxiqTIMYLuoIo=
github.com/AdguardTeam/urlfilter v0.14.6/go.mod h1:klx4JbOfc4EaNb5lWLqOwfg+pVcyRukmoJRvO55lL5U=
Expand Down
29 changes: 5 additions & 24 deletions internal/aghnet/net_freebsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,26 @@ import (
"fmt"
"io"
"net"
"os"
"strings"

"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
)

func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}

// maxCheckedFileSize is the maximum acceptable length of the /etc/rc.conf file.
const maxCheckedFileSize = 1024 * 1024

func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
const filename = "/etc/rc.conf"

var f *os.File
f, err = os.Open(filename)
if err != nil {
return false, err
}
defer func() { err = errors.WithDeferred(err, f.Close()) }()

var r io.Reader
r, err = aghio.LimitReader(f, maxCheckedFileSize)
if err != nil {
return false, err
}

return rcConfStaticConfig(r, ifaceName)
return aghos.FileWalker(ifaceNamed(ifaceName).rcConfStaticConfig).Walk(filename)
}

// rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to
// have a static IP.
func rcConfStaticConfig(r io.Reader, ifaceName string) (has bool, err error) {
func (n ifaceNamed) rcConfStaticConfig(r io.Reader) (_ []string, ok bool, err error) {
s := bufio.NewScanner(r)
for ifaceLinePref := fmt.Sprintf("ifconfig_%s", ifaceName); s.Scan(); {
for ifaceLinePref := fmt.Sprintf("ifconfig_%s", n); s.Scan(); {
line := strings.TrimSpace(s.Text())
if !strings.HasPrefix(line, ifaceLinePref) {
continue
Expand All @@ -66,11 +47,11 @@ func rcConfStaticConfig(r io.Reader, ifaceName string) (has bool, err error) {
if len(fields) >= 2 &&
strings.ToLower(fields[0]) == "inet" &&
net.ParseIP(fields[1]) != nil {
return true, s.Err()
return nil, true, s.Err()
}
}

return false, s.Err()
return nil, false, s.Err()
}

func ifaceSetStaticIP(string) (err error) {
Expand Down
4 changes: 2 additions & 2 deletions internal/aghnet/net_freebsd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

func TestRcConfStaticConfig(t *testing.T) {
const ifaceName = `em0`
const iface ifaceNamed = `em0`
const nl = "\n"

testCases := []struct {
Expand Down Expand Up @@ -51,7 +51,7 @@ func TestRcConfStaticConfig(t *testing.T) {
for _, tc := range testCases {
r := strings.NewReader(tc.rcconfData)
t.Run(tc.name, func(t *testing.T) {
has, err := rcConfStaticConfig(r, ifaceName)
_, has, err := iface.rcConfStaticConfig(r)
require.NoError(t, err)

assert.Equal(t, tc.wantHas, has)
Expand Down
6 changes: 1 addition & 5 deletions internal/aghnet/net_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ import (
"golang.org/x/sys/unix"
)

// ifaceNamed contains interface's name, and passes itself to the underlying
// methods.
type ifaceNamed string

// dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to
// have a static IP.
func (n ifaceNamed) dhcpcdStaticConfig(r io.Reader) (subsources []string, has bool, err error) {
Expand Down Expand Up @@ -100,7 +96,7 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
FileWalker: iface.ifacesStaticConfig,
filename: "/etc/network/interfaces",
}} {
has, err = pair.Check(pair.filename)
has, err = pair.Walk(pair.filename)
if err != nil {
return false, err
}
Expand Down
37 changes: 5 additions & 32 deletions internal/aghnet/net_openbsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,61 +8,34 @@ import (
"fmt"
"io"
"net"
"os"
"strings"

"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
)

func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}

// maxCheckedFileSize is the maximum acceptable length of the /etc/hostname.*
// files.
const maxCheckedFileSize = 1024 * 1024

func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
const filenameFmt = "/etc/hostname.%s"

filename := fmt.Sprintf(filenameFmt, ifaceName)
var f *os.File
if f, err = os.Open(filename); err != nil {
if errors.Is(err, os.ErrNotExist) {
err = nil
}

return false, err
}
defer func() { err = errors.WithDeferred(err, f.Close()) }()

var r io.Reader
r, err = aghio.LimitReader(f, maxCheckedFileSize)
if err != nil {
return false, err
}
filename := fmt.Sprintf("/etc/hostname.%s", ifaceName)

return hostnameIfStaticConfig(r)
return aghos.FileWalker(hostnameIfStaticConfig).Walk(filename)
}

// hostnameIfStaticConfig checks if the interface is configured by
// /etc/hostname.* to have a static IP.
//
// TODO(e.burkov): The platform-dependent functions to check the static IP
// address configured are rather similar. Think about unifying common parts.
func hostnameIfStaticConfig(r io.Reader) (has bool, err error) {
func hostnameIfStaticConfig(r io.Reader) (_ []string, ok bool, err error) {
s := bufio.NewScanner(r)
for s.Scan() {
line := strings.TrimSpace(s.Text())
fields := strings.Fields(line)
if len(fields) >= 2 && fields[0] == "inet" && net.ParseIP(fields[1]) != nil {
return true, s.Err()
return nil, true, s.Err()
}
}

return false, s.Err()
return nil, false, s.Err()
}

func ifaceSetStaticIP(string) (err error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/aghnet/net_openbsd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestHostnameIfStaticConfig(t *testing.T) {
for _, tc := range testCases {
r := strings.NewReader(tc.rcconfData)
t.Run(tc.name, func(t *testing.T) {
has, err := hostnameIfStaticConfig(r)
_, has, err := hostnameIfStaticConfig(r)
require.NoError(t, err)

assert.Equal(t, tc.wantHas, has)
Expand Down
8 changes: 8 additions & 0 deletions internal/aghnet/net_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
//go:build openbsd || freebsd || linux
// +build openbsd freebsd linux

package aghnet

// ifaceNamed contains interface's name, and passes itself to the underlying
// methods.
type ifaceNamed string
File renamed without changes.
26 changes: 13 additions & 13 deletions internal/aghos/filewalker.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,28 +70,28 @@ func handlePatterns(srcSet *stringutil.Set, patterns ...string) (sub []string, e
return sub, nil
}

// Check starts walking the files defined by initPattern which should be valid
// Walk starts walking the files defined by initPattern which should be valid
// for filepath.Glob method.
func (c FileWalker) Check(initPattern string) (ok bool, err error) {
func (c FileWalker) Walk(initPattern string) (ok bool, err error) {
// The slice of sources is keeps the order in which the files are walked
// since sourcesSet.Values() returns strings in undefined order.
sourcesSet := stringutil.NewSet()
var sources []string
sources, err = handlePatterns(sourcesSet, initPattern)
// since srcSet.Values() returns strings in undefined order.
srcSet := stringutil.NewSet()
var src []string
src, err = handlePatterns(srcSet, initPattern)
if err != nil {
return false, err
}

var i int
defer func() {
if i < len(sources) {
err = errors.Annotate(err, "checking %q: %w", sources[i])
if i < len(src) {
err = errors.Annotate(err, "checking %q: %w", src[i])
}
}()

for ; i < len(sources); i++ {
for ; i < len(src); i++ {
var patterns []string
patterns, ok, err = checkFile(c, sources[i])
patterns, ok, err = checkFile(c, src[i])
if err != nil {
if errors.Is(err, os.ErrNotExist) {
continue
Expand All @@ -104,13 +104,13 @@ func (c FileWalker) Check(initPattern string) (ok bool, err error) {
return true, nil
}

var subsources []string
subsources, err = handlePatterns(sourcesSet, patterns...)
var subsrc []string
subsrc, err = handlePatterns(srcSet, patterns...)
if err != nil {
return false, err
}

sources = append(sources, subsources...)
src = append(src, subsrc...)
}

return false, nil
Expand Down
12 changes: 6 additions & 6 deletions internal/aghos/filewalker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ import (
"github.com/stretchr/testify/require"
)

func TestFileWalker_Check(t *testing.T) {
func TestFileWalker_Walk(t *testing.T) {
testdataPref := filepath.Join(".", "testdata")

const attribute = "000"
const attribute = `000`

c := FileWalker(func(r io.Reader) (patterns []string, has bool, err error) {
s := bufio.NewScanner(r)
Expand Down Expand Up @@ -57,23 +57,23 @@ func TestFileWalker_Check(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ok, err := c.Check(tc.initPattern)
ok, err := c.Walk(tc.initPattern)
require.NoError(t, err)

assert.Equal(t, tc.want, ok)
})
}

t.Run("pattern_malformed", func(t *testing.T) {
ok, err := c.Check(`\`)
ok, err := c.Walk(`\`)
require.Error(t, err)

assert.False(t, ok)
assert.ErrorIs(t, err, filepath.ErrBadPattern)
})

t.Run("bad_filename", func(t *testing.T) {
ok, err := c.Check(filepath.Join(testdataPref, "bad_filename.txt"))
ok, err := c.Walk(filepath.Join(testdataPref, "bad_filename.txt"))
require.Error(t, err)

assert.False(t, ok)
Expand All @@ -85,7 +85,7 @@ func TestFileWalker_Check(t *testing.T) {

ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
return nil, false, rerr
}).Check(filepath.Join(testdataPref, "*"))
}).Walk(filepath.Join(testdataPref, "*"))
require.Error(t, err)
require.False(t, ok)

Expand Down
44 changes: 13 additions & 31 deletions internal/aghos/os_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
package aghos

import (
"bytes"
"io"
"os"
"path/filepath"
"strings"
"syscall"

"github.com/AdguardTeam/golibs/stringutil"
)

func setRlimit(val uint64) (err error) {
Expand All @@ -34,37 +34,19 @@ func isActuallyOpenWrt() (ok bool) {
}

func isOpenWrt() (ok bool) {
const etcDir = "/etc"

dirEnts, err := os.ReadDir(etcDir)
if err != nil {
return false
}

// fNameSubstr is a part of a name of the desired file.
const fNameSubstr = "release"
osNameData := []byte("OpenWrt")

for _, dirEnt := range dirEnts {
if dirEnt.IsDir() {
continue
}
var err error
ok, err = FileWalker(func(r io.Reader) (_ []string, ok bool, err error) {
const osNameData = "openwrt"

fn := dirEnt.Name()
if !strings.Contains(fn, fNameSubstr) {
continue
}

var body []byte
body, err = os.ReadFile(filepath.Join(etcDir, fn))
// This use of ReadAll is safe since it's size handled before.
var data []byte
data, err = io.ReadAll(r)
if err != nil {
continue
return nil, false, err
}

if bytes.Contains(body, osNameData) {
return true
}
}
return nil, stringutil.ContainsFold(string(data), osNameData), nil
}).Walk("/etc/*release*")

return false
return err == nil && ok
}

0 comments on commit 8459208

Please sign in to comment.