From 8fe7cb099dccfce3f9329d7207ef48f488f07e83 Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Tue, 15 Jun 2021 15:04:16 +0300 Subject: [PATCH] all: imp code, docs & tests --- internal/dnsforward/dnsforward.go | 8 - internal/dnsforward/http.go | 2 +- internal/home/authratelimiter.go | 2 +- internal/home/config.go | 1 - internal/home/duration_test.go | 285 ++++++++++++++---------------- 5 files changed, 137 insertions(+), 161 deletions(-) diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index fa6782d3299..d1d30d42bb5 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -223,14 +223,6 @@ func (s *Server) WriteDiskConfig(c *FilteringConfig) { c.UpstreamDNS = aghstrings.CloneSlice(sc.UpstreamDNS) } -// UpstreamTimeout returns the copy of actual RDNS configuration. -func (s *Server) UpstreamTimeout() (timeout time.Duration) { - s.serverLock.RLock() - defer s.serverLock.RUnlock() - - return s.conf.UpstreamTimeout -} - // RDNSSettings returns the copy of actual RDNS configuration. func (s *Server) RDNSSettings() (localPTRResolvers []string, resolveClients, resolvePTR bool) { s.serverLock.RLock() diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index d0c2ae7fded..06baa9a9dc6 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -585,7 +585,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { result := map[string]string{} bootstraps := req.BootstrapDNS - timeout := s.UpstreamTimeout() + timeout := s.conf.UpstreamTimeout for _, host := range req.Upstreams { err = checkDNS(host, bootstraps, timeout, checkDNSUpstreamExc) if err != nil { diff --git a/internal/home/authratelimiter.go b/internal/home/authratelimiter.go index c0b3da4054d..acdee35cbab 100644 --- a/internal/home/authratelimiter.go +++ b/internal/home/authratelimiter.go @@ -72,7 +72,7 @@ func (ab *authRateLimiter) check(usrID string) (left time.Duration) { // incLocked increments the number of unsuccessful attempts for attempter with // ip and updates it's blocking moment if needed. For internal use only. func (ab *authRateLimiter) incLocked(usrID string, now time.Time) { - var until time.Time = now.Add(failedAuthTTL) + until := now.Add(failedAuthTTL) var attNum uint = 1 a, ok := ab.failedAuths[usrID] diff --git a/internal/home/config.go b/internal/home/config.go index 2553308dda9..fedb43ab357 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -348,7 +348,6 @@ func (c *configuration) write() error { dns.LocalPTRResolvers, dns.ResolveClients, dns.UsePrivateRDNS = s.RDNSSettings() - dns.UpstreamTimeout = Duration{s.UpstreamTimeout()} } if Context.dhcpServer != nil { diff --git a/internal/home/duration_test.go b/internal/home/duration_test.go index 08e9bc832e5..4c728aadebe 100644 --- a/internal/home/duration_test.go +++ b/internal/home/duration_test.go @@ -1,186 +1,171 @@ package home import ( - "encoding" "encoding/json" - "encoding/xml" - "fmt" - "io" "strings" "testing" "time" - "github.com/AdguardTeam/golibs/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/yaml.v2" + yaml "gopkg.in/yaml.v2" ) +// durationMarshalTester is a helper struct to simplify testing different +// Duration marshalling and unmarshalling cases. +type durationMarshalTester struct { + PtrMap map[string]*Duration `json:"ptr_map"` + PtrSlice []*Duration `json:"ptr_slice"` + PtrValue *Duration `json:"ptr_value"` + PtrArray [1]*Duration `json:"ptr_array"` + Map map[string]Duration `json:"map"` + Slice []Duration `json:"slice"` + Value Duration `json:"value"` + Array [1]Duration `json:"array"` +} + +const nl = "\n" const ( - // ErrNotTextMarshaler is returned when passed interface does not - // implement the encoding.TextMarshaler interface. - ErrNotTextMarshaler errors.Error = "not a text marshaler" - // ErrNotTextUnmarshaler is returned when passed interface does not - // implement the encoding.TextUnmarshaler interface. - ErrNotTextUnmarshaler errors.Error = "not a text unmarshaler" + jsonStr = `{` + + `"ptr_map":{"dur":"1ms"},` + + `"ptr_slice":["1ms"],` + + `"ptr_value":"1ms",` + + `"ptr_array":["1ms"],` + + `"map":{"dur":"1ms"},` + + `"slice":["1ms"],` + + `"value":"1ms",` + + `"array":["1ms"]` + + `}` + yamlStr = `ptrmap:` + nl + + ` dur: 1ms` + nl + + `ptrslice:` + nl + + `- 1ms` + nl + + `ptrvalue: 1ms` + nl + + `ptrarray:` + nl + + `- 1ms` + nl + + `map:` + nl + + ` dur: 1ms` + nl + + `slice:` + nl + + `- 1ms` + nl + + `value: 1ms` + nl + + `array:` + nl + + `- 1ms` ) -// directText implements Encode and Decode methods like other encoding-related -// packages do. Simplifies testing of encoding.TextMarshaler and -// encoding.TextUnmarshaler interfaces implementations. -// -// TODO(e.burkov): Put into aghtest when there will be other -// encoding.TextMarshaler or encoding.TextUnmarshaler implementations. -type directText struct { - // w is an io.Writer that directText will write encoded data. - w io.Writer - // r is an io.Reader that directText will read data to decode from. - r io.Reader -} +// checkFields verifies m's fields. It expects the m to be unmarshalled from +// one of the constant strings above. +func (m *durationMarshalTester) checkFields(t *testing.T, d Duration) { + require.NotNil(t, m.PtrMap) -// Encode expects the v to be an encoding.TextMarshaler and writes the data from -// it using internal writer. -func (e *directText) Encode(v interface{}) (err error) { - val, ok := v.(encoding.TextMarshaler) - if !ok { - return ErrNotTextMarshaler - } + fromPtrMap, ok := m.PtrMap["dur"] + require.True(t, ok) + require.NotNil(t, fromPtrMap) - var data []byte - data, err = val.MarshalText() - if err != nil { - return err - } + require.Len(t, m.PtrSlice, 1) + fromPtrSlice := m.PtrSlice[0] + require.NotNil(t, fromPtrSlice) - _, err = e.w.Write(data) - if err != nil { - return err - } + fromPtrArray := m.PtrArray[0] + require.NotNil(t, fromPtrArray) - return nil -} + require.NotNil(t, m.PtrValue) -// Decode expects the v to be an encoding.TextUnmarshaler. It reads the data -// internal reader passing it into v. -func (e *directText) Decode(v interface{}) (err error) { - val, ok := v.(encoding.TextUnmarshaler) - if !ok { - return ErrNotTextUnmarshaler - } + var fromMap Duration + fromMap, ok = m.Map["dur"] + require.True(t, ok) - var data []byte - data, err = io.ReadAll(e.r) - if err != nil { - return err - } + require.Len(t, m.Slice, 1) - err = val.UnmarshalText(data) - if err != nil { - return err - } - - return nil + assert.Equal(t, d, *fromPtrMap) + assert.Equal(t, d, *fromPtrSlice) + assert.Equal(t, d, *m.PtrValue) + assert.Equal(t, d, *fromPtrArray) + assert.Equal(t, d, fromMap) + assert.Equal(t, d, m.Slice[0]) + assert.Equal(t, d, m.Value) + assert.Equal(t, d, m.Array[0]) } -// val is the default value throughout tests. -const val = 1 * time.Millisecond - -// valStr is a text representation of val. -var valStr = val.String() +// val is the default time.Duration value to be used throughout the tests of +// Duration. +const val = time.Millisecond func TestDuration_MarshalText(t *testing.T) { d := Duration{val} + dPtr := &d + + m := durationMarshalTester{ + PtrMap: map[string]*Duration{"dur": dPtr}, + PtrSlice: []*Duration{dPtr}, + PtrValue: dPtr, + PtrArray: [1]*Duration{dPtr}, + Map: map[string]Duration{"dur": d}, + Slice: []Duration{d}, + Value: d, + Array: [1]Duration{d}, + } + b := &strings.Builder{} + t.Run("json", func(t *testing.T) { + t.Cleanup(b.Reset) + err := json.NewEncoder(b).Encode(m) + require.NoError(t, err) - testCases := []struct { - enc interface { - Encode(v interface{}) (err error) - } - name string - fmtStr string - }{{ - enc: yaml.NewEncoder(b), - name: "yaml", - fmtStr: "%s\n", - }, { - enc: json.NewEncoder(b), - name: "json", - fmtStr: "%q\n", - }, { - enc: xml.NewEncoder(b), - name: "xml", - fmtStr: "%s", - }, { - enc: &directText{ - w: b, - }, - name: "direct", - fmtStr: "%s", - }} - - for _, tc := range testCases { - b.Reset() - t.Run(tc.name, func(t *testing.T) { - err := tc.enc.Encode(d) - require.NoError(t, err) - - assert.Equal(t, fmt.Sprintf(tc.fmtStr, val), b.String()) - }) - } + assert.JSONEq(t, jsonStr, b.String()) + }) + + t.Run("yaml", func(t *testing.T) { + t.Cleanup(b.Reset) + err := yaml.NewEncoder(b).Encode(m) + require.NoError(t, err) + + assert.YAMLEq(t, yamlStr, b.String(), b.String()) + }) + + t.Run("direct", func(t *testing.T) { + data, err := d.MarshalText() + require.NoError(t, err) + + assert.EqualValues(t, []byte(val.String()), data) + }) } func TestDuration_UnmarshalText(t *testing.T) { - d := Duration{} - - testCases := []struct { - dec interface { - Decode(v interface{}) (err error) - } - name string - }{{ - dec: yaml.NewDecoder( - strings.NewReader(valStr), - ), - name: "yaml", - }, { - dec: json.NewDecoder( - strings.NewReader(`"` + valStr + `"`), - ), - name: "json", - }, { - dec: xml.NewDecoder( - strings.NewReader("" + valStr + ""), - ), - name: "xml", - }, { - dec: &directText{ - r: strings.NewReader(valStr), - }, - name: "direct", - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := tc.dec.Decode(&d) - require.NoError(t, err) - - assert.Equal(t, val, d.Duration) - }) - } + d := Duration{val} + var m *durationMarshalTester + + t.Run("json", func(t *testing.T) { + m = &durationMarshalTester{} + + r := strings.NewReader(jsonStr) + err := json.NewDecoder(r).Decode(m) + require.NoError(t, err) + + m.checkFields(t, d) + }) + + t.Run("yaml", func(t *testing.T) { + m = &durationMarshalTester{} + + r := strings.NewReader(yamlStr) + err := yaml.NewDecoder(r).Decode(m) + require.NoError(t, err) + + m.checkFields(t, d) + }) + + t.Run("direct", func(t *testing.T) { + dd := &Duration{} + + err := dd.UnmarshalText([]byte(d.String())) + require.NoError(t, err) + + assert.Equal(t, d, *dd) + }) t.Run("bad_data", func(t *testing.T) { - const wrongDur = "abc" - - dec := &directText{ - r: strings.NewReader(wrongDur), - } - err := dec.Decode(&d) - require.Error(t, err) - - assert.Equal( - t, - fmt.Sprintf("unmarshalling duration: time: invalid duration %q", wrongDur), - err.Error(), - ) + const wrongData = `abc` + + assert.Error(t, (&Duration{}).UnmarshalText([]byte(wrongData))) }) }