Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allocate real pty #8

Merged
merged 11 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions _examples/ssh-ptystart/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package main

import (
"fmt"
"io"
"log"
"os"
"os/exec"
"runtime"
"time"

"github.com/charmbracelet/ssh"
)

func main() {
ssh.Handle(func(s ssh.Session) {
log.Printf("connected %s %s %q", s.User(), s.RemoteAddr(), s.RawCommand())
defer log.Printf("disconnected %s %s", s.User(), s.RemoteAddr())

pty, _, ok := s.Pty()
if !ok {
io.WriteString(s, "No PTY requested.\n")
s.Exit(1)
return
}

name := "bash"
if runtime.GOOS == "windows" {
name = "powershell.exe"
}
cmd := exec.Command(name)
cmd.Env = append(os.Environ(), "SSH_TTY="+pty.Name(), fmt.Sprintf("TERM=%s", pty.Term))
if err := pty.Start(cmd); err != nil {
fmt.Fprintln(s, err.Error())
s.Exit(1)
return
}

if runtime.GOOS == "windows" {
// ProcessState gets populated by pty.Start waiting on the process
// to exit.
for cmd.ProcessState == nil {
time.Sleep(100 * time.Millisecond)
}

s.Exit(cmd.ProcessState.ExitCode())
} else {
if err := cmd.Wait(); err != nil {
fmt.Fprintln(s, err)
s.Exit(cmd.ProcessState.ExitCode())
}
}
})

log.Println("starting ssh server on port 2222...")
if err := ssh.ListenAndServe(":2222", nil, ssh.AllocatePty()); err != nil && err != ssh.ErrServerClosed {
log.Fatal(err)
}
}
4 changes: 4 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ var (
// ContextKeyPublicKey is a context key for use with Contexts in this package.
// The associated value will be of type PublicKey.
ContextKeyPublicKey = &contextKey{"public-key"}

// ContextKeySession is a context key for use with Contexts in this package.
// The associated value will be of type Session.
ContextKeySession = &contextKey{"session"}
)

// Context is a package specific context interface. It exposes connection
Expand Down
6 changes: 5 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ go 1.17

require (
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
github.com/charmbracelet/x/exp/term v0.0.0-20240117030132-5a84c80527c7
github.com/creack/pty v1.1.21
github.com/u-root/u-root v0.11.0
aymanbagabas marked this conversation as resolved.
Show resolved Hide resolved
golang.org/x/crypto v0.17.0
golang.org/x/sys v0.16.0
)

require golang.org/x/sys v0.15.0 // indirect
require github.com/charmbracelet/x/errors v0.0.0-20240117030013-d31dba354651 // indirect
342 changes: 341 additions & 1 deletion go.sum

Large diffs are not rendered by default.

