Skip to content

Commit

Permalink
Merge pull request #14 from whiler/12-kmactor-update-cert
Browse files Browse the repository at this point in the history
update cert automatically
  • Loading branch information
whiler authored Mar 10, 2023
2 parents 235ad91 + 6a8d19f commit 418e1f9
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 52 deletions.
160 changes: 160 additions & 0 deletions cmd/cert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package main

import (
"bufio"
"context"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"log"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
)

var errTimeCertificate = errors.New("out of cert time")

var ErrInvalidRepo = errors.New("invalid repo")

func validate(cert tls.Certificate) error {
cur := time.Now()
if x509cert, err := x509.ParseCertificate(cert.Certificate[0]); err != nil {
return err
} else if cur.Before(x509cert.NotBefore) || cur.After(x509cert.NotAfter) {
return errTimeCertificate
} else {
return nil
}
}

func isValid(certpath, keypath string) bool {
if cert, err := tls.LoadX509KeyPair(certpath, keypath); err == nil {
return validate(cert) == nil
} else {
return false
}
}

func getCertNames(certpath, keypath string) ([]string, error) {
if certpath == "" || keypath == "" {
return []string{"localhost"}, nil
} else if cert, err := tls.LoadX509KeyPair(certpath, keypath); err != nil {
return nil, err
} else if x509cert, err := x509.ParseCertificate(cert.Certificate[0]); err != nil {
return nil, err
} else {
set := map[string]bool{}
set[x509cert.Subject.CommonName] = true
for _, name := range x509cert.DNSNames {
set[name] = true
}
for _, ip := range x509cert.IPAddresses {
set[ip.String()] = true
}
names := make([]string, 0, len(set))
for name := range set {
names = append(names, strings.ReplaceAll(name, "*", "local"))
}
return names, nil
}
}

func getRepo(path string) (string, error) {
if file, err := os.Open(path); err != nil {
return "", err
} else {
defer file.Close()
scanner := bufio.NewScanner(file)
scanner.Split(bufio.ScanLines)
for scanner.Scan() {
cur := strings.TrimSpace(scanner.Text())
if len(cur) == 0 {
continue
} else if strings.HasPrefix(cur, "#") {
continue
} else if u, err := url.Parse(cur); err != nil {
return "", err
} else if !strings.HasPrefix(u.Scheme, "http") || u.Host == "" {
break
} else {
return u.String(), nil
}
}
return "", ErrInvalidRepo
}
}

func wget(target string) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
if req, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil); err != nil {
return nil, err
} else {
req.Header.Add("User-Agent", "kmactor/"+ver)
if resp, err := http.DefaultClient.Do(req); err != nil {
return nil, err
} else {
defer resp.Body.Close()
return io.ReadAll(resp.Body)
}
}
}

func dumpto(path string, content []byte) error {
if file, err := os.CreateTemp("", filepath.Base(path)); err != nil {
return err
} else {
temp := file.Name()
if wrote, err := file.Write(content); err != nil {
file.Close()
os.Remove(temp)
return err
} else if wrote != len(content) {
file.Close()
os.Remove(temp)
return io.ErrShortWrite
} else {
file.Close()
return os.Rename(temp, path)
}
}
}

func fetchCert(repo string) ([]byte, []byte, error) {
log.Printf("fetching cert from %s", repo)
if u, err := url.Parse(repo); err != nil {
return nil, nil, err
} else if certContent, err := wget(u.JoinPath("cert.pem").String()); err != nil {
return nil, nil, err
} else if keypath, err := wget(u.JoinPath("key.pem").String()); err != nil {
return nil, nil, err
} else {
return certContent, keypath, nil
}
}

