Skip to content

Commit

Permalink
feat: add sftp library (#393)
Browse files Browse the repository at this point in the history
Co-authored-by: Dilip Kola <kdilipkola@gmail.com>
Co-authored-by: Akash Chetty <achetty.iitr@gmail.com>
  • Loading branch information
3 people authored Apr 5, 2024
1 parent c788d93 commit f0b67e9
Show file tree
Hide file tree
Showing 7 changed files with 607 additions and 1 deletion.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ require (
github.com/mitchellh/mapstructure v1.5.0
github.com/ory/dockertest/v3 v3.10.0
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5
github.com/pkg/sftp v1.13.6
github.com/prometheus/client_golang v1.19.0
github.com/prometheus/client_model v0.6.0
github.com/prometheus/common v0.51.1
Expand Down Expand Up @@ -113,7 +114,6 @@ require (
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/pierrec/lz4/v4 v4.1.17 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pkg/sftp v1.13.6 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/prometheus/procfs v0.12.0 // indirect
Expand Down
113 changes: 113 additions & 0 deletions sftp/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
//go:generate mockgen -destination=mock_sftp/mock_sftp_client.go -package mock_sftp github.com/rudderlabs/rudder-go-kit/sftp Client
package sftp

import (
"errors"
"fmt"
"io"
"time"

"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)

// SSHConfig represents the configuration for SSH connection
type SSHConfig struct {
HostName string
Port int
User string
AuthMethod string
PrivateKey string
Password string // Password for password-based authentication
DialTimeout time.Duration
}

// sshClientConfig constructs an SSH client configuration based on the provided SSHConfig.
func sshClientConfig(config *SSHConfig) (*ssh.ClientConfig, error) {
if config == nil {
return nil, errors.New("config should not be nil")
}

if config.HostName == "" {
return nil, errors.New("hostname should not be empty")
}

if config.Port == 0 {
return nil, errors.New("port should not be empty")
}

if config.User == "" {
return nil, errors.New("user should not be empty")
}

var authMethods ssh.AuthMethod

switch config.AuthMethod {
case PasswordAuth:
authMethods = ssh.Password(config.Password)
case KeyAuth:
privateKey, err := ssh.ParsePrivateKey([]byte(config.PrivateKey))
if err != nil {
return nil, fmt.Errorf("cannot parse private key: %w", err)
}
authMethods = ssh.PublicKeys(privateKey)
default:
return nil, errors.New("unsupported authentication method")
}

sshConfig := &ssh.ClientConfig{
User: config.User,
Auth: []ssh.AuthMethod{authMethods},
Timeout: config.DialTimeout,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}

return sshConfig, nil
}

// NewSSHClient establishes an SSH connection and returns an SSH client
func NewSSHClient(config *SSHConfig) (*ssh.Client, error) {
sshConfig, err := sshClientConfig(config)
if err != nil {
return nil, fmt.Errorf("cannot configure SSH client: %w", err)
}

sshClient, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", config.HostName, config.Port), sshConfig)
if err != nil {
return nil, fmt.Errorf("cannot dial SSH host %q:%d: %w", config.HostName, config.Port, err)
}
return sshClient, nil
}

type clientImpl struct {
client *sftp.Client
}

type Client interface {
Create(path string) (io.WriteCloser, error)
Open(path string) (io.ReadCloser, error)
Remove(path string) error
}

// newSFTPClient creates an SFTP client with existing SSH client
func newSFTPClient(client *ssh.Client) (Client, error) {
sftpClient, err := sftp.NewClient(client)
if err != nil {
return nil, fmt.Errorf("cannot create SFTP client: %w", err)
}
return &clientImpl{
client: sftpClient,
}, nil
}

func (c *clientImpl) Create(path string) (io.WriteCloser, error) {
return c.client.Create(path)
}

func (c *clientImpl) Open(path string) (io.ReadCloser, error) {
return c.client.Open(path)
}

func (c *clientImpl) Remove(path string) error {
return c.client.Remove(path)
}
79 changes: 79 additions & 0 deletions sftp/mock_sftp/mock_sftp_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

101 changes: 101 additions & 0 deletions sftp/sftp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package sftp

import (
"fmt"
"io"
"os"
"path/filepath"

"golang.org/x/crypto/ssh"
)

const (
// PasswordAuth indicates password-based authentication
PasswordAuth = "passwordAuth"
// KeyAuth indicates key-based authentication
KeyAuth = "keyAuth"
)

// FileManager is an interface for managing files on a remote server
type FileManager interface {
Upload(localFilePath, remoteDir string) error
Download(remoteFilePath, localDir string) error
Delete(remoteFilePath string) error
}

// fileManagerImpl is a real implementation of FileManager
type fileManagerImpl struct {
client Client
}

func NewFileManager(sshClient *ssh.Client) (FileManager, error) {
sftpClient, err := newSFTPClient(sshClient)
if err != nil {
return nil, fmt.Errorf("cannot create SFTP client: %w", err)
}
return &fileManagerImpl{client: sftpClient}, nil
}

// Upload uploads a file to the remote server
func (fm *fileManagerImpl) Upload(localFilePath, remoteDir string) error {
localFile, err := os.Open(localFilePath)
if err != nil {
return fmt.Errorf("cannot open local file: %w", err)
}
defer func() {
_ = localFile.Close()
}()

remoteFileName := filepath.Join(remoteDir, filepath.Base(localFilePath))
remoteFile, err := fm.client.Create(remoteFileName)
if err != nil {
return fmt.Errorf("cannot create remote file: %w", err)
}
defer func() {
_ = remoteFile.Close()
}()

_, err = io.Copy(remoteFile, localFile)
if err != nil {
return fmt.Errorf("error copying file: %w", err)
}

return nil
}

// Download downloads a file from the remote server
func (fm *fileManagerImpl) Download(remoteFilePath, localDir string) error {
remoteFile, err := fm.client.Open(remoteFilePath)
if err != nil {
return fmt.Errorf("cannot open remote file: %w", err)
}
defer func() {
_ = remoteFile.Close()
}()

localFileName := filepath.Join(localDir, filepath.Base(remoteFilePath))
localFile, err := os.Create(localFileName)
if err != nil {
return fmt.Errorf("cannot create local file: %w", err)
}
defer func() {
_ = localFile.Close()
}()

_, err = io.Copy(localFile, remoteFile)
if err != nil {
return fmt.Errorf("cannot copy remote file content to local file: %w", err)
}

return nil
}

// Delete deletes a file on the remote server
func (fm *fileManagerImpl) Delete(remoteFilePath string) error {
err := fm.client.Remove(remoteFilePath)
if err != nil {
return fmt.Errorf("cannot delete file: %w", err)
}

return nil
}
Loading

0 comments on commit f0b67e9

Please sign in to comment.