Skip to content

Commit

Permalink
Improve config allowed_network_types.
Browse files Browse the repository at this point in the history
- Rename values to "ipv4", "ipv6" and "any".
- Validate them when parsing.
  • Loading branch information
keuin committed Sep 11, 2022
1 parent 25fc31b commit 2b1e0db
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
26 changes: 22 additions & 4 deletions bilibili/netprobe.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,35 @@ package bilibili

import (
"context"
"fmt"
"net"
)

type IpNetType string

var (
IPv6Net IpNetType = "tcp6"
IPv4Net IpNetType = "tcp4"
IP64 IpNetType = "tcp"
IPv6Net IpNetType = "ipv6"
IPv4Net IpNetType = "ipv4"
IP64 IpNetType = "any"
)

// GetDialNetString returns the string accepted by net.Dialer::DialContext
func (t IpNetType) GetDialNetString() string {
switch t {
case IPv4Net:
return "tcp4"
case IPv6Net:
return "tcp6"
case IP64:
return "tcp"
}
return ""
}

func (t IpNetType) String() string {
return fmt.Sprintf("%s(%s)", string(t), t.GetDialNetString())
}

type netContext = func(context.Context, string, string) (net.Conn, error)

type netProbe struct {
Expand All @@ -36,6 +54,6 @@ func (p *netProbe) NextNetworkType(dialer net.Dialer) (netContext, IpNetType) {
network := p.list[p.i]
p.i++
return func(ctx context.Context, _, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, string(network), addr)
return dialer.DialContext(ctx, network.GetDialNetString(), addr)
}, network
}
14 changes: 13 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ import (
"context"
"fmt"
"github.com/akamensky/argparse"
"github.com/keuin/slbr/bilibili"
"github.com/keuin/slbr/common"
"github.com/keuin/slbr/logging"
"github.com/keuin/slbr/recording"
"github.com/mitchellh/mapstructure"
"github.com/spf13/viper"
"log"
"os"
"os/signal"
"reflect"
"sync"
"syscall"
)
Expand Down Expand Up @@ -104,7 +107,16 @@ func getTasks() (tasks []recording.TaskConfig) {
return
}
var gc GlobalConfig
err = viper.Unmarshal(&gc)
netType := reflect.TypeOf(bilibili.IP64)
err = viper.Unmarshal(&gc, func(conf *mapstructure.DecoderConfig) {
conf.DecodeHook = func(from reflect.Value, to reflect.Value) (interface{}, error) {
if to.Type() == netType &&
bilibili.IpNetType(from.String()).GetDialNetString() == "" {
return nil, fmt.Errorf("invalid IpNetType: %v", from.String())
}
return from.Interface(), nil
}
})
if err != nil {
err = fmt.Errorf("cannot parse config file \"%v\": %w", configFile, err)
return
Expand Down

0 comments on commit 2b1e0db

Please sign in to comment.