Skip to content

Commit

Permalink
refactor: add : prefix to ports during config unmarshaling
Browse files Browse the repository at this point in the history
  • Loading branch information
ThinkChaos committed Sep 4, 2024
1 parent 70afa43 commit 6aae002
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 24 deletions.
7 changes: 7 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,13 @@ func (l *ListenConfig) UnmarshalText(data []byte) error {

*l = strings.Split(addresses, ",")

// Prefix all ports with :
for i, addr := range *l {
if !strings.ContainsRune(addr, ':') {
(*l)[i] = ":" + addr
}
}

return nil
}

Expand Down
4 changes: 2 additions & 2 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ bootstrapDns:
err := l.UnmarshalText([]byte("55,:56"))
Expect(err).Should(Succeed())
Expect(*l).Should(HaveLen(2))
Expect(*l).Should(ContainElements("55", ":56"))
Expect(*l).Should(ContainElements(":55", ":56"))
})
})
})
Expand Down Expand Up @@ -958,7 +958,7 @@ bootstrapDns:
})

func defaultTestFileConfig(config *Config) {
Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"}))
Expect(config.Ports.DNS).Should(Equal(ListenConfig{":55553", ":55554", "[::1]:55555"}))
Expect(config.Upstreams.Init.Strategy).Should(Equal(InitStrategyFailOnError))
Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky"))
Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3))
Expand Down
16 changes: 12 additions & 4 deletions helpertest/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -31,20 +32,27 @@ const (
DS = dns.Type(dns.TypeDS)
)

// GetIntPort returns an port for the current testing
// GetIntPort returns a port for the current testing
// process by adding the current ginkgo parallel process to
// the base port and returning it as int
// the base port and returning it as int.
func GetIntPort(port int) int {
return port + ginkgo.GinkgoParallelProcess()
}

// GetStringPort returns an port for the current testing
// GetStringPort returns a port for the current testing
// process by adding the current ginkgo parallel process to
// the base port and returning it as string
// the base port and returning it as string.
func GetStringPort(port int) string {
return fmt.Sprintf("%d", GetIntPort(port))
}

// GetHostPort returns a host:port string for the current testing
// process by adding the current ginkgo parallel process to
// the base port and returning it as string.
func GetHostPort(host string, port int) string {
return net.JoinHostPort(host, GetStringPort(port))
}

// TempFile creates temp file with passed data
func TempFile(data string) *os.File {
f, err := os.CreateTemp("", "prefix")
Expand Down
12 changes: 2 additions & 10 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,6 @@ func tlsCipherSuites() []uint16 {
return tlsCipherSuites
}

func getServerAddress(addr string) string {
if !strings.Contains(addr, ":") {
addr = fmt.Sprintf(":%s", addr)
}

return addr
}

type NewServerFunc func(address string) (*dns.Server, error)

func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
Expand Down Expand Up @@ -195,7 +187,7 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error

addServers := func(newServer NewServerFunc, addresses config.ListenConfig) error {
for _, address := range addresses {
server, err := newServer(getServerAddress(address))
server, err := newServer(address)
if err != nil {
return err
}
Expand Down Expand Up @@ -236,7 +228,7 @@ func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listene
listeners := make([]net.Listener, 0, len(addresses))

for _, address := range addresses {
listener, err := net.Listen("tcp", getServerAddress(address))
listener, err := net.Listen("tcp", address)
if err != nil {
return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err)
}
Expand Down
17 changes: 9 additions & 8 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -43,7 +44,7 @@ var (
)

var _ = BeforeSuite(func() {
baseURL = "http://localhost:" + GetStringPort(httpBasePort) + "/"
baseURL = fmt.Sprintf("http://%s/", GetHostPort("localhost", httpBasePort))
queryURL = baseURL + "dns-query"
var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream
ctx, cancelFn := context.WithCancel(context.Background())
Expand Down Expand Up @@ -146,10 +147,10 @@ var _ = BeforeSuite(func() {
},

Ports: config.Ports{
DNS: config.ListenConfig{GetStringPort(dnsBasePort)},
TLS: config.ListenConfig{GetStringPort(tlsBasePort)},
HTTP: config.ListenConfig{GetStringPort(httpBasePort)},
HTTPS: config.ListenConfig{GetStringPort(httpsBasePort)},
DNS: config.ListenConfig{GetHostPort("", dnsBasePort)},
TLS: config.ListenConfig{GetHostPort("", tlsBasePort)},
HTTP: config.ListenConfig{GetHostPort("", httpBasePort)},
HTTPS: config.ListenConfig{GetHostPort("", httpsBasePort)},
},
CertFile: certPem.Path,
KeyFile: keyPem.Path,
Expand Down Expand Up @@ -633,7 +634,7 @@ var _ = Describe("Running DNS server", func() {
},
Blocking: config.Blocking{BlockType: "zeroIp"},
Ports: config.Ports{
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)},
DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
},
})

Expand Down Expand Up @@ -677,7 +678,7 @@ var _ = Describe("Running DNS server", func() {
},
Blocking: config.Blocking{BlockType: "zeroIp"},
Ports: config.Ports{
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)},
DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
},
})

Expand Down Expand Up @@ -751,7 +752,7 @@ var _ = Describe("Running DNS server", func() {
})

func requestServer(request *dns.Msg) *dns.Msg {
conn, err := net.Dial("udp", ":"+GetStringPort(dnsBasePort))
conn, err := net.Dial("udp", GetHostPort("", dnsBasePort))
if err != nil {
Log().Fatal("could not connect to server: ", err)
}
Expand Down

0 comments on commit 6aae002

Please sign in to comment.