diff --git a/internal/sshutil/sshutil.go b/internal/sshutil/sshutil.go index b997479..be63f1f 100644 --- a/internal/sshutil/sshutil.go +++ b/internal/sshutil/sshutil.go @@ -5,9 +5,7 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" - "fmt" "net" - "os" "time" "golang.org/x/crypto/ssh" @@ -24,74 +22,6 @@ type SSHClient struct { config *ssh.ClientConfig } -// NewClient creates a new SSHClient -func NewClient(host string, user string, privKeyFile string) (*SSHClient, error) { - // Create the ssh config - cfg := &ssh.ClientConfig{ - User: user, - Auth: []ssh.AuthMethod{ - publicKeyFile(privKeyFile), - }, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - } - - // Start the ssh connection - conn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%s", host, port), cfg) - if err != nil { - return nil, err - } - - // Return the ssh client - return &SSHClient{ - conn: conn, - config: cfg, - }, nil -} - -// RunAsSudo runs a command as sudo on the remote host -func (c *SSHClient) RunAsSudo(command string) error { - // Create a new session - session, err := c.conn.NewSession() - if err != nil { - return err - } - defer session.Close() - - // Run the command - return session.Run(fmt.Sprintf("sudo %s", command)) -} - -// RunAsUser runs a command as the user on the remote host -func (c *SSHClient) RunAsUser(command string) error { - // Create a new session - session, err := c.conn.NewSession() - if err != nil { - return err - } - defer session.Close() - - // Run the command - return session.Run(fmt.Sprintf(command)) -} - -// publicKeyFile reads the private key file and returns the ssh.AuthMethod -func publicKeyFile(file string) ssh.AuthMethod { - // Read the private key file - buffer, err := os.ReadFile(file) - if err != nil { - return nil - } - - // Parse the private key - key, err := ssh.ParsePrivateKey(buffer) - if err != nil { - return nil - } - - // Return the ssh.AuthMethod - return ssh.PublicKeys(key) -} - // GenerateNewSSHKeys generates a new SSH key pair func GenerateNewSSHKeys() ([]byte, []byte, error) { // Generate private key diff --git a/internal/sshutil/sshutil_test.go b/internal/sshutil/sshutil_test.go index 2d09540..85c03fa 100644 --- a/internal/sshutil/sshutil_test.go +++ b/internal/sshutil/sshutil_test.go @@ -1,30 +1,23 @@ package sshutil import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/pem" "fmt" "io" - "io/ioutil" "log" - "os" - "path/filepath" "testing" - "github.com/gliderlabs/ssh" + sshsrv "github.com/gliderlabs/ssh" "github.com/stretchr/testify/assert" ) -func sessionHandler(s ssh.Session) { +func sessionHandler(s sshsrv.Session) { io.WriteString(s, "Hello world\n") } const serverAddr = "127.0.0.1" func startServer(port string) { - s := &ssh.Server{ + s := &sshsrv.Server{ Addr: fmt.Sprintf("%s:%s", serverAddr, port), Handler: sessionHandler, } @@ -33,77 +26,6 @@ func startServer(port string) { log.Fatal(s.ListenAndServe()) } -// TestNewClient tests the NewClient function -func TestNewClient(t *testing.T) { - // Test creating a new SSHClient - port = "2222" - go startServer(port) - - tempDir := t.TempDir() - privKeyFile := filepath.Join(tempDir, "test_key.pem") - - // Create a temporary private key file for testing - privKeyData := generatePrivateKeyPEM(t) - err := ioutil.WriteFile(privKeyFile, privKeyData, 0600) - assert.Nil(t, err) - defer os.Remove(privKeyFile) - - // Test creating a new SSHClient - client, err := NewClient(serverAddr, "", privKeyFile) - assert.NotNil(t, client) - assert.Nil(t, err) -} - -// TestRunAsSudo tests the RunAsSudo function -func TestRunAsSudo(t *testing.T) { - // Test creating a new SSHClient - port = "2223" - go startServer(port) - - tempDir := t.TempDir() - privKeyFile := filepath.Join(tempDir, "test_key.pem") - - // Create a temporary private key file for testing - privKeyData := generatePrivateKeyPEM(t) - err := ioutil.WriteFile(privKeyFile, privKeyData, 0600) - assert.Nil(t, err) - defer os.Remove(privKeyFile) - - // Create a new SSHClient - client, err := NewClient(serverAddr, "", privKeyFile) - assert.NotNil(t, client) - assert.Nil(t, err) - - // Test running a command as sudo - err = client.RunAsSudo("ls") - assert.Nil(t, err) -} - -// TestRunAsUser tests the RunAsUser function -func TestRunAsUser(t *testing.T) { - // Test creating a new SSHClient - port = "2224" - go startServer(port) - - tempDir := t.TempDir() - privKeyFile := filepath.Join(tempDir, "test_key.pem") - - // Create a temporary private key file for testing - privKeyData := generatePrivateKeyPEM(t) - err := ioutil.WriteFile(privKeyFile, privKeyData, 0600) - assert.Nil(t, err) - defer os.Remove(privKeyFile) - - // Create a new SSHClient - client, err := NewClient(serverAddr, "", privKeyFile) - assert.NotNil(t, client) - assert.Nil(t, err) - - // Test running a command as sudo - err = client.RunAsUser("ls") - assert.Nil(t, err) -} - // TestGenerateNewSSHKeys tests the GenerateNewSSHKeys function func TestGenerateNewSSHKeys(t *testing.T) { // Test generating new SSH keys @@ -112,20 +34,3 @@ func TestGenerateNewSSHKeys(t *testing.T) { assert.NotNil(t, publicKeyBytes) assert.Nil(t, err) } - -// Helper function to generate a temporary private key file -func generatePrivateKeyPEM(t *testing.T) []byte { - privateKey, err := rsa.GenerateKey(rand.Reader, bitSize) - assert.Nil(t, err) - - privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey) - privateKeyPEM := pem.EncodeToMemory( - &pem.Block{ - Type: "RSA PRIVATE KEY", - Headers: nil, - Bytes: privateKeyBytes, - }, - ) - - return privateKeyPEM -}