30 changes: 29 additions & 1 deletion options.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func HostKeyPEM(bytes []byte) Option {
// denying PTY requests.
func NoPty() Option {
return func(srv *Server) error {
srv.PtyCallback = func(ctx Context, pty Pty) bool {
srv.PtyCallback = func(Context, Pty) bool {
return false
}
return nil
Expand All @@ -82,3 +82,31 @@ func WrapConn(fn ConnCallback) Option {
return nil
}
}

var contextKeyEmulatePty = &contextKey{"emulate-pty"}

func emulatePtyHandler(ctx Context, _ Session, _ Pty) (func() error, error) {
ctx.SetValue(contextKeyEmulatePty, true)
return func() error { return nil }, nil
}

// EmulatePty returns a functional option that fakes a PTY. It uses PtyWriter
// underneath.
func EmulatePty() Option {
return func(s *Server) error {
s.PtyHandler = emulatePtyHandler
return nil
}
}

// AllocatePty returns a functional option that allocates a PTY. Implementers
// who wish to use an actual PTY should use this along with the platform
// specific PTY implementation defined in pty_*.go.
func AllocatePty() Option {
return func(s *Server) error {
s.PtyHandler = func(_ Context, s Session, pty Pty) (func() error, error) {
return s.(*session).ptyAllocate(pty.Term, pty.Window, pty.Modes)
}
return nil
}
}
14 changes: 14 additions & 0 deletions pty.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@ package ssh

import (
"bytes"
"errors"
"io"
"os/exec"
)

// ErrUnsupported is returned when the platform does not support PTY.
var ErrUnsupported = errors.New("pty unsupported")

// NewPtyWriter creates a writer that handles when the session has a active
// PTY, replacing the \n with \r\n.
func NewPtyWriter(w io.Writer) io.Writer {
Expand Down Expand Up @@ -55,3 +60,12 @@ func (rw readWriterDelegate) Read(p []byte) (n int, err error) {
func (rw readWriterDelegate) Write(p []byte) (n int, err error) {
return rw.w.Write(p)
}

// Start starts a *exec.Cmd attached to the Session. If a PTY is allocated,
// it will use that for I/O.
// On Windows, the process execution lifecycle is not managed by Go and has to
// be managed manually. This means that c.Wait() won't work.
// See https://github.com/charmbracelet/x/blob/main/exp/term/windows/conpty/conpty_windows.go
func (p *Pty) Start(c *exec.Cmd) error {
return p.start(c)
}
44 changes: 44 additions & 0 deletions pty_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//go:build !linux && !darwin && !freebsd && !dragonfly && !netbsd && !openbsd && !solaris && !windows
// +build !linux,!darwin,!freebsd,!dragonfly,!netbsd,!openbsd,!solaris,!windows

package ssh

import (
"os/exec"

"golang.org/x/crypto/ssh"
)

type impl struct{}

func (i *impl) IsZero() bool {
return true
}

func (i *impl) Name() string {
return ""
}

func (i *impl) Read(p []byte) (n int, err error) {
return 0, ErrUnsupported
}

func (i *impl) Write(p []byte) (n int, err error) {
return 0, ErrUnsupported
}

func (i *impl) Resize(w int, h int) error {
return ErrUnsupported
}

func (i *impl) Close() error {
return nil
}

func (*impl) start(*exec.Cmd) error {
return ErrUnsupported
}

func newPty(Context, string, Window, ssh.TerminalModes) (impl, error) {
return impl{}, ErrUnsupported
}
199 changes: 199 additions & 0 deletions pty_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build darwin dragonfly freebsd linux netbsd openbsd solaris

package ssh

import (
"fmt"
"os"
"os/exec"
"syscall"

"github.com/creack/pty"
"github.com/u-root/u-root/pkg/termios"
"golang.org/x/crypto/ssh"
"golang.org/x/sys/unix"
)

type impl struct {
// Master is the master PTY file descriptor.
Master *os.File

// Slave is the slave PTY file descriptor.
Slave *os.File
}

func (i *impl) IsZero() bool {
return i.Master == nil && i.Slave == nil
}

// Name returns the name of the slave PTY.
func (i *impl) Name() string {
return i.Slave.Name()
}

// Read implements ptyInterface.
func (i *impl) Read(p []byte) (n int, err error) {
return i.Master.Read(p)
}

// Write implements ptyInterface.
func (i *impl) Write(p []byte) (n int, err error) {
return i.Master.Write(p)
}

func (i *impl) Close() error {
if err := i.Master.Close(); err != nil {
return err
}
return i.Slave.Close()
}

func (i *impl) Resize(w int, h int) (rErr error) {
conn, err := i.Master.SyscallConn()
if err != nil {
return err
}

return conn.Control(func(fd uintptr) {
rErr = termios.SetWinSize(fd, &termios.Winsize{
Winsize: unix.Winsize{
Row: uint16(h),
Col: uint16(w),
},
})
})
}

func (i *impl) start(c *exec.Cmd) error {
c.Stdin, c.Stdout, c.Stderr = i.Slave, i.Slave, i.Slave
if c.SysProcAttr == nil {
c.SysProcAttr = &syscall.SysProcAttr{}
}
c.SysProcAttr.Setctty = true
c.SysProcAttr.Setsid = true
return c.Start()
}

func newPty(_ Context, _ string, win Window, modes ssh.TerminalModes) (_ impl, rErr error) {
ptm, pts, err := pty.Open()
if err != nil {
return impl{}, err
}

conn, err := ptm.SyscallConn()
if err != nil {
return impl{}, err
}

if err := conn.Control(func(fd uintptr) {
rErr = applyTerminalModesToFd(fd, win.Width, win.Height, modes)
}); err != nil {
return impl{}, err
}

return impl{Master: ptm, Slave: pts}, rErr
}

func applyTerminalModesToFd(fd uintptr, width int, height int, modes ssh.TerminalModes) error {
// Get the current TTY configuration.
tios, err := termios.GTTY(int(fd))
if err != nil {
return fmt.Errorf("GTTY: %w", err)
}

// Apply the modes from the SSH request.
tios.Row = height
tios.Col = width

for c, v := range modes {
if c == ssh.TTY_OP_ISPEED {
tios.Ispeed = int(v)
continue
}
if c == ssh.TTY_OP_OSPEED {
tios.Ospeed = int(v)
continue
}
k, ok := terminalModeFlagNames[c]
if !ok {
continue
}
if _, ok := tios.CC[k]; ok {
tios.CC[k] = uint8(v)
continue
}
if _, ok := tios.Opts[k]; ok {
tios.Opts[k] = v > 0
continue
}
}

// Save the new TTY configuration.
if _, err := tios.STTY(int(fd)); err != nil {
return fmt.Errorf("STTY: %w", err)
}

return nil
}

// terminalModeFlagNames maps the SSH terminal mode flags to mnemonic
// names used by the termios package.
var terminalModeFlagNames = map[uint8]string{
ssh.VINTR: "intr",
ssh.VQUIT: "quit",
ssh.VERASE: "erase",
ssh.VKILL: "kill",
ssh.VEOF: "eof",
ssh.VEOL: "eol",
ssh.VEOL2: "eol2",
ssh.VSTART: "start",
ssh.VSTOP: "stop",
ssh.VSUSP: "susp",
ssh.VDSUSP: "dsusp",
ssh.VREPRINT: "rprnt",
ssh.VWERASE: "werase",
ssh.VLNEXT: "lnext",
ssh.VFLUSH: "flush",
ssh.VSWTCH: "swtch",
ssh.VSTATUS: "status",
ssh.VDISCARD: "discard",
ssh.IGNPAR: "ignpar",
ssh.PARMRK: "parmrk",
ssh.INPCK: "inpck",
ssh.ISTRIP: "istrip",
ssh.INLCR: "inlcr",
ssh.IGNCR: "igncr",
ssh.ICRNL: "icrnl",
ssh.IUCLC: "iuclc",
ssh.IXON: "ixon",
ssh.IXANY: "ixany",
ssh.IXOFF: "ixoff",
ssh.IMAXBEL: "imaxbel",
ssh.IUTF8: "iutf8",
ssh.ISIG: "isig",
ssh.ICANON: "icanon",
ssh.XCASE: "xcase",
ssh.ECHO: "echo",
ssh.ECHOE: "echoe",
ssh.ECHOK: "echok",
ssh.ECHONL: "echonl",
ssh.NOFLSH: "noflsh",
ssh.TOSTOP: "tostop",
ssh.IEXTEN: "iexten",
ssh.ECHOCTL: "echoctl",
ssh.ECHOKE: "echoke",
ssh.PENDIN: "pendin",
ssh.OPOST: "opost",
ssh.OLCUC: "olcuc",
ssh.ONLCR: "onlcr",
ssh.OCRNL: "ocrnl",
ssh.ONOCR: "onocr",
ssh.ONLRET: "onlret",
ssh.CS7: "cs7",
ssh.CS8: "cs8",
ssh.PARENB: "parenb",
ssh.PARODD: "parodd",
ssh.TTY_OP_ISPEED: "tty_op_ispeed",
ssh.TTY_OP_OSPEED: "tty_op_ospeed",
}
Loading