diff --git a/services/snmptrap/service.go b/services/snmptrap/service.go index 06d9ea85fd..15fa1c942c 100644 --- a/services/snmptrap/service.go +++ b/services/snmptrap/service.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "strconv" + "sync" "sync/atomic" text "text/template" @@ -14,9 +15,10 @@ import ( ) type Service struct { - configValue atomic.Value - snmpClientValue atomic.Value - logger *log.Logger + configValue atomic.Value + clientMu sync.Mutex + client *snmpgo.SNMP + logger *log.Logger } func NewService(c Config, l *log.Logger) *Service { @@ -24,38 +26,39 @@ func NewService(c Config, l *log.Logger) *Service { logger: l, } s.configValue.Store(c) - s.snmpClientValue.Store((*snmpgo.SNMP)(nil)) return s } func (s *Service) Open() error { c := s.config() if c.Enabled { - snmp, err := s.newSNMPClient(c) + err := s.loadNewSNMPClient(c) if err != nil { return err } - s.snmpClientValue.Store(snmp) } return nil } func (s *Service) Close() error { - snmp := s.snmpClient() - if snmp != nil { - snmp.Close() - } + s.closeClient() return nil } +func (s *Service) closeClient() { + s.clientMu.Lock() + defer s.clientMu.Unlock() + if s.client != nil { + s.client.Close() + } + s.client = nil +} + func (s *Service) config() Config { return s.configValue.Load().(Config) } -func (s *Service) snmpClient() *snmpgo.SNMP { - return s.snmpClientValue.Load().(*snmpgo.SNMP) -} -func (s *Service) newSNMPClient(c Config) (*snmpgo.SNMP, error) { +func (s *Service) loadNewSNMPClient(c Config) error { snmp, err := snmpgo.NewSNMP(snmpgo.SNMPArguments{ Version: snmpgo.V2c, Address: c.Addr, @@ -63,9 +66,12 @@ func (s *Service) newSNMPClient(c Config) (*snmpgo.SNMP, error) { Community: c.Community, }) if err != nil { - return nil, errors.Wrap(err, "invalid SNMP configuration") + return errors.Wrap(err, "invalid SNMP configuration") } - return snmp, nil + s.clientMu.Lock() + s.client = snmp + s.clientMu.Unlock() + return nil } func (s *Service) Update(newConfig []interface{}) error { @@ -78,17 +84,12 @@ func (s *Service) Update(newConfig []interface{}) error { old := s.config() if old != c { if c.Enabled { - snmp, err := s.newSNMPClient(c) + err := s.loadNewSNMPClient(c) if err != nil { return err } - s.snmpClientValue.Store(snmp) } else { - snmp := s.snmpClient() - if snmp != nil { - snmp.Close() - } - s.snmpClientValue.Store((*snmpgo.SNMP)(nil)) + s.closeClient() } s.configValue.Store(c) } @@ -185,12 +186,9 @@ func (s *Service) Trap(trapOid string, dataList []Data, level alert.Level) error } } - snmp := s.snmpClient() - // Open snmp client, idempotent call - if err := snmp.Open(); err != nil { - return errors.Wrap(err, "failed to SNMP open connection") - } - if err = snmp.V2Trap(varBinds); err != nil { + s.clientMu.Lock() + defer s.clientMu.Unlock() + if err = s.client.V2Trap(varBinds); err != nil { return errors.Wrap(err, "failed to send SNMP trap") } return nil