From a297bd0b701c033d0dc253707b49df43ea685805 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Wed, 4 Dec 2024 12:04:16 -0500 Subject: [PATCH 1/5] enhance: add mTLS between gptscript and daemon tools Signed-off-by: Grant Linville --- pkg/certs/certs.go | 63 ++++++++++++++++++++++++++++++++++++++ pkg/engine/daemon.go | 58 +++++++++++++++++++++++++++++++++-- pkg/engine/engine.go | 2 ++ pkg/engine/http.go | 43 +++++++++++++++++++++++++- pkg/gptscript/gptscript.go | 14 ++++++--- pkg/runner/runner.go | 7 ++++- pkg/tests/tester/runner.go | 6 +++- 7 files changed, 183 insertions(+), 10 deletions(-) create mode 100644 pkg/certs/certs.go diff --git a/pkg/certs/certs.go b/pkg/certs/certs.go new file mode 100644 index 00000000..f001532b --- /dev/null +++ b/pkg/certs/certs.go @@ -0,0 +1,63 @@ +package certs + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "time" +) + +// CertAndKey contains an x509 certificate (PEM format) and ECDSA private key (also PEM format) +type CertAndKey struct { + Cert []byte + Key []byte +} + +func GenerateGPTScriptCert() (CertAndKey, error) { + return GenerateSelfSignedCert("gptscript server") +} + +func GenerateSelfSignedCert(name string) (CertAndKey, error) { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return CertAndKey{}, fmt.Errorf("failed to generate ECDSA key: %v", err) + } + + marshalledPrivateKey, err := x509.MarshalECPrivateKey(privateKey) + if err != nil { + return CertAndKey{}, fmt.Errorf("failed to marshal ECDSA key: %v", err) + } + + marshalledPrivateKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: marshalledPrivateKey}) + + template := &x509.Certificate{ + SerialNumber: big.NewInt(time.Now().UnixNano()), + Subject: pkix.Name{ + CommonName: name, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), // a year from now + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + x509.ExtKeyUsageClientAuth, + }, + IsCA: false, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + cert, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey) + if err != nil { + return CertAndKey{}, fmt.Errorf("failed to create certificate: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert}) + + return CertAndKey{Cert: certPEM, Key: marshalledPrivateKeyPEM}, nil +} diff --git a/pkg/engine/daemon.go b/pkg/engine/daemon.go index b7877da3..cd068b5f 100644 --- a/pkg/engine/daemon.go +++ b/pkg/engine/daemon.go @@ -2,6 +2,9 @@ package engine import ( "context" + "crypto/tls" + "crypto/x509" + "encoding/base64" "fmt" "io" "math/rand" @@ -11,11 +14,13 @@ import ( "sync" "time" + "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/system" "github.com/gptscript-ai/gptscript/pkg/types" ) var ports Ports +var certificates Certs type Ports struct { daemonPorts map[string]int64 @@ -29,6 +34,11 @@ type Ports struct { daemonWG sync.WaitGroup } +type Certs struct { + daemonCerts map[string]certs.CertAndKey + daemonLock sync.Mutex +} + func IsDaemonRunning(url string) bool { ports.daemonLock.Lock() defer ports.daemonLock.Unlock() @@ -128,7 +138,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { tool.Instructions = types.CommandPrefix + instructions port, ok := ports.daemonPorts[tool.ID] - url := fmt.Sprintf("http://127.0.0.1:%d%s", port, path) + url := fmt.Sprintf("https://127.0.0.1:%d%s", port, path) if ok && ports.daemonsRunning[url] != nil { return url, nil } @@ -144,11 +154,31 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { ctx := ports.daemonCtx port = nextPort() - url = fmt.Sprintf("http://127.0.0.1:%d%s", port, path) + url = fmt.Sprintf("https://127.0.0.1:%d%s", port, path) + + // Generate a certificate for the daemon, unless one already exists. + certificates.daemonLock.Lock() + defer certificates.daemonLock.Unlock() + cert, exists := certificates.daemonCerts[tool.ID] + if !exists { + var err error + cert, err = certs.GenerateSelfSignedCert(tool.ID) + if err != nil { + return "", fmt.Errorf("failed to generate certificate for daemon: %v", err) + } + + if certificates.daemonCerts == nil { + certificates.daemonCerts = map[string]certs.CertAndKey{} + } + certificates.daemonCerts[tool.ID] = cert + } cmd, stop, err := e.newCommand(ctx, []string{ fmt.Sprintf("PORT=%d", port), + fmt.Sprintf("CERT=%s", base64.StdEncoding.EncodeToString(cert.Cert)), + fmt.Sprintf("PRIVATE_KEY=%s", base64.StdEncoding.EncodeToString(cert.Key)), fmt.Sprintf("GPTSCRIPT_PORT=%d", port), + fmt.Sprintf("GPTSCRIPT_CERT=%s", base64.StdEncoding.EncodeToString(e.GPTScriptCert.Cert)), }, tool, "{}", @@ -210,8 +240,30 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { ports.daemonWG.Done() }() + // Build HTTP client for checking the health of the daemon + clientCert, err := tls.X509KeyPair(e.GPTScriptCert.Cert, e.GPTScriptCert.Key) + if err != nil { + return "", fmt.Errorf("failed to create client certificate: %v", err) + } + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(cert.Cert) { + return "", fmt.Errorf("failed to append daemon certificate for [%s]", tool.ID) + } + + httpClient := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{clientCert}, + RootCAs: pool, + InsecureSkipVerify: false, + }, + }, + } + + // Check the health of the daemon for i := 0; i < 120; i++ { - resp, err := http.Get(url) + resp, err := httpClient.Get(url) if err == nil && resp.StatusCode == http.StatusOK { go func() { _, _ = io.ReadAll(resp.Body) diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index a195a8b4..88cb07ae 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -7,6 +7,7 @@ import ( "strings" "sync" + "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/gptscript-ai/gptscript/pkg/version" @@ -22,6 +23,7 @@ type RuntimeManager interface { } type Engine struct { + GPTScriptCert certs.CertAndKey Model Model RuntimeManager RuntimeManager Env []string diff --git a/pkg/engine/http.go b/pkg/engine/http.go index 109db559..d06c7169 100644 --- a/pkg/engine/http.go +++ b/pkg/engine/http.go @@ -2,6 +2,8 @@ package engine import ( "context" + "crypto/tls" + "crypto/x509" "encoding/json" "fmt" "io" @@ -40,6 +42,7 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too return nil, err } + var tlsConfigForDaemonRequest *tls.Config if strings.HasSuffix(parsed.Hostname(), DaemonURLSuffix) { referencedToolName := strings.TrimSuffix(parsed.Hostname(), DaemonURLSuffix) referencedToolRefs, ok := tool.ToolMapping[referencedToolName] @@ -60,6 +63,33 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too } parsed.Host = toolURLParsed.Host toolURL = parsed.String() + + // Find the certificate corresponding to this daemon tool + certificates.daemonLock.Lock() + daemonCert, exists := certificates.daemonCerts[referencedTool.ID] + certificates.daemonLock.Unlock() + + if !exists { + return nil, fmt.Errorf("missing daemon certificate for [%s]", referencedTool.ID) + } + + // Create a pool for the certificate to treat as a CA + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(daemonCert.Cert) { + return nil, fmt.Errorf("failed to append daemon certificate for [%s]", referencedTool.ID) + } + + clientCert, err := tls.X509KeyPair(e.GPTScriptCert.Cert, e.GPTScriptCert.Key) + if err != nil { + return nil, fmt.Errorf("failed to create client certificate: %v", err) + } + + // Create TLS config for use in the HTTP client later + tlsConfigForDaemonRequest = &tls.Config{ + Certificates: []tls.Certificate{clientCert}, + RootCAs: pool, + InsecureSkipVerify: false, + } } if tool.Blocking { @@ -112,7 +142,18 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too req.Header.Set("Content-Type", "text/plain") } - resp, err := http.DefaultClient.Do(req) + var httpClient *http.Client + if tlsConfigForDaemonRequest != nil { + httpClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfigForDaemonRequest, + }, + } + } else { + httpClient = http.DefaultClient + } + + resp, err := httpClient.Do(req) if err != nil { return nil, err } diff --git a/pkg/gptscript/gptscript.go b/pkg/gptscript/gptscript.go index dfb1771a..cac519a8 100644 --- a/pkg/gptscript/gptscript.go +++ b/pkg/gptscript/gptscript.go @@ -12,6 +12,7 @@ import ( "github.com/gptscript-ai/gptscript/pkg/builtin" "github.com/gptscript-ai/gptscript/pkg/cache" + "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/config" context2 "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/credentials" @@ -107,7 +108,12 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) { opts.Runner.RuntimeManager = runtimes.Default(cacheClient.CacheDir(), opts.SystemToolsDir) } - simplerRunner, err := newSimpleRunner(cacheClient, opts.Runner.RuntimeManager, opts.Env) + gptscriptCert, err := certs.GenerateGPTScriptCert() + if err != nil { + return nil, err + } + + simplerRunner, err := newSimpleRunner(cacheClient, opts.Runner.RuntimeManager, opts.Env, gptscriptCert) if err != nil { return nil, err } @@ -140,7 +146,7 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) { opts.Runner.MonitorFactory = monitor.NewConsole(opts.Monitor, monitor.Options{DebugMessages: *opts.Quiet}) } - runner, err := runner.New(registry, credStore, opts.Runner) + runner, err := runner.New(registry, credStore, gptscriptCert, opts.Runner) if err != nil { return nil, err } @@ -285,8 +291,8 @@ type simpleRunner struct { env []string } -func newSimpleRunner(cache *cache.Client, rm engine.RuntimeManager, env []string) (*simpleRunner, error) { - runner, err := runner.New(noopModel{}, credentials.NoopStore{}, runner.Options{ +func newSimpleRunner(cache *cache.Client, rm engine.RuntimeManager, env []string, gptscriptCert certs.CertAndKey) (*simpleRunner, error) { + runner, err := runner.New(noopModel{}, credentials.NoopStore{}, gptscriptCert, runner.Options{ RuntimeManager: rm, MonitorFactory: simpleMonitorFactory{}, }) diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index fc5737ef..931ab99b 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -11,6 +11,7 @@ import ( "time" "github.com/gptscript-ai/gptscript/pkg/builtin" + "github.com/gptscript-ai/gptscript/pkg/certs" context2 "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/engine" @@ -95,9 +96,10 @@ type Runner struct { credOverrides []string credStore credentials.CredentialStore sequential bool + gptscriptCert certs.CertAndKey } -func New(client engine.Model, credStore credentials.CredentialStore, opts ...Options) (*Runner, error) { +func New(client engine.Model, credStore credentials.CredentialStore, gptscriptCert certs.CertAndKey, opts ...Options) (*Runner, error) { opt := complete(opts...) runner := &Runner{ @@ -109,6 +111,7 @@ func New(client engine.Model, credStore credentials.CredentialStore, opts ...Opt credStore: credStore, sequential: opt.Sequential, auth: opt.Authorizer, + gptscriptCert: gptscriptCert, } if opt.StartPort != 0 { @@ -411,6 +414,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager), Progress: progress, Env: env, + GPTScriptCert: r.gptscriptCert, } callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause) @@ -593,6 +597,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager), Progress: progress, Env: env, + GPTScriptCert: r.gptscriptCert, } var contentInput string diff --git a/pkg/tests/tester/runner.go b/pkg/tests/tester/runner.go index 44ec4e3c..22095270 100644 --- a/pkg/tests/tester/runner.go +++ b/pkg/tests/tester/runner.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/adrg/xdg" + "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/loader" "github.com/gptscript-ai/gptscript/pkg/repos/runtimes" @@ -198,7 +199,10 @@ func NewRunner(t *testing.T) *Runner { rm := runtimes.Default(cacheDir, "") - run, err := runner.New(c, credentials.NoopStore{}, runner.Options{ + gptscriptCert, err := certs.GenerateGPTScriptCert() + require.NoError(t, err) + + run, err := runner.New(c, credentials.NoopStore{}, gptscriptCert, runner.Options{ Sequential: true, RuntimeManager: rm, }) From cd73f59335b02900774e01e258a224d331c83f16 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Thu, 5 Dec 2024 15:37:52 -0500 Subject: [PATCH 2/5] close the body Signed-off-by: Grant Linville --- pkg/engine/daemon.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pkg/engine/daemon.go b/pkg/engine/daemon.go index cd068b5f..a6d2975a 100644 --- a/pkg/engine/daemon.go +++ b/pkg/engine/daemon.go @@ -270,6 +270,8 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { _ = resp.Body.Close() }() return url, nil + } else { + _ = resp.Body.Close() } select { case <-killedCtx.Done(): From 94ffd427ccbfbce9b71ba63437ba8ba2796c9ef0 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Thu, 5 Dec 2024 15:57:20 -0500 Subject: [PATCH 3/5] fix lint issue Signed-off-by: Grant Linville --- pkg/engine/daemon.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/engine/daemon.go b/pkg/engine/daemon.go index a6d2975a..31f96018 100644 --- a/pkg/engine/daemon.go +++ b/pkg/engine/daemon.go @@ -270,9 +270,8 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { _ = resp.Body.Close() }() return url, nil - } else { - _ = resp.Body.Close() } + _ = resp.Body.Close() select { case <-killedCtx.Done(): return url, fmt.Errorf("daemon failed to start: %w", context.Cause(killedCtx)) From 2b7cb50bdb5d076e3a6211ea0961193f569a8613 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Mon, 16 Dec 2024 16:27:29 -0500 Subject: [PATCH 4/5] improvements Signed-off-by: Grant Linville --- pkg/engine/daemon.go | 23 ++++++++++++++++------- pkg/engine/engine.go | 2 -- pkg/engine/http.go | 9 +++++---- pkg/gptscript/gptscript.go | 14 ++++---------- pkg/runner/runner.go | 7 +------ pkg/tests/tester/runner.go | 6 +----- 6 files changed, 27 insertions(+), 34 deletions(-) diff --git a/pkg/engine/daemon.go b/pkg/engine/daemon.go index 31f96018..899b27fc 100644 --- a/pkg/engine/daemon.go +++ b/pkg/engine/daemon.go @@ -36,7 +36,8 @@ type Ports struct { type Certs struct { daemonCerts map[string]certs.CertAndKey - daemonLock sync.Mutex + clientCert certs.CertAndKey + lock sync.Mutex } func IsDaemonRunning(url string) bool { @@ -157,8 +158,8 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { url = fmt.Sprintf("https://127.0.0.1:%d%s", port, path) // Generate a certificate for the daemon, unless one already exists. - certificates.daemonLock.Lock() - defer certificates.daemonLock.Unlock() + certificates.lock.Lock() + defer certificates.lock.Unlock() cert, exists := certificates.daemonCerts[tool.ID] if !exists { var err error @@ -173,12 +174,21 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { certificates.daemonCerts[tool.ID] = cert } + // Set the client certificate if there isn't one already. + if len(certificates.clientCert.Cert) == 0 { + gptscriptCert, err := certs.GenerateGPTScriptCert() + if err != nil { + return "", fmt.Errorf("failed to generate GPTScript certificate: %v", err) + } + certificates.clientCert = gptscriptCert + } + cmd, stop, err := e.newCommand(ctx, []string{ fmt.Sprintf("PORT=%d", port), fmt.Sprintf("CERT=%s", base64.StdEncoding.EncodeToString(cert.Cert)), fmt.Sprintf("PRIVATE_KEY=%s", base64.StdEncoding.EncodeToString(cert.Key)), fmt.Sprintf("GPTSCRIPT_PORT=%d", port), - fmt.Sprintf("GPTSCRIPT_CERT=%s", base64.StdEncoding.EncodeToString(e.GPTScriptCert.Cert)), + fmt.Sprintf("GPTSCRIPT_CERT=%s", base64.StdEncoding.EncodeToString(certificates.clientCert.Cert)), }, tool, "{}", @@ -241,7 +251,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { }() // Build HTTP client for checking the health of the daemon - clientCert, err := tls.X509KeyPair(e.GPTScriptCert.Cert, e.GPTScriptCert.Key) + tlsClientCert, err := tls.X509KeyPair(certificates.clientCert.Cert, certificates.clientCert.Key) if err != nil { return "", fmt.Errorf("failed to create client certificate: %v", err) } @@ -254,7 +264,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { httpClient := &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ - Certificates: []tls.Certificate{clientCert}, + Certificates: []tls.Certificate{tlsClientCert}, RootCAs: pool, InsecureSkipVerify: false, }, @@ -271,7 +281,6 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { }() return url, nil } - _ = resp.Body.Close() select { case <-killedCtx.Done(): return url, fmt.Errorf("daemon failed to start: %w", context.Cause(killedCtx)) diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 88cb07ae..a195a8b4 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -7,7 +7,6 @@ import ( "strings" "sync" - "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/gptscript-ai/gptscript/pkg/version" @@ -23,7 +22,6 @@ type RuntimeManager interface { } type Engine struct { - GPTScriptCert certs.CertAndKey Model Model RuntimeManager RuntimeManager Env []string diff --git a/pkg/engine/http.go b/pkg/engine/http.go index d06c7169..d8bf0ef2 100644 --- a/pkg/engine/http.go +++ b/pkg/engine/http.go @@ -65,9 +65,10 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too toolURL = parsed.String() // Find the certificate corresponding to this daemon tool - certificates.daemonLock.Lock() + certificates.lock.Lock() daemonCert, exists := certificates.daemonCerts[referencedTool.ID] - certificates.daemonLock.Unlock() + clientCert := certificates.clientCert + certificates.lock.Unlock() if !exists { return nil, fmt.Errorf("missing daemon certificate for [%s]", referencedTool.ID) @@ -79,14 +80,14 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too return nil, fmt.Errorf("failed to append daemon certificate for [%s]", referencedTool.ID) } - clientCert, err := tls.X509KeyPair(e.GPTScriptCert.Cert, e.GPTScriptCert.Key) + tlsClientCert, err := tls.X509KeyPair(clientCert.Cert, clientCert.Key) if err != nil { return nil, fmt.Errorf("failed to create client certificate: %v", err) } // Create TLS config for use in the HTTP client later tlsConfigForDaemonRequest = &tls.Config{ - Certificates: []tls.Certificate{clientCert}, + Certificates: []tls.Certificate{tlsClientCert}, RootCAs: pool, InsecureSkipVerify: false, } diff --git a/pkg/gptscript/gptscript.go b/pkg/gptscript/gptscript.go index cac519a8..dfb1771a 100644 --- a/pkg/gptscript/gptscript.go +++ b/pkg/gptscript/gptscript.go @@ -12,7 +12,6 @@ import ( "github.com/gptscript-ai/gptscript/pkg/builtin" "github.com/gptscript-ai/gptscript/pkg/cache" - "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/config" context2 "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/credentials" @@ -108,12 +107,7 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) { opts.Runner.RuntimeManager = runtimes.Default(cacheClient.CacheDir(), opts.SystemToolsDir) } - gptscriptCert, err := certs.GenerateGPTScriptCert() - if err != nil { - return nil, err - } - - simplerRunner, err := newSimpleRunner(cacheClient, opts.Runner.RuntimeManager, opts.Env, gptscriptCert) + simplerRunner, err := newSimpleRunner(cacheClient, opts.Runner.RuntimeManager, opts.Env) if err != nil { return nil, err } @@ -146,7 +140,7 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) { opts.Runner.MonitorFactory = monitor.NewConsole(opts.Monitor, monitor.Options{DebugMessages: *opts.Quiet}) } - runner, err := runner.New(registry, credStore, gptscriptCert, opts.Runner) + runner, err := runner.New(registry, credStore, opts.Runner) if err != nil { return nil, err } @@ -291,8 +285,8 @@ type simpleRunner struct { env []string } -func newSimpleRunner(cache *cache.Client, rm engine.RuntimeManager, env []string, gptscriptCert certs.CertAndKey) (*simpleRunner, error) { - runner, err := runner.New(noopModel{}, credentials.NoopStore{}, gptscriptCert, runner.Options{ +func newSimpleRunner(cache *cache.Client, rm engine.RuntimeManager, env []string) (*simpleRunner, error) { + runner, err := runner.New(noopModel{}, credentials.NoopStore{}, runner.Options{ RuntimeManager: rm, MonitorFactory: simpleMonitorFactory{}, }) diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 931ab99b..fc5737ef 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -11,7 +11,6 @@ import ( "time" "github.com/gptscript-ai/gptscript/pkg/builtin" - "github.com/gptscript-ai/gptscript/pkg/certs" context2 "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/engine" @@ -96,10 +95,9 @@ type Runner struct { credOverrides []string credStore credentials.CredentialStore sequential bool - gptscriptCert certs.CertAndKey } -func New(client engine.Model, credStore credentials.CredentialStore, gptscriptCert certs.CertAndKey, opts ...Options) (*Runner, error) { +func New(client engine.Model, credStore credentials.CredentialStore, opts ...Options) (*Runner, error) { opt := complete(opts...) runner := &Runner{ @@ -111,7 +109,6 @@ func New(client engine.Model, credStore credentials.CredentialStore, gptscriptCe credStore: credStore, sequential: opt.Sequential, auth: opt.Authorizer, - gptscriptCert: gptscriptCert, } if opt.StartPort != 0 { @@ -414,7 +411,6 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager), Progress: progress, Env: env, - GPTScriptCert: r.gptscriptCert, } callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause) @@ -597,7 +593,6 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager), Progress: progress, Env: env, - GPTScriptCert: r.gptscriptCert, } var contentInput string diff --git a/pkg/tests/tester/runner.go b/pkg/tests/tester/runner.go index 22095270..44ec4e3c 100644 --- a/pkg/tests/tester/runner.go +++ b/pkg/tests/tester/runner.go @@ -9,7 +9,6 @@ import ( "testing" "github.com/adrg/xdg" - "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/loader" "github.com/gptscript-ai/gptscript/pkg/repos/runtimes" @@ -199,10 +198,7 @@ func NewRunner(t *testing.T) *Runner { rm := runtimes.Default(cacheDir, "") - gptscriptCert, err := certs.GenerateGPTScriptCert() - require.NoError(t, err) - - run, err := runner.New(c, credentials.NoopStore{}, gptscriptCert, runner.Options{ + run, err := runner.New(c, credentials.NoopStore{}, runner.Options{ Sequential: true, RuntimeManager: rm, }) From aa5ef5750bb6d820fb72fd49dbdb23c084ff4c3e Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Tue, 17 Dec 2024 14:47:29 -0500 Subject: [PATCH 5/5] fixes Signed-off-by: Grant Linville --- pkg/engine/daemon.go | 23 +++++++++++++++++++ pkg/engine/http.go | 54 +++++++++++++++++++++++++++++++++----------- pkg/openai/client.go | 51 +++++++++++++++++++++++++++++++++++------ pkg/remote/remote.go | 18 ++++++++++++--- 4 files changed, 123 insertions(+), 23 deletions(-) diff --git a/pkg/engine/daemon.go b/pkg/engine/daemon.go index 899b27fc..59c93361 100644 --- a/pkg/engine/daemon.go +++ b/pkg/engine/daemon.go @@ -40,6 +40,29 @@ type Certs struct { lock sync.Mutex } +func GetClientCert() (certs.CertAndKey, error) { + certificates.lock.Lock() + defer certificates.lock.Unlock() + if len(certificates.clientCert.Cert) == 0 { + cert, err := certs.GenerateGPTScriptCert() + if err != nil { + return certs.CertAndKey{}, fmt.Errorf("failed to generate GPTScript certificate: %v", err) + } + certificates.clientCert = cert + } + return certificates.clientCert, nil +} + +func GetDaemonCert(toolID string) ([]byte, error) { + certificates.lock.Lock() + defer certificates.lock.Unlock() + cert, exists := certificates.daemonCerts[toolID] + if !exists { + return nil, fmt.Errorf("daemon certificate for [%s] not found", toolID) + } + return cert.Cert, nil +} + func IsDaemonRunning(url string) bool { ports.daemonLock.Lock() defer ports.daemonLock.Unlock() diff --git a/pkg/engine/http.go b/pkg/engine/http.go index d8bf0ef2..222d6977 100644 --- a/pkg/engine/http.go +++ b/pkg/engine/http.go @@ -14,6 +14,7 @@ import ( "slices" "strings" + "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/types" ) @@ -74,22 +75,22 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too return nil, fmt.Errorf("missing daemon certificate for [%s]", referencedTool.ID) } - // Create a pool for the certificate to treat as a CA - pool := x509.NewCertPool() - if !pool.AppendCertsFromPEM(daemonCert.Cert) { - return nil, fmt.Errorf("failed to append daemon certificate for [%s]", referencedTool.ID) - } - - tlsClientCert, err := tls.X509KeyPair(clientCert.Cert, clientCert.Key) + tlsConfigForDaemonRequest, err = getTLSConfig(clientCert, daemonCert.Cert) if err != nil { - return nil, fmt.Errorf("failed to create client certificate: %v", err) + return nil, err } + } else if isLocalhostHTTPS(toolURL) { + // This sometimes happens when talking to a model provider + certificates.lock.Lock() + daemonCert, exists := certificates.daemonCerts[tool.ID] + clientCert := certificates.clientCert + certificates.lock.Unlock() - // Create TLS config for use in the HTTP client later - tlsConfigForDaemonRequest = &tls.Config{ - Certificates: []tls.Certificate{tlsClientCert}, - RootCAs: pool, - InsecureSkipVerify: false, + if exists { + tlsConfigForDaemonRequest, err = getTLSConfig(clientCert, daemonCert.Cert) + if err != nil { + return nil, err + } } } @@ -185,3 +186,30 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too Result: &s, }, nil } + +func isLocalhostHTTPS(u string) bool { + parsed, err := url.Parse(u) + if err != nil { + return false + } + + return parsed.Scheme == "https" && (parsed.Hostname() == "localhost" || parsed.Hostname() == "127.0.0.1") +} + +func getTLSConfig(clientCert certs.CertAndKey, daemonCert []byte) (*tls.Config, error) { + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(daemonCert) { + return nil, fmt.Errorf("failed to append daemon certificate") + } + + tlsClientCert, err := tls.X509KeyPair(clientCert.Cert, clientCert.Key) + if err != nil { + return nil, fmt.Errorf("failed to create client certificate: %v", err) + } + + return &tls.Config{ + Certificates: []tls.Certificate{tlsClientCert}, + RootCAs: pool, + InsecureSkipVerify: false, + }, nil +} diff --git a/pkg/openai/client.go b/pkg/openai/client.go index 1894bdda..0c660b4c 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -2,9 +2,12 @@ package openai import ( "context" + "crypto/tls" + "crypto/x509" "errors" "io" "log/slog" + "net/http" "os" "slices" "sort" @@ -13,6 +16,7 @@ import ( openai "github.com/gptscript-ai/chat-completion-client" "github.com/gptscript-ai/gptscript/pkg/cache" + "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/hash" @@ -51,13 +55,15 @@ type Client struct { } type Options struct { - BaseURL string `usage:"OpenAI base URL" name:"openai-base-url" env:"OPENAI_BASE_URL"` - APIKey string `usage:"OpenAI API KEY" name:"openai-api-key" env:"OPENAI_API_KEY"` - OrgID string `usage:"OpenAI organization ID" name:"openai-org-id" env:"OPENAI_ORG_ID"` - DefaultModel string `usage:"Default LLM model to use" default:"gpt-4o"` - ConfigFile string `usage:"Path to GPTScript config file" name:"config"` - SetSeed bool `usage:"-"` - CacheKey string `usage:"-"` + BaseURL string `usage:"OpenAI base URL" name:"openai-base-url" env:"OPENAI_BASE_URL"` + APIKey string `usage:"OpenAI API KEY" name:"openai-api-key" env:"OPENAI_API_KEY"` + OrgID string `usage:"OpenAI organization ID" name:"openai-org-id" env:"OPENAI_ORG_ID"` + DefaultModel string `usage:"Default LLM model to use" default:"gpt-4o"` + ConfigFile string `usage:"Path to GPTScript config file" name:"config"` + SetSeed bool `usage:"-"` + CacheKey string `usage:"-"` + ClientCert certs.CertAndKey `usage:"-"` + ServerCert []byte `usage:"-"` Cache *cache.Client } @@ -70,6 +76,14 @@ func Complete(opts ...Options) (result Options) { result.DefaultModel = types.FirstSet(opt.DefaultModel, result.DefaultModel) result.SetSeed = types.FirstSet(opt.SetSeed, result.SetSeed) result.CacheKey = types.FirstSet(opt.CacheKey, result.CacheKey) + + if len(opt.ClientCert.Cert) > 0 { + result.ClientCert = opt.ClientCert + } + + if len(opt.ServerCert) > 0 { + result.ServerCert = opt.ServerCert + } } return result @@ -116,6 +130,29 @@ func NewClient(ctx context.Context, credStore credentials.CredentialStore, opts cfg.BaseURL = types.FirstSet(opt.BaseURL, cfg.BaseURL) cfg.OrgID = types.FirstSet(opt.OrgID, cfg.OrgID) + // Set up for mTLS, if configured. + if opt.ServerCert != nil && len(opt.ClientCert.Cert) > 0 { + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(opt.ServerCert) { + return nil, errors.New("failed to append server cert to pool") + } + + clientCert, err := tls.X509KeyPair(opt.ClientCert.Cert, opt.ClientCert.Key) + if err != nil { + return nil, err + } + + cfg.HTTPClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{clientCert}, + RootCAs: pool, + InsecureSkipVerify: false, + }, + }, + } + } + cacheKeyBase := opt.CacheKey if cacheKeyBase == "" { cacheKeyBase = hash.ID(opt.APIKey, opt.BaseURL) diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 5542372b..e5132d78 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -166,10 +166,22 @@ func (c *Client) load(ctx context.Context, toolName string, env ...string) (*ope return nil, err } + clientCert, err := engine.GetClientCert() + if err != nil { + return nil, err + } + + serverCert, err := engine.GetDaemonCert(prg.EntryToolID) + if err != nil { + return nil, err + } + oClient, err := openai.NewClient(ctx, c.credStore, openai.Options{ - BaseURL: strings.TrimSuffix(url, "/") + "/v1", - Cache: c.cache, - CacheKey: prg.EntryToolID, + BaseURL: strings.TrimSuffix(url, "/") + "/v1", + Cache: c.cache, + CacheKey: prg.EntryToolID, + ClientCert: clientCert, + ServerCert: serverCert, }) if err != nil { return nil, err