diff --git a/main.go b/main.go index d2c2b8e..f3a9b58 100644 --- a/main.go +++ b/main.go @@ -41,6 +41,7 @@ type rootCmd struct { noReuseConnection bool bindAddr string sshFlags string + codeServerPath string } func (c *rootCmd) Spec() cli.CommandSpec { @@ -58,6 +59,7 @@ func (c *rootCmd) RegisterFlags(fl *pflag.FlagSet) { fl.BoolVar(&c.noReuseConnection, "no-reuse-connection", false, "do not reuse SSH connection via control socket") fl.StringVar(&c.bindAddr, "bind", "", "local bind address for SSH tunnel, in [HOST][:PORT] syntax (default: 127.0.0.1)") fl.StringVar(&c.sshFlags, "ssh-flags", "", "custom SSH flags") + fl.StringVar(&c.codeServerPath, "code-server-path", "", "custom code-server binary to upload") } func (c *rootCmd) Run(fl *pflag.FlagSet) { @@ -84,6 +86,7 @@ func (c *rootCmd) Run(fl *pflag.FlagSet) { bindAddr: c.bindAddr, syncBack: c.syncBack, reuseConnection: !c.noReuseConnection, + codeServerPath: c.codeServerPath, }) if err != nil { diff --git a/sshcode.go b/sshcode.go index e4a623d..2e47332 100644 --- a/sshcode.go +++ b/sshcode.go @@ -36,6 +36,7 @@ type options struct { bindAddr string remotePort string sshFlags string + codeServerPath string } func sshCode(host, dir string, o options) error { @@ -76,23 +77,49 @@ func sshCode(host, dir string, o options) error { } } - flog.Info("ensuring code-server is updated...") - dlScript := downloadScript(codeServerPath) + // Upload local code-server or download code-server from CI server. + if o.codeServerPath != "" { + flog.Info("uploading local code-server binary...") + err = copyCodeServerBinary(o.sshFlags, host, o.codeServerPath, codeServerPath) + if err != nil { + return xerrors.Errorf("failed to upload local code-server binary to remote server: %w", err) + } - // Downloads the latest code-server and allows it to be executed. - sshCmdStr := fmt.Sprintf("ssh %v %v '/usr/bin/env bash -l'", o.sshFlags, host) + sshCmdStr := + fmt.Sprintf("ssh %v %v 'chmod +x %v'", + o.sshFlags, host, codeServerPath, + ) - sshCmd := exec.Command("sh", "-l", "-c", sshCmdStr) - sshCmd.Stdout = os.Stdout - sshCmd.Stderr = os.Stderr - sshCmd.Stdin = strings.NewReader(dlScript) - err = sshCmd.Run() - if err != nil { - return xerrors.Errorf("failed to update code-server: \n---ssh cmd---\n%s\n---download script---\n%s: %w", - sshCmdStr, - dlScript, - err, - ) + sshCmd := exec.Command("sh", "-l", "-c", sshCmdStr) + sshCmd.Stdout = os.Stdout + sshCmd.Stderr = os.Stderr + err = sshCmd.Run() + if err != nil { + return xerrors.Errorf("failed to make code-server binary executable:\n---ssh cmd---\n%s: %w", + sshCmdStr, + err, + ) + } + } else { + flog.Info("ensuring code-server is updated...") + dlScript := downloadScript(codeServerPath) + + // Downloads the latest code-server and allows it to be executed. + sshCmdStr := fmt.Sprintf("ssh %v %v '/usr/bin/env bash -l'", o.sshFlags, host) + + sshCmd := exec.Command("sh", "-l", "-c", sshCmdStr) + sshCmd.Stdout = os.Stdout + sshCmd.Stderr = os.Stderr + sshCmd.Stdin = strings.NewReader(dlScript) + err = sshCmd.Run() + if err != nil { + return xerrors.Errorf("failed to update code-server:\n---ssh cmd---\n%s"+ + "\n---download script---\n%s: %w", + sshCmdStr, + dlScript, + err, + ) + } } if !o.skipSync { @@ -117,13 +144,13 @@ func sshCode(host, dir string, o options) error { flog.Info("Tunneling remote port %v to %v", o.remotePort, o.bindAddr) - sshCmdStr = + sshCmdStr := fmt.Sprintf("ssh -tt -q -L %v:localhost:%v %v %v 'cd %v; %v --host 127.0.0.1 --allow-http --no-auth --port=%v'", o.bindAddr, o.remotePort, o.sshFlags, host, dir, codeServerPath, o.remotePort, ) // Starts code-server and forwards the remote port. - sshCmd = exec.Command("sh", "-l", "-c", sshCmdStr) + sshCmd := exec.Command("sh", "-l", "-c", sshCmdStr) sshCmd.Stdin = os.Stdin sshCmd.Stdout = os.Stdout sshCmd.Stderr = os.Stderr @@ -399,6 +426,20 @@ func checkSSHMaster(sshMasterCmd *exec.Cmd, sshFlags string, host string) error return xerrors.Errorf("max number of tries exceeded: %d", maxTries) } +// copyCodeServerBinary copies a code-server binary from local to remote. +func copyCodeServerBinary(sshFlags string, host string, localPath string, remotePath string) error { + if err := ensureFile(localPath); err != nil { + return err + } + + var ( + src = localPath + dest = host + ":" + remotePath + ) + + return rsync(src, dest, sshFlags) +} + func syncUserSettings(sshFlags string, host string, back bool) error { localConfDir, err := configDir() if err != nil { @@ -517,6 +558,18 @@ func ensureDir(path string) error { return nil } +// ensureFile tries to stat the specified path and ensure it's a file. +func ensureFile(path string) error { + info, err := os.Stat(path) + if err != nil { + return err + } + if info.IsDir() { + return xerrors.New("path is a directory") + } + return nil +} + // parseHost parses the host argument. If 'gcp:' is prefixed to the // host then a lookup is done using gcloud to determine the external IP and any // additional SSH arguments that should be used for ssh commands. Otherwise, host