diff --git a/cmd/nps/nps.go b/cmd/nps/nps.go index baa930bb..c3b4d334 100644 --- a/cmd/nps/nps.go +++ b/cmd/nps/nps.go @@ -1,14 +1,6 @@ package main import ( - "ehang.io/nps/lib/crypt" - "ehang.io/nps/lib/file" - "ehang.io/nps/lib/install" - "ehang.io/nps/lib/version" - "ehang.io/nps/server" - "ehang.io/nps/server/connection" - "ehang.io/nps/server/tool" - "ehang.io/nps/web/routers" "flag" "log" "os" @@ -18,7 +10,16 @@ import ( "strings" "sync" + "ehang.io/nps/lib/file" + "ehang.io/nps/lib/install" + "ehang.io/nps/lib/version" + "ehang.io/nps/server" + "ehang.io/nps/server/connection" + "ehang.io/nps/server/tool" + "ehang.io/nps/web/routers" + "ehang.io/nps/lib/common" + "ehang.io/nps/lib/crypt" "ehang.io/nps/lib/daemon" "github.com/astaxie/beego" "github.com/astaxie/beego/logs" @@ -200,7 +201,8 @@ func run() { } logs.Info("the version of server is %s ,allow client core version to be %s", version.VERSION, version.GetVersion()) connection.InitConnectionService() - crypt.InitTls(filepath.Join(common.GetRunPath(), "conf", "server.pem"), filepath.Join(common.GetRunPath(), "conf", "server.key")) + //crypt.InitTls(filepath.Join(common.GetRunPath(), "conf", "server.pem"), filepath.Join(common.GetRunPath(), "conf", "server.key")) + crypt.InitTls() tool.InitAllowPort() tool.StartSystemInfo() go server.StartNewServer(bridgePort, task, beego.AppConfig.String("bridge_type")) diff --git a/lib/crypt/tls.go b/lib/crypt/tls.go index 35a0a748..c301be88 100644 --- a/lib/crypt/tls.go +++ b/lib/crypt/tls.go @@ -1,22 +1,37 @@ package crypt import ( + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "log" + "math/big" "net" "os" + "time" "github.com/astaxie/beego/logs" ) -var pemPath, keyPath string +var ( + cert tls.Certificate +) -func InitTls(pem, key string) { - pemPath = pem - keyPath = key +func InitTls() { + c, k, err := generateKeyPair("NPS Corp,.Inc") + if err == nil { + cert, err = tls.X509KeyPair(c, k) + } + if err != nil { + log.Fatalln("Error initializing crypto certs", err) + } } func NewTlsServerConn(conn net.Conn) net.Conn { - cert, err := tls.LoadX509KeyPair(pemPath, keyPath) + var err error if err != nil { logs.Error(err) os.Exit(0) @@ -32,3 +47,41 @@ func NewTlsClientConn(conn net.Conn) net.Conn { } return tls.Client(conn, conf) } + +func generateKeyPair(CommonName string) (rawCert, rawKey []byte, err error) { + // Create private key and self-signed certificate + // Adapted from https://golang.org/src/crypto/tls/generate_cert.go + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return + } + validFor := time.Hour * 24 * 365 * 10 // ten years + notBefore := time.Now() + notAfter := notBefore.Add(validFor) + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"My Company Name LTD."}, + CommonName: CommonName, + Country: []string{"US"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return + } + + rawCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + rawKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + return +} diff --git a/server/proxy/http.go b/server/proxy/http.go index af4ad805..73bfa10e 100644 --- a/server/proxy/http.go +++ b/server/proxy/http.go @@ -1,7 +1,7 @@ package proxy import ( - "bufio" + "context" "crypto/tls" "io" "net" @@ -10,8 +10,8 @@ import ( "os" "path/filepath" "strconv" - "strings" "sync" + "time" "ehang.io/nps/bridge" "ehang.io/nps/lib/cache" @@ -101,174 +101,161 @@ func (s *httpServer) Close() error { return nil } -func (s *httpServer) handleTunneling(w http.ResponseWriter, r *http.Request) { - hijacker, ok := w.(http.Hijacker) - if !ok { - http.Error(w, "Hijacking not supported", http.StatusInternalServerError) - return - } - c, _, err := hijacker.Hijack() - if err != nil { - http.Error(w, err.Error(), http.StatusServiceUnavailable) +func (s *httpServer) NewServer(port int, scheme string) *http.Server { + rProxy := NewHttpReverseProxy(s) + return &http.Server{ + Addr: ":" + strconv.Itoa(port), + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.URL.Scheme = scheme + rProxy.ServeHTTP(w, r) + }), + // Disable HTTP/2. + TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), } - s.handleHttp(conn.NewConn(c), r) } -func (s *httpServer) handleHttp(c *conn.Conn, r *http.Request) { +type HttpReverseProxy struct { + proxy *ReverseProxy + + responseHeaderTimeout time.Duration +} + +func (rp *HttpReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { var ( host *file.Host - target net.Conn - err error - connClient io.ReadWriteCloser - scheme = r.URL.Scheme - lk *conn.Link targetAddr string - lenConn *conn.LenConn - isReset bool - wg sync.WaitGroup + err error ) - defer func() { - if connClient != nil { - connClient.Close() - }else { - s.writeConnFail(c.Conn) - } - c.Close() - }() -reset: - if isReset { - host.Client.AddConn() - } - if host, err = file.GetDb().GetInfoByHost(r.Host, r); err != nil { - logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI) + if host, err = file.GetDb().GetInfoByHost(req.Host, req); err != nil { + rw.WriteHeader(http.StatusNotFound) + rw.Write([]byte(req.Host + " not found")) return } - if err := s.CheckFlowAndConnNum(host.Client); err != nil { - logs.Warn("client id %d, host id %d, error %s, when https connection", host.Client.Id, host.Id, err.Error()) - return - } - if !isReset { - defer host.Client.AddConn() - } - if err = s.auth(r, c, host.Client.Cnf.U, host.Client.Cnf.P); err != nil { - logs.Warn("auth error", err, r.RemoteAddr) + if host.Client.Cnf.U != "" && host.Client.Cnf.P != "" && !common.CheckAuth(req, host.Client.Cnf.U, host.Client.Cnf.P) { + rw.Header().Set("WWW-Authenticate", "Basic realm=\"Private Area\"") + rw.WriteHeader(http.StatusUnauthorized) + rw.Write([]byte("Unauthorized")) return } if targetAddr, err = host.Target.GetRandomTarget(); err != nil { - logs.Warn(err.Error()) + rw.WriteHeader(http.StatusBadGateway) + rw.Write([]byte("502 Bad Gateway")) return } - lk = conn.NewLink("http", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy) - if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil { - logs.Notice("connect to target %s error %s", lk.Host, err) - return + req.URL.Host = req.Host + req = req.WithContext(context.WithValue(req.Context(), "host", host)) + req = req.WithContext(context.WithValue(req.Context(), "target", targetAddr)) + req = req.WithContext(context.WithValue(req.Context(), "req", req)) + + rp.proxy.ServeHTTP(rw, req) +} + +func NewHttpReverseProxy(s *httpServer) *HttpReverseProxy { + rp := &HttpReverseProxy{ + responseHeaderTimeout: 30 * time.Second, } - connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true) + local, _ := net.ResolveTCPAddr("tcp", "127.0.0.1") + proxy := NewReverseProxy(&httputil.ReverseProxy{ + Director: func(r *http.Request) { + host := r.Context().Value("host").(*file.Host) + common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, "", false) + }, + Transport: &http.Transport{ + ResponseHeaderTimeout: rp.responseHeaderTimeout, + DisableKeepAlives: true, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + var ( + host *file.Host + target net.Conn + err error + connClient io.ReadWriteCloser + targetAddr string + lk *conn.Link + ) - //read from inc-client - go func() { - wg.Add(1) - isReset = false - defer connClient.Close() - defer func() { - wg.Done() - if !isReset { - c.Close() - } - }() - for { - if resp, err := http.ReadResponse(bufio.NewReader(connClient), r); err != nil || resp == nil { - return - } else { - //if the cache is start and the response is in the extension,store the response to the cache list - if s.useCache && r.URL != nil && strings.Contains(r.URL.Path, ".") { - b, err := httputil.DumpResponse(resp, true) - if err != nil { - return - } - c.Write(b) - host.Flow.Add(0, int64(len(b))) - s.cache.Add(filepath.Join(host.Host, r.URL.Path), b) - } else { - lenConn := conn.NewLenConn(c) - if err := resp.Write(lenConn); err != nil { - logs.Error(err) - return - } - host.Flow.Add(0, int64(lenConn.Len)) - } - } - } - }() + r := ctx.Value("req").(*http.Request) + host = ctx.Value("host").(*file.Host) + targetAddr = ctx.Value("target").(string) - for { - //if the cache start and the request is in the cache list, return the cache - if s.useCache { - if v, ok := s.cache.Get(filepath.Join(host.Host, r.URL.Path)); ok { - n, err := c.Write(v.([]byte)) - if err != nil { - break - } - logs.Trace("%s request, method %s, host %s, url %s, remote address %s, return cache", r.URL.Scheme, r.Method, r.Host, r.URL.Path, c.RemoteAddr().String()) - host.Flow.Add(0, int64(n)) - //if return cache and does not create a new conn with client and Connection is not set or close, close the connection. - if strings.ToLower(r.Header.Get("Connection")) == "close" || strings.ToLower(r.Header.Get("Connection")) == "" { - break + lk = conn.NewLink("http", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy) + if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil { + logs.Notice("connect to target %s error %s", lk.Host, err) + return nil, NewHTTPError(http.StatusBadGateway, "Cannot connect to the server") } - goto readReq - } - } - - //change the host and header and set proxy setting - common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, c.Conn.RemoteAddr().String(), s.addOrigin) - logs.Trace("%s request, method %s, host %s, url %s, remote address %s, target %s", r.URL.Scheme, r.Method, r.Host, r.URL.Path, c.RemoteAddr().String(), lk.Host) - //write - lenConn = conn.NewLenConn(connClient) - if err := r.Write(lenConn); err != nil { - logs.Error(err) - break - } - host.Flow.Add(int64(lenConn.Len), 0) + connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true) + return &flowConn{ + ReadWriteCloser: connClient, + fakeAddr: local, + host: host, + }, nil + }, + }, + ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { + logs.Warn("do http proxy request error: %v", err) + rw.WriteHeader(http.StatusNotFound) + }, + }) + proxy.WebSocketDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + var ( + host *file.Host + target net.Conn + err error + connClient io.ReadWriteCloser + targetAddr string + lk *conn.Link + ) + r := ctx.Value("req").(*http.Request) + host = ctx.Value("host").(*file.Host) + targetAddr = ctx.Value("target").(string) - readReq: - //read req from connection - if r, err = http.ReadRequest(bufio.NewReader(c)); err != nil { - break - } - r.URL.Scheme = scheme - //What happened ,Why one character less??? - r.Method = resetReqMethod(r.Method) - if hostTmp, err := file.GetDb().GetInfoByHost(r.Host, r); err != nil { - logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI) - break - } else if host != hostTmp { - host = hostTmp - isReset = true - connClient.Close() - goto reset + lk = conn.NewLink("tcp", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy) + if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil { + logs.Notice("connect to target %s error %s", lk.Host, err) + return nil, NewHTTPError(http.StatusBadGateway, "Cannot connect to the target") } + connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true) + return &flowConn{ + ReadWriteCloser: connClient, + fakeAddr: local, + host: host, + }, nil } - wg.Wait() + rp.proxy = proxy + return rp } -func resetReqMethod(method string) string { - if method == "ET" { - return "GET" - } - if method == "OST" { - return "POST" - } - return method +type flowConn struct { + io.ReadWriteCloser + fakeAddr net.Addr + host *file.Host + flowIn int64 + flowOut int64 + once sync.Once } -func (s *httpServer) NewServer(port int, scheme string) *http.Server { - return &http.Server{ - Addr: ":" + strconv.Itoa(port), - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.URL.Scheme = scheme - s.handleTunneling(w, r) - }), - // Disable HTTP/2. - TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), - } +func (c *flowConn) Read(p []byte) (n int, err error) { + n, err = c.ReadWriteCloser.Read(p) + c.flowIn += int64(n) + return n, err +} + +func (c *flowConn) Write(p []byte) (n int, err error) { + n, err = c.ReadWriteCloser.Write(p) + c.flowOut += int64(n) + return n, err } + +func (c *flowConn) Close() error { + c.once.Do(func() { c.host.Flow.Add(c.flowIn, c.flowOut) }) + return c.ReadWriteCloser.Close() +} + +func (c *flowConn) LocalAddr() net.Addr { return c.fakeAddr } + +func (c *flowConn) RemoteAddr() net.Addr { return c.fakeAddr } + +func (*flowConn) SetDeadline(t time.Time) error { return nil } + +func (*flowConn) SetReadDeadline(t time.Time) error { return nil } + +func (*flowConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/server/proxy/reverseproxy.go b/server/proxy/reverseproxy.go new file mode 100644 index 00000000..df7e8660 --- /dev/null +++ b/server/proxy/reverseproxy.go @@ -0,0 +1,136 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP reverse proxy handler + +package proxy + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "net/http/httputil" + "net/url" + "strings" + "sync" +) + +type HTTPError struct { + error + HTTPCode int +} + +func NewHTTPError(code int, errmsg string) error { + return &HTTPError{ + error: errors.New(errmsg), + HTTPCode: code, + } +} + +type ReverseProxy struct { + *httputil.ReverseProxy + WebSocketDialContext func(ctx context.Context, network, addr string) (net.Conn, error) +} + +func IsWebsocketRequest(req *http.Request) bool { + containsHeader := func(name, value string) bool { + items := strings.Split(req.Header.Get(name), ",") + for _, item := range items { + if value == strings.ToLower(strings.TrimSpace(item)) { + return true + } + } + return false + } + return containsHeader("Connection", "upgrade") && containsHeader("Upgrade", "websocket") +} + +func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { + rp := &ReverseProxy{ + ReverseProxy: httputil.NewSingleHostReverseProxy(target), + WebSocketDialContext: nil, + } + rp.ErrorHandler = rp.errHandler + return rp +} + +func NewReverseProxy(orp *httputil.ReverseProxy) *ReverseProxy { + rp := &ReverseProxy{ + ReverseProxy: orp, + WebSocketDialContext: nil, + } + rp.ErrorHandler = rp.errHandler + return rp +} + +func (p *ReverseProxy) errHandler(rw http.ResponseWriter, r *http.Request, e error) { + if e == io.EOF { + rw.WriteHeader(521) + //rw.Write(getWaitingPageContent()) + } else { + if httperr, ok := e.(*HTTPError); ok { + rw.WriteHeader(httperr.HTTPCode) + } else { + rw.WriteHeader(http.StatusNotFound) + } + rw.Write([]byte("error: " + e.Error())) + } +} + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + if IsWebsocketRequest(req) { + p.serveWebSocket(rw, req) + } else { + p.ReverseProxy.ServeHTTP(rw, req) + } +} + +func (p *ReverseProxy) serveWebSocket(rw http.ResponseWriter, req *http.Request) { + if p.WebSocketDialContext == nil { + rw.WriteHeader(500) + return + } + targetConn, err := p.WebSocketDialContext(req.Context(), "tcp", "") + if err != nil { + rw.WriteHeader(501) + return + } + defer targetConn.Close() + + p.Director(req) + + hijacker, ok := rw.(http.Hijacker) + if !ok { + rw.WriteHeader(500) + return + } + conn, _, errHijack := hijacker.Hijack() + if errHijack != nil { + rw.WriteHeader(500) + return + } + defer conn.Close() + + req.Write(targetConn) + Join(conn, targetConn) +} + +func Join(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) (inCount int64, outCount int64) { + var wait sync.WaitGroup + pipe := func(to io.ReadWriteCloser, from io.ReadWriteCloser, count *int64) { + defer to.Close() + defer from.Close() + defer wait.Done() + + *count, _ = io.Copy(to, from) + } + + wait.Add(2) + go pipe(c1, c2, &inCount) + go pipe(c2, c1, &outCount) + wait.Wait() + return +} diff --git a/server/test/test.go b/server/test/test.go index a30d03d6..3fce9eed 100644 --- a/server/test/test.go +++ b/server/test/test.go @@ -52,10 +52,10 @@ func TestServerConfig() { if port, err := strconv.Atoi(p); err != nil { log.Fatalln("get https port error", err) } else { - if !common.FileExists(filepath.Join(common.GetRunPath(), beego.AppConfig.String("pemPath"))) { + if beego.AppConfig.String("pemPath") != "" && !common.FileExists(filepath.Join(common.GetRunPath(), beego.AppConfig.String("pemPath"))) { log.Fatalf("ssl certFile %s is not exist", beego.AppConfig.String("pemPath")) } - if !common.FileExists(filepath.Join(common.GetRunPath(), beego.AppConfig.String("ketPath"))) { + if beego.AppConfig.String("keyPath") != "" && !common.FileExists(filepath.Join(common.GetRunPath(), beego.AppConfig.String("keyPath"))) { log.Fatalf("ssl keyFile %s is not exist", beego.AppConfig.String("pemPath")) } isInArr(&postTcpArr, port, "http port", "tcp")