diff --git a/connection_manager.go b/connection_manager.go index 83040b36..aa7961ba 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -14,6 +14,8 @@ package ouroboros +import "sync" + // ConnectionManagerErrorFunc is a function that takes a connection ID and an error type ConnectionManagerErrorFunc func(int, error) @@ -53,9 +55,10 @@ func (c ConnectionManagerTag) String() string { } type ConnectionManager struct { - config ConnectionManagerConfig - hosts []ConnectionManagerHost - connections map[int]*ConnectionManagerConnection + config ConnectionManagerConfig + hosts []ConnectionManagerHost + connections map[int]*ConnectionManagerConnection + connectionsMutex sync.Mutex } type ConnectionManagerConfig struct { @@ -107,10 +110,12 @@ func (c *ConnectionManager) AddHostsFromTopology(topology *TopologyConfig) { } func (c *ConnectionManager) AddConnection(connId int, conn *Connection) { + c.connectionsMutex.Lock() c.connections[connId] = &ConnectionManagerConnection{ Id: connId, Conn: conn, } + c.connectionsMutex.Unlock() go func() { err, ok := <-conn.ErrorChan() if !ok { @@ -123,15 +128,20 @@ func (c *ConnectionManager) AddConnection(connId int, conn *Connection) { } func (c *ConnectionManager) RemoveConnection(connId int) { + c.connectionsMutex.Lock() delete(c.connections, connId) + c.connectionsMutex.Unlock() } func (c *ConnectionManager) GetConnectionById(connId int) *ConnectionManagerConnection { + c.connectionsMutex.Lock() + defer c.connectionsMutex.Unlock() return c.connections[connId] } func (c *ConnectionManager) GetConnectionsByTags(tags ...ConnectionManagerTag) []*ConnectionManagerConnection { var ret []*ConnectionManagerConnection + c.connectionsMutex.Lock() for _, conn := range c.connections { skipConn := false for _, tag := range tags { @@ -144,6 +154,7 @@ func (c *ConnectionManager) GetConnectionsByTags(tags ...ConnectionManagerTag) [ ret = append(ret, conn) } } + c.connectionsMutex.Unlock() return ret }