Skip to content

Commit

Permalink
Improved session management in loops
Browse files Browse the repository at this point in the history
  • Loading branch information
rg2011 committed Oct 10, 2019
1 parent 9f98292 commit 3625cef
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 49 deletions.
42 changes: 21 additions & 21 deletions pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,29 +51,29 @@ func (p *Pool) Push(md, username, pass string, commands []Task, script Script, u
defer p.wg.Done()
defer controller.Close()
defer close(stream)
if err := controller.Dial(); err != nil {
stream <- Result{Data: nil, Err: err}
return
}
for repeat := true; repeat; {
data, done, err := func() ([]interface{}, bool, error) {
// Concurrency limit
p.sem <- struct{}{}
defer func() { <-p.sem }()
// Iterate on the switches, delivering tasks to the queue
return p.run(controller, commands, script)
}()
// Do not wait on the stream with the semaphore locked!
for {
// Dial does session caching, will refresh credentials if needed
var data []interface{}
var done bool
err := controller.Dial()
if err == nil {
data, done, err = func() ([]interface{}, bool, error) {
// Do this in a closure to use defer() and make sure
// we release the lock after running the task, whatever the error
p.sem <- struct{}{}
defer func() { <-p.sem }()
// Iterate on the switches, delivering tasks to the queue
return p.run(controller, commands, script)
}()
}
stream <- Result{Data: data, Err: err}
if done || p.loop <= 0 {
repeat = false
} else {
select {
case <-time.After(p.loop):
repeat = true
case <-p.cancel:
repeat = false
}
return
}
select {
case <-time.After(p.loop): // do nothing
case <-p.cancel:
return
}
}
}()
Expand Down
65 changes: 37 additions & 28 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,17 @@ func (c *Controller) IP() string {
}

// Login opens a new session to the controller
func (c *Controller) login() error {
func (c *Controller) login() (token string, expires time.Time, err error) {
apiURL, data := fmt.Sprintf("%s/api/login", c.url), url.Values{}
parsedURL, err := url.Parse(apiURL)
if err != nil {
return errors.Wrapf(err, "Parsing login URL '%s' failed", apiURL)
return token, expires, errors.Wrapf(err, "Parsing login URL '%s' failed", apiURL)
}
data.Set("username", c.username)
data.Set("password", c.password)
req, err := http.NewRequest(http.MethodPost, apiURL, strings.NewReader(data.Encode()))
if err != nil {
return errors.Wrapf(err, "Building request for '%s' failed", apiURL)
return token, expires, errors.Wrapf(err, "Building request for '%s' failed", apiURL)
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Accept", "application/json")
Expand All @@ -83,31 +83,31 @@ func (c *Controller) login() error {
defer resp.Body.Close()
}
if err != nil {
return errors.Wrapf(err, "Login request to MD '%s' failed", c.md)
return token, expires, errors.Wrapf(err, "Login request to MD '%s' failed", c.md)
}
if resp.StatusCode != 200 {
return errors.Errorf("MD '%s': Login incorrect (username '%s')", c.md, c.username)
return token, expires, errors.Errorf("MD '%s': Login incorrect (username '%s')", c.md, c.username)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return errors.Errorf("MD '%s': Could not read login response", c.md)
return token, expires, errors.Errorf("MD '%s': Could not read login response", c.md)
}
lr := loginResponse{}
if err := json.Unmarshal(body, &lr); err != nil {
return errors.Errorf("MD '%s': Expected login response, got '%s'", c.md, string(body))
return token, expires, errors.Errorf("MD '%s': Expected login response, got '%s'", c.md, string(body))
}
c.lastToken = lr.GlobalResult.UIDARUBA
token = lr.GlobalResult.UIDARUBA
for _, cookie := range c.client.Jar.Cookies(parsedURL) {
if cookie.Name == "SESSION" {
c.expires = cookie.Expires
return nil
expires = cookie.Expires
return token, expires, nil
}
}
return errors.Errorf("MD '%s': No SESSION cookie received", c.md)
return token, expires, errors.Errorf("MD '%s': No SESSION cookie received", c.md)
}

// sshLogin opens a new session to the controller
func (c *Controller) sshLogin() error {
func (c *Controller) sshLogin() (*ssh.Client, error) {
config := &ssh.ClientConfig{
User: c.username,
Auth: []ssh.AuthMethod{
Expand All @@ -118,14 +118,18 @@ func (c *Controller) sshLogin() error {
}
client, err := ssh.Dial("tcp", fmt.Sprintf("%s:22", c.md), config)
if err != nil {
return errors.Wrapf(err, "Failed to dial SSH to '%s'", c.md)
return nil, errors.Wrapf(err, "Failed to dial SSH to '%s'", c.md)
}
c.sshClient = client
c.lastSSH = time.Now()
return nil
return client, nil
}

func (c *Controller) logout() error {
defer func() {
// Make sure we clean the struct no matter what
c.lastUsed = time.Time{}
c.lastToken = ""
c.expires = time.Time{}
}()
apiURL := fmt.Sprintf("%s/api/logout?UIDARUBA=%s", c.url, c.lastToken)
resp, err := c.client.Get(apiURL)
if resp != nil && resp.Body != nil {
Expand All @@ -148,12 +152,14 @@ func (c *Controller) logout() error {
}

func (c *Controller) sshLogout() error {
if c.sshClient == nil {
return nil
defer func() {
c.sshClient = nil
c.lastSSH = time.Time{}
}()
if c.sshClient != nil {
return errors.WithStack(c.sshClient.Close())
}
err := c.sshClient.Close()
c.sshClient = nil
return errors.WithStack(err)
return nil
}

// Dial an SSH API session, before running
Expand All @@ -164,28 +170,33 @@ func (c *Controller) Dial() error {
if c.lastToken != "" {
c.logout()
}
if err := c.login(); err != nil {
token, expires, err := c.login()
if err != nil {
return err
}
c.lastToken = token
c.expires = expires
}
c.lastUsed = now
// SSH is only dialed on demand
if c.useSSH {
return c.sshDial()
return c.sshDial(now)
}
return nil
}

func (c *Controller) sshDial() error {
now := time.Now()
func (c *Controller) sshDial(now time.Time) error {
if (c.sshClient == nil) || c.lastSSH.IsZero() || (now.Sub(c.lastSSH).Minutes() > 5) {
if c.sshClient != nil {
c.sshLogout()
}
if err := c.sshLogin(); err != nil {
client, err := c.sshLogin()
if err != nil {
return err
}
c.sshClient = client
}
c.lastSSH = now
return nil
}

Expand All @@ -196,13 +207,11 @@ func (c *Controller) Close() error {
if err1 := c.logout(); err1 != nil {
err = err1
}
c.lastToken = ""
}
if c.sshClient != nil {
if err2 := c.sshLogout(); err2 != nil && err == nil {
err = err2
}
c.sshClient = nil
}
if err != nil {
// A common pattern will be just defer session.Close()
Expand Down

0 comments on commit 3625cef

Please sign in to comment.