diff --git a/.dockerignore b/.dockerignore deleted file mode 100644 index b613028..0000000 --- a/.dockerignore +++ /dev/null @@ -1,3 +0,0 @@ -.aptcache -apt-proxy -release/* \ No newline at end of file diff --git a/.gitignore b/.gitignore index 7614621..14a2ab5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ .aptcache +cachedata apt-proxy last-cid -Godeps/_workspace diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 276104c..0000000 --- a/Dockerfile +++ /dev/null @@ -1,13 +0,0 @@ -FROM ubuntu:14.10 - -RUN apt-get update -RUN apt-get install -y golang - -RUN mkdir -p /go -ENV GOPATH /go:/go/src/github.com/lox/apt-proxy/Godeps/_workspace -ENV GOBIN /go/bin -ADD . /go/src/github.com/lox/apt-proxy - -EXPOSE 3142 -WORKDIR /go/src/github.com/lox/apt-proxy -CMD ["go", "run", "/go/src/github.com/lox/apt-proxy/apt-proxy.go"] \ No newline at end of file diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json deleted file mode 100644 index ffa3e5a..0000000 --- a/Godeps/Godeps.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "ImportPath": "github.com/lox/apt-proxy", - "GoVersion": "go1.3.3", - "Deps": [ - { - "ImportPath": "github.com/lox/httpcache", - "Rev": "6a9f0d42fe2cb0308ff6656d4860eb5cf3173b74" - }, - { - "ImportPath": "gopkgs.com/vfs.v1", - "Rev": "60e0a240148f4dce1a8a1f5639b0e0c9e25fd288" - } - ] -} diff --git a/Godeps/Readme b/Godeps/Readme deleted file mode 100644 index 4cdaa53..0000000 --- a/Godeps/Readme +++ /dev/null @@ -1,5 +0,0 @@ -This directory tree is generated automatically by godep. - -Please do not edit. - -See https://github.com/tools/godep for more information. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..1cbc14f --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2022 Su Yang + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 6e0b103..7357163 100644 --- a/README.md +++ b/README.md @@ -1,42 +1,31 @@ -# Apt Proxy +# APT Proxy -A caching proxy specifically for apt package caching, also rewrites to the fastest local mirror. Built as a tiny docker image for easy deployment. +Small and Reliable APT packages cache tool, supports both Ubuntu and Debian. -Built because [apt-cacher-ng](https://www.unix-ag.uni-kl.de/~bloch/acng/) is unreliable. +You can safely use it instead of [apt-cacher-ng](https://www.unix-ag.uni-kl.de/~bloch/acng/). -## Running via Go +## (WIP) Usage -```bash -go install github.com/lox/apt-proxy -$GOBIN/apt-proxy -``` - -## Running in Docker for Development - -```bash -docker build --rm --tag=apt-proxy-dev . -docker run -it --rm --publish=3142 --net host apt-proxy-dev -``` +- Binaries +- Docker -## Building in Docker for Release +### (WIP) Development ```bash -docker build --rm --tag=apt-proxy-dev . -docker run -it --cidfile last-cid apt-proxy-dev ./build.sh -docker cp $(cat last-cid):/apt-proxy release/ -docker build --tag=apt-proxy ./release -rm last-cid +go run apt-proxy.go ``` -## Running from Docker +## Ubuntu / Debian Debugging ``` -docker run -it --rm --publish=3142 --net host lox24/apt-proxy +http_proxy=http://192.168.33.1:3142 apt-get -o Debug::pkgProblemResolver=true -o Debug::Acquire::http=true update +http_proxy=http://192.168.33.1:3142 apt-get -o Debug::pkgProblemResolver=true -o Debug::Acquire::http=true install apache2 ``` -## Debugging +## Licenses, contains dependent software -``` -http_proxy=http://192.168.33.1:3142 apt-get -o Debug::pkgProblemResolver=true -o Debug::Acquire::http=true update -http_proxy=http://192.168.33.1:3142 apt-get -o Debug::pkgProblemResolver=true -o Debug::Acquire::http=true install apache2 -``` \ No newline at end of file +- MIT: [lox/httpcache](https://github.com/lox/httpcache/blob/master/LICENSE) +- NOT FOUND: [lox/apt-proxy](https://github.com/lox/apt-proxy#readme) +- MIT: [djherbis/stream](https://github.com/djherbis/stream/blob/master/LICENSE) +- MPL 2.0 [rainycape/vfs](https://github.com/rainycape/vfs/blob/master/LICENSE) +- MIT: [stretchr/testify](https://github.com/stretchr/testify/blob/master/LICENSE) \ No newline at end of file diff --git a/apt-proxy.go b/apt-proxy.go index 4c38a63..f25a075 100644 --- a/apt-proxy.go +++ b/apt-proxy.go @@ -5,43 +5,65 @@ import ( "log" "net/http" - "github.com/lox/apt-proxy/proxy" - "github.com/lox/httpcache" - "github.com/lox/httpcache/httplog" + "github.com/soulteary/apt-proxy/httpcache" + "github.com/soulteary/apt-proxy/linux" + "github.com/soulteary/apt-proxy/pkgs/httplog" + "github.com/soulteary/apt-proxy/proxy" ) const ( - defaultListen = "0.0.0.0:3142" - defaultDir = "./.aptcache" + DEFAULT_HOST = "0.0.0.0" + DEFAULT_PORT = "3142" + DEFAULT_CACHE_DIR = "./.aptcache" + DEFAULT_MIRROR = "" // "https://mirrors.tuna.tsinghua.edu.cn/ubuntu/" + DEFAULT_TYPE = linux.UBUNTU + DEFAULT_DEBUG = false ) var ( - version string - listen string - dir string - debug bool + version string + listen string + mirror string + types string + cacheDir string + debug bool ) func init() { - flag.StringVar(&listen, "listen", defaultListen, "the host and port to bind to") - flag.StringVar(&dir, "cachedir", defaultDir, "the dir to store cache data in") - flag.BoolVar(&debug, "debug", false, "whether to output debugging logging") + var ( + host string + port string + ) + flag.StringVar(&host, "host", DEFAULT_HOST, "the host to bind to") + flag.StringVar(&port, "port", DEFAULT_PORT, "the port to bind to") + flag.BoolVar(&debug, "debug", DEFAULT_DEBUG, "whether to output debugging logging") + flag.StringVar(&mirror, "mirror", DEFAULT_MIRROR, "the mirror for fetching packages") + flag.StringVar(&types, "type", DEFAULT_TYPE, "select the type of system to cache: ubuntu/debian") + flag.StringVar(&cacheDir, "cachedir", DEFAULT_CACHE_DIR, "the dir to store cache data in") flag.Parse() + + if types != linux.UBUNTU && types != linux.DEBIAN { + types = linux.UBUNTU + } + + listen = host + ":" + port } func main() { + log.Printf("running apt-proxy %s", version) if debug { + log.Printf("enable debug: true") httpcache.DebugLogging = true } - cache, err := httpcache.NewDiskCache(dir) + cache, err := httpcache.NewDiskCache(cacheDir) if err != nil { log.Fatal(err) } - ap := proxy.NewAptProxyFromDefaults() + ap := proxy.NewAptProxyFromDefaults(mirror, types) ap.Handler = httpcache.NewHandler(cache, ap.Handler) logger := httplog.NewResponseLogger(ap.Handler) diff --git a/build.sh b/build.sh deleted file mode 100755 index ec3ceeb..0000000 --- a/build.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -ex -go build -o /apt-proxy . \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..3308081 --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/soulteary/apt-proxy + +go 1.18 + +require github.com/stretchr/testify v1.7.2 + +require ( + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..bdf5107 --- /dev/null +++ b/go.sum @@ -0,0 +1,11 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s= +github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/httpcache/LICENSE b/httpcache/LICENSE new file mode 100644 index 0000000..40b4d9f --- /dev/null +++ b/httpcache/LICENSE @@ -0,0 +1,21 @@ +The MIT License + +Copyright (c) 2010-2014 Lachlan Donald http://lachlan.me + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/httpcache/README.md b/httpcache/README.md new file mode 100644 index 0000000..5c3ba66 --- /dev/null +++ b/httpcache/README.md @@ -0,0 +1,57 @@ + +# httpcache + +`httpcache` provides an [rfc7234][] compliant golang [http.Handler](http://golang.org/pkg/net/http/#Handler). + +[![wercker status](https://app.wercker.com/status/a76986990d27e72ea656bb37bb93f59f/m "wercker status")](https://app.wercker.com/project/bykey/a76986990d27e72ea656bb37bb93f59f) + +[![GoDoc](https://godoc.org/github.com/lox/httpcache?status.svg)](https://godoc.org/github.com/lox/httpcache) + +## Example + +This example is from the included CLI, it runs a caching proxy on http://localhost:8080. + +```go +proxy := &httputil.ReverseProxy{ + Director: func(r *http.Request) { + }, +} + +handler := httpcache.NewHandler(httpcache.NewMemoryCache(), proxy) +handler.Shared = true + +log.Printf("proxy listening on http://%s", listen) +log.Fatal(http.ListenAndServe(listen, handler)) +``` + +## Implemented + +- All of [rfc7234][], except those listed below +- Disk and Memory storage +- Apache-like logging via `httplog` package + +## Todo + +- Offline operation +- Size constraints on memory/disk cache and cache eviction +- Correctly handle mixture of HTTP1.0 clients and 1.1 upstreams +- More detail in `Via` header +- Support for weak entities with `If-Match` and `If-None-Match` +- Invalidation based on `Content-Location` and request method +- Better handling of duplicate headers and CacheControl values + +## Caveats + +- Conditional requests are never cached, this includes `Range` requests + +## Testing + +Tests are currently conducted via the test suite and verified via the [CoAdvisor tool](http://coad.measurement-factory.com/). + +## Reading List + +- http://httpwg.github.io/specs/rfc7234.html +- https://www.mnot.net/blog/2011/07/11/what_proxies_must_do +- https://www.mnot.net/blog/2014/06/07/rfc2616_is_dead + +[rfc7234]: http://httpwg.github.io/specs/rfc7234.html diff --git a/httpcache/bench_test.go b/httpcache/bench_test.go new file mode 100644 index 0000000..5196414 --- /dev/null +++ b/httpcache/bench_test.go @@ -0,0 +1,39 @@ +package httpcache_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "testing" + + "github.com/soulteary/apt-proxy/httpcache" +) + +func BenchmarkCachingFiles(b *testing.B) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Cache-Control", "max-age=100000") + fmt.Fprintf(w, "cache server payload") + })) + defer backend.Close() + + u, err := url.Parse(backend.URL) + if err != nil { + b.Fatal(err) + } + + handler := httpcache.NewHandler(httpcache.NewMemoryCache(), httputil.NewSingleHostReverseProxy(u)) + handler.Shared = true + cacheServer := httptest.NewServer(handler) + defer cacheServer.Close() + + for n := 0; n < b.N; n++ { + client := http.Client{} + resp, err := client.Get(fmt.Sprintf("%s/llamas/%d", cacheServer.URL, n)) + if err != nil { + b.Fatal(err) + } + resp.Body.Close() + } +} diff --git a/httpcache/cache.go b/httpcache/cache.go new file mode 100644 index 0000000..ea1a514 --- /dev/null +++ b/httpcache/cache.go @@ -0,0 +1,238 @@ +package httpcache + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "hash/fnv" + "io" + "log" + "net/http" + "net/textproto" + "os" + pathutil "path" + "strconv" + "strings" + "time" + + "github.com/soulteary/apt-proxy/pkgs/vfs" +) + +const ( + headerPrefix = "header/" + bodyPrefix = "body/" + formatPrefix = "v1/" +) + +// Returned when a resource doesn't exist +var ErrNotFoundInCache = errors.New("Not found in cache") + +type Cache interface { + Header(key string) (Header, error) + Store(res *Resource, keys ...string) error + Retrieve(key string) (*Resource, error) + Invalidate(keys ...string) + Freshen(res *Resource, keys ...string) error +} + +// cache provides a storage mechanism for cached Resources +type cache struct { + fs vfs.VFS + stale map[string]time.Time +} + +var _ Cache = (*cache)(nil) + +type Header struct { + http.Header + StatusCode int +} + +// NewCache returns a cache backend off the provided VFS +func NewVFSCache(fs vfs.VFS) Cache { + return &cache{fs: fs, stale: map[string]time.Time{}} +} + +// NewMemoryCache returns an ephemeral cache in memory +func NewMemoryCache() Cache { + return NewVFSCache(vfs.Memory()) +} + +// NewDiskCache returns a disk-backed cache +func NewDiskCache(dir string) (Cache, error) { + if err := os.MkdirAll(dir, 0777); err != nil { + return nil, err + } + fs, err := vfs.FS(dir) + if err != nil { + return nil, err + } + chfs, err := vfs.Chroot("/", fs) + if err != nil { + return nil, err + } + return NewVFSCache(chfs), nil +} + +func (c *cache) vfsWrite(path string, r io.Reader) error { + if err := vfs.MkdirAll(c.fs, pathutil.Dir(path), 0700); err != nil { + return err + } + f, err := c.fs.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + if err != nil { + return err + } + defer f.Close() + if _, err := io.Copy(f, r); err != nil { + return err + } + return nil +} + +// Retrieve the Status and Headers for a given key path +func (c *cache) Header(key string) (Header, error) { + path := headerPrefix + formatPrefix + hashKey(key) + f, err := c.fs.Open(path) + if err != nil { + if vfs.IsNotExist(err) { + return Header{}, ErrNotFoundInCache + } + return Header{}, err + } + + return readHeaders(bufio.NewReader(f)) +} + +// Store a resource against a number of keys +func (c *cache) Store(res *Resource, keys ...string) error { + var buf = &bytes.Buffer{} + + if length, err := strconv.ParseInt(res.Header().Get("Content-Length"), 10, 64); err == nil { + if _, err = io.CopyN(buf, res, length); err != nil { + return err + } + } else if _, err = io.Copy(buf, res); err != nil { + return err + } + + for _, key := range keys { + delete(c.stale, key) + + if err := c.storeBody(buf, key); err != nil { + return err + } + + if err := c.storeHeader(res.Status(), res.Header(), key); err != nil { + return err + } + } + + return nil +} + +func (c *cache) storeBody(r io.Reader, key string) error { + if err := c.vfsWrite(bodyPrefix+formatPrefix+hashKey(key), r); err != nil { + return err + } + return nil +} + +func (c *cache) storeHeader(code int, h http.Header, key string) error { + hb := &bytes.Buffer{} + hb.Write([]byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n", code, http.StatusText(code)))) + headersToWriter(h, hb) + + if err := c.vfsWrite(headerPrefix+formatPrefix+hashKey(key), bytes.NewReader(hb.Bytes())); err != nil { + return err + } + return nil +} + +// Retrieve returns a cached Resource for the given key +func (c *cache) Retrieve(key string) (*Resource, error) { + f, err := c.fs.Open(bodyPrefix + formatPrefix + hashKey(key)) + if err != nil { + if vfs.IsNotExist(err) { + return nil, ErrNotFoundInCache + } + return nil, err + } + h, err := c.Header(key) + if err != nil { + if vfs.IsNotExist(err) { + return nil, ErrNotFoundInCache + } + return nil, err + } + res := NewResource(h.StatusCode, f, h.Header) + if staleTime, exists := c.stale[key]; exists { + if !res.DateAfter(staleTime) { + log.Printf("stale marker of %s found", staleTime) + res.MarkStale() + } + } + return res, nil +} + +func (c *cache) Invalidate(keys ...string) { + log.Printf("invalidating %q", keys) + for _, key := range keys { + c.stale[key] = Clock() + } +} + +func (c *cache) Freshen(res *Resource, keys ...string) error { + for _, key := range keys { + if h, err := c.Header(key); err == nil { + if h.StatusCode == res.Status() && headersEqual(h.Header, res.Header()) { + debugf("freshening key %s", key) + if err := c.storeHeader(h.StatusCode, res.Header(), key); err != nil { + return err + } + } else { + debugf("freshen failed, invalidating %s", key) + c.Invalidate(key) + } + } + } + return nil +} + +func hashKey(key string) string { + h := fnv.New64a() + h.Write([]byte(key)) + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func readHeaders(r *bufio.Reader) (Header, error) { + tp := textproto.NewReader(r) + line, err := tp.ReadLine() + if err != nil { + return Header{}, err + } + + f := strings.SplitN(line, " ", 3) + if len(f) < 2 { + return Header{}, fmt.Errorf("malformed HTTP response: %s", line) + } + statusCode, err := strconv.Atoi(f[1]) + if err != nil { + return Header{}, fmt.Errorf("malformed HTTP status code: %s", f[1]) + } + + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + return Header{}, err + } + return Header{StatusCode: statusCode, Header: http.Header(mimeHeader)}, nil +} + +func headersToWriter(h http.Header, w io.Writer) error { + if err := h.Write(w); err != nil { + return err + } + // ReadMIMEHeader expects a trailing newline + _, err := w.Write([]byte("\r\n")) + return err +} diff --git a/httpcache/cache_test.go b/httpcache/cache_test.go new file mode 100644 index 0000000..5e442e5 --- /dev/null +++ b/httpcache/cache_test.go @@ -0,0 +1,51 @@ +package httpcache_test + +import ( + "net/http" + "strings" + "testing" + + "github.com/soulteary/apt-proxy/httpcache" + "github.com/stretchr/testify/require" +) + +func TestSaveResource(t *testing.T) { + var body = strings.Repeat("llamas", 5000) + var cache = httpcache.NewMemoryCache() + + res := httpcache.NewResourceBytes(http.StatusOK, []byte(body), http.Header{ + "Llamas": []string{"true"}, + }) + + if err := cache.Store(res, "testkey"); err != nil { + t.Fatal(err) + } + + resOut, err := cache.Retrieve("testkey") + if err != nil { + t.Fatal(err) + } + + require.NotNil(t, resOut) + require.Equal(t, res.Header(), resOut.Header()) + require.Equal(t, body, readAllString(resOut)) +} + +func TestSaveResourceWithIncorrectContentLength(t *testing.T) { + var body = "llamas" + var cache = httpcache.NewMemoryCache() + + res := httpcache.NewResourceBytes(http.StatusOK, []byte(body), http.Header{ + "Llamas": []string{"true"}, + "Content-Length": []string{"10"}, + }) + + if err := cache.Store(res, "testkey"); err == nil { + t.Fatal("Entry should have generated an error") + } + + _, err := cache.Retrieve("testkey") + if err != httpcache.ErrNotFoundInCache { + t.Fatal("Entry shouldn't have been cached") + } +} diff --git a/httpcache/cachecontrol.go b/httpcache/cachecontrol.go new file mode 100644 index 0000000..c35b9b6 --- /dev/null +++ b/httpcache/cachecontrol.go @@ -0,0 +1,109 @@ +package httpcache + +import ( + "bytes" + "fmt" + "net/http" + "sort" + "strings" + "time" +) + +const ( + CacheControlHeader = "Cache-Control" +) + +type CacheControl map[string][]string + +func ParseCacheControlHeaders(h http.Header) (CacheControl, error) { + return ParseCacheControl(strings.Join(h["Cache-Control"], ", ")) +} + +func ParseCacheControl(input string) (CacheControl, error) { + cc := make(CacheControl) + length := len(input) + isValue := false + lastKey := "" + + for pos := 0; pos < length; pos++ { + var token string + switch input[pos] { + case '"': + if offset := strings.IndexAny(input[pos+1:], `"`); offset != -1 { + token = input[pos+1 : pos+1+offset] + } else { + token = input[pos+1:] + } + pos += len(token) + 1 + case ',', '\n', '\r', ' ', '\t': + continue + case '=': + isValue = true + continue + default: + if offset := strings.IndexAny(input[pos:], "\"\n\t\r ,="); offset != -1 { + token = input[pos : pos+offset] + } else { + token = input[pos:] + } + pos += len(token) - 1 + } + if isValue { + cc.Add(lastKey, token) + isValue = false + } else { + cc.Add(token, "") + lastKey = token + } + } + + return cc, nil +} + +func (cc CacheControl) Get(key string) (string, bool) { + v, exists := cc[key] + if exists && len(v) > 0 { + return v[0], true + } + return "", exists +} + +func (cc CacheControl) Add(key, val string) { + if !cc.Has(key) { + cc[key] = []string{} + } + if val != "" { + cc[key] = append(cc[key], val) + } +} + +func (cc CacheControl) Has(key string) bool { + _, exists := cc[key] + return exists +} + +func (cc CacheControl) Duration(key string) (time.Duration, error) { + d, _ := cc.Get(key) + return time.ParseDuration(d + "s") +} + +func (cc CacheControl) String() string { + keys := make([]string, len(cc)) + for k, _ := range cc { + keys = append(keys, k) + } + sort.Strings(keys) + buf := bytes.Buffer{} + + for _, k := range keys { + vals := cc[k] + if len(vals) == 0 { + buf.WriteString(k + ", ") + } + for _, val := range vals { + buf.WriteString(fmt.Sprintf("%s=%q, ", k, val)) + } + } + + return strings.TrimSuffix(buf.String(), ", ") +} diff --git a/httpcache/cachecontrol_test.go b/httpcache/cachecontrol_test.go new file mode 100644 index 0000000..981e416 --- /dev/null +++ b/httpcache/cachecontrol_test.go @@ -0,0 +1,58 @@ +package httpcache_test + +import ( + "testing" + + . "github.com/soulteary/apt-proxy/httpcache" + "github.com/stretchr/testify/require" +) + +func TestParsingCacheControl(t *testing.T) { + table := []struct { + ccString string + ccStruct CacheControl + }{ + {`public, private="set-cookie", max-age=100`, CacheControl{ + "public": []string{}, + "private": []string{"set-cookie"}, + "max-age": []string{"100"}, + }}, + {` foo="max-age=8, space", public`, CacheControl{ + "public": []string{}, + "foo": []string{"max-age=8, space"}, + }}, + {`s-maxage=86400`, CacheControl{ + "s-maxage": []string{"86400"}, + }}, + {`max-stale`, CacheControl{ + "max-stale": []string{}, + }}, + {`max-stale=60`, CacheControl{ + "max-stale": []string{"60"}, + }}, + {`" max-age=8,max-age=8 "=blah`, CacheControl{ + " max-age=8,max-age=8 ": []string{"blah"}, + }}, + } + + for _, expect := range table { + cc, err := ParseCacheControl(expect.ccString) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, cc, expect.ccStruct) + require.NotEmpty(t, cc.String()) + } +} + +func BenchmarkCacheControlParsing(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := ParseCacheControl(`public, private="set-cookie", max-age=100`) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/httpcache/handler.go b/httpcache/handler.go new file mode 100644 index 0000000..9da30e4 --- /dev/null +++ b/httpcache/handler.go @@ -0,0 +1,590 @@ +package httpcache + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "math" + "net/http" + "strconv" + "sync" + "time" + + "github.com/soulteary/apt-proxy/pkgs/stream.v1" +) + +const ( + CacheHeader = "X-Cache" + ProxyDateHeader = "Proxy-Date" +) + +var Writes sync.WaitGroup + +var storeable = map[int]bool{ + http.StatusOK: true, + http.StatusFound: true, + http.StatusNonAuthoritativeInfo: true, + http.StatusMultipleChoices: true, + http.StatusMovedPermanently: true, + http.StatusGone: true, + http.StatusNotFound: true, +} + +var cacheableByDefault = map[int]bool{ + http.StatusOK: true, + http.StatusFound: true, + http.StatusNotModified: true, + http.StatusNonAuthoritativeInfo: true, + http.StatusMultipleChoices: true, + http.StatusMovedPermanently: true, + http.StatusGone: true, + http.StatusPartialContent: true, +} + +type Handler struct { + Shared bool + upstream http.Handler + validator *Validator + cache Cache +} + +func NewHandler(cache Cache, upstream http.Handler) *Handler { + return &Handler{ + upstream: upstream, + cache: cache, + validator: &Validator{upstream}, + Shared: false, + } +} + +func (h *Handler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + cReq, err := newCacheRequest(r) + if err != nil { + http.Error(rw, "invalid request: "+err.Error(), + http.StatusBadRequest) + return + } + + if !cReq.isCacheable() { + debugf("request not cacheable") + rw.Header().Set(CacheHeader, "SKIP") + h.pipeUpstream(rw, cReq) + return + } + + res, err := h.lookup(cReq) + if err != nil && err != ErrNotFoundInCache { + http.Error(rw, "lookup error: "+err.Error(), + http.StatusInternalServerError) + return + } + + cacheType := "private" + if h.Shared { + cacheType = "shared" + } + + if err == ErrNotFoundInCache { + if cReq.CacheControl.Has("only-if-cached") { + http.Error(rw, "key not in cache", + http.StatusGatewayTimeout) + return + } + debugf("%s %s not in %s cache", r.Method, r.URL.String(), cacheType) + h.passUpstream(rw, cReq) + return + } else { + debugf("%s %s found in %s cache", r.Method, r.URL.String(), cacheType) + } + + if h.needsValidation(res, cReq) { + if cReq.CacheControl.Has("only-if-cached") { + http.Error(rw, "key was in cache, but required validation", + http.StatusGatewayTimeout) + return + } + + debugf("validating cached response") + if h.validator.Validate(r, res) { + debugf("response is valid") + h.cache.Freshen(res, cReq.Key.String()) + } else { + debugf("response is changed") + h.passUpstream(rw, cReq) + return + } + } + + debugf("serving from cache") + res.Header().Set(CacheHeader, "HIT") + h.serveResource(res, rw, cReq) + + if err := res.Close(); err != nil { + errorf("Error closing resource: %s", err.Error()) + } +} + +// freshness returns the duration that a requested resource will be fresh for +func (h *Handler) freshness(res *Resource, r *cacheRequest) (time.Duration, error) { + maxAge, err := res.MaxAge(h.Shared) + if err != nil { + return time.Duration(0), err + } + + if r.CacheControl.Has("max-age") { + reqMaxAge, err := r.CacheControl.Duration("max-age") + if err != nil { + return time.Duration(0), err + } + + if reqMaxAge < maxAge { + debugf("using request max-age of %s", reqMaxAge.String()) + maxAge = reqMaxAge + } + } + + age, err := res.Age() + if err != nil { + return time.Duration(0), err + } + + if res.IsStale() { + return time.Duration(0), nil + } + + if hFresh := res.HeuristicFreshness(); hFresh > maxAge { + debugf("using heuristic freshness of %q", hFresh) + maxAge = hFresh + } + + return maxAge - age, nil +} + +func (h *Handler) needsValidation(res *Resource, r *cacheRequest) bool { + if res.MustValidate(h.Shared) { + return true + } + + freshness, err := h.freshness(res, r) + if err != nil { + debugf("error calculating freshness: %s", err.Error()) + return true + } + + if r.CacheControl.Has("min-fresh") { + reqMinFresh, err := r.CacheControl.Duration("min-fresh") + if err != nil { + debugf("error parsing request min-fresh: %s", err.Error()) + return true + } + + if freshness < reqMinFresh { + debugf("resource is fresh, but won't satisfy min-fresh of %s", reqMinFresh) + return true + } + } + + debugf("resource has a freshness of %s", freshness) + + if freshness <= 0 && r.CacheControl.Has("max-stale") { + if len(r.CacheControl["max-stale"]) == 0 { + debugf("resource is stale, but client sent max-stale") + return false + } else if maxStale, _ := r.CacheControl.Duration("max-stale"); maxStale >= (freshness * -1) { + log.Printf("resource is stale, but within allowed max-stale period of %s", maxStale) + return false + } + } + + return freshness <= 0 +} + +// pipeUpstream makes the request via the upstream handler, the response is not stored or modified +func (h *Handler) pipeUpstream(w http.ResponseWriter, r *cacheRequest) { + rw := newResponseStreamer(w) + rdr, err := rw.Stream.NextReader() + if err != nil { + debugf("error creating next stream reader: %v", err) + w.Header().Set(CacheHeader, "SKIP") + h.upstream.ServeHTTP(w, r.Request) + return + } + defer rdr.Close() + + debugf("piping request upstream") + go func() { + h.upstream.ServeHTTP(rw, r.Request) + rw.Stream.Close() + }() + rw.WaitHeaders() + + if r.Method != "HEAD" && !r.isStateChanging() { + return + } + + res := rw.Resource() + defer res.Close() + + if r.Method == "HEAD" { + h.cache.Freshen(res, r.Key.ForMethod("GET").String()) + } else if res.IsNonErrorStatus() { + h.invalidateResource(res, r) + } +} + +// passUpstream makes the request via the upstream handler and stores the result +func (h *Handler) passUpstream(w http.ResponseWriter, r *cacheRequest) { + rw := newResponseStreamer(w) + rdr, err := rw.Stream.NextReader() + if err != nil { + debugf("error creating next stream reader: %v", err) + w.Header().Set(CacheHeader, "SKIP") + h.upstream.ServeHTTP(w, r.Request) + return + } + + t := Clock() + debugf("passing request upstream") + rw.Header().Set(CacheHeader, "MISS") + + go func() { + h.upstream.ServeHTTP(rw, r.Request) + rw.Stream.Close() + }() + rw.WaitHeaders() + debugf("upstream responded headers in %s", Clock().Sub(t).String()) + + // just the headers! + res := NewResourceBytes(rw.StatusCode, nil, rw.Header()) + if !h.isCacheable(res, r) { + rdr.Close() + debugf("resource is uncacheable") + rw.Header().Set(CacheHeader, "SKIP") + return + } + b, err := ioutil.ReadAll(rdr) + rdr.Close() + if err != nil { + debugf("error reading stream: %v", err) + rw.Header().Set(CacheHeader, "SKIP") + return + } + debugf("full upstream response took %s", Clock().Sub(t).String()) + res.ReadSeekCloser = &byteReadSeekCloser{bytes.NewReader(b)} + + if age, err := correctedAge(res.Header(), t, Clock()); err == nil { + res.Header().Set("Age", strconv.Itoa(int(math.Ceil(age.Seconds())))) + } else { + debugf("error calculating corrected age: %s", err.Error()) + } + + rw.Header().Set(ProxyDateHeader, Clock().Format(http.TimeFormat)) + h.storeResource(res, r) +} + +// correctedAge adjusts the age of a resource for clock skew and travel time +// https://httpwg.github.io/specs/rfc7234.html#rfc.section.4.2.3 +func correctedAge(h http.Header, reqTime, respTime time.Time) (time.Duration, error) { + date, err := timeHeader("Date", h) + if err != nil { + return time.Duration(0), err + } + + apparentAge := respTime.Sub(date) + if apparentAge < 0 { + apparentAge = 0 + } + + respDelay := respTime.Sub(reqTime) + ageSeconds, err := intHeader("Age", h) + age := time.Second * time.Duration(ageSeconds) + correctedAge := age + respDelay + + if apparentAge > correctedAge { + correctedAge = apparentAge + } + + residentTime := Clock().Sub(respTime) + currentAge := correctedAge + residentTime + + return currentAge, nil +} + +func (h *Handler) isCacheable(res *Resource, r *cacheRequest) bool { + cc, err := res.cacheControl() + if err != nil { + errorf("Error parsing cache-control: %s", err.Error()) + return false + } + + if cc.Has("no-cache") || cc.Has("no-store") { + return false + } + + if cc.Has("private") && len(cc["private"]) == 0 && h.Shared { + return false + } + + if _, ok := storeable[res.Status()]; !ok { + return false + } + + if r.Header.Get("Authorization") != "" && h.Shared { + return false + } + + if res.Header().Get("Authorization") != "" && h.Shared && + !cc.Has("must-revalidate") && !cc.Has("s-maxage") { + return false + } + + if res.HasExplicitExpiration() { + return true + } + + if _, ok := cacheableByDefault[res.Status()]; !ok && !cc.Has("public") { + return false + } + + if res.HasValidators() { + return true + } else if res.HeuristicFreshness() > 0 { + return true + } + + return false +} + +func (h *Handler) serveResource(res *Resource, w http.ResponseWriter, req *cacheRequest) { + for key, headers := range res.Header() { + for _, header := range headers { + w.Header().Add(key, header) + } + } + + age, err := res.Age() + if err != nil { + http.Error(w, "Error calculating age: "+err.Error(), + http.StatusInternalServerError) + return + } + + // http://httpwg.github.io/specs/rfc7234.html#warn.113 + if age > (time.Hour*24) && res.HeuristicFreshness() > (time.Hour*24) { + w.Header().Add("Warning", `113 - "Heuristic Expiration"`) + } + + // http://httpwg.github.io/specs/rfc7234.html#warn.110 + freshness, err := h.freshness(res, req) + if err != nil || freshness <= 0 { + w.Header().Add("Warning", `110 - "Response is Stale"`) + } + + debugf("resource is %s old, updating age from %s", + age.String(), w.Header().Get("Age")) + + w.Header().Set("Age", fmt.Sprintf("%.f", math.Floor(age.Seconds()))) + w.Header().Set("Via", res.Via()) + + // hacky handler for non-ok statuses + if res.Status() != http.StatusOK { + w.WriteHeader(res.Status()) + io.Copy(w, res) + } else { + http.ServeContent(w, req.Request, "", res.LastModified(), res) + } +} + +func (h *Handler) invalidateResource(res *Resource, r *cacheRequest) { + Writes.Add(1) + + go func() { + defer Writes.Done() + debugf("invalidating resource %+v", res) + }() +} + +func (h *Handler) storeResource(res *Resource, r *cacheRequest) { + Writes.Add(1) + + go func() { + defer Writes.Done() + t := Clock() + keys := []string{r.Key.String()} + headers := res.Header() + + if h.Shared { + res.RemovePrivateHeaders() + } + + // store a secondary vary version + if vary := headers.Get("Vary"); vary != "" { + keys = append(keys, r.Key.Vary(vary, r.Request).String()) + } + + if err := h.cache.Store(res, keys...); err != nil { + errorf("storing resources %#v failed with error: %s", keys, err.Error()) + } + + debugf("stored resources %+v in %s", keys, Clock().Sub(t)) + }() +} + +// lookupResource finds the best matching Resource for the +// request, or nil and ErrNotFoundInCache if none is found +func (h *Handler) lookup(req *cacheRequest) (*Resource, error) { + res, err := h.cache.Retrieve(req.Key.String()) + + // HEAD requests can possibly be served from GET + if err == ErrNotFoundInCache && req.Method == "HEAD" { + res, err = h.cache.Retrieve(req.Key.ForMethod("GET").String()) + if err != nil { + return nil, err + } + + if res.HasExplicitExpiration() && req.isCacheable() { + debugf("using cached GET request for serving HEAD") + return res, nil + } else { + return nil, ErrNotFoundInCache + } + } else if err != nil { + return res, err + } + + // Secondary lookup for Vary + if vary := res.Header().Get("Vary"); vary != "" { + res, err = h.cache.Retrieve(req.Key.Vary(vary, req.Request).String()) + if err != nil { + return res, err + } + } + + return res, nil +} + +type cacheRequest struct { + *http.Request + Key Key + Time time.Time + CacheControl CacheControl +} + +func newCacheRequest(r *http.Request) (*cacheRequest, error) { + cc, err := ParseCacheControl(r.Header.Get("Cache-Control")) + if err != nil { + return nil, err + } + + if r.Proto == "HTTP/1.1" && r.Host == "" { + return nil, errors.New("Host header can't be empty") + } + + return &cacheRequest{ + Request: r, + Key: NewRequestKey(r), + Time: Clock(), + CacheControl: cc, + }, nil +} + +func (r *cacheRequest) isStateChanging() bool { + if !(r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE") { + return true + } + + return false +} + +func (r *cacheRequest) isCacheable() bool { + if !(r.Method == "GET" || r.Method == "HEAD") { + return false + } + + if r.Header.Get("If-Match") != "" || + r.Header.Get("If-Unmodified-Since") != "" || + r.Header.Get("If-Range") != "" { + return false + } + + if maxAge, ok := r.CacheControl.Get("max-age"); ok && maxAge == "0" { + return false + } + + if r.CacheControl.Has("no-store") || r.CacheControl.Has("no-cache") { + return false + } + + return true +} + +func newResponseStreamer(w http.ResponseWriter) *responseStreamer { + strm, err := stream.NewStream("responseBuffer", stream.NewMemFS()) + if err != nil { + panic(err) + } + return &responseStreamer{ + ResponseWriter: w, + Stream: strm, + C: make(chan struct{}), + } +} + +type responseStreamer struct { + StatusCode int + http.ResponseWriter + *stream.Stream + // C will be closed by WriteHeader to signal the headers' writing. + C chan struct{} +} + +// WaitHeaders returns iff and when WriteHeader has been called. +func (rw *responseStreamer) WaitHeaders() { + for range rw.C { + } +} + +func (rw *responseStreamer) WriteHeader(status int) { + defer close(rw.C) + rw.StatusCode = status + rw.ResponseWriter.WriteHeader(status) +} + +func (rw *responseStreamer) Write(b []byte) (int, error) { + rw.Stream.Write(b) + return rw.ResponseWriter.Write(b) +} +func (rw *responseStreamer) Close() error { + return rw.Stream.Close() +} + +// Resource returns a copy of the responseStreamer as a Resource object +func (rw *responseStreamer) Resource() *Resource { + r, err := rw.Stream.NextReader() + if err == nil { + b, err := ioutil.ReadAll(r) + r.Close() + if err == nil { + return NewResourceBytes(rw.StatusCode, b, rw.Header()) + } + } + return &Resource{ + header: rw.Header(), + statusCode: rw.StatusCode, + ReadSeekCloser: errReadSeekCloser{err}, + } +} + +type errReadSeekCloser struct { + err error +} + +func (e errReadSeekCloser) Error() string { + return e.err.Error() +} +func (e errReadSeekCloser) Close() error { return e.err } +func (e errReadSeekCloser) Read(_ []byte) (int, error) { return 0, e.err } +func (e errReadSeekCloser) Seek(_ int64, _ int) (int64, error) { return 0, e.err } diff --git a/httpcache/header.go b/httpcache/header.go new file mode 100644 index 0000000..6bd5262 --- /dev/null +++ b/httpcache/header.go @@ -0,0 +1,26 @@ +package httpcache + +import ( + "errors" + "net/http" + "strconv" + "time" +) + +var errNoHeader = errors.New("Header doesn't exist") + +func timeHeader(key string, h http.Header) (time.Time, error) { + if header := h.Get(key); header != "" { + return http.ParseTime(header) + } else { + return time.Time{}, errNoHeader + } +} + +func intHeader(key string, h http.Header) (int, error) { + if header := h.Get(key); header != "" { + return strconv.Atoi(header) + } else { + return 0, errNoHeader + } +} diff --git a/httpcache/key.go b/httpcache/key.go new file mode 100644 index 0000000..b4e7162 --- /dev/null +++ b/httpcache/key.go @@ -0,0 +1,83 @@ +package httpcache + +import ( + "bytes" + "fmt" + "net/http" + "net/url" + "strings" +) + +// Key represents a unique identifier for a resource in the cache +type Key struct { + method string + header http.Header + u url.URL + vary []string +} + +// NewKey returns a new Key instance +func NewKey(method string, u *url.URL, h http.Header) Key { + return Key{method: method, header: h, u: *u, vary: []string{}} +} + +// NewRequestKey generates a Key for a request +func NewRequestKey(r *http.Request) Key { + URL := r.URL + + if location := r.Header.Get("Content-Location"); location != "" { + u, err := url.Parse(location) + if err == nil { + if !u.IsAbs() { + u = r.URL.ResolveReference(u) + } + if u.Host != r.Host { + debugf("illegal host %q in Content-Location", u.Host) + } else { + debugf("using Content-Location: %q", u.String()) + URL = u + } + } else { + debugf("failed to parse Content-Location %q", location) + } + } + + return NewKey(r.Method, URL, r.Header) +} + +// ForMethod returns a new Key with a given method +func (k Key) ForMethod(method string) Key { + k2 := k + k2.method = method + return k2 +} + +// Vary returns a Key that is varied on particular headers in a http.Request +func (k Key) Vary(varyHeader string, r *http.Request) Key { + k2 := k + + for _, header := range strings.Split(varyHeader, ", ") { + k2.vary = append(k2.vary, header+"="+r.Header.Get(header)) + } + + return k2 +} + +func (k Key) String() string { + URL := strings.ToLower(canonicalURL(&k.u).String()) + b := &bytes.Buffer{} + b.WriteString(fmt.Sprintf("%s:%s", k.method, URL)) + + if len(k.vary) > 0 { + b.WriteString("::") + for _, v := range k.vary { + b.WriteString(v + ":") + } + } + + return b.String() +} + +func canonicalURL(u *url.URL) *url.URL { + return u +} diff --git a/httpcache/key_test.go b/httpcache/key_test.go new file mode 100644 index 0000000..2e41f35 --- /dev/null +++ b/httpcache/key_test.go @@ -0,0 +1,60 @@ +package httpcache_test + +import ( + "net/url" + "testing" + + "github.com/soulteary/apt-proxy/httpcache" + "github.com/stretchr/testify/assert" +) + +func mustParseUrl(u string) *url.URL { + ru, err := url.Parse(u) + if err != nil { + panic(err) + } + return ru +} + +func TestKeysDiffer(t *testing.T) { + k1 := httpcache.NewKey("GET", mustParseUrl("http://x.org/test"), nil) + k2 := httpcache.NewKey("GET", mustParseUrl("http://y.org/test"), nil) + + assert.NotEqual(t, k1.String(), k2.String()) +} + +func TestRequestKey(t *testing.T) { + r := newRequest("GET", "http://x.org/test") + + k1 := httpcache.NewKey("GET", mustParseUrl("http://x.org/test"), nil) + k2 := httpcache.NewRequestKey(r) + + assert.Equal(t, k1.String(), k2.String()) +} + +func TestVaryKey(t *testing.T) { + r := newRequest("GET", "http://x.org/test", "Llamas-1: true", "Llamas-2: false") + + k1 := httpcache.NewRequestKey(r) + k2 := httpcache.NewRequestKey(r).Vary("Llamas-1, Llamas-2", r) + + assert.NotEqual(t, k1.String(), k2.String()) +} + +func TestRequestKeyWithContentLocation(t *testing.T) { + r := newRequest("GET", "http://x.org/test1", "Content-Location: http://x.org/test2") + + k1 := httpcache.NewKey("GET", mustParseUrl("http://x.org/test2"), nil) + k2 := httpcache.NewRequestKey(r) + + assert.Equal(t, k1.String(), k2.String()) +} + +func TestRequestKeyWithIllegalContentLocation(t *testing.T) { + r := newRequest("GET", "http://x.org/test1", "Content-Location: http://y.org/test2") + + k1 := httpcache.NewKey("GET", mustParseUrl("http://x.org/test1"), nil) + k2 := httpcache.NewRequestKey(r) + + assert.Equal(t, k1.String(), k2.String()) +} diff --git a/httpcache/logger.go b/httpcache/logger.go new file mode 100644 index 0000000..75ab11f --- /dev/null +++ b/httpcache/logger.go @@ -0,0 +1,20 @@ +package httpcache + +import "log" + +const ( + ansiRed = "\x1b[31;1m" + ansiReset = "\x1b[0m" +) + +var DebugLogging = false + +func debugf(format string, args ...interface{}) { + if DebugLogging { + log.Printf(format, args...) + } +} + +func errorf(format string, args ...interface{}) { + log.Printf(ansiRed+"✗ "+format+ansiReset, args) +} diff --git a/httpcache/resource.go b/httpcache/resource.go new file mode 100644 index 0000000..6783741 --- /dev/null +++ b/httpcache/resource.go @@ -0,0 +1,249 @@ +package httpcache + +import ( + "bytes" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +const ( + lastModDivisor = 10 + viaPseudonym = "httpcache" +) + +var Clock = func() time.Time { + return time.Now().UTC() +} + +type ReadSeekCloser interface { + io.Reader + io.Seeker + io.Closer +} + +type byteReadSeekCloser struct { + *bytes.Reader +} + +func (brsc *byteReadSeekCloser) Close() error { return nil } + +type Resource struct { + ReadSeekCloser + RequestTime, ResponseTime time.Time + header http.Header + statusCode int + cc CacheControl + stale bool +} + +func NewResource(statusCode int, body ReadSeekCloser, hdrs http.Header) *Resource { + return &Resource{ + header: hdrs, + ReadSeekCloser: body, + statusCode: statusCode, + } +} + +func NewResourceBytes(statusCode int, b []byte, hdrs http.Header) *Resource { + return &Resource{ + header: hdrs, + statusCode: statusCode, + ReadSeekCloser: &byteReadSeekCloser{bytes.NewReader(b)}, + } +} + +func (r *Resource) IsNonErrorStatus() bool { + return r.statusCode >= 200 && r.statusCode < 400 +} + +func (r *Resource) Status() int { + return r.statusCode +} + +func (r *Resource) Header() http.Header { + return r.header +} + +func (r *Resource) IsStale() bool { + return r.stale +} + +func (r *Resource) MarkStale() { + r.stale = true +} + +func (r *Resource) cacheControl() (CacheControl, error) { + if r.cc != nil { + return r.cc, nil + } + + cc, err := ParseCacheControlHeaders(r.header) + if err != nil { + return cc, err + } + + r.cc = cc + return cc, nil +} + +func (r *Resource) LastModified() time.Time { + var modTime time.Time + + if lastModHeader := r.header.Get("Last-Modified"); lastModHeader != "" { + if t, err := http.ParseTime(lastModHeader); err == nil { + modTime = t + } + } + + return modTime +} + +func (r *Resource) Expires() (time.Time, error) { + if expires := r.header.Get("Expires"); expires != "" { + return http.ParseTime(expires) + } + + return time.Time{}, nil +} + +func (r *Resource) MustValidate(shared bool) bool { + cc, err := r.cacheControl() + if err != nil { + debugf("Error parsing Cache-Control: ", err.Error()) + return true + } + + // The s-maxage directive also implies the semantics of proxy-revalidate + if cc.Has("s-maxage") && shared { + return true + } + + if cc.Has("must-revalidate") || (cc.Has("proxy-revalidate") && shared) { + return true + } + + return false +} + +func (r *Resource) DateAfter(d time.Time) bool { + if dateHeader := r.header.Get("Date"); dateHeader != "" { + if t, err := http.ParseTime(dateHeader); err != nil { + return false + } else { + return t.After(d) + } + } + return false +} + +// Calculate the age of the resource +func (r *Resource) Age() (time.Duration, error) { + var age time.Duration + + if ageInt, err := intHeader("Age", r.header); err == nil { + age = time.Second * time.Duration(ageInt) + } + + if proxyDate, err := timeHeader(ProxyDateHeader, r.header); err == nil { + return Clock().Sub(proxyDate) + age, nil + } + + if date, err := timeHeader("Date", r.header); err == nil { + return Clock().Sub(date) + age, nil + } + + return time.Duration(0), errors.New("Unable to calculate age") +} + +func (r *Resource) MaxAge(shared bool) (time.Duration, error) { + cc, err := r.cacheControl() + if err != nil { + return time.Duration(0), err + } + + if cc.Has("s-maxage") && shared { + if maxAge, err := cc.Duration("s-maxage"); err != nil { + return time.Duration(0), err + } else if maxAge > 0 { + return maxAge, nil + } + } + + if cc.Has("max-age") { + if maxAge, err := cc.Duration("max-age"); err != nil { + return time.Duration(0), err + } else if maxAge > 0 { + return maxAge, nil + } + } + + if expiresVal := r.header.Get("Expires"); expiresVal != "" { + expires, err := http.ParseTime(expiresVal) + if err != nil { + return time.Duration(0), err + } + return expires.Sub(Clock()), nil + } + + return time.Duration(0), nil +} + +func (r *Resource) RemovePrivateHeaders() { + cc, err := r.cacheControl() + if err != nil { + debugf("Error parsing Cache-Control: %s", err.Error()) + } + + for _, p := range cc["private"] { + debugf("removing private header %q", p) + r.header.Del(p) + } +} + +func (r *Resource) HasValidators() bool { + if r.header.Get("Last-Modified") != "" || r.header.Get("Etag") != "" { + return true + } + + return false +} + +func (r *Resource) HasExplicitExpiration() bool { + cc, err := r.cacheControl() + if err != nil { + debugf("Error parsing Cache-Control: %s", err.Error()) + return false + } + + if d, _ := cc.Duration("max-age"); d > time.Duration(0) { + return true + } + + if d, _ := cc.Duration("s-maxage"); d > time.Duration(0) { + return true + } + + if exp, _ := r.Expires(); !exp.IsZero() { + return true + } + + return false +} + +func (r *Resource) HeuristicFreshness() time.Duration { + if !r.HasExplicitExpiration() && r.header.Get("Last-Modified") != "" { + return Clock().Sub(r.LastModified()) / time.Duration(lastModDivisor) + } + + return time.Duration(0) +} + +func (r *Resource) Via() string { + via := []string{} + via = append(via, fmt.Sprintf("1.1 %s", viaPseudonym)) + return strings.Join(via, ",") +} diff --git a/httpcache/spec_test.go b/httpcache/spec_test.go new file mode 100644 index 0000000..a6e187e --- /dev/null +++ b/httpcache/spec_test.go @@ -0,0 +1,485 @@ +package httpcache_test + +import ( + "fmt" + "io/ioutil" + "log" + "net/http" + "testing" + "time" + + "github.com/soulteary/apt-proxy/httpcache" + "github.com/soulteary/apt-proxy/pkgs/httplog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testSetup() (*client, *upstreamServer) { + upstream := &upstreamServer{ + Body: []byte("llamas"), + asserts: []func(r *http.Request){}, + Now: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), + Header: http.Header{}, + } + + httpcache.Clock = func() time.Time { + return upstream.Now + } + + cacheHandler := httpcache.NewHandler( + httpcache.NewMemoryCache(), + upstream, + ) + + var handler http.Handler = cacheHandler + + if testing.Verbose() { + rlogger := httplog.NewResponseLogger(cacheHandler) + rlogger.DumpRequests = true + rlogger.DumpResponses = true + handler = rlogger + httpcache.DebugLogging = true + } else { + log.SetOutput(ioutil.Discard) + } + + return &client{handler, cacheHandler}, upstream +} + +func TestSpecResponseCacheControl(t *testing.T) { + var cases = []struct { + cacheControl string + cacheStatus string + requests int + secondsElapsed time.Duration + shared bool + }{ + {cacheControl: "", requests: 2}, + {cacheControl: "no-cache", requests: 2, cacheStatus: "SKIP"}, + {cacheControl: "no-store", requests: 2, cacheStatus: "SKIP"}, + {cacheControl: "max-age=0, no-cache", requests: 2, cacheStatus: "SKIP"}, + {cacheControl: "max-age=0", requests: 2, cacheStatus: "SKIP"}, + {cacheControl: "s-maxage=0", requests: 2, cacheStatus: "SKIP", shared: true}, + {cacheControl: "s-maxage=60", requests: 2, cacheStatus: "HIT", shared: true}, + {cacheControl: "s-maxage=60", requests: 2, secondsElapsed: 65, shared: true}, + {cacheControl: "max-age=60", requests: 1, cacheStatus: "HIT"}, + {cacheControl: "max-age=60", requests: 1, secondsElapsed: 35, cacheStatus: "HIT"}, + {cacheControl: "max-age=60", requests: 2, secondsElapsed: 65}, + {cacheControl: "max-age=60, must-revalidate", requests: 2, cacheStatus: "HIT"}, + {cacheControl: "max-age=60, proxy-revalidate", requests: 1, cacheStatus: "HIT"}, + {cacheControl: "max-age=60, proxy-revalidate", requests: 2, cacheStatus: "HIT", shared: true}, + {cacheControl: "private, max-age=60", requests: 1, cacheStatus: "HIT"}, + {cacheControl: "private, max-age=60", requests: 2, cacheStatus: "SKIP", shared: true}, + } + + for idx, c := range cases { + client, upstream := testSetup() + upstream.CacheControl = c.cacheControl + client.cacheHandler.Shared = c.shared + + assert.Equal(t, http.StatusOK, client.get("/").Code) + upstream.timeTravel(time.Second * time.Duration(c.secondsElapsed)) + + r := client.get("/") + assert.Equal(t, http.StatusOK, r.statusCode) + require.Equal(t, c.requests, upstream.requests, + fmt.Sprintf("case #%d failed, %+v", idx+1, c)) + + if c.cacheStatus != "" { + require.Equal(t, c.cacheStatus, r.cacheStatus, + fmt.Sprintf("case #%d failed, %+v", idx+1, c)) + } + } +} + +func TestSpecResponseCacheControlWithPrivateHeaders(t *testing.T) { + client, upstream := testSetup() + client.cacheHandler.Shared = false + upstream.CacheControl = `max-age=10, private=X-Llamas, private=Set-Cookie"` + upstream.Header.Add("X-Llamas", "fully") + upstream.Header.Add("Set-Cookie", "llamas=true") + assert.Equal(t, http.StatusOK, client.get("/r1").Code) + + r1 := client.get("/r1") + assert.Equal(t, http.StatusOK, r1.statusCode) + assert.Equal(t, "HIT", r1.cacheStatus) + assert.Equal(t, "fully", r1.HeaderMap.Get("X-Llamas")) + assert.Equal(t, "llamas=true", r1.HeaderMap.Get("Set-Cookie")) + assert.Equal(t, 1, upstream.requests) + + client.cacheHandler.Shared = true + assert.Equal(t, http.StatusOK, client.get("/r2").Code) + + r2 := client.get("/r2") + assert.Equal(t, http.StatusOK, r1.statusCode) + assert.Equal(t, "HIT", r2.cacheStatus) + assert.Equal(t, "", r2.HeaderMap.Get("X-Llamas")) + assert.Equal(t, "", r2.HeaderMap.Get("Set-Cookie")) + assert.Equal(t, 2, upstream.requests) +} + +func TestSpecResponseCacheControlWithAuthorizationHeaders(t *testing.T) { + client, upstream := testSetup() + client.cacheHandler.Shared = true + upstream.CacheControl = `max-age=10` + upstream.Header.Add("Authorization", "fully") + assert.Equal(t, http.StatusOK, client.get("/r1").Code) + + r1 := client.get("/r1") + assert.Equal(t, http.StatusOK, r1.statusCode) + assert.Equal(t, "SKIP", r1.cacheStatus) + assert.Equal(t, "fully", r1.HeaderMap.Get("Authorization")) + assert.Equal(t, 2, upstream.requests) + + client.cacheHandler.Shared = false + assert.Equal(t, http.StatusOK, client.get("/r2").Code) + + r3 := client.get("/r2") + assert.Equal(t, http.StatusOK, r3.statusCode) + assert.Equal(t, "HIT", r3.cacheStatus) + assert.Equal(t, "fully", r3.HeaderMap.Get("Authorization")) + assert.Equal(t, 3, upstream.requests) +} + +func TestSpecRequestCacheControl(t *testing.T) { + var cases = []struct { + cacheControl string + cacheStatus string + requests int + secondsElapsed time.Duration + }{ + {cacheControl: "", requests: 1}, + {cacheControl: "no-cache", requests: 2}, + {cacheControl: "no-store", requests: 2}, + {cacheControl: "max-age=0", requests: 2}, + {cacheControl: "max-stale", requests: 1, secondsElapsed: 65}, + {cacheControl: "max-stale=0", requests: 2, secondsElapsed: 65}, + {cacheControl: "max-stale=60", requests: 1, secondsElapsed: 65}, + {cacheControl: "max-stale=60", requests: 1, secondsElapsed: 65}, + {cacheControl: "max-age=30", requests: 2, secondsElapsed: 40}, + {cacheControl: "min-fresh=5", requests: 1}, + {cacheControl: "min-fresh=120", requests: 2}, + } + + for idx, c := range cases { + client, upstream := testSetup() + upstream.CacheControl = "max-age=60" + + assert.Equal(t, http.StatusOK, client.get("/").Code) + upstream.timeTravel(time.Second * time.Duration(c.secondsElapsed)) + + r := client.get("/", "Cache-Control: "+c.cacheControl) + assert.Equal(t, http.StatusOK, r.statusCode) + assert.Equal(t, c.requests, upstream.requests, + fmt.Sprintf("case #%d failed, %+v", idx+1, c)) + } +} + +func TestSpecRequestCacheControlWithOnlyIfCached(t *testing.T) { + client, upstream := testSetup() + upstream.CacheControl = "max-age=10" + + assert.Equal(t, http.StatusOK, client.get("/").Code) + assert.Equal(t, http.StatusOK, client.get("/").Code) + + upstream.timeTravel(time.Second * 20) + assert.Equal(t, http.StatusGatewayTimeout, + client.get("/", "Cache-Control: only-if-cached").Code) + + assert.Equal(t, 1, upstream.requests) +} + +func TestSpecCachingStatusCodes(t *testing.T) { + client, upstream := testSetup() + upstream.StatusCode = http.StatusNotFound + upstream.CacheControl = "public, max-age=60" + + r1 := client.get("/r1") + assert.Equal(t, http.StatusNotFound, r1.statusCode) + assert.Equal(t, "MISS", r1.cacheStatus) + assert.Equal(t, string(upstream.Body), string(r1.body)) + + upstream.timeTravel(time.Second * 10) + r2 := client.get("/r1") + assert.Equal(t, http.StatusNotFound, r2.statusCode) + assert.Equal(t, "HIT", r2.cacheStatus) + assert.Equal(t, string(upstream.Body), string(r2.body)) + assert.Equal(t, time.Second*10, r2.age) + + upstream.StatusCode = http.StatusPaymentRequired + r3 := client.get("/r2") + assert.Equal(t, http.StatusPaymentRequired, r3.statusCode) + assert.Equal(t, "SKIP", r3.cacheStatus) +} + +func TestSpecConditionalCaching(t *testing.T) { + client, upstream := testSetup() + upstream.Etag = `"llamas"` + + r1 := client.get("/") + assert.Equal(t, "MISS", r1.cacheStatus) + assert.Equal(t, string(upstream.Body), string(r1.body)) + + r2 := client.get("/", `If-None-Match: "llamas"`) + assert.Equal(t, http.StatusNotModified, r2.Code) + assert.Equal(t, "", string(r2.body)) + assert.Equal(t, "HIT", r2.cacheStatus) +} + +func TestSpecRangeRequests(t *testing.T) { + client, upstream := testSetup() + + r1 := client.get("/", "Range: bytes=0-3") + assert.Equal(t, http.StatusPartialContent, r1.Code) + assert.Equal(t, "SKIP", r1.cacheStatus) + assert.Equal(t, string(upstream.Body[0:4]), string(r1.body)) +} + +func TestSpecHeuristicCaching(t *testing.T) { + client, upstream := testSetup() + upstream.LastModified = upstream.Now.AddDate(-1, 0, 0) + assert.Equal(t, "MISS", client.get("/").cacheStatus) + + upstream.timeTravel(time.Hour * 48) + r2 := client.get("/") + assert.Equal(t, "HIT", r2.cacheStatus) + assert.Equal(t, []string{"113 - \"Heuristic Expiration\""}, r2.Header()["Warning"]) + assert.Equal(t, 1, upstream.requests, "The second request shouldn't validate") +} + +func TestSpecCacheControlTrumpsExpires(t *testing.T) { + client, upstream := testSetup() + upstream.LastModified = upstream.Now.AddDate(-1, 0, 0) + upstream.CacheControl = "max-age=2" + assert.Equal(t, "MISS", client.get("/").cacheStatus) + assert.Equal(t, "HIT", client.get("/").cacheStatus) + assert.Equal(t, 1, upstream.requests) + + upstream.timeTravel(time.Hour * 48) + assert.Equal(t, "HIT", client.get("/").cacheStatus) + assert.Equal(t, 2, upstream.requests) +} + +func TestSpecNotCachedWithoutValidatorOrExpiration(t *testing.T) { + client, upstream := testSetup() + upstream.LastModified = time.Time{} + upstream.Etag = "" + + assert.Equal(t, "SKIP", client.get("/").cacheStatus) + assert.Equal(t, "SKIP", client.get("/").cacheStatus) + assert.Equal(t, 2, upstream.requests) +} + +func TestSpecNoCachingForInvalidExpires(t *testing.T) { + client, upstream := testSetup() + upstream.LastModified = time.Time{} + upstream.Header.Set("Expires", "-1") + + assert.Equal(t, "SKIP", client.get("/").cacheStatus) +} + +func TestSpecRequestsWithoutHostHeader(t *testing.T) { + client, _ := testSetup() + + r := newRequest("GET", "http://example.org") + r.Header.Del("Host") + r.Host = "" + + resp := client.do(r) + assert.Equal(t, http.StatusBadRequest, resp.Code, + "Requests without a Host header should result in a 400") +} + +func TestSpecCacheControlMaxStale(t *testing.T) { + client, upstream := testSetup() + upstream.CacheControl = "max-age=60" + assert.Equal(t, "MISS", client.get("/").cacheStatus) + + upstream.timeTravel(time.Second * 90) + upstream.Body = []byte("brand new content") + r2 := client.get("/", "Cache-Control: max-stale=3600") + assert.Equal(t, "HIT", r2.cacheStatus) + assert.Equal(t, time.Second*90, r2.age) + + upstream.timeTravel(time.Second * 90) + r3 := client.get("/") + assert.Equal(t, "MISS", r3.cacheStatus) + assert.Equal(t, time.Duration(0), r3.age) +} + +func TestSpecValidatingStaleResponsesUnchanged(t *testing.T) { + client, upstream := testSetup() + upstream.CacheControl = "max-age=60" + upstream.Etag = "llamas1" + assert.Equal(t, "MISS", client.get("/").cacheStatus) + + upstream.timeTravel(time.Second * 90) + upstream.Header.Add("X-New-Header", "1") + + r2 := client.get("/") + assert.Equal(t, http.StatusOK, r2.Code) + assert.Equal(t, string(upstream.Body), string(r2.body)) + assert.Equal(t, "HIT", r2.cacheStatus) +} + +func TestSpecValidatingStaleResponsesWithNewContent(t *testing.T) { + client, upstream := testSetup() + upstream.CacheControl = "max-age=60" + assert.Equal(t, "MISS", client.get("/").cacheStatus) + + upstream.timeTravel(time.Second * 90) + upstream.Body = []byte("brand new content") + + r2 := client.get("/") + assert.Equal(t, http.StatusOK, r2.Code) + assert.Equal(t, "MISS", r2.cacheStatus) + assert.Equal(t, "brand new content", string(r2.body)) + assert.Equal(t, time.Duration(0), r2.age) +} + +func TestSpecValidatingStaleResponsesWithNewEtag(t *testing.T) { + client, upstream := testSetup() + upstream.CacheControl = "max-age=60" + upstream.Etag = "llamas1" + + assert.Equal(t, "MISS", client.get("/").cacheStatus) + + upstream.timeTravel(time.Second * 90) + upstream.Etag = "llamas2" + + r2 := client.get("/") + assert.Equal(t, http.StatusOK, r2.Code) + assert.Equal(t, "MISS", r2.cacheStatus) +} + +func TestSpecVaryHeader(t *testing.T) { + client, upstream := testSetup() + upstream.CacheControl = "max-age=60" + upstream.Vary = "Accept-Language" + upstream.Etag = "llamas" + + assert.Equal(t, "MISS", client.get("/", "Accept-Language: en").cacheStatus) + assert.Equal(t, "HIT", client.get("/", "Accept-Language: en").cacheStatus) + assert.Equal(t, "MISS", client.get("/", "Accept-Language: de").cacheStatus) + assert.Equal(t, "HIT", client.get("/", "Accept-Language: de").cacheStatus) +} + +func TestSpecHeadersPropagated(t *testing.T) { + client, upstream := testSetup() + upstream.CacheControl = "max-age=60" + upstream.Header.Add("X-Llamas", "1") + upstream.Header.Add("X-Llamas", "3") + upstream.Header.Add("X-Llamas", "2") + + assert.Equal(t, "MISS", client.get("/").cacheStatus) + + r2 := client.get("/") + assert.Equal(t, "HIT", r2.cacheStatus) + assert.Equal(t, []string{"1", "3", "2"}, r2.Header()["X-Llamas"]) +} + +func TestSpecAgeHeaderFromUpstream(t *testing.T) { + client, upstream := testSetup() + upstream.CacheControl = "max-age=86400" + upstream.Header.Set("Age", "3600") //1hr + assert.Equal(t, time.Hour, client.get("/").age) + + upstream.timeTravel(time.Hour * 2) + assert.Equal(t, time.Hour*3, client.get("/").age) +} + +func TestSpecAgeHeaderWithResponseDelay(t *testing.T) { + client, upstream := testSetup() + upstream.CacheControl = "max-age=86400" + upstream.Header.Set("Age", "3600") //1hr + upstream.ResponseDuration = time.Second * 2 + assert.Equal(t, time.Second*3602, client.get("/").age) + + upstream.timeTravel(time.Second * 60) + assert.Equal(t, time.Second*3662, client.get("/").age) + assert.Equal(t, 1, upstream.requests) +} + +func TestSpecAgeHeaderGeneratedWhereNoneExists(t *testing.T) { + client, upstream := testSetup() + upstream.CacheControl = "max-age=86400" + upstream.ResponseDuration = time.Second * 2 + assert.Equal(t, time.Second*2, client.get("/").age) + + upstream.timeTravel(time.Second * 60) + assert.Equal(t, time.Second*62, client.get("/").age) + assert.Equal(t, 1, upstream.requests) +} + +func TestSpecWarningForOldContent(t *testing.T) { + client, upstream := testSetup() + upstream.LastModified = upstream.Now.AddDate(-1, 0, 0) + assert.Equal(t, "MISS", client.get("/").cacheStatus) + + upstream.timeTravel(time.Hour * 48) + r2 := client.get("/") + assert.Equal(t, "HIT", r2.cacheStatus) + assert.Equal(t, []string{"113 - \"Heuristic Expiration\""}, r2.Header()["Warning"]) +} + +func TestSpecHeadCanBeServedFromCacheOnlyWithExplicitFreshness(t *testing.T) { + client, upstream := testSetup() + upstream.CacheControl = "max-age=3600" + assert.Equal(t, "MISS", client.get("/explicit").cacheStatus) + assert.Equal(t, "HIT", client.head("/explicit").cacheStatus) + assert.Equal(t, "HIT", client.head("/explicit").cacheStatus) + + upstream.CacheControl = "" + assert.Equal(t, "SKIP", client.get("/implicit").cacheStatus) + assert.Equal(t, "SKIP", client.head("/implicit").cacheStatus) + assert.Equal(t, "SKIP", client.head("/implicit").cacheStatus) +} + +func TestSpecInvalidatingGetWithHeadRequest(t *testing.T) { + client, upstream := testSetup() + upstream.CacheControl = "max-age=3600" + assert.Equal(t, "MISS", client.get("/explicit").cacheStatus) + + upstream.Body = []byte("brand new content") + assert.Equal(t, "SKIP", client.head("/explicit", "Cache-Control: max-age=0").cacheStatus) + assert.Equal(t, "MISS", client.get("/explicit").cacheStatus) +} + +func TestSpecFresheningGetWithHeadRequest(t *testing.T) { + client, upstream := testSetup() + upstream.CacheControl = "max-age=3600" + assert.Equal(t, "MISS", client.get("/explicit").cacheStatus) + + upstream.timeTravel(time.Second * 10) + assert.Equal(t, 10*time.Second, client.get("/explicit").age) + + upstream.Header.Add("X-Llamas", "llamas") + assert.Equal(t, "SKIP", client.head("/explicit", "Cache-Control: max-age=0").cacheStatus) + + refreshed := client.get("/explicit") + assert.Equal(t, "HIT", refreshed.cacheStatus) + assert.Equal(t, time.Duration(0), refreshed.age) + assert.Equal(t, "llamas", refreshed.header.Get("X-Llamas")) +} + +func TestSpecContentHeaderInRequestRespected(t *testing.T) { + client, upstream := testSetup() + upstream.CacheControl = "max-age=3600" + + r1 := client.get("/llamas/rock") + assert.Equal(t, "MISS", r1.cacheStatus) + assert.Equal(t, string(upstream.Body), string(r1.body)) + + r2 := client.get("/another/llamas", "Content-Location: /llamas/rock") + assert.Equal(t, "HIT", r2.cacheStatus) + assert.Equal(t, string(upstream.Body), string(r2.body)) +} + +func TestSpecMultipleCacheControlHeaders(t *testing.T) { + client, upstream := testSetup() + upstream.Header.Add("Cache-Control", "max-age=60, max-stale=10") + upstream.Header.Add("Cache-Control", "no-cache") + + r1 := client.get("/") + assert.Equal(t, "SKIP", r1.cacheStatus) +} diff --git a/httpcache/util_test.go b/httpcache/util_test.go new file mode 100644 index 0000000..f6ce075 --- /dev/null +++ b/httpcache/util_test.go @@ -0,0 +1,196 @@ +package httpcache_test + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "time" + + "github.com/soulteary/apt-proxy/httpcache" +) + +func newRequest(method, url string, h ...string) *http.Request { + req, err := http.NewRequest(method, url, strings.NewReader("")) + if err != nil { + panic(err) + } + req.Header = parseHeaders(h) + req.RemoteAddr = "test.local" + return req +} + +func newResponse(status int, body []byte, h ...string) *http.Response { + return &http.Response{ + Status: fmt.Sprintf("%d %s", status, http.StatusText(status)), + StatusCode: status, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: int64(len(body)), + Body: ioutil.NopCloser(bytes.NewReader(body)), + Header: parseHeaders(h), + Close: true, + } +} + +func parseHeaders(input []string) http.Header { + headers := http.Header{} + for _, header := range input { + if idx := strings.Index(header, ": "); idx != -1 { + headers.Add(header[0:idx], strings.TrimSpace(header[idx+1:])) + } + } + return headers +} + +type client struct { + handler http.Handler + cacheHandler *httpcache.Handler +} + +func (c *client) do(r *http.Request) *clientResponse { + rec := httptest.NewRecorder() + c.handler.ServeHTTP(rec, r) + rec.Flush() + + var age int + var err error + + if ageHeader := rec.HeaderMap.Get("Age"); ageHeader != "" { + age, err = strconv.Atoi(ageHeader) + if err != nil { + panic("Can't parse age header") + } + } + + // wait for writes to finish + httpcache.Writes.Wait() + + return &clientResponse{ + ResponseRecorder: rec, + cacheStatus: rec.HeaderMap.Get(httpcache.CacheHeader), + statusCode: rec.Code, + age: time.Second * time.Duration(age), + body: rec.Body.Bytes(), + header: rec.HeaderMap, + } +} + +func (c *client) get(path string, headers ...string) *clientResponse { + return c.do(newRequest("GET", "http://example.org"+path, headers...)) +} + +func (c *client) head(path string, headers ...string) *clientResponse { + return c.do(newRequest("HEAD", "http://example.org"+path, headers...)) +} + +func (c *client) put(path string, headers ...string) *clientResponse { + return c.do(newRequest("PUT", "http://example.org"+path, headers...)) +} + +func (c *client) post(path string, headers ...string) *clientResponse { + return c.do(newRequest("POST", "http://example.org"+path, headers...)) +} + +type clientResponse struct { + *httptest.ResponseRecorder + cacheStatus string + statusCode int + age time.Duration + body []byte + header http.Header +} + +type upstreamServer struct { + Now time.Time + Body []byte + Filename string + CacheControl string + Etag, Vary string + LastModified time.Time + ResponseDuration time.Duration + StatusCode int + Header http.Header + asserts []func(r *http.Request) + requests int +} + +func (u *upstreamServer) timeTravel(d time.Duration) { + u.Now = u.Now.Add(d) +} + +func (u *upstreamServer) assert(f func(r *http.Request)) { + u.asserts = append(u.asserts, f) +} + +func (u *upstreamServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + u.requests = u.requests + 1 + + for _, assertf := range u.asserts { + assertf(req) + } + + if !u.Now.IsZero() { + rw.Header().Set("Date", u.Now.Format(http.TimeFormat)) + } + + if u.CacheControl != "" { + rw.Header().Set("Cache-Control", u.CacheControl) + } + + if u.Etag != "" { + rw.Header().Set("Etag", u.Etag) + } + + if u.Vary != "" { + rw.Header().Set("Vary", u.Vary) + } + + if u.Header != nil { + for key, headers := range u.Header { + for _, header := range headers { + rw.Header().Add(key, header) + } + } + } + + u.timeTravel(u.ResponseDuration) + + if u.StatusCode != 0 && u.StatusCode != 200 { + rw.WriteHeader(u.StatusCode) + io.Copy(rw, bytes.NewReader(u.Body)) + } else { + http.ServeContent(rw, req, u.Filename, u.LastModified, bytes.NewReader(u.Body)) + } +} + +func (u *upstreamServer) RoundTrip(req *http.Request) (*http.Response, error) { + rec := httptest.NewRecorder() + u.ServeHTTP(rec, req) + rec.Flush() + + resp := newResponse(rec.Code, rec.Body.Bytes()) + resp.Header = rec.HeaderMap + return resp, nil +} + +func cc(cc string) string { + return fmt.Sprintf("Cache-Control: %s", cc) +} + +func readAll(r io.Reader) []byte { + b, err := ioutil.ReadAll(r) + if err != nil { + panic(err) + } + return b +} + +func readAllString(r io.Reader) string { + return string(readAll(r)) +} diff --git a/httpcache/validator.go b/httpcache/validator.go new file mode 100644 index 0000000..9070acb --- /dev/null +++ b/httpcache/validator.go @@ -0,0 +1,66 @@ +package httpcache + +import ( + "fmt" + "net/http" + "net/http/httptest" +) + +type Validator struct { + Handler http.Handler +} + +func (v *Validator) Validate(req *http.Request, res *Resource) bool { + outreq := cloneRequest(req) + resHeaders := res.Header() + + if etag := resHeaders.Get("Etag"); etag != "" { + outreq.Header.Set("If-None-Match", etag) + } else if lastMod := resHeaders.Get("Last-Modified"); lastMod != "" { + outreq.Header.Set("If-Modified-Since", lastMod) + } + + t := Clock() + resp := httptest.NewRecorder() + v.Handler.ServeHTTP(resp, outreq) + resp.Flush() + + if age, err := correctedAge(resp.HeaderMap, t, Clock()); err == nil { + resp.Header().Set("Age", fmt.Sprintf("%.f", age.Seconds())) + } + + if headersEqual(resHeaders, resp.HeaderMap) { + res.header = resp.HeaderMap + res.header.Set(ProxyDateHeader, Clock().Format(http.TimeFormat)) + return true + } + + return false +} + +var validationHeaders = []string{"ETag", "Content-MD5", "Last-Modified", "Content-Length"} + +func headersEqual(h1, h2 http.Header) bool { + for _, header := range validationHeaders { + if value := h2.Get(header); value != "" { + if h1.Get(header) != value { + debugf("%s changed, %q != %q", header, value, h1.Get(header)) + return false + } + } + } + + return true +} + +// cloneRequest returns a clone of the provided *http.Request. +// The clone is a shallow copy of the struct and its Header map. +func cloneRequest(r *http.Request) *http.Request { + r2 := new(http.Request) + *r2 = *r + r2.Header = make(http.Header) + for k, s := range r.Header { + r2.Header[k] = s + } + return r2 +} diff --git a/linux/benchmark.go b/linux/benchmark.go new file mode 100644 index 0000000..7f41060 --- /dev/null +++ b/linux/benchmark.go @@ -0,0 +1,87 @@ +package linux + +import ( + "errors" + "io/ioutil" + "log" + "net/http" + "time" +) + +func benchmark(base string, query string, times int) (time.Duration, error) { + var sum int64 + var d time.Duration + url := base + query + + timeout := time.Duration(mirrorTimeout * time.Second) + client := http.Client{ + Timeout: timeout, + } + + for i := 0; i < times; i++ { + timer := time.Now() + response, err := client.Get(url) + if err != nil { + return d, err + } + + defer response.Body.Close() + _, err = ioutil.ReadAll(response.Body) + if err != nil { + return d, err + } + + sum = sum + int64(time.Since(timer)) + } + + return time.Duration(sum / int64(times)), nil +} + +type benchmarkResult struct { + URL string + Duration time.Duration +} + +func fastest(m Mirrors, testUrl string) (string, error) { + ch := make(chan benchmarkResult) + log.Printf("Start benchmarking mirrors") + // kick off all benchmarks in parallel + for _, url := range m.URLs { + go func(u string) { + duration, err := benchmark(u, testUrl, benchmarkTimes) + if err == nil { + ch <- benchmarkResult{u, duration} + } + }(url) + } + + readN := len(m.URLs) + if 3 < readN { + readN = 3 + } + + // wait for the fastest results to come back + results, err := readResults(ch, readN) + log.Printf("Finished benchmarking mirrors") + if len(results) == 0 { + return "", errors.New("No results found: " + err.Error()) + } else if err != nil { + log.Printf("Error benchmarking mirrors: %s", err.Error()) + } + + return results[0].URL, nil +} + +func readResults(ch <-chan benchmarkResult, size int) (br []benchmarkResult, err error) { + for { + select { + case r := <-ch: + br = append(br, r) + if len(br) >= size { + return br, nil + } + case <-time.After(benchmarkTimeout * time.Second): + return br, errors.New("Timed out waiting for results") + } + } +} diff --git a/ubuntu/mirrors_test.go b/linux/benchmark_test.go similarity index 51% rename from ubuntu/mirrors_test.go rename to linux/benchmark_test.go index f136087..fec97c0 100644 --- a/ubuntu/mirrors_test.go +++ b/linux/benchmark_test.go @@ -1,28 +1,24 @@ -package ubuntu +package linux import ( "log" "testing" ) -func TestMirrors(t *testing.T) { - mirrors, err := GetGeoMirrors() +func TestBenchmark(t *testing.T) { + _, err := benchmark(UBUNTU_MIRROR_URLS, "", benchmarkTimes) if err != nil { t.Fatal(err) } - - if len(mirrors.URLs) == 0 { - t.Fatal("No mirrors found") - } } func TestMirrorsBenchmark(t *testing.T) { - mirrors, err := GetGeoMirrors() + mirrors, err := getGeoMirrors(UBUNTU_MIRROR_URLS) if err != nil { t.Fatal(err) } - fastest, err := mirrors.Fastest() + fastest, err := fastest(mirrors, UBUNTU_BENCHMAKR_URL) if err != nil { t.Fatal(err) } diff --git a/linux/common.go b/linux/common.go new file mode 100644 index 0000000..e7ef44b --- /dev/null +++ b/linux/common.go @@ -0,0 +1,81 @@ +package linux + +import "regexp" + +const ( + UBUNTU string = "ubuntu" + DEBIAN = "debian" +) + +type Rule struct { + Pattern *regexp.Regexp + CacheControl string + Rewrite bool +} + +const ( + mirrorTimeout = 15 // seconds, detect resource timeout + benchmarkTimes = 3 // times, maximum number of attempts + benchmarkTimeout = 10 // 10 seconds, for select fast mirror +) + +// DEBIAN +const ( + DEBIAN_BENCHMAKR_URL = "dists/bullseye/main/binary-amd64/Release" +) + +var DEBIAN_MIRROR_URLS = []string{ + "http://ftp.cn.debian.org/debian/", + "http://mirror.bjtu.edu.cn/debian/", + "http://mirror.lzu.edu.cn/debian/", + "http://mirror.nju.edu.cn/debian/", + "http://mirrors.163.com/debian/", + "http://mirrors.bfsu.edu.cn/debian/", + "http://mirrors.hit.edu.cn/debian/", + "http://mirrors.huaweicloud.com/debian/", + "http://mirrors.neusoft.edu.cn/debian/", + "http://mirrors.tuna.tsinghua.edu.cn/debian/", + "http://mirrors.ustc.edu.cn/debian/", +} + +var DEBIAN_HOST_PATTERN = regexp.MustCompile( + `https?://(deb|security|snapshot).debian.org/debian/(.+)$`, +) + +var DEBIAN_DEFAULT_CACHE_RULES = []Rule{ + {Pattern: regexp.MustCompile(`deb$`), CacheControl: `max-age=100000`, Rewrite: true}, + {Pattern: regexp.MustCompile(`udeb$`), CacheControl: `max-age=100000`, Rewrite: true}, + {Pattern: regexp.MustCompile(`DiffIndex$`), CacheControl: `max-age=3600`, Rewrite: true}, + {Pattern: regexp.MustCompile(`PackagesIndex$`), CacheControl: `max-age=3600`, Rewrite: true}, + {Pattern: regexp.MustCompile(`Packages\.(bz2|gz|lzma)$`), CacheControl: `max-age=3600`, Rewrite: true}, + {Pattern: regexp.MustCompile(`SourcesIndex$`), CacheControl: `max-age=3600`, Rewrite: true}, + {Pattern: regexp.MustCompile(`Sources\.(bz2|gz|lzma)$`), CacheControl: `max-age=3600`, Rewrite: true}, + {Pattern: regexp.MustCompile(`Release(\.gpg)?$`), CacheControl: `max-age=3600`, Rewrite: true}, + {Pattern: regexp.MustCompile(`Translation-(en|fr)\.(gz|bz2|bzip2|lzma)$`), CacheControl: `max-age=3600`, Rewrite: true}, + // Add file file hash + {Pattern: regexp.MustCompile(`/by-hash/`), CacheControl: `max-age=3600`, Rewrite: true}, +} + +// Ubuntu +const ( + UBUNTU_MIRROR_URLS = "http://mirrors.ubuntu.com/mirrors.txt" + UBUNTU_BENCHMAKR_URL = "dists/jammy/main/binary-amd64/Release" +) + +var UBUNTU_HOST_PATTERN = regexp.MustCompile( + `https?://(security|archive).ubuntu.com/ubuntu/(.+)$`, +) + +var UBUNTU_DEFAULT_CACHE_RULES = []Rule{ + {Pattern: regexp.MustCompile(`deb$`), CacheControl: `max-age=100000`, Rewrite: true}, + {Pattern: regexp.MustCompile(`udeb$`), CacheControl: `max-age=100000`, Rewrite: true}, + {Pattern: regexp.MustCompile(`DiffIndex$`), CacheControl: `max-age=3600`, Rewrite: true}, + {Pattern: regexp.MustCompile(`PackagesIndex$`), CacheControl: `max-age=3600`, Rewrite: true}, + {Pattern: regexp.MustCompile(`Packages\.(bz2|gz|lzma)$`), CacheControl: `max-age=3600`, Rewrite: true}, + {Pattern: regexp.MustCompile(`SourcesIndex$`), CacheControl: `max-age=3600`, Rewrite: true}, + {Pattern: regexp.MustCompile(`Sources\.(bz2|gz|lzma)$`), CacheControl: `max-age=3600`, Rewrite: true}, + {Pattern: regexp.MustCompile(`Release(\.gpg)?$`), CacheControl: `max-age=3600`, Rewrite: true}, + {Pattern: regexp.MustCompile(`Translation-(en|fr)\.(gz|bz2|bzip2|lzma)$`), CacheControl: `max-age=3600`, Rewrite: true}, + // Add file file hash + {Pattern: regexp.MustCompile(`/by-hash/`), CacheControl: `max-age=3600`, Rewrite: true}, +} diff --git a/linux/mirrors.go b/linux/mirrors.go new file mode 100644 index 0000000..4bc8590 --- /dev/null +++ b/linux/mirrors.go @@ -0,0 +1,40 @@ +package linux + +import ( + "bufio" + "net/http" + "regexp" +) + +type Mirrors struct { + URLs []string +} + +func getGeoMirrors(mirrorListUrl string) (m Mirrors, err error) { + if len(mirrorListUrl) == 0 { + m.URLs = DEBIAN_MIRROR_URLS + return m, nil + } + response, err := http.Get(mirrorListUrl) + if err != nil { + return + } + + defer response.Body.Close() + scanner := bufio.NewScanner(response.Body) + m.URLs = []string{} + + for scanner.Scan() { + m.URLs = append(m.URLs, scanner.Text()) + } + + return m, scanner.Err() +} + +func getPredefinedConfiguration(osType string) (string, string, *regexp.Regexp) { + if osType == UBUNTU { + return UBUNTU_MIRROR_URLS, UBUNTU_BENCHMAKR_URL, UBUNTU_HOST_PATTERN + } else { + return "", DEBIAN_BENCHMAKR_URL, DEBIAN_HOST_PATTERN + } +} diff --git a/linux/mirrors_test.go b/linux/mirrors_test.go new file mode 100644 index 0000000..5e542e8 --- /dev/null +++ b/linux/mirrors_test.go @@ -0,0 +1,27 @@ +package linux + +import ( + "testing" +) + +func TestGetGeoMirrors(t *testing.T) { + mirrors, err := getGeoMirrors(UBUNTU_MIRROR_URLS) + if err != nil { + t.Fatal(err) + } + + if len(mirrors.URLs) == 0 { + t.Fatal("No mirrors found") + } +} + +func TestGetMirrorsUrlAndBenchmarkUrl(t *testing.T) { + url, res, pattern := getPredefinedConfiguration(UBUNTU) + if url != UBUNTU_MIRROR_URLS || res != UBUNTU_BENCHMAKR_URL { + t.Fatal("Failed to get resource link") + } + + if !pattern.MatchString("http://archive.ubuntu.com/ubuntu/InRelease") { + t.Fatal("Failed to verify domain name rules") + } +} diff --git a/linux/rewriter.go b/linux/rewriter.go new file mode 100644 index 0000000..9760cbe --- /dev/null +++ b/linux/rewriter.go @@ -0,0 +1,81 @@ +package linux + +import ( + "fmt" + "log" + "net/http" + "net/url" + "regexp" +) + +type URLRewriter struct { + mirror *url.URL + pattern *regexp.Regexp +} + +func NewRewriter(mirror string, osType string) *URLRewriter { + u := &URLRewriter{} + + if len(mirror) > 0 { + mirrorUrl, err := url.Parse(mirror) + if err == nil { + log.Printf("using ubuntu mirror %s", mirror) + u.mirror = mirrorUrl + _, _, pattern := getPredefinedConfiguration(osType) + u.pattern = pattern + return u + } + } + + // benchmark in the background to make sure we have the fastest + go func() { + mirrorsListUrl, benchmarkUrl, pattern := getPredefinedConfiguration(osType) + u.pattern = pattern + + mirrors, err := getGeoMirrors(mirrorsListUrl) + if err != nil { + log.Fatal(err) + } + + mirror, err := fastest(mirrors, benchmarkUrl) + if err != nil { + log.Println("Error finding fastest mirror", err) + } + + if mirrorUrl, err := url.Parse(mirror); err == nil { + log.Printf("using ubuntu mirror %s", mirror) + u.mirror = mirrorUrl + } + }() + + return u +} + +func Rewrite(r *http.Request, rewriter *URLRewriter) { + uri := r.URL.String() + if rewriter.mirror != nil && rewriter.pattern.MatchString(uri) { + r.Header.Add("Content-Location", uri) + m := rewriter.pattern.FindAllStringSubmatch(uri, -1) + // Fix the problem of double escaping of symbols + unescapedQuery, err := url.PathUnescape(m[0][2]) + if err != nil { + unescapedQuery = m[0][2] + } + r.URL.Host = rewriter.mirror.Host + r.URL.Path = rewriter.mirror.Path + unescapedQuery + } +} + +func MatchingRule(subject string, rules []Rule) (*Rule, bool) { + for _, rule := range rules { + if rule.Pattern.MatchString(subject) { + return &rule, true + } + } + return nil, false +} + +func (r *Rule) String() string { + return fmt.Sprintf("%s Cache-Control=%s Rewrite=%#v", + r.Pattern.String(), r.CacheControl, r.Rewrite) +} diff --git a/pkgs/httplog/log.go b/pkgs/httplog/log.go new file mode 100644 index 0000000..e722921 --- /dev/null +++ b/pkgs/httplog/log.go @@ -0,0 +1,131 @@ +package httplog + +import ( + "bytes" + "fmt" + "io" + "log" + "net/http" + "net/http/httputil" + "os" + "strings" + "time" +) + +const ( + CacheHeader = "X-Cache" +) + +type responseWriter struct { + http.ResponseWriter + status int + size int + t time.Time + errorOutput bytes.Buffer +} + +func (l *responseWriter) Header() http.Header { + return l.ResponseWriter.Header() +} + +func (l *responseWriter) Write(b []byte) (int, error) { + if l.status == 0 { + l.status = http.StatusOK + } + if isError(l.status) { + l.errorOutput.Write(b) + } + size, err := l.ResponseWriter.Write(b) + l.size += size + return size, err +} + +func (l *responseWriter) WriteHeader(s int) { + l.ResponseWriter.WriteHeader(s) + l.status = s +} + +func (l *responseWriter) Status() int { + return l.status +} + +func (l *responseWriter) Size() int { + return l.size +} + +func NewResponseLogger(delegate http.Handler) *ResponseLogger { + return &ResponseLogger{Handler: delegate} +} + +type ResponseLogger struct { + http.Handler + DumpRequests, DumpErrors, DumpResponses bool +} + +func (l *ResponseLogger) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if l.DumpRequests { + b, _ := httputil.DumpRequest(req, false) + writePrefixString(strings.TrimSpace(string(b)), ">> ", os.Stderr) + } + + respWr := &responseWriter{ResponseWriter: w, t: time.Now()} + l.Handler.ServeHTTP(respWr, req) + + if l.DumpResponses { + buf := &bytes.Buffer{} + buf.WriteString(fmt.Sprintf("HTTP/1.1 %d %s\r\n", + respWr.status, http.StatusText(respWr.status), + )) + respWr.Header().Write(buf) + writePrefixString(strings.TrimSpace(buf.String()), "<< ", os.Stderr) + } + + if l.DumpErrors && isError(respWr.status) { + writePrefixString(respWr.errorOutput.String(), "<< ", os.Stderr) + } + + l.writeLog(req, respWr) +} + +func (l *ResponseLogger) writeLog(req *http.Request, respWr *responseWriter) { + cacheStatus := respWr.Header().Get(CacheHeader) + + if strings.HasPrefix(cacheStatus, "HIT") { + cacheStatus = "\x1b[32;1mHIT\x1b[0m" + } else if strings.HasPrefix(cacheStatus, "MISS") { + cacheStatus = "\x1b[31;1mMISS\x1b[0m" + } else { + cacheStatus = "\x1b[33;1mSKIP\x1b[0m" + } + + clientIP := req.RemoteAddr + if colon := strings.LastIndex(clientIP, ":"); colon != -1 { + clientIP = clientIP[:colon] + } + + log.Printf( + "%s \"%s %s %s\" (%s) %d %s %s", + clientIP, + req.Method, + req.URL.String(), + req.Proto, + http.StatusText(respWr.status), + respWr.size, + cacheStatus, + time.Now().Sub(respWr.t).String(), + ) +} + +func isError(code int) bool { + return code >= 500 +} + +func writePrefixString(s, prefix string, w io.Writer) { + os.Stderr.Write([]byte("\n")) + for _, line := range strings.Split(s, "\r\n") { + w.Write([]byte(prefix)) + w.Write([]byte(line)) + w.Write([]byte("\n")) + } + os.Stderr.Write([]byte("\n")) +} diff --git a/pkgs/stream.v1/LICENSE b/pkgs/stream.v1/LICENSE new file mode 100644 index 0000000..1e7b7cc --- /dev/null +++ b/pkgs/stream.v1/LICENSE @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2015 Dustin H + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/pkgs/stream.v1/README.md b/pkgs/stream.v1/README.md new file mode 100644 index 0000000..d6034d8 --- /dev/null +++ b/pkgs/stream.v1/README.md @@ -0,0 +1,80 @@ +stream +========== + +[![GoDoc](https://godoc.org/github.com/djherbis/stream?status.svg)](https://godoc.org/github.com/djherbis/stream) +[![Release](https://img.shields.io/github/release/djherbis/stream.svg)](https://github.com/djherbis/stream/releases/latest) +[![Software License](https://img.shields.io/badge/license-MIT-brightgreen.svg)](LICENSE.txt) +[![Build Status](https://travis-ci.org/djherbis/stream.svg?branch=master)](https://travis-ci.org/djherbis/stream) +[![Coverage Status](https://coveralls.io/repos/djherbis/stream/badge.svg?branch=master)](https://coveralls.io/r/djherbis/stream?branch=master) + +Usage +------------ + +Write and Read concurrently, and independently. + +To explain further, if you need to write to multiple places you can use io.MultiWriter, +if you need multiple Readers on something you can use io.TeeReader. If you want concurrency you can use io.Pipe(). + +However all of these methods "tie" each Read/Write together, your readers can't read from different places in the stream, each write must be distributed to all readers in sequence. + +This package provides a way for multiple Readers to read off the same Writer, without waiting for the others. This is done by writing to a "File" interface which buffers the input so it can be read at any time from many independent readers. Readers can even be created while writing or after the stream is closed. They will all see a consistent view of the stream and will block until the section of the stream they request is written, all while being unaffected by the actions of the other readers. + +The use case for this stems from my other project djherbis/fscache. I needed a byte caching mechanism which allowed many independent clients to have access to the data while it was being written, rather than re-generating the byte stream for each of them or waiting for a complete copy of the stream which could be stored and then re-used. + +```go +import( + "io" + "log" + "os" + "time" + + "github.com/djherbis/stream" +) + +func main(){ + w, err := stream.New("mystream") + if err != nil { + log.Fatal(err) + } + + go func(){ + io.WriteString(w, "Hello World!") + <-time.After(time.Second) + io.WriteString(w, "Streaming updates...") + w.Close() + }() + + waitForReader := make(chan struct{}) + go func(){ + // Read from the stream + r, err := w.NextReader() + if err != nil { + log.Fatal(err) + } + io.Copy(os.Stdout, r) // Hello World! (1 second) Streaming updates... + r.Close() + close(waitForReader) + }() + + // Full copy of the stream! + r, err := w.NextReader() + if err != nil { + log.Fatal(err) + } + io.Copy(os.Stdout, r) // Hello World! (1 second) Streaming updates... + + // r supports io.ReaderAt too. + p := make([]byte, 4) + r.ReadAt(p, 1) // Read "ello" into p + + r.Close() + + <-waitForReader // don't leave main before go-routine finishes +} +``` + +Installation +------------ +```sh +go get github.com/djherbis/stream +``` diff --git a/pkgs/stream.v1/fs.go b/pkgs/stream.v1/fs.go new file mode 100644 index 0000000..fe808bf --- /dev/null +++ b/pkgs/stream.v1/fs.go @@ -0,0 +1,39 @@ +package stream + +import ( + "io" + "os" +) + +// File is a backing data-source for a Stream. +type File interface { + Name() string // The name used to Create/Open the File + io.Reader // Reader must continue reading after EOF on subsequent calls after more Writes. + io.ReaderAt // Similarly to Reader + io.Writer // Concurrent reading/writing must be supported. + io.Closer // Close should do any cleanup when done with the File. +} + +// FileSystem is used to manage Files +type FileSystem interface { + Create(name string) (File, error) // Create must return a new File for Writing + Open(name string) (File, error) // Open must return an existing File for Reading + Remove(name string) error // Remove deletes an existing File +} + +// StdFileSystem is backed by the os package. +var StdFileSystem FileSystem = stdFS{} + +type stdFS struct{} + +func (fs stdFS) Create(name string) (File, error) { + return os.Create(name) +} + +func (fs stdFS) Open(name string) (File, error) { + return os.Open(name) +} + +func (fs stdFS) Remove(name string) error { + return os.Remove(name) +} diff --git a/pkgs/stream.v1/memfs.go b/pkgs/stream.v1/memfs.go new file mode 100644 index 0000000..b4432ae --- /dev/null +++ b/pkgs/stream.v1/memfs.go @@ -0,0 +1,107 @@ +package stream + +import ( + "bytes" + "errors" + "io" + "sync" +) + +// ErrNotFoundInMem is returned when an in-memory FileSystem cannot find a file. +var ErrNotFoundInMem = errors.New("not found") + +type memfs struct { + mu sync.RWMutex + files map[string]*memFile +} + +// NewMemFS returns a New in-memory FileSystem +func NewMemFS() FileSystem { + return &memfs{ + files: make(map[string]*memFile), + } +} + +func (fs *memfs) Create(key string) (File, error) { + fs.mu.Lock() + defer fs.mu.Unlock() + + file := &memFile{ + name: key, + r: bytes.NewBuffer(nil), + } + file.memReader.memFile = file + fs.files[key] = file + return file, nil +} + +func (fs *memfs) Open(key string) (File, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + + if f, ok := fs.files[key]; ok { + return &memReader{memFile: f}, nil + } + return nil, ErrNotFoundInMem +} + +func (fs *memfs) Remove(key string) error { + fs.mu.Lock() + defer fs.mu.Unlock() + delete(fs.files, key) + return nil +} + +type memFile struct { + mu sync.RWMutex + name string + r *bytes.Buffer + memReader +} + +func (f *memFile) Name() string { + return f.name +} + +func (f *memFile) Write(p []byte) (int, error) { + if len(p) > 0 { + f.mu.Lock() + defer f.mu.Unlock() + return f.r.Write(p) + } + return len(p), nil +} + +func (f *memFile) Bytes() []byte { + f.mu.RLock() + defer f.mu.RUnlock() + return f.r.Bytes() +} + +func (f *memFile) Close() error { + return nil +} + +type memReader struct { + *memFile + n int +} + +func (r *memReader) ReadAt(p []byte, off int64) (n int, err error) { + data := r.Bytes() + if int64(len(data)) < off { + return 0, io.EOF + } + n, err = bytes.NewReader(data[off:]).ReadAt(p, 0) + return n, err +} + +func (r *memReader) Read(p []byte) (n int, err error) { + n, err = bytes.NewReader(r.Bytes()[r.n:]).Read(p) + r.n += n + return n, err +} + +func (r *memReader) Close() error { + return nil +} diff --git a/pkgs/stream.v1/reader.go b/pkgs/stream.v1/reader.go new file mode 100644 index 0000000..8321270 --- /dev/null +++ b/pkgs/stream.v1/reader.go @@ -0,0 +1,82 @@ +package stream + +import "io" + +// Reader is a concurrent-safe Stream Reader. +type Reader struct { + s *Stream + file File +} + +// Name returns the name of the underlying File in the FileSystem. +func (r *Reader) Name() string { return r.file.Name() } + +// ReadAt lets you Read from specific offsets in the Stream. +// ReadAt blocks while waiting for the requested section of the Stream to be written, +// unless the Stream is closed in which case it will always return immediately. +func (r *Reader) ReadAt(p []byte, off int64) (n int, err error) { + r.s.b.RLock() + defer r.s.b.RUnlock() + + var m int + + for { + + m, err = r.file.ReadAt(p[n:], off+int64(n)) + n += m + + if r.s.b.IsOpen() { + + switch { + case n != 0 && err == nil: + return n, err + case err == io.EOF: + r.s.b.Wait() + case err != nil: + return n, err + } + + } else { + return n, err + } + + } +} + +// Read reads from the Stream. If the end of an open Stream is reached, Read +// blocks until more data is written or the Stream is Closed. +func (r *Reader) Read(p []byte) (n int, err error) { + r.s.b.RLock() + defer r.s.b.RUnlock() + + var m int + + for { + + m, err = r.file.Read(p[n:]) + n += m + + if r.s.b.IsOpen() { + + switch { + case n != 0 && err == nil: + return n, err + case err == io.EOF: + r.s.b.Wait() + case err != nil: + return n, err + } + + } else { + return n, err + } + + } +} + +// Close closes this Reader on the Stream. This must be called when done with the +// Reader or else the Stream cannot be Removed. +func (r *Reader) Close() error { + defer r.s.dec() + return r.file.Close() +} diff --git a/pkgs/stream.v1/stream.go b/pkgs/stream.v1/stream.go new file mode 100644 index 0000000..a0b3e1a --- /dev/null +++ b/pkgs/stream.v1/stream.go @@ -0,0 +1,92 @@ +// Package stream provides a way to read and write to a synchronous buffered pipe, with multiple reader support. +package stream + +import ( + "errors" + "sync" +) + +// ErrRemoving is returned when requesting a Reader on a Stream which is being Removed. +var ErrRemoving = errors.New("cannot open a new reader while removing file") + +// Stream is used to concurrently Write and Read from a File. +type Stream struct { + grp sync.WaitGroup + b *broadcaster + file File + fs FileSystem + removing chan struct{} +} + +// New creates a new Stream from the StdFileSystem with Name "name". +func New(name string) (*Stream, error) { + return NewStream(name, StdFileSystem) +} + +// NewStream creates a new Stream with Name "name" in FileSystem fs. +func NewStream(name string, fs FileSystem) (*Stream, error) { + f, err := fs.Create(name) + sf := &Stream{ + file: f, + fs: fs, + b: newBroadcaster(), + removing: make(chan struct{}), + } + sf.inc() + return sf, err +} + +// Name returns the name of the underlying File in the FileSystem. +func (s *Stream) Name() string { return s.file.Name() } + +// Write writes p to the Stream. It's concurrent safe to be called with Stream's other methods. +func (s *Stream) Write(p []byte) (int, error) { + defer s.b.Broadcast() + s.b.Lock() + defer s.b.Unlock() + return s.file.Write(p) +} + +// Close will close the active stream. This will cause Readers to return EOF once they have +// read the entire stream. +func (s *Stream) Close() error { + defer s.dec() + defer s.b.Close() + s.b.Lock() + defer s.b.Unlock() + return s.file.Close() +} + +// Remove will block until the Stream and all its Readers have been Closed, +// at which point it will delete the underlying file. NextReader() will return +// ErrRemoving if called after Remove. +func (s *Stream) Remove() error { + close(s.removing) + s.grp.Wait() + return s.fs.Remove(s.file.Name()) +} + +// NextReader will return a concurrent-safe Reader for this stream. Each Reader will +// see a complete and independent view of the stream, and can Read will the stream +// is written to. +func (s *Stream) NextReader() (*Reader, error) { + s.inc() + + select { + case <-s.removing: + s.dec() + return nil, ErrRemoving + default: + } + + file, err := s.fs.Open(s.file.Name()) + if err != nil { + s.dec() + return nil, err + } + + return &Reader{file: file, s: s}, nil +} + +func (s *Stream) inc() { s.grp.Add(1) } +func (s *Stream) dec() { s.grp.Done() } diff --git a/pkgs/stream.v1/stream_test.go b/pkgs/stream.v1/stream_test.go new file mode 100644 index 0000000..fef2942 --- /dev/null +++ b/pkgs/stream.v1/stream_test.go @@ -0,0 +1,178 @@ +package stream + +import ( + "bytes" + "errors" + "io" + "os" + "testing" + "time" +) + +var ( + testdata = []byte("hello\nworld\n") + errFail = errors.New("fail") +) + +type badFs struct { + readers []File +} +type badFile struct{ name string } + +func (r badFile) Name() string { return r.name } +func (r badFile) Read(p []byte) (int, error) { return 0, errFail } +func (r badFile) ReadAt(p []byte, off int64) (int, error) { return 0, errFail } +func (r badFile) Write(p []byte) (int, error) { return 0, errFail } +func (r badFile) Close() error { return errFail } + +func (fs badFs) Create(name string) (File, error) { return os.Create(name) } +func (fs badFs) Open(name string) (File, error) { + if len(fs.readers) > 0 { + f := fs.readers[len(fs.readers)-1] + fs.readers = fs.readers[:len(fs.readers)-1] + return f, nil + } + return nil, errFail +} +func (fs badFs) Remove(name string) error { return os.Remove(name) } + +func TestMemFs(t *testing.T) { + fs := NewMemFS() + if _, err := fs.Open("not found"); err != ErrNotFoundInMem { + t.Error(err) + t.FailNow() + } +} + +func TestBadFile(t *testing.T) { + fs := badFs{readers: make([]File, 0, 1)} + fs.readers = append(fs.readers, badFile{name: "test"}) + f, err := NewStream("test", fs) + if err != nil { + t.Error(err) + t.FailNow() + } + defer f.Remove() + defer f.Close() + + r, err := f.NextReader() + if err != nil { + t.Error(err) + t.FailNow() + } + defer r.Close() + if r.Name() != "test" { + t.Errorf("expected name to to be 'test' got %s", r.Name()) + t.FailNow() + } + if _, err := r.ReadAt(nil, 0); err == nil { + t.Error("expected ReadAt error") + t.FailNow() + } + if _, err := r.Read(nil); err == nil { + t.Error("expected Read error") + t.FailNow() + } +} + +func TestBadFs(t *testing.T) { + f, err := NewStream("test", badFs{}) + if err != nil { + t.Error(err) + t.FailNow() + } + defer f.Remove() + defer f.Close() + + r, err := f.NextReader() + if err == nil { + t.Error("expected open error") + t.FailNow() + } else { + return + } + r.Close() +} + +func TestStd(t *testing.T) { + f, err := New("test.txt") + if err != nil { + t.Error(err) + t.FailNow() + } + if f.Name() != "test.txt" { + t.Errorf("expected name to be test.txt: %s", f.Name()) + } + testFile(f, t) +} + +func TestMem(t *testing.T) { + f, err := NewStream("test.txt", NewMemFS()) + if err != nil { + t.Error(err) + t.FailNow() + } + f.Write(nil) + testFile(f, t) +} + +func TestRemove(t *testing.T) { + f, err := NewStream("test.txt", NewMemFS()) + if err != nil { + t.Error(err) + t.FailNow() + } + defer f.Close() + go f.Remove() + <-time.After(100 * time.Millisecond) + r, err := f.NextReader() + switch err { + case ErrRemoving: + case nil: + t.Error("expected error on NextReader()") + r.Close() + default: + t.Error("expected diff error on NextReader()", err) + } + +} + +func testFile(f *Stream, t *testing.T) { + + for i := 0; i < 10; i++ { + go testReader(f, t) + } + + for i := 0; i < 10; i++ { + f.Write(testdata) + <-time.After(10 * time.Millisecond) + } + + f.Close() + testReader(f, t) + f.Remove() +} + +func testReader(f *Stream, t *testing.T) { + r, err := f.NextReader() + if err != nil { + t.Error(err) + t.FailNow() + } + defer r.Close() + + buf := bytes.NewBuffer(nil) + sr := io.NewSectionReader(r, 1+int64(len(testdata)*5), 5) + io.Copy(buf, sr) + if !bytes.Equal(buf.Bytes(), testdata[1:6]) { + t.Errorf("unequal %s", buf.Bytes()) + return + } + + buf.Reset() + io.Copy(buf, r) + if !bytes.Equal(buf.Bytes(), bytes.Repeat(testdata, 10)) { + t.Errorf("unequal %s", buf.Bytes()) + return + } +} diff --git a/pkgs/stream.v1/sync.go b/pkgs/stream.v1/sync.go new file mode 100644 index 0000000..26096ed --- /dev/null +++ b/pkgs/stream.v1/sync.go @@ -0,0 +1,34 @@ +package stream + +import ( + "sync" + "sync/atomic" +) + +type broadcaster struct { + sync.RWMutex + closed uint32 + *sync.Cond +} + +func newBroadcaster() *broadcaster { + var b broadcaster + b.Cond = sync.NewCond(b.RWMutex.RLocker()) + return &b +} + +func (b *broadcaster) Wait() { + if b.IsOpen() { + b.Cond.Wait() + } +} + +func (b *broadcaster) IsOpen() bool { + return atomic.LoadUint32(&b.closed) == 0 +} + +func (b *broadcaster) Close() error { + atomic.StoreUint32(&b.closed, 1) + b.Cond.Broadcast() + return nil +} diff --git a/pkgs/vfs/LICENSE b/pkgs/vfs/LICENSE new file mode 100644 index 0000000..14e2f77 --- /dev/null +++ b/pkgs/vfs/LICENSE @@ -0,0 +1,373 @@ +Mozilla Public License Version 2.0 +================================== + +1. Definitions +-------------- + +1.1. "Contributor" + means each individual or legal entity that creates, contributes to + the creation of, or owns Covered Software. + +1.2. "Contributor Version" + means the combination of the Contributions of others (if any) used + by a Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + means Source Code Form to which the initial Contributor has attached + the notice in Exhibit A, the Executable Form of such Source Code + Form, and Modifications of such Source Code Form, in each case + including portions thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + (a) that the initial Contributor has attached the notice described + in Exhibit B to the Covered Software; or + + (b) that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the + terms of a Secondary License. + +1.6. "Executable Form" + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + means a work that combines Covered Software with other material, in + a separate file or files, that is not Covered Software. + +1.8. "License" + means this document. + +1.9. "Licensable" + means having the right to grant, to the maximum extent possible, + whether at the time of the initial grant or subsequently, any and + all of the rights conveyed by this License. + +1.10. "Modifications" + means any of the following: + + (a) any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered + Software; or + + (b) any new file in Source Code Form that contains any Covered + Software. + +1.11. "Patent Claims" of a Contributor + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the + License, by the making, using, selling, offering for sale, having + made, import, or transfer of either its Contributions or its + Contributor Version. + +1.12. "Secondary License" + means either the GNU General Public License, Version 2.0, the GNU + Lesser General Public License, Version 2.1, the GNU Affero General + Public License, Version 3.0, or any later versions of those + licenses. + +1.13. "Source Code Form" + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that + controls, is controlled by, or is under common control with You. For + purposes of this definition, "control" means (a) the power, direct + or indirect, to cause the direction or management of such entity, + whether by contract or otherwise, or (b) ownership of more than + fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. + +2. License Grants and Conditions +-------------------------------- + +2.1. Grants + +Each Contributor hereby grants You a world-wide, royalty-free, +non-exclusive license: + +(a) under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + +(b) under Patent Claims of such Contributor to make, use, sell, offer + for sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + +The licenses granted in Section 2.1 with respect to any Contribution +become effective for each Contribution on the date the Contributor first +distributes such Contribution. + +2.3. Limitations on Grant Scope + +The licenses granted in this Section 2 are the only rights granted under +this License. No additional rights or licenses will be implied from the +distribution or licensing of Covered Software under this License. +Notwithstanding Section 2.1(b) above, no patent license is granted by a +Contributor: + +(a) for any code that a Contributor has removed from Covered Software; + or + +(b) for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + +(c) under Patent Claims infringed by Covered Software in the absence of + its Contributions. + +This License does not grant any rights in the trademarks, service marks, +or logos of any Contributor (except as may be necessary to comply with +the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + +No Contributor makes additional grants as a result of Your choice to +distribute the Covered Software under a subsequent version of this +License (see Section 10.2) or under the terms of a Secondary License (if +permitted under the terms of Section 3.3). + +2.5. Representation + +Each Contributor represents that the Contributor believes its +Contributions are its original creation(s) or it has sufficient rights +to grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + +This License is not intended to limit any rights You have under +applicable copyright doctrines of fair use, fair dealing, or other +equivalents. + +2.7. Conditions + +Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted +in Section 2.1. + +3. Responsibilities +------------------- + +3.1. Distribution of Source Form + +All distribution of Covered Software in Source Code Form, including any +Modifications that You create or to which You contribute, must be under +the terms of this License. You must inform recipients that the Source +Code Form of the Covered Software is governed by the terms of this +License, and how they can obtain a copy of this License. You may not +attempt to alter or restrict the recipients' rights in the Source Code +Form. + +3.2. Distribution of Executable Form + +If You distribute Covered Software in Executable Form then: + +(a) such Covered Software must also be made available in Source Code + Form, as described in Section 3.1, and You must inform recipients of + the Executable Form how they can obtain a copy of such Source Code + Form by reasonable means in a timely manner, at a charge no more + than the cost of distribution to the recipient; and + +(b) You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter + the recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + +You may create and distribute a Larger Work under terms of Your choice, +provided that You also comply with the requirements of this License for +the Covered Software. If the Larger Work is a combination of Covered +Software with a work governed by one or more Secondary Licenses, and the +Covered Software is not Incompatible With Secondary Licenses, this +License permits You to additionally distribute such Covered Software +under the terms of such Secondary License(s), so that the recipient of +the Larger Work may, at their option, further distribute the Covered +Software under the terms of either this License or such Secondary +License(s). + +3.4. Notices + +You may not remove or alter the substance of any license notices +(including copyright notices, patent notices, disclaimers of warranty, +or limitations of liability) contained within the Source Code Form of +the Covered Software, except that You may alter any license notices to +the extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + +You may choose to offer, and to charge a fee for, warranty, support, +indemnity or liability obligations to one or more recipients of Covered +Software. However, You may do so only on Your own behalf, and not on +behalf of any Contributor. You must make it absolutely clear that any +such warranty, support, indemnity, or liability obligation is offered by +You alone, and You hereby agree to indemnify every Contributor for any +liability incurred by such Contributor as a result of warranty, support, +indemnity or liability terms You offer. You may include additional +disclaimers of warranty and limitations of liability specific to any +jurisdiction. + +4. Inability to Comply Due to Statute or Regulation +--------------------------------------------------- + +If it is impossible for You to comply with any of the terms of this +License with respect to some or all of the Covered Software due to +statute, judicial order, or regulation then You must: (a) comply with +the terms of this License to the maximum extent possible; and (b) +describe the limitations and the code they affect. Such description must +be placed in a text file included with all distributions of the Covered +Software under this License. Except to the extent prohibited by statute +or regulation, such description must be sufficiently detailed for a +recipient of ordinary skill to be able to understand it. + +5. Termination +-------------- + +5.1. The rights granted under this License will terminate automatically +if You fail to comply with any of its terms. However, if You become +compliant, then the rights granted under this License from a particular +Contributor are reinstated (a) provisionally, unless and until such +Contributor explicitly and finally terminates Your grants, and (b) on an +ongoing basis, if such Contributor fails to notify You of the +non-compliance by some reasonable means prior to 60 days after You have +come back into compliance. Moreover, Your grants from a particular +Contributor are reinstated on an ongoing basis if such Contributor +notifies You of the non-compliance by some reasonable means, this is the +first time You have received notice of non-compliance with this License +from such Contributor, and You become compliant prior to 30 days after +Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent +infringement claim (excluding declaratory judgment actions, +counter-claims, and cross-claims) alleging that a Contributor Version +directly or indirectly infringes any patent, then the rights granted to +You by any and all Contributors for the Covered Software under Section +2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all +end user license agreements (excluding distributors and resellers) which +have been validly granted by You or Your distributors under this License +prior to termination shall survive termination. + +************************************************************************ +* * +* 6. Disclaimer of Warranty * +* ------------------------- * +* * +* Covered Software is provided under this License on an "as is" * +* basis, without warranty of any kind, either expressed, implied, or * +* statutory, including, without limitation, warranties that the * +* Covered Software is free of defects, merchantable, fit for a * +* particular purpose or non-infringing. The entire risk as to the * +* quality and performance of the Covered Software is with You. * +* Should any Covered Software prove defective in any respect, You * +* (not any Contributor) assume the cost of any necessary servicing, * +* repair, or correction. This disclaimer of warranty constitutes an * +* essential part of this License. No use of any Covered Software is * +* authorized under this License except under this disclaimer. * +* * +************************************************************************ + +************************************************************************ +* * +* 7. Limitation of Liability * +* -------------------------- * +* * +* Under no circumstances and under no legal theory, whether tort * +* (including negligence), contract, or otherwise, shall any * +* Contributor, or anyone who distributes Covered Software as * +* permitted above, be liable to You for any direct, indirect, * +* special, incidental, or consequential damages of any character * +* including, without limitation, damages for lost profits, loss of * +* goodwill, work stoppage, computer failure or malfunction, or any * +* and all other commercial damages or losses, even if such party * +* shall have been informed of the possibility of such damages. This * +* limitation of liability shall not apply to liability for death or * +* personal injury resulting from such party's negligence to the * +* extent applicable law prohibits such limitation. Some * +* jurisdictions do not allow the exclusion or limitation of * +* incidental or consequential damages, so this exclusion and * +* limitation may not apply to You. * +* * +************************************************************************ + +8. Litigation +------------- + +Any litigation relating to this License may be brought only in the +courts of a jurisdiction where the defendant maintains its principal +place of business and such litigation shall be governed by laws of that +jurisdiction, without reference to its conflict-of-law provisions. +Nothing in this Section shall prevent a party's ability to bring +cross-claims or counter-claims. + +9. Miscellaneous +---------------- + +This License represents the complete agreement concerning the subject +matter hereof. If any provision of this License is held to be +unenforceable, such provision shall be reformed only to the extent +necessary to make it enforceable. Any law or regulation which provides +that the language of a contract shall be construed against the drafter +shall not be used to construe this License against a Contributor. + +10. Versions of the License +--------------------------- + +10.1. New Versions + +Mozilla Foundation is the license steward. Except as provided in Section +10.3, no one other than the license steward has the right to modify or +publish new versions of this License. Each version will be given a +distinguishing version number. + +10.2. Effect of New Versions + +You may distribute the Covered Software under the terms of the version +of the License under which You originally received the Covered Software, +or under the terms of any subsequent version published by the license +steward. + +10.3. Modified Versions + +If you create software not governed by this License, and you want to +create a new license for such software, you may create and use a +modified version of this License if you rename the license and remove +any references to the name of the license steward (except to note that +such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary +Licenses + +If You choose to distribute Source Code Form that is Incompatible With +Secondary Licenses under the terms of this version of the License, the +notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice +------------------------------------------- + + This Source Code Form is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular +file, then You may include the notice in a location (such as a LICENSE +file in a relevant directory) where a recipient would be likely to look +for such a notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice +--------------------------------------------------------- + + This Source Code Form is "Incompatible With Secondary Licenses", as + defined by the Mozilla Public License, v. 2.0. diff --git a/pkgs/vfs/README.md b/pkgs/vfs/README.md new file mode 100644 index 0000000..a49691d --- /dev/null +++ b/pkgs/vfs/README.md @@ -0,0 +1,5 @@ +# vfs + +vfs implements Virtual File Systems with read-write support in Go (golang) + +[![GoDoc](https://godoc.org/github.com/rainycape/vfs?status.svg)](https://godoc.org/github.com/rainycape/vfs) diff --git a/pkgs/vfs/bench_test.go b/pkgs/vfs/bench_test.go new file mode 100644 index 0000000..a191265 --- /dev/null +++ b/pkgs/vfs/bench_test.go @@ -0,0 +1,43 @@ +package vfs + +import ( + "bytes" + "compress/gzip" + "io/ioutil" + "os" + "testing" +) + +func BenchmarkLoadGoSrc(b *testing.B) { + f := openOptionalTestFile(b, goTestFile) + defer f.Close() + // Decompress to avoid measuring the time to gunzip + zr, err := gzip.NewReader(f) + if err != nil { + b.Fatal(err) + } + defer zr.Close() + data, err := ioutil.ReadAll(zr) + if err != nil { + b.Fatal(err) + } + b.ResetTimer() + for ii := 0; ii < b.N; ii++ { + if _, err := Tar(bytes.NewReader(data)); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkWalkGoSrc(b *testing.B) { + f := openOptionalTestFile(b, goTestFile) + defer f.Close() + fs, err := TarGzip(f) + if err != nil { + b.Fatal(err) + } + b.ResetTimer() + for ii := 0; ii < b.N; ii++ { + Walk(fs, "/", func(_ VFS, _ string, _ os.FileInfo, _ error) error { return nil }) + } +} diff --git a/pkgs/vfs/chroot.go b/pkgs/vfs/chroot.go new file mode 100644 index 0000000..719d913 --- /dev/null +++ b/pkgs/vfs/chroot.go @@ -0,0 +1,70 @@ +package vfs + +import ( + "fmt" + "os" + "path" +) + +type chrootFileSystem struct { + root string + fs VFS +} + +func (fs *chrootFileSystem) path(p string) string { + // root always ends with /, if there are double + // slashes they will be fixed by the underlying + // VFS + return fs.root + p +} + +func (fs *chrootFileSystem) VFS() VFS { + return fs.fs +} + +func (fs *chrootFileSystem) Open(path string) (RFile, error) { + return fs.fs.Open(fs.path(path)) +} + +func (fs *chrootFileSystem) OpenFile(path string, flag int, perm os.FileMode) (WFile, error) { + return fs.fs.OpenFile(fs.path(path), flag, perm) +} + +func (fs *chrootFileSystem) Lstat(path string) (os.FileInfo, error) { + return fs.fs.Lstat(fs.path(path)) +} + +func (fs *chrootFileSystem) Stat(path string) (os.FileInfo, error) { + return fs.fs.Stat(fs.path(path)) +} + +func (fs *chrootFileSystem) ReadDir(path string) ([]os.FileInfo, error) { + return fs.fs.ReadDir(fs.path(path)) +} + +func (fs *chrootFileSystem) Mkdir(path string, perm os.FileMode) error { + return fs.fs.Mkdir(fs.path(path), perm) +} + +func (fs *chrootFileSystem) Remove(path string) error { + return fs.fs.Remove(fs.path(path)) +} + +func (fs *chrootFileSystem) String() string { + return fmt.Sprintf("Chroot %s %s", fs.root, fs.fs.String()) +} + +// Chroot returns a new VFS wrapping the given VFS, making the given +// directory the new root ("/"). Note that root must be an existing +// directory in the given file system, otherwise an error is returned. +func Chroot(root string, fs VFS) (VFS, error) { + root = path.Clean("/" + root) + st, err := fs.Stat(root) + if err != nil { + return nil, err + } + if !st.IsDir() { + return nil, fmt.Errorf("%s is not a directory", root) + } + return &chrootFileSystem{root: root + "/", fs: fs}, nil +} diff --git a/pkgs/vfs/doc.go b/pkgs/vfs/doc.go new file mode 100644 index 0000000..b0be92f --- /dev/null +++ b/pkgs/vfs/doc.go @@ -0,0 +1,23 @@ +// Package vfs implements Virtual File Systems with read-write support. +// +// All implementatations use slash ('/') separated paths, with / representing +// the root directory. This means that to manipulate or construct paths, the +// functions in path package should be used, like path.Join or path.Dir. +// There's also no notion of the current directory nor relative paths. The paths +// /a/b/c and a/b/c are considered to point to the same element. +// +// This package also implements some shorthand functions which might be used with +// any VFS implementation, providing the same functionality than functions in the +// io/ioutil, os and path/filepath packages: +// +// io/ioutil.ReadFile => ReadFile +// io/ioutil.WriteFile => WriteFile +// os.IsExist => IsExist +// os.IsNotExist => IsNotExist +// os.MkdirAll => MkdirAll +// os.RemoveAll => RemoveAll +// path/filepath.Walk => Walk +// +// All VFS implementations are thread safe, so multiple readers and writers might +// operate on them at any time. +package vfs diff --git a/pkgs/vfs/file.go b/pkgs/vfs/file.go new file mode 100644 index 0000000..8f1f3eb --- /dev/null +++ b/pkgs/vfs/file.go @@ -0,0 +1,200 @@ +package vfs + +import ( + "os" + "path" + "sync" + "time" +) + +// EntryType indicates the type of the entry. +type EntryType uint8 + +const ( + // EntryTypeFile indicates the entry is a file. + EntryTypeFile EntryType = iota + 1 + // EntryTypeDir indicates the entry is a directory. + EntryTypeDir +) + +const ( + ModeCompress os.FileMode = 1 << 16 +) + +// Entry is the interface implemented by the in-memory representations +// of files and directories. +type Entry interface { + // Type returns the entry type, either EntryTypeFile or + // EntryTypeDir. + Type() EntryType + // Size returns the file size. For directories, it's always zero. + Size() int64 + // FileMode returns the file mode as an os.FileMode. + FileMode() os.FileMode + // ModificationTime returns the last time the file or the directory + // was modified. + ModificationTime() time.Time +} + +// Type File represents an in-memory file. Most in-memory VFS implementations +// should use this structure to represent their files, in order to save work. +type File struct { + sync.RWMutex + // Data contains the file data. + Data []byte + // Mode is the file or directory mode. Note that some filesystems + // might ignore the permission bits. + Mode os.FileMode + // ModTime represents the last modification time to the file. + ModTime time.Time +} + +func (f *File) Type() EntryType { + return EntryTypeFile +} + +func (f *File) Size() int64 { + f.RLock() + defer f.RUnlock() + return int64(len(f.Data)) +} + +func (f *File) FileMode() os.FileMode { + return f.Mode +} + +func (f *File) ModificationTime() time.Time { + f.RLock() + defer f.RUnlock() + return f.ModTime +} + +// Type Dir represents an in-memory directory. Most in-memory VFS +// implementations should use this structure to represent their +// directories, in order to save work. +type Dir struct { + sync.RWMutex + // Mode is the file or directory mode. Note that some filesystems + // might ignore the permission bits. + Mode os.FileMode + // ModTime represents the last modification time to directory. + ModTime time.Time + // Entry names in this directory, in order. + EntryNames []string + // Entries in the same order as EntryNames. + Entries []Entry +} + +func (d *Dir) Type() EntryType { + return EntryTypeDir +} + +func (d *Dir) Size() int64 { + return 0 +} + +func (d *Dir) FileMode() os.FileMode { + return d.Mode +} + +func (d *Dir) ModificationTime() time.Time { + d.RLock() + defer d.RUnlock() + return d.ModTime +} + +// Add ads a new entry to the directory. If there's already an +// entry ith the same name, an error is returned. +func (d *Dir) Add(name string, entry Entry) error { + // TODO: Binary search + for ii, v := range d.EntryNames { + if v > name { + names := make([]string, len(d.EntryNames)+1) + copy(names, d.EntryNames[:ii]) + names[ii] = name + copy(names[ii+1:], d.EntryNames[ii:]) + d.EntryNames = names + + entries := make([]Entry, len(d.Entries)+1) + copy(entries, d.Entries[:ii]) + entries[ii] = entry + copy(entries[ii+1:], d.Entries[ii:]) + + d.Entries = entries + return nil + } + if v == name { + return os.ErrExist + } + } + // Not added yet, put at the end + d.EntryNames = append(d.EntryNames, name) + d.Entries = append(d.Entries, entry) + return nil +} + +// Find returns the entry with the given name and its index, +// or an error if an entry with that name does not exist in +// the directory. +func (d *Dir) Find(name string) (Entry, int, error) { + for ii, v := range d.EntryNames { + if v == name { + return d.Entries[ii], ii, nil + } + } + return nil, -1, os.ErrNotExist +} + +// EntryInfo implements the os.FileInfo interface wrapping +// a given File and its Path in its VFS. +type EntryInfo struct { + // Path is the full path to the entry in its VFS. + Path string + // Entry is the instance used by the VFS to represent + // the in-memory entry. + Entry Entry +} + +func (info *EntryInfo) Name() string { + return path.Base(info.Path) +} + +func (info *EntryInfo) Size() int64 { + return info.Entry.Size() +} + +func (info *EntryInfo) Mode() os.FileMode { + return info.Entry.FileMode() +} + +func (info *EntryInfo) ModTime() time.Time { + return info.Entry.ModificationTime() +} + +func (info *EntryInfo) IsDir() bool { + return info.Entry.Type() == EntryTypeDir +} + +// Sys returns the underlying Entry. +func (info *EntryInfo) Sys() interface{} { + return info.Entry +} + +// FileInfos represents an slice of os.FileInfo which +// implements the sort.Interface. This type is only +// exported for users who want to implement their own +// filesystems, since VFS.ReadDir requires the returned +// []os.FileInfo to be sorted by name. +type FileInfos []os.FileInfo + +func (f FileInfos) Len() int { + return len(f) +} + +func (f FileInfos) Less(i, j int) bool { + return f[i].Name() < f[j].Name() +} + +func (f FileInfos) Swap(i, j int) { + f[i], f[j] = f[j], f[i] +} diff --git a/pkgs/vfs/file_util.go b/pkgs/vfs/file_util.go new file mode 100644 index 0000000..38514c4 --- /dev/null +++ b/pkgs/vfs/file_util.go @@ -0,0 +1,171 @@ +package vfs + +import ( + "bytes" + "compress/zlib" + "errors" + "fmt" + "io" + "os" + "runtime" + "time" +) + +var ( + errFileClosed = errors.New("file is closed") +) + +// NewRFile returns a RFile from a *File. +func NewRFile(f *File) (RFile, error) { + data, err := fileData(f) + if err != nil { + return nil, err + } + return &file{f: f, data: data, readable: true}, nil +} + +// NewWFile returns a WFile from a *File. +func NewWFile(f *File, read bool, write bool) (WFile, error) { + data, err := fileData(f) + if err != nil { + return nil, err + } + w := &file{f: f, data: data, readable: read, writable: write} + runtime.SetFinalizer(w, closeFile) + return w, nil +} + +func closeFile(f *file) { + f.Close() +} + +func fileData(f *File) ([]byte, error) { + if len(f.Data) == 0 || f.Mode&ModeCompress == 0 { + return f.Data, nil + } + zr, err := zlib.NewReader(bytes.NewReader(f.Data)) + if err != nil { + return nil, err + } + defer zr.Close() + var out bytes.Buffer + if _, err := io.Copy(&out, zr); err != nil { + return nil, err + } + return out.Bytes(), nil +} + +type file struct { + f *File + data []byte + offset int + readable bool + writable bool + closed bool +} + +func (f *file) Read(p []byte) (int, error) { + if !f.readable { + return 0, ErrWriteOnly + } + f.f.RLock() + defer f.f.RUnlock() + if f.closed { + return 0, errFileClosed + } + if f.offset > len(f.data) { + return 0, io.EOF + } + n := copy(p, f.data[f.offset:]) + f.offset += n + if n < len(p) { + return n, io.EOF + } + return n, nil +} + +func (f *file) Seek(offset int64, whence int) (int64, error) { + f.f.Lock() + defer f.f.Unlock() + if f.closed { + return 0, errFileClosed + } + switch whence { + case os.SEEK_SET: + f.offset = int(offset) + case os.SEEK_CUR: + f.offset += int(offset) + case os.SEEK_END: + f.offset = len(f.data) + int(offset) + default: + panic(fmt.Errorf("Seek: invalid whence %d", whence)) + } + if f.offset > len(f.data) { + f.offset = len(f.data) + } else if f.offset < 0 { + f.offset = 0 + } + return int64(f.offset), nil +} + +func (f *file) Write(p []byte) (int, error) { + if !f.writable { + return 0, ErrReadOnly + } + f.f.Lock() + defer f.f.Unlock() + if f.closed { + return 0, errFileClosed + } + count := len(p) + n := copy(f.data[f.offset:], p) + if n < count { + f.data = append(f.data, p[n:]...) + } + f.offset += count + f.f.ModTime = time.Now() + return count, nil +} + +func (f *file) Close() error { + if !f.closed { + f.f.Lock() + defer f.f.Unlock() + if !f.closed { + if f.f.Mode&ModeCompress != 0 { + var buf bytes.Buffer + zw := zlib.NewWriter(&buf) + if _, err := zw.Write(f.data); err != nil { + return err + } + if err := zw.Close(); err != nil { + return err + } + if buf.Len() < len(f.data) { + f.f.Data = buf.Bytes() + } else { + f.f.Mode &= ^ModeCompress + f.f.Data = f.data + } + } else { + f.f.Data = f.data + } + f.closed = true + } + } + return nil +} + +func (f *file) IsCompressed() bool { + return f.f.Mode&ModeCompress != 0 +} + +func (f *file) SetCompressed(c bool) { + f.f.Lock() + defer f.f.Unlock() + if c { + f.f.Mode |= ModeCompress + } else { + f.f.Mode &= ^ModeCompress + } +} diff --git a/pkgs/vfs/fs.go b/pkgs/vfs/fs.go new file mode 100644 index 0000000..891e503 --- /dev/null +++ b/pkgs/vfs/fs.go @@ -0,0 +1,132 @@ +package vfs + +import ( + "fmt" + "io/ioutil" + "os" + "path" + "path/filepath" +) + +// IMPORTANT: Note about wrapping os. functions: os.Open, os.OpenFile etc... will return a non-nil +// interface pointing to a nil instance in case of error (whoever decided this disctintion in Go +// was a good idea deservers to be hung by his thumbs). This is highly undesirable, since users +// can't rely on checking f != nil to know if a correct handle was returned. That's why the +// methods in fileSystem do the error checking themselves and return a true nil in case of error. + +type fileSystem struct { + root string + temporary bool +} + +func (fs *fileSystem) path(name string) string { + name = path.Clean("/" + name) + return filepath.Join(fs.root, filepath.FromSlash(name)) +} + +// Root returns the root directory of the fileSystem, as an +// absolute path native to the current operating system. +func (fs *fileSystem) Root() string { + return fs.root +} + +// IsTemporary returns wheter the fileSystem is temporary. +func (fs *fileSystem) IsTemporary() bool { + return fs.temporary +} + +func (fs *fileSystem) Open(path string) (RFile, error) { + f, err := os.Open(fs.path(path)) + if err != nil { + return nil, err + } + return f, nil +} + +func (fs *fileSystem) OpenFile(path string, flag int, mode os.FileMode) (WFile, error) { + f, err := os.OpenFile(fs.path(path), flag, mode) + if err != nil { + return nil, err + } + return f, nil +} + +func (fs *fileSystem) Lstat(path string) (os.FileInfo, error) { + info, err := os.Lstat(fs.path(path)) + if err != nil { + return nil, err + } + return info, nil +} + +func (fs *fileSystem) Stat(path string) (os.FileInfo, error) { + info, err := os.Stat(fs.path(path)) + if err != nil { + return nil, err + } + return info, nil +} + +func (fs *fileSystem) ReadDir(path string) ([]os.FileInfo, error) { + files, err := ioutil.ReadDir(fs.path(path)) + if err != nil { + return nil, err + } + return files, nil +} + +func (fs *fileSystem) Mkdir(path string, perm os.FileMode) error { + return os.Mkdir(fs.path(path), perm) +} + +func (fs *fileSystem) Remove(path string) error { + return os.Remove(fs.path(path)) +} + +func (fs *fileSystem) String() string { + return fmt.Sprintf("fileSystem: %s", fs.root) +} + +// Close is a no-op on non-temporary filesystems. On temporary +// ones (as returned by TmpFS), it removes all the temporary files. +func (f *fileSystem) Close() error { + if f.temporary { + return os.RemoveAll(f.root) + } + return nil +} + +func newFS(root string) (*fileSystem, error) { + abs, err := filepath.Abs(root) + if err != nil { + return nil, err + } + return &fileSystem{root: abs}, nil +} + +// FS returns a VFS at the given path, which must be provided +// as native path of the current operating system. The path might be +// either absolute or relative, but the fileSystem will be anchored +// at the absolute path represented by root at the time of the function +// call. +func FS(root string) (VFS, error) { + return newFS(root) +} + +// TmpFS returns a temporary file system with the given prefix and its root +// directory name, which might be empty. The temporary file system is created +// in the default temporary directory for the operating system. Once you're +// done with the temporary filesystem, you might can all its files by calling +// its Close method. +func TmpFS(prefix string) (TemporaryVFS, error) { + dir, err := ioutil.TempDir("", prefix) + if err != nil { + return nil, err + } + fs, err := newFS(dir) + if err != nil { + return nil, err + } + fs.temporary = true + return fs, nil +} diff --git a/pkgs/vfs/map.go b/pkgs/vfs/map.go new file mode 100644 index 0000000..c9a467b --- /dev/null +++ b/pkgs/vfs/map.go @@ -0,0 +1,48 @@ +package vfs + +import ( + "path" + "sort" +) + +// Map returns an in-memory file system using the given files argument to +// populate it (which might be nil). Note that the files map does +// not need to contain any directories, they will be created automatically. +// If the files contain conflicting paths (e.g. files named a and a/b, thus +// making "a" both a file and a directory), an error will be returned. +func Map(files map[string]*File) (VFS, error) { + fs := newMemory() + keys := make([]string, 0, len(files)) + for k := range files { + keys = append(keys, k) + } + sort.Strings(keys) + var dir *Dir + var prevDir *Dir + var prevDirPath string + for _, k := range keys { + file := files[k] + if file.Mode == 0 { + file.Mode = 0644 + } + fileDir, fileBase := path.Split(k) + if prevDir != nil && fileDir == prevDirPath { + dir = prevDir + } else { + if err := MkdirAll(fs, fileDir, 0755); err != nil { + return nil, err + } + var err error + dir, err = fs.dirEntry(fileDir) + if err != nil { + return nil, err + } + prevDir = dir + prevDirPath = fileDir + } + if err := dir.Add(fileBase, file); err != nil { + return nil, err + } + } + return fs, nil +} diff --git a/pkgs/vfs/mem.go b/pkgs/vfs/mem.go new file mode 100644 index 0000000..7a878b9 --- /dev/null +++ b/pkgs/vfs/mem.go @@ -0,0 +1,239 @@ +package vfs + +import ( + "errors" + "fmt" + "os" + pathpkg "path" + "strings" + "sync" + "time" +) + +var ( + errNoEmptyNameFile = errors.New("can't create file with empty name") + errNoEmptyNameDir = errors.New("can't create directory with empty name") +) + +type memoryFileSystem struct { + mu sync.RWMutex + root *Dir +} + +// entry must always be called with the lock held +func (fs *memoryFileSystem) entry(path string) (Entry, *Dir, int, error) { + path = cleanPath(path) + if path == "" || path == "/" || path == "." { + return fs.root, nil, 0, nil + } + if path[0] == '/' { + path = path[1:] + } + dir := fs.root + for { + p := strings.IndexByte(path, '/') + name := path + if p > 0 { + name = path[:p] + path = path[p+1:] + } else { + path = "" + } + dir.RLock() + entry, pos, err := dir.Find(name) + dir.RUnlock() + if err != nil { + return nil, nil, 0, err + } + if len(path) == 0 { + return entry, dir, pos, nil + } + if entry.Type() != EntryTypeDir { + break + } + dir = entry.(*Dir) + } + return nil, nil, 0, os.ErrNotExist +} + +func (fs *memoryFileSystem) dirEntry(path string) (*Dir, error) { + entry, _, _, err := fs.entry(path) + if err != nil { + return nil, err + } + if entry.Type() != EntryTypeDir { + return nil, fmt.Errorf("%s it's not a directory", path) + } + return entry.(*Dir), nil +} + +func (fs *memoryFileSystem) Open(path string) (RFile, error) { + entry, _, _, err := fs.entry(path) + if err != nil { + return nil, err + } + if entry.Type() != EntryTypeFile { + return nil, fmt.Errorf("%s is not a file", path) + } + return NewRFile(entry.(*File)) +} + +func (fs *memoryFileSystem) OpenFile(path string, flag int, mode os.FileMode) (WFile, error) { + if mode&os.ModeType != 0 { + return nil, fmt.Errorf("%T does not support special files", fs) + } + path = cleanPath(path) + dir, base := pathpkg.Split(path) + if base == "" { + return nil, errNoEmptyNameFile + } + fs.mu.RLock() + d, err := fs.dirEntry(dir) + fs.mu.RUnlock() + if err != nil { + return nil, err + } + + d.Lock() + defer d.Unlock() + f, _, _ := d.Find(base) + if f == nil && flag&os.O_CREATE == 0 { + return nil, os.ErrNotExist + } + // Read only file? + if flag&os.O_WRONLY == 0 && flag&os.O_RDWR == 0 { + if f == nil { + return nil, os.ErrNotExist + } + return NewWFile(f.(*File), true, false) + } + // Write file, either f != nil or flag&os.O_CREATE + if f != nil { + if f.Type() != EntryTypeFile { + return nil, fmt.Errorf("%s is not a file", path) + } + if flag&os.O_EXCL != 0 { + return nil, os.ErrExist + } + // Check if we should truncate + if flag&os.O_TRUNC != 0 { + file := f.(*File) + file.Lock() + file.ModTime = time.Now() + file.Data = nil + file.Unlock() + } + } else { + f = &File{ModTime: time.Now()} + d.Add(base, f) + } + return NewWFile(f.(*File), flag&os.O_RDWR != 0, true) +} + +func (fs *memoryFileSystem) Lstat(path string) (os.FileInfo, error) { + return fs.Stat(path) +} + +func (fs *memoryFileSystem) Stat(path string) (os.FileInfo, error) { + entry, _, _, err := fs.entry(path) + if err != nil { + return nil, err + } + return &EntryInfo{Path: path, Entry: entry}, nil +} + +func (fs *memoryFileSystem) ReadDir(path string) ([]os.FileInfo, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + return fs.readDir(path) +} + +func (fs *memoryFileSystem) readDir(path string) ([]os.FileInfo, error) { + entry, _, _, err := fs.entry(path) + if err != nil { + return nil, err + } + if entry.Type() != EntryTypeDir { + return nil, fmt.Errorf("%s is not a directory", path) + } + dir := entry.(*Dir) + dir.RLock() + infos := make([]os.FileInfo, len(dir.Entries)) + for ii, v := range dir.EntryNames { + infos[ii] = &EntryInfo{ + Path: pathpkg.Join(path, v), + Entry: dir.Entries[ii], + } + } + dir.RUnlock() + return infos, nil +} + +func (fs *memoryFileSystem) Mkdir(path string, perm os.FileMode) error { + path = cleanPath(path) + dir, base := pathpkg.Split(path) + if base == "" { + if dir == "/" || dir == "" { + return os.ErrExist + } + return errNoEmptyNameDir + } + fs.mu.RLock() + d, err := fs.dirEntry(dir) + fs.mu.RUnlock() + if err != nil { + return err + } + d.Lock() + defer d.Unlock() + if _, p, _ := d.Find(base); p >= 0 { + return os.ErrExist + } + d.Add(base, &Dir{ + Mode: os.ModeDir | perm, + ModTime: time.Now(), + }) + return nil +} + +func (fs *memoryFileSystem) Remove(path string) error { + entry, dir, pos, err := fs.entry(path) + if err != nil { + return err + } + if entry.Type() == EntryTypeDir && len(entry.(*Dir).Entries) > 0 { + return fmt.Errorf("directory %s not empty", path) + } + // Lock again, the position might have changed + dir.Lock() + _, pos, err = dir.Find(pathpkg.Base(path)) + if err == nil { + dir.EntryNames = append(dir.EntryNames[:pos], dir.EntryNames[pos+1:]...) + dir.Entries = append(dir.Entries[:pos], dir.Entries[pos+1:]...) + } + dir.Unlock() + return err +} + +func (fs *memoryFileSystem) String() string { + return "MemoryFileSystem" +} + +func newMemory() *memoryFileSystem { + fs := &memoryFileSystem{ + root: &Dir{ + Mode: os.ModeDir | 0755, + ModTime: time.Now(), + }, + } + return fs +} + +// Memory returns an empty in memory VFS. +func Memory() VFS { + return newMemory() +} + +func cleanPath(path string) string { + return strings.Trim(pathpkg.Clean("/"+path), "/") +} diff --git a/pkgs/vfs/mounter.go b/pkgs/vfs/mounter.go new file mode 100644 index 0000000..8d42ba8 --- /dev/null +++ b/pkgs/vfs/mounter.go @@ -0,0 +1,163 @@ +package vfs + +import ( + "fmt" + "os" + "path" + "strings" +) + +const ( + separator = "/" +) + +func hasSubdir(root, dir string) (string, bool) { + root = path.Clean(root) + if !strings.HasSuffix(root, separator) { + root += separator + } + dir = path.Clean(dir) + if !strings.HasPrefix(dir, root) { + return "", false + } + return dir[len(root):], true +} + +type mountPoint struct { + point string + fs VFS +} + +func (m *mountPoint) String() string { + return fmt.Sprintf("%s at %s", m.fs, m.point) +} + +// Mounter implements the VFS interface and allows mounting different virtual +// file systems at arbitraty points, working much like a UNIX filesystem. +// Note that the first mounted filesystem must be always at "/". +type Mounter struct { + points []*mountPoint +} + +func (m *Mounter) fs(p string) (VFS, string, error) { + for ii := len(m.points) - 1; ii >= 0; ii-- { + if rel, ok := hasSubdir(m.points[ii].point, p); ok { + return m.points[ii].fs, rel, nil + } + } + return nil, "", os.ErrNotExist +} + +// Mount mounts the given filesystem at the given mount point. Unless the +// mount point is /, it must be an already existing directory. +func (m *Mounter) Mount(fs VFS, point string) error { + point = path.Clean(point) + if point == "." || point == "" { + point = "/" + } + if point == "/" { + if len(m.points) > 0 { + return fmt.Errorf("%s is already mounted at /", m.points[0]) + } + m.points = append(m.points, &mountPoint{point, fs}) + return nil + } + stat, err := m.Stat(point) + if err != nil { + return err + } + if !stat.IsDir() { + return fmt.Errorf("%s is not a directory", point) + } + m.points = append(m.points, &mountPoint{point, fs}) + return nil +} + +// Umount umounts the filesystem from the given mount point. If there are other filesystems +// mounted below it or there's no filesystem mounted at that point, an error is returned. +func (m *Mounter) Umount(point string) error { + point = path.Clean(point) + for ii, v := range m.points { + if v.point == point { + // Check if we have mount points below this one + for _, vv := range m.points[ii:] { + if _, ok := hasSubdir(v.point, vv.point); ok { + return fmt.Errorf("can't umount %s because %s is mounted below it", point, vv) + } + } + m.points = append(m.points[:ii], m.points[ii+1:]...) + return nil + } + } + return fmt.Errorf("no filesystem mounted at %s", point) +} + +func (m *Mounter) Open(path string) (RFile, error) { + fs, p, err := m.fs(path) + if err != nil { + return nil, err + } + return fs.Open(p) +} + +func (m *Mounter) OpenFile(path string, flag int, perm os.FileMode) (WFile, error) { + fs, p, err := m.fs(path) + if err != nil { + return nil, err + } + return fs.OpenFile(p, flag, perm) +} + +func (m *Mounter) Lstat(path string) (os.FileInfo, error) { + fs, p, err := m.fs(path) + if err != nil { + return nil, err + } + return fs.Lstat(p) +} + +func (m *Mounter) Stat(path string) (os.FileInfo, error) { + fs, p, err := m.fs(path) + if err != nil { + return nil, err + } + return fs.Stat(p) +} + +func (m *Mounter) ReadDir(path string) ([]os.FileInfo, error) { + fs, p, err := m.fs(path) + if err != nil { + return nil, err + } + return fs.ReadDir(p) +} + +func (m *Mounter) Mkdir(path string, perm os.FileMode) error { + fs, p, err := m.fs(path) + if err != nil { + return err + } + return fs.Mkdir(p, perm) +} + +func (m *Mounter) Remove(path string) error { + // TODO: Don't allow removing an empty directory + // with a mount below it. + fs, p, err := m.fs(path) + if err != nil { + return err + } + return fs.Remove(p) +} + +func (m *Mounter) String() string { + s := make([]string, len(m.points)) + for ii, v := range m.points { + s[ii] = v.String() + } + return fmt.Sprintf("Mounter: %s", strings.Join(s, ", ")) +} + +func mounterCompileTimeCheck() VFS { + return &Mounter{} +} diff --git a/pkgs/vfs/open.go b/pkgs/vfs/open.go new file mode 100644 index 0000000..e099114 --- /dev/null +++ b/pkgs/vfs/open.go @@ -0,0 +1,142 @@ +package vfs + +import ( + "archive/tar" + "archive/zip" + "bytes" + "compress/bzip2" + "compress/gzip" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" +) + +// Zip returns an in-memory VFS initialized with the +// contents of the .zip file read from the given io.Reader. +// Since archive/zip requires an io.ReaderAt rather than an +// io.Reader, and a known size, Zip will read the whole file +// into memory and provide its own buffering if r does not +// implement io.ReaderAt or size is <= 0. +func Zip(r io.Reader, size int64) (VFS, error) { + rat, _ := r.(io.ReaderAt) + if rat == nil || size <= 0 { + data, err := ioutil.ReadAll(r) + if err != nil { + return nil, err + } + rat = bytes.NewReader(data) + size = int64(len(data)) + } + zr, err := zip.NewReader(rat, size) + if err != nil { + return nil, err + } + files := make(map[string]*File) + for _, file := range zr.File { + if file.Mode().IsDir() { + continue + } + f, err := file.Open() + if err != nil { + return nil, err + } + data, err := ioutil.ReadAll(f) + f.Close() + if err != nil { + return nil, err + } + files[file.Name] = &File{ + Data: data, + Mode: file.Mode(), + ModTime: file.ModTime(), + } + } + return Map(files) +} + +// Tar returns an in-memory VFS initialized with the +// contents of the .tar file read from the given io.Reader. +func Tar(r io.Reader) (VFS, error) { + files := make(map[string]*File) + tr := tar.NewReader(r) + for { + hdr, err := tr.Next() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + if hdr.FileInfo().IsDir() { + continue + } + data, err := ioutil.ReadAll(tr) + if err != nil { + return nil, err + } + files[hdr.Name] = &File{ + Data: data, + Mode: hdr.FileInfo().Mode(), + ModTime: hdr.ModTime, + } + } + return Map(files) +} + +// TarGzip returns an in-memory VFS initialized with the +// contents of the .tar.gz file read from the given io.Reader. +func TarGzip(r io.Reader) (VFS, error) { + zr, err := gzip.NewReader(r) + if err != nil { + return nil, err + } + defer zr.Close() + return Tar(zr) +} + +// TarBzip2 returns an in-memory VFS initialized with the +// contents of then .tar.bz2 file read from the given io.Reader. +func TarBzip2(r io.Reader) (VFS, error) { + bzr := bzip2.NewReader(r) + return Tar(bzr) +} + +// Open returns an in-memory VFS initialized with the contents +// of the given filename, which must have one of the following +// extensions: +// +// - .zip +// - .tar +// - .tar.gz +// - .tar.bz2 +func Open(filename string) (VFS, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + base := filepath.Base(filename) + ext := strings.ToLower(filepath.Ext(base)) + nonExt := filename[:len(filename)-len(ext)] + if strings.ToLower(filepath.Ext(nonExt)) == ".tar" { + ext = ".tar" + ext + } + switch ext { + case ".zip": + st, err := f.Stat() + if err != nil { + return nil, err + } + return Zip(f, st.Size()) + case ".tar": + return Tar(f) + case ".tar.gz": + return TarGzip(f) + case ".tar.bz2": + return TarBzip2(f) + } + return nil, fmt.Errorf("can't open a VFS from a %s file", ext) +} diff --git a/pkgs/vfs/open_test.go b/pkgs/vfs/open_test.go new file mode 100644 index 0000000..a078ccf --- /dev/null +++ b/pkgs/vfs/open_test.go @@ -0,0 +1,48 @@ +package vfs + +import ( + "path/filepath" + "testing" +) + +func testOpenedVFS(t *testing.T, fs VFS) { + data1, err := ReadFile(fs, "a/b/c/d") + if err != nil { + t.Fatal(err) + } + if string(data1) != "go" { + t.Errorf("expecting a/b/c/d to contain \"go\", it contains %q instead", string(data1)) + } + data2, err := ReadFile(fs, "empty") + if err != nil { + t.Fatal(err) + } + if len(data2) > 0 { + t.Error("non-empty empty file") + } +} + +func testOpenFilename(t *testing.T, filename string) { + p := filepath.Join("testdata", filename) + fs, err := Open(p) + if err != nil { + t.Fatal(err) + } + testOpenedVFS(t, fs) +} + +func TestOpenZip(t *testing.T) { + testOpenFilename(t, "fs.zip") +} + +func TestOpenTar(t *testing.T) { + testOpenFilename(t, "fs.tar") +} + +func TestOpenTarGzip(t *testing.T) { + testOpenFilename(t, "fs.tar.gz") +} + +func TestOpenTarBzip2(t *testing.T) { + testOpenFilename(t, "fs.tar.bz2") +} diff --git a/pkgs/vfs/rewriter.go b/pkgs/vfs/rewriter.go new file mode 100644 index 0000000..9d7f419 --- /dev/null +++ b/pkgs/vfs/rewriter.go @@ -0,0 +1,56 @@ +package vfs + +import ( + "fmt" + "os" +) + +type rewriterFileSystem struct { + fs VFS + rewriter func(string) string +} + +func (fs *rewriterFileSystem) VFS() VFS { + return fs.fs +} + +func (fs *rewriterFileSystem) Open(path string) (RFile, error) { + return fs.fs.Open(fs.rewriter(path)) +} + +func (fs *rewriterFileSystem) OpenFile(path string, flag int, perm os.FileMode) (WFile, error) { + return fs.fs.OpenFile(fs.rewriter(path), flag, perm) +} + +func (fs *rewriterFileSystem) Lstat(path string) (os.FileInfo, error) { + return fs.fs.Lstat(fs.rewriter(path)) +} + +func (fs *rewriterFileSystem) Stat(path string) (os.FileInfo, error) { + return fs.fs.Stat(fs.rewriter(path)) +} + +func (fs *rewriterFileSystem) ReadDir(path string) ([]os.FileInfo, error) { + return fs.fs.ReadDir(fs.rewriter(path)) +} + +func (fs *rewriterFileSystem) Mkdir(path string, perm os.FileMode) error { + return fs.fs.Mkdir(fs.rewriter(path), perm) +} + +func (fs *rewriterFileSystem) Remove(path string) error { + return fs.fs.Remove(fs.rewriter(path)) +} + +func (fs *rewriterFileSystem) String() string { + return fmt.Sprintf("Rewriter %s", fs.fs.String()) +} + +// Rewriter returns a file system which uses the provided function +// to rewrite paths. +func Rewriter(fs VFS, rewriter func(oldPath string) (newPath string)) VFS { + if rewriter == nil { + return fs + } + return &rewriterFileSystem{fs: fs, rewriter: rewriter} +} diff --git a/pkgs/vfs/ro.go b/pkgs/vfs/ro.go new file mode 100644 index 0000000..4104c88 --- /dev/null +++ b/pkgs/vfs/ro.go @@ -0,0 +1,61 @@ +package vfs + +import ( + "errors" + "fmt" + "os" +) + +var ( + // ErrReadOnlyFileSystem is the error returned by read only file systems + // from calls which would result in a write operation. + ErrReadOnlyFileSystem = errors.New("read-only filesystem") +) + +type readOnlyFileSystem struct { + fs VFS +} + +func (fs *readOnlyFileSystem) VFS() VFS { + return fs.fs +} + +func (fs *readOnlyFileSystem) Open(path string) (RFile, error) { + return fs.fs.Open(path) +} + +func (fs *readOnlyFileSystem) OpenFile(path string, flag int, perm os.FileMode) (WFile, error) { + if flag&(os.O_CREATE|os.O_WRONLY|os.O_RDWR) != 0 { + return nil, ErrReadOnlyFileSystem + } + return fs.fs.OpenFile(path, flag, perm) +} + +func (fs *readOnlyFileSystem) Lstat(path string) (os.FileInfo, error) { + return fs.fs.Lstat(path) +} + +func (fs *readOnlyFileSystem) Stat(path string) (os.FileInfo, error) { + return fs.fs.Stat(path) +} + +func (fs *readOnlyFileSystem) ReadDir(path string) ([]os.FileInfo, error) { + return fs.fs.ReadDir(path) +} + +func (fs *readOnlyFileSystem) Mkdir(path string, perm os.FileMode) error { + return ErrReadOnlyFileSystem +} + +func (fs *readOnlyFileSystem) Remove(path string) error { + return ErrReadOnlyFileSystem +} + +func (fs *readOnlyFileSystem) String() string { + return fmt.Sprintf("RO %s", fs.fs.String()) +} + +// ReadOnly returns a read-only filesystem wrapping the given fs. +func ReadOnly(fs VFS) VFS { + return &readOnlyFileSystem{fs: fs} +} diff --git a/pkgs/vfs/testdata/download-data.sh b/pkgs/vfs/testdata/download-data.sh new file mode 100755 index 0000000..96ca2fd --- /dev/null +++ b/pkgs/vfs/testdata/download-data.sh @@ -0,0 +1,11 @@ +#!/bin/sh + +SRC=https://storage.googleapis.com/golang/go1.3.src.tar.gz +if which curl > /dev/null 2>&1; then + curl -O ${SRC} +elif which wget > /dev/null 2&1; then + wget -O `basename ${SRC}` ${SRC} +else + echo "no curl nor wget found" 1>&2 + exit 1 +fi diff --git a/pkgs/vfs/testdata/fs.tar b/pkgs/vfs/testdata/fs.tar new file mode 100644 index 0000000..75a82ab Binary files /dev/null and b/pkgs/vfs/testdata/fs.tar differ diff --git a/pkgs/vfs/testdata/fs.tar.bz2 b/pkgs/vfs/testdata/fs.tar.bz2 new file mode 100644 index 0000000..1e363b0 Binary files /dev/null and b/pkgs/vfs/testdata/fs.tar.bz2 differ diff --git a/pkgs/vfs/testdata/fs.tar.gz b/pkgs/vfs/testdata/fs.tar.gz new file mode 100644 index 0000000..c5946ce Binary files /dev/null and b/pkgs/vfs/testdata/fs.tar.gz differ diff --git a/pkgs/vfs/testdata/fs.zip b/pkgs/vfs/testdata/fs.zip new file mode 100644 index 0000000..3bd1967 Binary files /dev/null and b/pkgs/vfs/testdata/fs.zip differ diff --git a/pkgs/vfs/testdata/fs/a/b/c/d b/pkgs/vfs/testdata/fs/a/b/c/d new file mode 100644 index 0000000..c08e80d --- /dev/null +++ b/pkgs/vfs/testdata/fs/a/b/c/d @@ -0,0 +1 @@ +go \ No newline at end of file diff --git a/pkgs/vfs/testdata/fs/empty b/pkgs/vfs/testdata/fs/empty new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/vfs/testdata/update-fs.sh b/pkgs/vfs/testdata/update-fs.sh new file mode 100755 index 0000000..b290b1a --- /dev/null +++ b/pkgs/vfs/testdata/update-fs.sh @@ -0,0 +1,9 @@ +#!/bin/sh + +set -e +cd fs +zip -r ../fs.zip * +tar cvvf ../fs.tar * +tar cvvzf ../fs.tar.gz * +tar cvvjf ../fs.tar.bz2 * +cd - diff --git a/pkgs/vfs/util.go b/pkgs/vfs/util.go new file mode 100644 index 0000000..c22089b --- /dev/null +++ b/pkgs/vfs/util.go @@ -0,0 +1,227 @@ +package vfs + +import ( + "errors" + "fmt" + "io/ioutil" + "os" + pathpkg "path" + "strings" +) + +var ( + // SkipDir is used by a WalkFunc to signal Walk that + // it wans to skip the given directory. + SkipDir = errors.New("skip this directory") + // ErrReadOnly is returned from Write() on a read-only file. + ErrReadOnly = errors.New("can't write to read only file") + // ErrWriteOnly is returned from Read() on a write-only file. + ErrWriteOnly = errors.New("can't read from write only file") +) + +// WalkFunc is the function type used by Walk to iterate over a VFS. +type WalkFunc func(fs VFS, path string, info os.FileInfo, err error) error + +func walk(fs VFS, p string, info os.FileInfo, fn WalkFunc) error { + err := fn(fs, p, info, nil) + if err != nil { + if info.IsDir() && err == SkipDir { + err = nil + } + return err + } + if !info.IsDir() { + return nil + } + infos, err := fs.ReadDir(p) + if err != nil { + return fn(fs, p, info, err) + } + for _, v := range infos { + name := pathpkg.Join(p, v.Name()) + fileInfo, err := fs.Lstat(name) + if err != nil { + if err := fn(fs, name, fileInfo, err); err != nil && err != SkipDir { + return err + } + continue + } + if err := walk(fs, name, fileInfo, fn); err != nil && (!fileInfo.IsDir() || err != SkipDir) { + return err + } + } + return nil +} + +// Walk iterates over all the files in the VFS which descend from the given +// root (including root itself), descending into any subdirectories. In each +// directory, files are visited in alphabetical order. The given function might +// chose to skip a directory by returning SkipDir. +func Walk(fs VFS, root string, fn WalkFunc) error { + info, err := fs.Lstat(root) + if err != nil { + return fn(fs, root, nil, err) + } + return walk(fs, root, info, fn) +} + +func makeDir(fs VFS, path string, perm os.FileMode) error { + stat, err := fs.Lstat(path) + if err == nil { + if !stat.IsDir() { + return fmt.Errorf("%s exists and is not a directory", path) + } + } else { + if err := fs.Mkdir(path, perm); err != nil { + return err + } + } + return nil +} + +// MkdirAll makes all directories pointed by the given path, using the same +// permissions for all of them. Note that MkdirAll skips directories which +// already exists rather than returning an error. +func MkdirAll(fs VFS, path string, perm os.FileMode) error { + cur := "/" + if err := makeDir(fs, cur, perm); err != nil { + return err + } + parts := strings.Split(path, "/") + for _, v := range parts { + cur += v + if err := makeDir(fs, cur, perm); err != nil { + return err + } + cur += "/" + } + return nil +} + +// RemoveAll removes all files from the given fs and path, including +// directories (by removing its contents first). +func RemoveAll(fs VFS, path string) error { + stat, err := fs.Lstat(path) + if err != nil { + if err == os.ErrNotExist { + return nil + } + return err + } + if stat.IsDir() { + files, err := fs.ReadDir(path) + if err != nil { + return err + } + for _, v := range files { + filePath := pathpkg.Join(path, v.Name()) + if err := RemoveAll(fs, filePath); err != nil { + return err + } + } + } + return fs.Remove(path) +} + +// ReadFile reads the file at the given path from the given fs, returning +// either its contents or an error if the file couldn't be read. +func ReadFile(fs VFS, path string) ([]byte, error) { + f, err := fs.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + return ioutil.ReadAll(f) +} + +// WriteFile writes a file at the given path and fs with the given data and +// permissions. If the file already exists, WriteFile truncates it before +// writing. If the file can't be created, an error will be returned. +func WriteFile(fs VFS, path string, data []byte, perm os.FileMode) error { + f, err := fs.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, perm) + if err != nil { + return err + } + if _, err := f.Write(data); err != nil { + f.Close() + return err + } + return f.Close() +} + +// Clone copies all the files from the src VFS to dst. Note that files or directories with +// all permissions set to 0 will be set to 0755 for directories and 0644 for files. If you +// need more granularity, use Walk directly to clone the file systems. +func Clone(dst VFS, src VFS) error { + err := Walk(src, "/", func(fs VFS, path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + perm := info.Mode() & os.ModePerm + if perm == 0 { + perm = 0755 + } + err := dst.Mkdir(path, info.Mode()|perm) + if err != nil && !IsExist(err) { + return err + } + return nil + } + data, err := ReadFile(fs, path) + if err != nil { + return err + } + perm := info.Mode() & os.ModePerm + if perm == 0 { + perm = 0644 + } + if err := WriteFile(dst, path, data, info.Mode()|perm); err != nil { + return err + } + return nil + }) + return err +} + +// IsExist returns wheter the error indicates that the file or directory +// already exists. +func IsExist(err error) bool { + return os.IsExist(err) +} + +// IsExist returns wheter the error indicates that the file or directory +// does not exist. +func IsNotExist(err error) bool { + return os.IsNotExist(err) +} + +// Compressor is the interface implemented by VFS files which can be +// transparently compressed and decompressed. Currently, this is only +// supported by the in-memory filesystems. +type Compressor interface { + IsCompressed() bool + SetCompressed(c bool) +} + +// Compress is a shorthand method for compressing all the files in a VFS. +// Note that not all file systems support transparent compression/decompression. +func Compress(fs VFS) error { + return Walk(fs, "/", func(fs VFS, p string, info os.FileInfo, err error) error { + if err != nil { + return err + } + mode := info.Mode() + if mode.IsDir() || mode&ModeCompress != 0 { + return nil + } + f, err := fs.Open(p) + if err != nil { + return err + } + if c, ok := f.(Compressor); ok { + c.SetCompressed(true) + } + return f.Close() + }) +} diff --git a/pkgs/vfs/vfs.go b/pkgs/vfs/vfs.go new file mode 100644 index 0000000..3a98e63 --- /dev/null +++ b/pkgs/vfs/vfs.go @@ -0,0 +1,85 @@ +package vfs + +import ( + "io" + "os" +) + +// Opener is the interface which specifies the methods for +// opening a file. All the VFS implementations implement +// this interface. +type Opener interface { + // Open returns a readable file at the given path. See also + // the shorthand function ReadFile. + Open(path string) (RFile, error) + // OpenFile returns a readable and writable file at the given + // path. Note that, depending on the flags, the file might be + // only readable or only writable. See also the shorthand + // function WriteFile. + OpenFile(path string, flag int, perm os.FileMode) (WFile, error) +} + +// RFile is the interface implemented by the returned value from a VFS +// Open method. It allows reading and seeking, and must be closed after use. +type RFile interface { + io.Reader + io.Seeker + io.Closer +} + +// WFile is the interface implemented by the returned value from a VFS +// OpenFile method. It allows reading, seeking and writing, and must +// be closed after use. Note that, depending on the flags passed to +// OpenFile, the Read or Write methods might always return an error (e.g. +// if the file was opened in read-only or write-only mode). +type WFile interface { + io.Reader + io.Writer + io.Seeker + io.Closer +} + +// VFS is the interface implemented by all the Virtual File Systems. +type VFS interface { + Opener + // Lstat returns the os.FileInfo for the given path, without + // following symlinks. + Lstat(path string) (os.FileInfo, error) + // Stat returns the os.FileInfo for the given path, following + // symlinks. + Stat(path string) (os.FileInfo, error) + // ReadDir returns the contents of the directory at path as an slice + // of os.FileInfo, ordered alphabetically by name. If path is not a + // directory or the permissions don't allow it, an error will be + // returned. + ReadDir(path string) ([]os.FileInfo, error) + // Mkdir creates a directory at the given path. If the directory + // already exists or its parent directory does not exist or + // the permissions don't allow it, an error will be returned. See + // also the shorthand function MkdirAll. + Mkdir(path string, perm os.FileMode) error + // Remove removes the item at the given path. If the path does + // not exists or the permissions don't allow removing it or it's + // a non-empty directory, an error will be returned. See also + // the shorthand function RemoveAll. + Remove(path string) error + // String returns a human-readable description of the VFS. + String() string +} + +// TemporaryVFS represents a temporary on-disk file system which can be removed +// by calling its Close method. +type TemporaryVFS interface { + VFS + // Root returns the root directory for the temporary VFS. + Root() string + // Close removes all the files in temporary VFS. + Close() error +} + +// Container is implemented by some file systems which +// contain another one. +type Container interface { + // VFS returns the underlying VFS. + VFS() VFS +} diff --git a/pkgs/vfs/vfs_test.go b/pkgs/vfs/vfs_test.go new file mode 100644 index 0000000..7ffe75a --- /dev/null +++ b/pkgs/vfs/vfs_test.go @@ -0,0 +1,336 @@ +package vfs + +import ( + "crypto/sha1" + "encoding/hex" + "fmt" + "io" + "os" + "path/filepath" + "reflect" + "testing" +) + +const ( + goTestFile = "go1.3.src.tar.gz" +) + +type errNoTestFile string + +func (e errNoTestFile) Error() string { + return fmt.Sprintf("%s test file not found, use testdata/download-data.sh to fetch it", filepath.Base(string(e))) +} + +func openOptionalTestFile(t testing.TB, name string) *os.File { + filename := filepath.Join("testdata", name) + f, err := os.Open(filename) + if err != nil { + t.Skip(errNoTestFile(filename)) + } + return f +} + +func testVFS(t *testing.T, fs VFS) { + if err := WriteFile(fs, "a", []byte("A"), 0644); err != nil { + t.Fatal(err) + } + data, err := ReadFile(fs, "a") + if err != nil { + t.Fatal(err) + } + if string(data) != "A" { + t.Errorf("expecting file a to contain \"A\" got %q instead", string(data)) + } + if err := WriteFile(fs, "b", []byte("B"), 0755); err != nil { + t.Fatal(err) + } + if _, err := fs.OpenFile("b", os.O_CREATE|os.O_TRUNC|os.O_EXCL|os.O_WRONLY, 0755); err == nil || !IsExist(err) { + t.Errorf("error should be ErrExist, it's %v", err) + } + fb, err := fs.OpenFile("b", os.O_TRUNC|os.O_WRONLY, 0755) + if err != nil { + t.Fatalf("error opening b: %s", err) + } + if _, err := fb.Write([]byte("BB")); err != nil { + t.Errorf("error writing to b: %s", err) + } + if _, err := fb.Seek(0, os.SEEK_SET); err != nil { + t.Errorf("error seeking b: %s", err) + } + if _, err := fb.Read(make([]byte, 2)); err == nil { + t.Error("allowed reading WRONLY file b") + } + if err := fb.Close(); err != nil { + t.Errorf("error closing b: %s", err) + } + files, err := fs.ReadDir("/") + if err != nil { + t.Fatal(err) + } + if len(files) != 2 { + t.Errorf("expecting 2 files, got %d", len(files)) + } + if n := files[0].Name(); n != "a" { + t.Errorf("expecting first file named \"a\", got %q", n) + } + if n := files[1].Name(); n != "b" { + t.Errorf("expecting first file named \"b\", got %q", n) + } + for ii, v := range files { + es := int64(ii + 1) + if s := v.Size(); es != s { + t.Errorf("expecting file %s to have size %d, has %d", v.Name(), es, s) + } + } + if err := MkdirAll(fs, "a/b/c/d", 0); err == nil { + t.Error("should not allow dir over file") + } + if err := MkdirAll(fs, "c/d", 0755); err != nil { + t.Fatal(err) + } + // Idempotent + if err := MkdirAll(fs, "c/d", 0755); err != nil { + t.Fatal(err) + } + if err := fs.Mkdir("c", 0755); err == nil || !IsExist(err) { + t.Errorf("err should be ErrExist, it's %v", err) + } + // Should fail to remove, c is not empty + if err := fs.Remove("c"); err == nil { + t.Fatalf("removed non-empty directory") + } + var walked []os.FileInfo + var walkedNames []string + err = Walk(fs, "c", func(fs VFS, path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + walked = append(walked, info) + walkedNames = append(walkedNames, path) + return nil + }) + if err != nil { + t.Fatal(err) + } + if exp := []string{"c", "c/d"}; !reflect.DeepEqual(exp, walkedNames) { + t.Error("expecting walked names %v, got %v", exp, walkedNames) + } + for _, v := range walked { + if !v.IsDir() { + t.Errorf("%s should be a dir", v.Name()) + } + } + if err := RemoveAll(fs, "c"); err != nil { + t.Fatal(err) + } + err = Walk(fs, "c", func(fs VFS, path string, info os.FileInfo, err error) error { + return err + }) + if err == nil || !IsNotExist(err) { + t.Errorf("error should be ErrNotExist, it's %v", err) + } +} + +func TestMapFS(t *testing.T) { + fs, err := Map(nil) + if err != nil { + t.Fatal(err) + } + testVFS(t, fs) +} + +func TestPopulatedMap(t *testing.T) { + files := map[string]*File{ + "a/1": &File{}, + "a/2": &File{}, + } + fs, err := Map(files) + if err != nil { + t.Fatal(err) + } + infos, err := fs.ReadDir("a") + if err != nil { + t.Fatal(err) + } + if c := len(infos); c != 2 { + t.Fatalf("expecting 2 files in a, got %d", c) + } + if infos[0].Name() != "1" || infos[1].Name() != "2" { + t.Errorf("expecting names 1, 2 got %q, %q", infos[0].Name(), infos[1].Name()) + } +} + +func TestBadPopulatedMap(t *testing.T) { + // 1 can't be file and directory + files := map[string]*File{ + "a/1": &File{}, + "a/1/2": &File{}, + } + _, err := Map(files) + if err == nil { + t.Fatal("Map should not work with a path as both file and directory") + } +} + +func TestTmpFS(t *testing.T) { + fs, err := TmpFS("vfs-test") + if err != nil { + t.Fatal(err) + } + defer fs.Close() + testVFS(t, fs) +} + +const ( + go13FileCount = 4157 + // +1 because of the root, the real count is 407 + go13DirCount = 407 + 1 +) + +func countFileSystem(fs VFS) (int, int, error) { + files, dirs := 0, 0 + err := Walk(fs, "/", func(fs VFS, _ string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + dirs++ + } else { + files++ + } + return nil + }) + return files, dirs, err +} + +func testGoFileCount(t *testing.T, fs VFS) { + files, dirs, err := countFileSystem(fs) + if err != nil { + t.Fatal(err) + } + if files != go13FileCount { + t.Errorf("expecting %d files in go1.3, got %d instead", go13FileCount, files) + } + if dirs != go13DirCount { + t.Errorf("expecting %d directories in go1.3, got %d instead", go13DirCount, dirs) + } +} + +func TestGo13Files(t *testing.T) { + f := openOptionalTestFile(t, goTestFile) + defer f.Close() + fs, err := TarGzip(f) + if err != nil { + t.Fatal(err) + } + testGoFileCount(t, fs) +} + +func TestMounter(t *testing.T) { + m := &Mounter{} + f := openOptionalTestFile(t, goTestFile) + defer f.Close() + fs, err := TarGzip(f) + if err != nil { + t.Fatal(err) + } + m.Mount(fs, "/") + testGoFileCount(t, m) +} + +func TestClone(t *testing.T) { + fs, err := Open(filepath.Join("testdata", "fs.zip")) + if err != nil { + t.Fatal(err) + } + infos1, err := fs.ReadDir("/") + if err != nil { + t.Fatal(err) + } + mem1 := Memory() + if err := Clone(mem1, fs); err != nil { + t.Fatal(err) + } + infos2, err := mem1.ReadDir("/") + if err != nil { + t.Fatal(err) + } + if len(infos2) != len(infos1) { + t.Fatalf("cloned fs has %d entries in / rather than %d", len(infos2), len(infos1)) + } + mem2 := Memory() + if err := Clone(mem2, mem1); err != nil { + t.Fatal(err) + } + infos3, err := mem2.ReadDir("/") + if err != nil { + t.Fatal(err) + } + if len(infos3) != len(infos2) { + t.Fatalf("cloned fs has %d entries in / rather than %d", len(infos3), len(infos2)) + } +} + +func measureVFSMemorySize(t testing.TB, fs VFS) int { + mem, ok := fs.(*memoryFileSystem) + if !ok { + t.Fatalf("%T is not a memory filesystem", fs) + } + var total int + var f func(d *Dir) + f = func(d *Dir) { + for _, v := range d.Entries { + total += int(v.Size()) + if sd, ok := v.(*Dir); ok { + f(sd) + } + } + } + f(mem.root) + return total +} + +func hashVFS(t testing.TB, fs VFS) string { + sha := sha1.New() + err := Walk(fs, "/", func(fs VFS, p string, info os.FileInfo, err error) error { + if err != nil || info.IsDir() { + return err + } + f, err := fs.Open(p) + if err != nil { + return err + } + defer f.Close() + if _, err := io.Copy(sha, f); err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatal(err) + } + return hex.EncodeToString(sha.Sum(nil)) +} + +func TestCompress(t *testing.T) { + f := openOptionalTestFile(t, goTestFile) + defer f.Close() + fs, err := TarGzip(f) + if err != nil { + t.Fatal(err) + } + size1 := measureVFSMemorySize(t, fs) + hash1 := hashVFS(t, fs) + if err := Compress(fs); err != nil { + t.Fatalf("can't compress fs: %s", err) + } + testGoFileCount(t, fs) + size2 := measureVFSMemorySize(t, fs) + hash2 := hashVFS(t, fs) + if size2 >= size1 { + t.Fatalf("compressed fs takes more memory %d than bare fs %d", size2, size1) + } + if hash1 != hash2 { + t.Fatalf("compressing fs changed hash from %s to %s", hash1, hash2) + } +} diff --git a/pkgs/vfs/write.go b/pkgs/vfs/write.go new file mode 100644 index 0000000..dc51b24 --- /dev/null +++ b/pkgs/vfs/write.go @@ -0,0 +1,81 @@ +package vfs + +import ( + "archive/tar" + "archive/zip" + "compress/gzip" + "io" + "os" +) + +func copyVFS(fs VFS, copier func(p string, info os.FileInfo, f io.Reader) error) error { + return Walk(fs, "/", func(vfs VFS, p string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + f, err := fs.Open(p) + if err != nil { + return err + } + defer f.Close() + return copier(p[1:], info, f) + }) +} + +// WriteZip writes the given VFS as a zip file to the given io.Writer. +func WriteZip(w io.Writer, fs VFS) error { + zw := zip.NewWriter(w) + err := copyVFS(fs, func(p string, info os.FileInfo, f io.Reader) error { + hdr, err := zip.FileInfoHeader(info) + if err != nil { + return err + } + hdr.Name = p + fw, err := zw.CreateHeader(hdr) + if err != nil { + return err + } + _, err = io.Copy(fw, f) + return err + }) + if err != nil { + return err + } + return zw.Close() +} + +// WriteTar writes the given VFS as a tar file to the given io.Writer. +func WriteTar(w io.Writer, fs VFS) error { + tw := tar.NewWriter(w) + err := copyVFS(fs, func(p string, info os.FileInfo, f io.Reader) error { + hdr, err := tar.FileInfoHeader(info, "") + if err != nil { + return err + } + hdr.Name = p + if err := tw.WriteHeader(hdr); err != nil { + return err + } + _, err = io.Copy(tw, f) + return err + }) + if err != nil { + return err + } + return tw.Close() +} + +// WriteTarGzip writes the given VFS as a tar.gz file to the given io.Writer. +func WriteTarGzip(w io.Writer, fs VFS) error { + gw, err := gzip.NewWriterLevel(w, gzip.BestCompression) + if err != nil { + return err + } + if err := WriteTar(gw, fs); err != nil { + return err + } + return gw.Close() +} diff --git a/pkgs/vfs/write_test.go b/pkgs/vfs/write_test.go new file mode 100644 index 0000000..9969079 --- /dev/null +++ b/pkgs/vfs/write_test.go @@ -0,0 +1,41 @@ +package vfs + +import ( + "bytes" + "io" + "path/filepath" + "testing" +) + +type writeTester struct { + name string + writer func(io.Writer, VFS) error + reader func(io.Reader) (VFS, error) +} + +func TestWrite(t *testing.T) { + var ( + writeTests = []writeTester{ + {"zip", WriteZip, func(r io.Reader) (VFS, error) { return Zip(r, 0) }}, + {"tar", WriteTar, Tar}, + {"tar.gz", WriteTarGzip, TarGzip}, + } + ) + p := filepath.Join("testdata", "fs.zip") + fs, err := Open(p) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + for _, v := range writeTests { + buf.Reset() + if err := v.writer(&buf, fs); err != nil { + t.Fatalf("error writing %s: %s", v.name, err) + } + newFs, err := v.reader(&buf) + if err != nil { + t.Fatalf("error reading %s: %s", v.name, err) + } + testOpenedVFS(t, newFs) + } +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 1d2f2e3..e1d4267 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -6,25 +6,33 @@ import ( "net/http/httputil" "time" - "github.com/lox/apt-proxy/ubuntu" + "github.com/soulteary/apt-proxy/linux" ) -var ubuntuRewriter = ubuntu.NewRewriter() +var rewriter *linux.URLRewriter var defaultTransport http.RoundTripper = &http.Transport{ - Proxy: http.ProxyFromEnvironment, + Proxy: http.ProxyFromEnvironment, ResponseHeaderTimeout: time.Second * 45, DisableKeepAlives: true, } type AptProxy struct { Handler http.Handler - Rules []Rule + Rules []linux.Rule } -func NewAptProxyFromDefaults() *AptProxy { +func NewAptProxyFromDefaults(mirror string, osType string) *AptProxy { + rewriter = linux.NewRewriter(mirror, osType) + var rules []linux.Rule + if osType == linux.UBUNTU { + rules = linux.UBUNTU_DEFAULT_CACHE_RULES + } else if osType == linux.DEBIAN { + rules = linux.DEBIAN_DEFAULT_CACHE_RULES + } + return &AptProxy{ - Rules: DefaultRules, + Rules: rules, Handler: &httputil.ReverseProxy{ Director: func(r *http.Request) {}, Transport: defaultTransport, @@ -33,7 +41,7 @@ func NewAptProxyFromDefaults() *AptProxy { } func (ap *AptProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - rule, match := matchingRule(r.URL.Path, ap.Rules) + rule, match := linux.MatchingRule(r.URL.Path, ap.Rules) if match { r.Header.Del("Cache-Control") if rule.Rewrite { @@ -46,14 +54,14 @@ func (ap *AptProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) { func (ap *AptProxy) rewriteRequest(r *http.Request) { before := r.URL.String() - ubuntuRewriter.Rewrite(r) + linux.Rewrite(r, rewriter) log.Printf("rewrote %q to %q", before, r.URL.String()) r.Host = r.URL.Host } type responseWriter struct { http.ResponseWriter - rule *Rule + rule *linux.Rule } func (rw *responseWriter) WriteHeader(status int) { diff --git a/proxy/rules.go b/proxy/rules.go deleted file mode 100644 index 53d2cd9..0000000 --- a/proxy/rules.go +++ /dev/null @@ -1,39 +0,0 @@ -package proxy - -import ( - "fmt" - "regexp" -) - -var DefaultRules = []Rule{ - {Pattern: regexp.MustCompile(`deb$`), CacheControl: `max-age=100000`, Rewrite: true}, - {Pattern: regexp.MustCompile(`udeb$`), CacheControl: `max-age=100000`, Rewrite: true}, - {Pattern: regexp.MustCompile(`DiffIndex$`), CacheControl: `max-age=3600`}, - {Pattern: regexp.MustCompile(`PackagesIndex$`), CacheControl: `max-age=3600`}, - {Pattern: regexp.MustCompile(`Packages\.(bz2|gz|lzma)$`), CacheControl: `max-age=3600`}, - {Pattern: regexp.MustCompile(`SourcesIndex$`), CacheControl: `max-age=3600`}, - {Pattern: regexp.MustCompile(`Sources\.(bz2|gz|lzma)$`), CacheControl: `max-age=3600`}, - {Pattern: regexp.MustCompile(`Release(\.gpg)?$`), CacheControl: `max-age=3600`}, - {Pattern: regexp.MustCompile(`Translation-(en|fr)\.(gz|bz2|bzip2|lzma)$`), CacheControl: `max-age=3600`}, -} - -type Rule struct { - Pattern *regexp.Regexp - CacheControl string - Rewrite bool -} - -func (r *Rule) String() string { - return fmt.Sprintf("%s Cache-Control=%s Rewrite=%#v", - r.Pattern.String(), r.CacheControl, r.Rewrite) -} - -func matchingRule(subject string, rules []Rule) (*Rule, bool) { - for _, rule := range rules { - if rule.Pattern.MatchString(subject) { - return &rule, true - } - } - - return nil, false -} diff --git a/release/Dockerfile b/release/Dockerfile deleted file mode 100644 index 0adda91..0000000 --- a/release/Dockerfile +++ /dev/null @@ -1,4 +0,0 @@ -FROM progrium/busybox -ADD apt-proxy /apt-proxy -EXPOSE 3142 -CMD ["/apt-proxy"] \ No newline at end of file diff --git a/ubuntu/mirrors.go b/ubuntu/mirrors.go deleted file mode 100644 index 94ce281..0000000 --- a/ubuntu/mirrors.go +++ /dev/null @@ -1,112 +0,0 @@ -package ubuntu - -import ( - "bufio" - "errors" - "io" - "log" - "net/http" - "time" -) - -const ( - mirrorsUrl = "http://mirrors.ubuntu.com/mirrors.txt" - benchmarkUrl = "dists/saucy/main/binary-amd64/Packages.bz2" - benchmarkTimes = 3 - benchmarkBytes = 1024 * 512 // 512Kb - benchmarkTimeout = 20 // 20 seconds -) - -type Mirrors struct { - URLs []string -} - -func GetGeoMirrors() (m Mirrors, err error) { - response, err := http.Get(mirrorsUrl) - if err != nil { - return - } - - defer response.Body.Close() - scanner := bufio.NewScanner(response.Body) - m.URLs = []string{} - - // read urls line by line - for scanner.Scan() { - m.URLs = append(m.URLs, scanner.Text()) - } - - return m, scanner.Err() -} - -func (m Mirrors) Fastest() (string, error) { - ch := make(chan benchmarkResult) - - // kick off all benchmarks in parallel - for _, url := range m.URLs { - go func(u string) { - duration, err := m.benchmark(u, benchmarkTimes) - if err == nil { - ch <- benchmarkResult{u, duration} - } - }(url) - } - - readN := len(m.URLs) - if 3 < readN { - readN = 3 - } - - // wait for the fastest results to come back - results, err := m.readResults(ch, readN) - if len(results) == 0 { - return "", errors.New("No results found: " + err.Error()) - } else if err != nil { - log.Printf("Error benchmarking mirrors: %s", err.Error()) - } - - return results[0].URL, nil -} - -func (m Mirrors) readResults(ch <-chan benchmarkResult, size int) (br []benchmarkResult, err error) { - for { - select { - case r := <-ch: - br = append(br, r) - if len(br) == size { - return - } - case <-time.After(benchmarkTimeout * time.Second): - return br, errors.New("Timed out waiting for results") - } - } -} - -func (m Mirrors) benchmark(url string, times int) (time.Duration, error) { - var sum int64 - var d time.Duration - url = url + benchmarkUrl - - for i := 0; i < times; i++ { - timer := time.Now() - response, err := http.Get(url) - if err != nil { - return d, err - } - - _, err = io.ReadAtLeast(response.Body, make([]byte, benchmarkBytes), benchmarkBytes) - if err != nil { - return d, err - } - - sum = sum + int64(time.Since(timer)) - response.Body.Close() - } - - return time.Duration(sum / int64(times)), nil -} - -type benchmarkResult struct { - URL string - Duration time.Duration -} diff --git a/ubuntu/rewriter.go b/ubuntu/rewriter.go deleted file mode 100644 index c76b831..0000000 --- a/ubuntu/rewriter.go +++ /dev/null @@ -1,50 +0,0 @@ -package ubuntu - -import ( - "log" - "net/http" - "net/url" - "regexp" -) - -type ubuntuRewriter struct { - mirror *url.URL -} - -var hostPattern = regexp.MustCompile( - `https?://(security|archive).ubuntu.com/ubuntu/(.+)$`, -) - -func NewRewriter() *ubuntuRewriter { - u := &ubuntuRewriter{} - - // benchmark in the background to make sure we have the fastest - go func() { - mirrors, err := GetGeoMirrors() - if err != nil { - log.Fatal(err) - } - - mirror, err := mirrors.Fastest() - if err != nil { - log.Println("Error finding fastest mirror", err) - } - - if mirrorUrl, err := url.Parse(mirror); err == nil { - log.Printf("using ubuntu mirror %s", mirror) - u.mirror = mirrorUrl - } - }() - - return u -} - -func (ur *ubuntuRewriter) Rewrite(r *http.Request) { - url := r.URL.String() - if ur.mirror != nil && hostPattern.MatchString(url) { - r.Header.Add("Content-Location", url) - m := hostPattern.FindAllStringSubmatch(url, -1) - r.URL.Host = ur.mirror.Host - r.URL.Path = ur.mirror.Path + m[0][2] - } -}