func updateCert(certpath, keypath, repopath string) error {
if repopath == "" || certpath == "" || keypath == "" {
return nil
} else if isValid(certpath, keypath) {
return nil
} else if repo, err := getRepo(repopath); err != nil {
return err
} else if certContent, keyContent, err := fetchCert(repo); err != nil {
return err
} else if cert, err := tls.X509KeyPair(certContent, keyContent); err != nil {
return err
} else if err = validate(cert); err != nil {
return err
} else if err = dumpto(certpath, certContent); err != nil {
return err
} else if err = dumpto(keypath, keyContent); err != nil {
return err
} else {
return nil
}
}
73 changes: 22 additions & 51 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package main

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"flag"
"fmt"
Expand All @@ -12,7 +10,6 @@ import (
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"

Expand All @@ -23,17 +20,20 @@ import (
)

var (
normal = "0.1.1"
normal = "0.1.2"
preRelease = "dev"
buildRevision string

ErrNoCertificate = errors.New("no cert")
ErrTimeCertificate = errors.New("out of cert time")
ver string
)

type dummp struct{}

func (*dummp) Close() error { return nil }
func use(path string) string {
if info, err := os.Stat(path); err != nil || info.IsDir() {
return ""
} else {
return path
}
}

func main() {
var (
Expand All @@ -42,20 +42,20 @@ func main() {
token string
cert, key string
logto string
repo string
version bool
)

log.SetFlags(log.Ltime)

flagSet.IntVar(&port, "port", 9242, "local port")
flagSet.StringVar(&token, "token", "", "token")
flagSet.StringVar(&cert, "cert", ensure("cert.pem"), "cert file path")
flagSet.StringVar(&key, "key", ensure("key.pem"), "key file path")
flagSet.StringVar(&cert, "cert", use("cert.pem"), "cert file path")
flagSet.StringVar(&key, "key", use("key.pem"), "key file path")
flagSet.StringVar(&logto, "log", "kmactor.log", "log file path")
flagSet.StringVar(&repo, "repo", use("repo.txt"), "auto update cert from repo")
flagSet.BoolVar(&version, "version", false, "version")

ver := fmt.Sprintf("%s-%s+%s", normal, preRelease, buildRevision)

if err := flagSet.Parse(os.Args[1:]); err != nil {
log.Println(err)
} else if version {
Expand All @@ -68,7 +68,9 @@ func main() {
log.Printf("invalid port: %d", port)
} else if (cert != "" && key == "") || (cert == "" && key != "") {
log.Println("cert and key are required at the same time")
} else if names, err := getCertName(cert, key); err != nil {
} else if err = updateCert(cert, key, repo); err != nil {
log.Println(err)
} else if names, err := getCertNames(cert, key); err != nil {
log.Println(err)
} else if handler, err := kmactor.Build(ver, token); err != nil {
log.Println(err)
Expand Down Expand Up @@ -123,6 +125,10 @@ func main() {
}
}

type dummp struct{}

func (*dummp) Close() error { return nil }

func logging(path string) (io.Closer, error) {
if path == "-" {
return &dummp{}, nil
Expand All @@ -134,41 +140,6 @@ func logging(path string) (io.Closer, error) {
}
}

func ensure(path string) string {
if info, err := os.Stat(path); err != nil || info.IsDir() {
return ""
} else {
return path
}
}

func getCertName(certpath, keypath string) ([]string, error) {
if certpath == "" || keypath == "" {
return []string{"localhost"}, nil
} else {
cur := time.Now()
if cert, err := tls.LoadX509KeyPair(certpath, keypath); err != nil {
return nil, err
} else if len(cert.Certificate) == 0 {
return nil, ErrNoCertificate
} else if x509cert, err := x509.ParseCertificate(cert.Certificate[0]); err != nil {
return nil, err
} else if cur.Before(x509cert.NotBefore) || cur.After(x509cert.NotAfter) {
return nil, ErrTimeCertificate
} else {
set := map[string]bool{}
set[x509cert.Subject.CommonName] = true
for _, name := range x509cert.DNSNames {
set[name] = true
}
for _, ip := range x509cert.IPAddresses {
set[ip.String()] = true
}
names := make([]string, 0, len(set))
for name := range set {
names = append(names, strings.ReplaceAll(name, "*", "local"))
}
return names, nil
}
}
func init() {
ver = fmt.Sprintf("%s-%s+%s", normal, preRelease, buildRevision)
}
2 changes: 1 addition & 1 deletion makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ endif
.PHONY: all amd64 arm64 build release tidy updep

build:
CGO_ENABLED=1 go build $(GOFLAGS) -ldflags "$(GOLDFLAGS)" -tags="$(TAGS)" -o $(ARTIFACT) cmd/main.go
CGO_ENABLED=1 go build $(GOFLAGS) -ldflags "$(GOLDFLAGS)" -tags="$(TAGS)" -o $(ARTIFACT) cmd/*.go

release:
GOFLAGS="-trimpath" GOLDFLAGS="$(GOLDFLAGS) -s -w" TAGS="release" $(MAKE) build
Expand Down

0 comments on commit 418e1f9

Please sign in to comment.