From 6a8d19f0398a743192e1df3f75b274c81bbc6619 Mon Sep 17 00:00:00 2001 From: whiler Date: Fri, 10 Mar 2023 17:13:32 +0800 Subject: [PATCH] update cert automatically --- cmd/cert.go | 160 ++++++++++++++++++++++++++++++++++++++++++++++++++++ cmd/main.go | 73 ++++++++---------------- makefile | 2 +- 3 files changed, 183 insertions(+), 52 deletions(-) create mode 100644 cmd/cert.go diff --git a/cmd/cert.go b/cmd/cert.go new file mode 100644 index 0000000..873f26b --- /dev/null +++ b/cmd/cert.go @@ -0,0 +1,160 @@ +package main + +import ( + "bufio" + "context" + "crypto/tls" + "crypto/x509" + "errors" + "io" + "log" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "time" +) + +var errTimeCertificate = errors.New("out of cert time") + +var ErrInvalidRepo = errors.New("invalid repo") + +func validate(cert tls.Certificate) error { + cur := time.Now() + if x509cert, err := x509.ParseCertificate(cert.Certificate[0]); err != nil { + return err + } else if cur.Before(x509cert.NotBefore) || cur.After(x509cert.NotAfter) { + return errTimeCertificate + } else { + return nil + } +} + +func isValid(certpath, keypath string) bool { + if cert, err := tls.LoadX509KeyPair(certpath, keypath); err == nil { + return validate(cert) == nil + } else { + return false + } +} + +func getCertNames(certpath, keypath string) ([]string, error) { + if certpath == "" || keypath == "" { + return []string{"localhost"}, nil + } else if cert, err := tls.LoadX509KeyPair(certpath, keypath); err != nil { + return nil, err + } else if x509cert, err := x509.ParseCertificate(cert.Certificate[0]); err != nil { + return nil, err + } else { + set := map[string]bool{} + set[x509cert.Subject.CommonName] = true + for _, name := range x509cert.DNSNames { + set[name] = true + } + for _, ip := range x509cert.IPAddresses { + set[ip.String()] = true + } + names := make([]string, 0, len(set)) + for name := range set { + names = append(names, strings.ReplaceAll(name, "*", "local")) + } + return names, nil + } +} + +func getRepo(path string) (string, error) { + if file, err := os.Open(path); err != nil { + return "", err + } else { + defer file.Close() + scanner := bufio.NewScanner(file) + scanner.Split(bufio.ScanLines) + for scanner.Scan() { + cur := strings.TrimSpace(scanner.Text()) + if len(cur) == 0 { + continue + } else if strings.HasPrefix(cur, "#") { + continue + } else if u, err := url.Parse(cur); err != nil { + return "", err + } else if !strings.HasPrefix(u.Scheme, "http") || u.Host == "" { + break + } else { + return u.String(), nil + } + } + return "", ErrInvalidRepo + } +} + +func wget(target string) ([]byte, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + if req, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil); err != nil { + return nil, err + } else { + req.Header.Add("User-Agent", "kmactor/"+ver) + if resp, err := http.DefaultClient.Do(req); err != nil { + return nil, err + } else { + defer resp.Body.Close() + return io.ReadAll(resp.Body) + } + } +} + +func dumpto(path string, content []byte) error { + if file, err := os.CreateTemp("", filepath.Base(path)); err != nil { + return err + } else { + temp := file.Name() + if wrote, err := file.Write(content); err != nil { + file.Close() + os.Remove(temp) + return err + } else if wrote != len(content) { + file.Close() + os.Remove(temp) + return io.ErrShortWrite + } else { + file.Close() + return os.Rename(temp, path) + } + } +} + +func fetchCert(repo string) ([]byte, []byte, error) { + log.Printf("fetching cert from %s", repo) + if u, err := url.Parse(repo); err != nil { + return nil, nil, err + } else if certContent, err := wget(u.JoinPath("cert.pem").String()); err != nil { + return nil, nil, err + } else if keypath, err := wget(u.JoinPath("key.pem").String()); err != nil { + return nil, nil, err + } else { + return certContent, keypath, nil + } +} + +func updateCert(certpath, keypath, repopath string) error { + if repopath == "" || certpath == "" || keypath == "" { + return nil + } else if isValid(certpath, keypath) { + return nil + } else if repo, err := getRepo(repopath); err != nil { + return err + } else if certContent, keyContent, err := fetchCert(repo); err != nil { + return err + } else if cert, err := tls.X509KeyPair(certContent, keyContent); err != nil { + return err + } else if err = validate(cert); err != nil { + return err + } else if err = dumpto(certpath, certContent); err != nil { + return err + } else if err = dumpto(keypath, keyContent); err != nil { + return err + } else { + return nil + } +} diff --git a/cmd/main.go b/cmd/main.go index 7ab5ea0..18fd83a 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -2,8 +2,6 @@ package main import ( "context" - "crypto/tls" - "crypto/x509" "errors" "flag" "fmt" @@ -12,7 +10,6 @@ import ( "net/http" "os" "os/signal" - "strings" "syscall" "time" @@ -23,17 +20,20 @@ import ( ) var ( - normal = "0.1.1" + normal = "0.1.2" preRelease = "dev" buildRevision string - ErrNoCertificate = errors.New("no cert") - ErrTimeCertificate = errors.New("out of cert time") + ver string ) -type dummp struct{} - -func (*dummp) Close() error { return nil } +func use(path string) string { + if info, err := os.Stat(path); err != nil || info.IsDir() { + return "" + } else { + return path + } +} func main() { var ( @@ -42,6 +42,7 @@ func main() { token string cert, key string logto string + repo string version bool ) @@ -49,13 +50,12 @@ func main() { flagSet.IntVar(&port, "port", 9242, "local port") flagSet.StringVar(&token, "token", "", "token") - flagSet.StringVar(&cert, "cert", ensure("cert.pem"), "cert file path") - flagSet.StringVar(&key, "key", ensure("key.pem"), "key file path") + flagSet.StringVar(&cert, "cert", use("cert.pem"), "cert file path") + flagSet.StringVar(&key, "key", use("key.pem"), "key file path") flagSet.StringVar(&logto, "log", "kmactor.log", "log file path") + flagSet.StringVar(&repo, "repo", use("repo.txt"), "auto update cert from repo") flagSet.BoolVar(&version, "version", false, "version") - ver := fmt.Sprintf("%s-%s+%s", normal, preRelease, buildRevision) - if err := flagSet.Parse(os.Args[1:]); err != nil { log.Println(err) } else if version { @@ -68,7 +68,9 @@ func main() { log.Printf("invalid port: %d", port) } else if (cert != "" && key == "") || (cert == "" && key != "") { log.Println("cert and key are required at the same time") - } else if names, err := getCertName(cert, key); err != nil { + } else if err = updateCert(cert, key, repo); err != nil { + log.Println(err) + } else if names, err := getCertNames(cert, key); err != nil { log.Println(err) } else if handler, err := kmactor.Build(ver, token); err != nil { log.Println(err) @@ -123,6 +125,10 @@ func main() { } } +type dummp struct{} + +func (*dummp) Close() error { return nil } + func logging(path string) (io.Closer, error) { if path == "-" { return &dummp{}, nil @@ -134,41 +140,6 @@ func logging(path string) (io.Closer, error) { } } -func ensure(path string) string { - if info, err := os.Stat(path); err != nil || info.IsDir() { - return "" - } else { - return path - } -} - -func getCertName(certpath, keypath string) ([]string, error) { - if certpath == "" || keypath == "" { - return []string{"localhost"}, nil - } else { - cur := time.Now() - if cert, err := tls.LoadX509KeyPair(certpath, keypath); err != nil { - return nil, err - } else if len(cert.Certificate) == 0 { - return nil, ErrNoCertificate - } else if x509cert, err := x509.ParseCertificate(cert.Certificate[0]); err != nil { - return nil, err - } else if cur.Before(x509cert.NotBefore) || cur.After(x509cert.NotAfter) { - return nil, ErrTimeCertificate - } else { - set := map[string]bool{} - set[x509cert.Subject.CommonName] = true - for _, name := range x509cert.DNSNames { - set[name] = true - } - for _, ip := range x509cert.IPAddresses { - set[ip.String()] = true - } - names := make([]string, 0, len(set)) - for name := range set { - names = append(names, strings.ReplaceAll(name, "*", "local")) - } - return names, nil - } - } +func init() { + ver = fmt.Sprintf("%s-%s+%s", normal, preRelease, buildRevision) } diff --git a/makefile b/makefile index 3db6888..b23b6ef 100644 --- a/makefile +++ b/makefile @@ -29,7 +29,7 @@ endif .PHONY: all amd64 arm64 build release tidy updep build: - CGO_ENABLED=1 go build $(GOFLAGS) -ldflags "$(GOLDFLAGS)" -tags="$(TAGS)" -o $(ARTIFACT) cmd/main.go + CGO_ENABLED=1 go build $(GOFLAGS) -ldflags "$(GOLDFLAGS)" -tags="$(TAGS)" -o $(ARTIFACT) cmd/*.go release: GOFLAGS="-trimpath" GOLDFLAGS="$(GOLDFLAGS) -s -w" TAGS="release" $(MAKE) build