Skip to content

Commit

Permalink
fixes with multiple drivers of same service
Browse files Browse the repository at this point in the history
  • Loading branch information
gmgigi96 committed Jun 28, 2023
1 parent 4e6e141 commit 5ae5c4d
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 77 deletions.
36 changes: 24 additions & 12 deletions cmd/revad/pkg/config/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,33 @@ type DriverConfig struct {
Config map[string]any `key:",squash"`
Address string `key:"address"`
Network string `key:"network"`
Label string `key:"-"`
}

func newSvcConfigFromList(l []map[string]any) (ServicesConfig, error) {
func (s *ServicesConfig) Add(svc string, c *DriverConfig) {
l := len(*s)
if l == 0 {
// the label is simply the service name
c.Label = svc
} else {
c.Label = label(svc, l)
if l == 1 {
(*s)[0].Label = label(svc, 0)
}
}
*s = append(*s, c)
}

func newSvcConfigFromList(name string, l []map[string]any) (ServicesConfig, error) {
cfg := make(ServicesConfig, 0, len(l))
for _, c := range l {
cfg = append(cfg, &DriverConfig{Config: c})
cfg.Add(name, &DriverConfig{Config: c})
}
return cfg, nil
}

func newSvcConfigFromMap(m map[string]any) ServicesConfig {
s, _ := newSvcConfigFromList([]map[string]any{m})
func newSvcConfigFromMap(name string, m map[string]any) ServicesConfig {
s, _ := newSvcConfigFromList(name, []map[string]any{m})
return s
}

Expand All @@ -70,7 +85,7 @@ func parseServices(cfg map[string]any) (map[string]ServicesConfig, error) {
// cfg can be a list or a map
cfgLst, ok := cfg.([]map[string]any)
if ok {
s, err := newSvcConfigFromList(cfgLst)
s, err := newSvcConfigFromList(name, cfgLst)
if err != nil {
return nil, err
}
Expand All @@ -81,7 +96,7 @@ func parseServices(cfg map[string]any) (map[string]ServicesConfig, error) {
if !ok {
return nil, fmt.Errorf("grpc.services.%s must be a list or a map. got %T", name, cfg)
}
services[name] = newSvcConfigFromMap(cfgMap)
services[name] = newSvcConfigFromMap(name, cfgMap)
}

return services, nil
Expand Down Expand Up @@ -138,23 +153,20 @@ func (i iterableImpl) ForEachService(f ServiceFunc) {
return
}
for name, c := range i.i.services() {
for i, cfg := range c {
for _, cfg := range c {
f(&Service{
raw: cfg,
Address: cfg.Address,
Network: cfg.Network,
Label: label(name, i, len(c)),
Label: cfg.Label,
Name: name,
Config: cfg.Config,
})
}
}
}

func label(name string, i, tot int) string {
if tot == 1 {
return name
}
func label(name string, i int) string {
return fmt.Sprintf("%s_%d", name, i)
}

Expand Down
33 changes: 2 additions & 31 deletions cmd/revad/pkg/grace/grace.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"syscall"
"time"

netutil "github.com/cs3org/reva/pkg/utils/net"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
Expand Down Expand Up @@ -343,43 +344,13 @@ func (w *Watcher) GetListeners(servers map[string]Addressable) (map[string]net.L

func get(lns map[string]net.Listener, address, network string) (net.Listener, bool) {
for _, ln := range lns {
if addressEqual(ln.Addr(), network, address) {
if netutil.AddressEqual(ln.Addr(), network, address) {
return ln, true
}
}
return nil, false
}

func addressEqual(a net.Addr, network, address string) bool {
if a.Network() != network {
return false
}

switch network {
case "tcp":
t, err := net.ResolveTCPAddr(network, address)
if err != nil {
return false
}
return tcpAddressEqual(a.(*net.TCPAddr), t)
case "unix":
t, err := net.ResolveUnixAddr(network, address)
if err != nil {
return false
}
return unixAddressEqual(a.(*net.UnixAddr), t)
}
return false
}

func tcpAddressEqual(a1, a2 *net.TCPAddr) bool {
return a1.Port == a2.Port
}

func unixAddressEqual(a1, a2 *net.UnixAddr) bool {
return a1.Name == a2.Name && a1.Net == a2.Net
}

type Addressable interface {
Network() string
Address() string
Expand Down
74 changes: 40 additions & 34 deletions cmd/revad/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/cs3org/reva/pkg/sharedconf"
"github.com/cs3org/reva/pkg/utils/list"
"github.com/cs3org/reva/pkg/utils/maps"
netutil "github.com/cs3org/reva/pkg/utils/net"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
)
Expand Down Expand Up @@ -71,9 +72,6 @@ func New(config *config.Config, opt ...Option) (*Reva, error) {
return nil, err
}

grpc, addrGRPC := groupGRPCByAddress(config)
http, addrHTTP := groupHTTPByAddress(config)

if opts.PidFile == "" {
return nil, errors.New("pid file not provided")
}
Expand All @@ -83,7 +81,7 @@ func New(config *config.Config, opt ...Option) (*Reva, error) {
return nil, err
}

listeners, err := watcher.GetListeners(maps.Merge(addrGRPC, addrHTTP))
listeners, err := watcher.GetListeners(servicesAddresses(config))
if err != nil {
return nil, err
}
Expand All @@ -95,6 +93,8 @@ func New(config *config.Config, opt ...Option) (*Reva, error) {
}
initSharedConf(config)

grpc := groupGRPCByAddress(config)
http := groupHTTPByAddress(config)
servers, err := newServers(grpc, http, listeners, log)
if err != nil {
return nil, err
Expand All @@ -116,6 +116,17 @@ func New(config *config.Config, opt ...Option) (*Reva, error) {
}, nil
}

func servicesAddresses(cfg *config.Config) map[string]grace.Addressable {
a := make(map[string]grace.Addressable)
cfg.GRPC.ForEachService(func(s *config.Service) {
a[s.Label] = &addr{address: s.Address, network: s.Network}
})
cfg.HTTP.ForEachService(func(s *config.Service) {
a[s.Label] = &addr{address: s.Address, network: s.Network}
})
return a
}

func newServerless(config *config.Config, log *zerolog.Logger) (*rserverless.Serverless, error) {
sl := make(map[string]rserverless.Service)
logger := log.With().Str("pkg", "serverless").Logger()
Expand Down Expand Up @@ -155,6 +166,8 @@ func setRandomAddresses(c *config.Config, lns map[string]net.Listener, log *zero
log.Fatal().Msg("port not assigned for service " + s.Label)
}
s.SetAddress(ln.Addr().String())
log.Debug().
Msgf("set random address %s:%s to service %s", ln.Addr().Network(), ln.Addr().String(), s.Label)
}
c.GRPC.ForEachService(f)
c.HTTP.ForEachService(f)
Expand All @@ -173,10 +186,9 @@ func (a *addr) Network() string {
return a.network
}

func groupGRPCByAddress(cfg *config.Config) (map[string]*config.GRPC, map[string]grace.Addressable) {
func groupGRPCByAddress(cfg *config.Config) []*config.GRPC {
// TODO: same address cannot be used in different configurations
g := map[string]*config.GRPC{}
a := map[string]grace.Addressable{}
cfg.GRPC.ForEachService(func(s *config.Service) {
if _, ok := g[s.Address]; !ok {
g[s.Address] = &config.GRPC{
Expand All @@ -188,17 +200,19 @@ func groupGRPCByAddress(cfg *config.Config) (map[string]*config.GRPC, map[string
Interceptors: cfg.GRPC.Interceptors,
}
}
a[s.Label] = &addr{address: s.Address, network: s.Network}
g[s.Address].Services[s.Name] = config.ServicesConfig{
{Config: s.Config, Address: s.Address, Network: s.Network},
{Config: s.Config, Address: s.Address, Network: s.Network, Label: s.Label},
}
})
return g, a
l := make([]*config.GRPC, 0, len(g))
for _, c := range g {
l = append(l, c)
}
return l
}

func groupHTTPByAddress(cfg *config.Config) (map[string]*config.HTTP, map[string]grace.Addressable) {
func groupHTTPByAddress(cfg *config.Config) []*config.HTTP {
g := map[string]*config.HTTP{}
a := map[string]grace.Addressable{}
cfg.HTTP.ForEachService(func(s *config.Service) {
if _, ok := g[s.Address]; !ok {
g[s.Address] = &config.HTTP{
Expand All @@ -210,12 +224,15 @@ func groupHTTPByAddress(cfg *config.Config) (map[string]*config.HTTP, map[string
Middlewares: cfg.HTTP.Middlewares,
}
}
a[s.Label] = &addr{address: s.Address, network: s.Network}
g[s.Address].Services[s.Name] = config.ServicesConfig{
{Config: s.Config, Address: s.Address, Network: s.Network},
{Config: s.Config, Address: s.Address, Network: s.Network, Label: s.Label},
}
})
return g, a
l := make([]*config.HTTP, 0, len(g))
for _, c := range g {
l = append(l, c)
}
return l
}

func (r *Reva) Start() error {
Expand Down Expand Up @@ -314,27 +331,16 @@ func adjustCPU(cpu string) (int, error) {
return numCPU, nil
}

func firstKey[K comparable, V any](m map[K]V) (K, bool) {
for k := range m {
return k, true
}
var k K
return k, false
}

func listenerFromServices[V any](lns map[string]net.Listener, svcs map[string]V) net.Listener {
svc, ok := firstKey(svcs)
if !ok {
panic("services map should be not empty")
}
ln, ok := lns[svc]
if !ok {
panic("listener not assigned for service " + svc)
func listenerFromAddress(lns map[string]net.Listener, network, address string) net.Listener {
for _, ln := range lns {
if netutil.AddressEqual(ln.Addr(), network, address) {
return ln
}
}
return ln
panic(fmt.Sprintf("listener not found for address %s:%s", network, address))
}

func newServers(grpc map[string]*config.GRPC, http map[string]*config.HTTP, lns map[string]net.Listener, log *zerolog.Logger) ([]*Server, error) {
func newServers(grpc []*config.GRPC, http []*config.HTTP, lns map[string]net.Listener, log *zerolog.Logger) ([]*Server, error) {
servers := make([]*Server, 0, len(grpc)+len(http))
for _, cfg := range grpc {
services, err := rgrpc.InitServices(cfg.Services)
Expand All @@ -356,7 +362,7 @@ func newServers(grpc map[string]*config.GRPC, http map[string]*config.HTTP, lns
if err != nil {
return nil, err
}
ln := listenerFromServices(lns, services)
ln := listenerFromAddress(lns, cfg.Network, cfg.Address)
server := &Server{
server: s,
listener: ln,
Expand Down Expand Up @@ -385,7 +391,7 @@ func newServers(grpc map[string]*config.GRPC, http map[string]*config.HTTP, lns
if err != nil {
return nil, err
}
ln := listenerFromServices(lns, services)
ln := listenerFromAddress(lns, cfg.Network, cfg.Address)
server := &Server{
server: s,
listener: ln,
Expand Down
33 changes: 33 additions & 0 deletions pkg/utils/net/net.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package net

import "net"

func AddressEqual(a net.Addr, network, address string) bool {
if a.Network() != network {
return false
}

switch network {
case "tcp":
t, err := net.ResolveTCPAddr(network, address)
if err != nil {
return false
}
return tcpAddressEqual(a.(*net.TCPAddr), t)
case "unix":
t, err := net.ResolveUnixAddr(network, address)
if err != nil {
return false
}
return unixAddressEqual(a.(*net.UnixAddr), t)
}
return false
}

func tcpAddressEqual(a1, a2 *net.TCPAddr) bool {
return a1.Port == a2.Port
}

func unixAddressEqual(a1, a2 *net.UnixAddr) bool {
return a1.Name == a2.Name && a1.Net == a2.Net
}

0 comments on commit 5ae5c4d

Please sign in to comment.