diff --git a/lib/files/plugins/plugins.go b/lib/files/plugins/plugins.go index 6f9b62f..a16e75d 100644 --- a/lib/files/plugins/plugins.go +++ b/lib/files/plugins/plugins.go @@ -10,5 +10,4 @@ import ( _ "github.com/puellanivis/breton/lib/files/home" _ "github.com/puellanivis/breton/lib/files/httpfiles" _ "github.com/puellanivis/breton/lib/files/socketfiles" - _ "github.com/puellanivis/breton/lib/files/unixsocket" ) diff --git a/lib/files/socketfiles/dgram.go b/lib/files/socketfiles/dgram.go new file mode 100644 index 0000000..78ffc1c --- /dev/null +++ b/lib/files/socketfiles/dgram.go @@ -0,0 +1,226 @@ +package socketfiles + +import ( + "context" + "io" + "net" + "os" + "sync" + "time" + + "github.com/puellanivis/breton/lib/files/wrapper" +) + +type datagramWriter struct { + *wrapper.Info + + mu sync.Mutex + closed chan struct{} + + noerrs bool + off int + buf []byte + + sock *socket +} + +func (w *datagramWriter) IgnoreErrors(state bool) bool { + w.mu.Lock() + defer w.mu.Unlock() + + prev := w.noerrs + + w.noerrs = state + + return prev +} + +func (w *datagramWriter) err(err error) error { + if w.noerrs && err != io.ErrShortWrite { + return nil + } + + return err +} + +func (w *datagramWriter) SetPacketSize(size int) int { + w.mu.Lock() + defer w.mu.Unlock() + + prev := len(w.buf) + + switch { + case size <= 0: + w.buf = nil + + case size <= len(w.buf): + w.buf = w.buf[:size] + + default: + w.buf = append(w.buf, make([]byte, size-len(w.buf))...) + } + + if w.off > len(w.buf) { + w.off = len(w.buf) + } + + w.sock.packetSize = len(w.buf) + w.sock.updateDelay(len(w.buf)) + + return prev +} + +func (w *datagramWriter) SetBitrate(bitrate int) int { + w.mu.Lock() + defer w.mu.Unlock() + + return w.sock.setBitrate(bitrate, len(w.buf)) +} + +func (w *datagramWriter) Sync() error { + w.mu.Lock() + defer w.mu.Unlock() + + _, err := w.sync() + return w.err(err) +} + +func (w *datagramWriter) sync() (n int, err error) { + if w.off < 1 { + return 0, nil + } + + // zero out the end of the buffer. + b := w.buf[w.off:] + for i := range b { + b[i] = 0 + } + + w.off = 0 + return w.write(w.buf) +} + +func (w *datagramWriter) write(b []byte) (n int, err error) { + // We should have already prescaled the delay, so scale=1 here. + w.sock.throttle(1) + + n, err = w.sock.conn.Write(b) + if n != len(b) { + if (w.noerrs && n > 0) || err == nil { + err = io.ErrShortWrite + } + } + + return n, err +} + +func (w *datagramWriter) Close() error { + w.mu.Lock() + defer w.mu.Unlock() + + select { + case <-w.closed: + default: + close(w.closed) + } + + _, err := w.sync() + + if err2 := w.sock.conn.Close(); err == nil { + err = err2 + } + + return err +} + +func (w *datagramWriter) Write(b []byte) (n int, err error) { + w.mu.Lock() + defer w.mu.Unlock() + + if len(w.buf) < 1 { + w.sock.throttle(len(b)) + + n, err = w.sock.conn.Write(b) + return n, w.err(err) + } + + if w.off > 0 { + n = copy(w.buf[w.off:], b) + w.off += n + + if w.off < len(w.buf) { + // The full length of b was copied into buffer, + // and we haven’t filled the buffer. + // So, we’re done. + return n, nil + } + + _, err2 := w.sync() + if err = w.err(err2); err != nil { + return n, err + } + + b = b[n:] + } + + sz := len(w.buf) + for len(b) >= sz { + n2, err2 := w.write(b[:sz]) + n += n2 + + if err = w.err(err2); err != nil { + return n, err + } + + // skip the whole packet size, even if n2 < sz + b = b[sz:] + } + + if len(b) > 0 { + w.off = copy(w.buf, b) + n += w.off + } + + return n, nil +} + +func newDatagramWriter(ctx context.Context, sock *socket) *datagramWriter { + var buf []byte + if sock.packetSize > 0 { + buf = make([]byte, sock.packetSize) + } + + w := &datagramWriter{ + Info: wrapper.NewInfo(sock.uri(), 0, time.Now()), + sock: sock, + + closed: make(chan struct{}), + buf: buf, + } + + go func() { + select { + case <-w.closed: + case <-ctx.Done(): + w.Close() + } + }() + + return w +} + +type datagramReader struct { + *wrapper.Info + net.Conn +} + +func (r *datagramReader) Seek(offset int64, whence int) (int64, error) { + return 0, os.ErrInvalid +} + +func newDatagramReader(ctx context.Context, sock *socket) *datagramReader { + return &datagramReader{ + Info: wrapper.NewInfo(sock.uri(), 0, time.Now()), + Conn: sock.conn, + } +} diff --git a/lib/files/socketfiles/socket.go b/lib/files/socketfiles/socket.go index 752fb80..d5d1e15 100644 --- a/lib/files/socketfiles/socket.go +++ b/lib/files/socketfiles/socket.go @@ -28,131 +28,216 @@ const ( FieldTTL = "ttl" ) -type ipSocket struct { - laddr, raddr net.Addr +type socket struct { + conn net.Conn + + addr, qaddr net.Addr bufferSize int + packetSize int tos, ttl int throttler } -func (s *ipSocket) uriQuery() url.Values { +func (s *socket) uri() *url.URL { + q := s.uriQuery() + + switch qaddr := s.qaddr.(type) { + case *net.TCPAddr: + q.Set(FieldLocalAddress, qaddr.IP.String()) + q.Set(FieldLocalPort, strconv.Itoa(qaddr.Port)) + + case *net.UDPAddr: + q.Set(FieldLocalAddress, qaddr.IP.String()) + q.Set(FieldLocalPort, strconv.Itoa(qaddr.Port)) + + case *net.UnixAddr: + q.Set(FieldLocalAddress, qaddr.String()) + } + + host, path := s.addr.String(), "" + + switch s.addr.Network() { + case "unix", "unixgram", "unixpacket": + host, path = "", host + } + + return &url.URL{ + Scheme: s.addr.Network(), + Host: host, + Path: path, + RawQuery: q.Encode(), + } +} + +func (s *socket) uriQuery() url.Values { q := make(url.Values) if s.bitrate > 0 { - setInt(q, FieldMaxBitrate, s.bitrate) + q.Set(FieldMaxBitrate, strconv.Itoa(s.bitrate)) } if s.bufferSize > 0 { - setInt(q, FieldBufferSize, s.bufferSize) + q.Set(FieldBufferSize, strconv.Itoa(s.bufferSize)) } - if s.tos > 0 { - q.Set(FieldTOS, "0x"+strconv.FormatInt(int64(s.tos), 16)) + network := s.addr.Network() + + switch network { + case "udp", "udp4", "udp6", "unixgram", "unixpacket": + if s.packetSize > 0 { + q.Set(FieldPacketSize, strconv.Itoa(s.packetSize)) + } } - if s.ttl > 0 { - setInt(q, FieldTTL, s.ttl) + switch network { + case "udp", "udp4", "tcp", "tcp4": + if s.tos > 0 { + q.Set(FieldTOS, "0x"+strconv.FormatInt(int64(s.tos), 16)) + } + + if s.ttl > 0 { + q.Set(FieldTTL, strconv.Itoa(s.ttl)) + } } return q } -func (s *ipSocket) setForReader(conn net.Conn, q url.Values) error { - s.laddr = conn.LocalAddr() - - type bufferSizeSetter interface { - SetReadBuffer(int) error +func sockReader(conn net.Conn, q url.Values) (*socket, error) { + bufferSize, err := getSize(q, FieldBufferSize) + if err != nil { + return nil, err } - if bufferSize, ok, err := getSize(q, FieldBufferSize); ok || err != nil { - if err != nil { - return err + + if bufferSize > 0 { + type readBufferSetter interface { + SetReadBuffer(int) error } - conn, ok := conn.(bufferSizeSetter) + conn, ok := conn.(readBufferSetter) if !ok { - return syscall.EINVAL + return nil, syscall.EINVAL } if err := conn.SetReadBuffer(bufferSize); err != nil { - return err + return nil, err } - - s.bufferSize = bufferSize } - return nil + return &socket{ + conn: conn, + + addr: conn.LocalAddr(), + + bufferSize: bufferSize, + }, nil } -func (s *ipSocket) setForWriter(conn net.Conn, q url.Values) error { - s.laddr = conn.LocalAddr() - s.raddr = conn.RemoteAddr() +func sockWriter(conn net.Conn, showLocalAddr bool, q url.Values) (*socket, error) { + raddr := conn.RemoteAddr() - if err := s.setThrottle(q); err != nil { - return err + bufferSize, err := getSize(q, FieldBufferSize) + if err != nil { + return nil, err } - type bufferSizeSetter interface { - SetWriteBuffer(int) error - } - if bufferSize, ok, err := getSize(q, FieldBufferSize); ok || err != nil { - if err != nil { - return err + if bufferSize > 0 { + type writeBufferSetter interface { + SetWriteBuffer(int) error } - conn, ok := conn.(bufferSizeSetter) + conn, ok := conn.(writeBufferSetter) if !ok { - return syscall.EINVAL + return nil, syscall.EINVAL } if err := conn.SetWriteBuffer(bufferSize); err != nil { - return err + return nil, err } + } - s.bufferSize = bufferSize + var packetSize int + switch raddr.Network() { + case "udp", "udp4", "udp6", "unixgram", "unixpacket": + packetSize, err = getSize(q, FieldPacketSize) + if err != nil { + return nil, err + } + } + + bitrate, err := getSize(q, FieldMaxBitrate) + if err != nil { + return nil, err } - var p *ipv4.Conn + var t throttler + if bitrate > 0 { + t.setBitrate(bitrate, packetSize) + } + + var tos, ttl int + + switch raddr.Network() { + case "udp", "udp4", "tcp", "tcp4": + var p *ipv4.Conn - if tos, ok, err := getInt(q, FieldTOS); ok || err != nil { + tos, err = getInt(q, FieldTOS) if err != nil { - return err + return nil, err } - if p == nil { - p = ipv4.NewConn(conn) - } + if tos > 0 { + if p == nil { + p = ipv4.NewConn(conn) + } - if err := p.SetTOS(tos); err != nil { - return err - } + if err := p.SetTOS(tos); err != nil { + return nil, err + } - s.tos, _ = p.TOS() - } + tos, _ = p.TOS() + } - if ttl, ok, err := getInt(q, FieldTTL); ok || err != nil { + ttl, err = getInt(q, FieldTTL) if err != nil { - return err + return nil, err } - if p == nil { - p = ipv4.NewConn(conn) - } + if ttl > 0 { + if p == nil { + p = ipv4.NewConn(conn) + } + + if err := p.SetTTL(ttl); err != nil { + return nil, err + } - if err := p.SetTTL(ttl); err != nil { - return err + ttl, _ = p.TTL() } + } - s.ttl, _ = p.TTL() + var laddr net.Addr + if showLocalAddr { + laddr = conn.LocalAddr() } - return nil -} + return &socket{ + conn: conn, + + addr: raddr, + qaddr: laddr, + + bufferSize: bufferSize, + packetSize: packetSize, + + tos: tos, + ttl: ttl, -func setInt(q url.Values, field string, val int) { - q.Set(field, strconv.Itoa(val)) + throttler: t, + }, nil } var scales = map[byte]int{ @@ -164,65 +249,46 @@ var scales = map[byte]int{ 'k': 1000, } -func getSize(q url.Values, field string) (val int, specified bool, err error) { - s := q.Get(field) - if s == "" { - return 0, false, nil +func getSize(q url.Values, field string) (val int, err error) { + value := q.Get(field) + if value == "" { + return 0, nil } - suffix := s[len(s)-1] + suffix := value[len(value)-1] scale := 1 - if val, ok := scales[suffix]; ok { - scale = val - s = s[:len(s)-1] + if s := scales[suffix]; s > 0 { + scale = s + value = value[:len(value)-1] } - i, err := strconv.ParseInt(s, 0, strconv.IntSize) + i, err := strconv.ParseInt(value, 0, strconv.IntSize) if err != nil { - return 0, true, err + return 0, err } - return int(i) * scale, true, nil + return int(i) * scale, nil } -func getInt(q url.Values, field string) (val int, specified bool, err error) { - s := q.Get(field) - if s == "" { - return 0, false, nil +func getInt(q url.Values, field string) (val int, err error) { + value := q.Get(field) + if value == "" { + return 0, nil } - i, err := strconv.ParseInt(s, 0, strconv.IntSize) + i, err := strconv.ParseInt(value, 0, strconv.IntSize) if err != nil { - return 0, true, err - } - - return int(i), true, nil -} - -func buildAddr(addr, portString string) (ip net.IP, port int, err error) { - if addr != "" { - ip = net.ParseIP(addr) - if ip == nil { - return nil, 0, errInvalidIP - } - } - - if portString != "" { - p, err := strconv.ParseInt(portString, 10, strconv.IntSize) - if err != nil { - return nil, 0, err - } - - port = int(p) + return 0, err } - return ip, port, nil + return int(i), nil } -func withContext(ctx context.Context, fn func() error) (err error) { +func do(ctx context.Context, fn func() error) error { done := make(chan struct{}) + var err error go func() { defer close(done) diff --git a/lib/files/socketfiles/stream.go b/lib/files/socketfiles/stream.go new file mode 100644 index 0000000..01c5bdf --- /dev/null +++ b/lib/files/socketfiles/stream.go @@ -0,0 +1,178 @@ +package socketfiles + +import ( + "context" + "net" + "net/url" + "os" + "sync" + "time" + + "github.com/puellanivis/breton/lib/files" + "github.com/puellanivis/breton/lib/files/wrapper" +) + +type streamWriter struct { + *wrapper.Info + + mu sync.Mutex + closed chan struct{} + + sock *socket +} + +func (w *streamWriter) SetBitrate(bitrate int) int { + w.mu.Lock() + defer w.mu.Unlock() + + return w.sock.setBitrate(bitrate, 1) +} + +func (w *streamWriter) Sync() error { + return nil +} + +func (w *streamWriter) Close() error { + w.mu.Lock() + defer w.mu.Unlock() + + select { + case <-w.closed: + default: + close(w.closed) + } + + return w.sock.conn.Close() +} + +func (w *streamWriter) Write(b []byte) (n int, err error) { + w.mu.Lock() + defer w.mu.Unlock() + + w.sock.throttle(len(b)) + + return w.sock.conn.Write(b) +} + +func (w *streamWriter) uri() *url.URL { + return w.sock.uri() +} + +func newStreamWriter(ctx context.Context, sock *socket) *streamWriter { + w := &streamWriter{ + Info: wrapper.NewInfo(sock.uri(), 0, time.Now()), + sock: sock, + + closed: make(chan struct{}), + } + + go func() { + select { + case <-w.closed: + case <-ctx.Done(): + w.Close() + } + }() + + return w +} + +type streamReader struct { + *wrapper.Info + + loading <-chan struct{} + + err error + conn net.Conn +} + +func (r *streamReader) Read(b []byte) (n int, err error) { + for range r.loading { + } + + if r.err != nil { + return 0, r.err + } + + return r.conn.Read(b) +} + +func (r *streamReader) Seek(offset int64, whence int) (int64, error) { + return 0, os.ErrInvalid +} + +func (r *streamReader) Close() error { + for range r.loading { + } + + // Never connected, so just return nil. + if r.conn == nil { + return nil + } + + // Ignore the r.err, as it is a request-scope error, and not relevant to closing. + + return r.conn.Close() +} + +func newStreamReader(ctx context.Context, l net.Listener) (*streamReader, error) { + // Maybe we asked for an arbitrary port, + // so, refresh our address to the one we’re actually listening on. + laddr := l.Addr() + + host, path := laddr.String(), "" + switch laddr.Network() { + case "unix": + host, path = "", host + } + + uri := &url.URL{ + Scheme: laddr.Network(), + Host: host, + Path: path, + } + + loading := make(chan struct{}) + r := &streamReader{ + Info: wrapper.NewInfo(uri, 0, time.Now()), + + loading: loading, + } + + go func() { + defer close(loading) + defer l.Close() + + select { + case loading <- struct{}{}: + case <-ctx.Done(): + r.err = files.PathError("open", uri.String(), ctx.Err()) + return + } + + var conn net.Conn + accept := func() error { + var err error + + conn, err = l.Accept() + + return err + } + + if err := do(ctx, accept); err != nil { + r.err = files.PathError("accept", uri.String(), err) + return + } + + // TODO: make the a configurable option? + /* if err := conn.CloseWrite(); err != nil { + conn.Close() + r.err = err + return + } */ + + r.conn = conn + }() + + return r, nil +} diff --git a/lib/files/socketfiles/tcp.go b/lib/files/socketfiles/tcp.go index 68c2978..83cb8f4 100644 --- a/lib/files/socketfiles/tcp.go +++ b/lib/files/socketfiles/tcp.go @@ -5,11 +5,8 @@ import ( "net" "net/url" "os" - "sync" - "time" "github.com/puellanivis/breton/lib/files" - "github.com/puellanivis/breton/lib/files/wrapper" ) type tcpHandler struct{} @@ -18,69 +15,22 @@ func init() { files.RegisterScheme(&tcpHandler{}, "tcp") } -type tcpWriter struct { - mu sync.Mutex - - closed chan struct{} - - conn *net.TCPConn - *wrapper.Info - ipSocket -} - -func (w *tcpWriter) SetBitrate(bitrate int) int { - w.mu.Lock() - defer w.mu.Unlock() - - prev := w.bitrate - - w.bitrate = bitrate - w.updateDelay(1) - - return prev -} - -func (w *tcpWriter) Sync() error { - return nil -} - -func (w *tcpWriter) Close() error { - w.mu.Lock() - defer w.mu.Unlock() - - select { - case <-w.closed: - default: - close(w.closed) +func (h *tcpHandler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { + if uri.Host == "" { + return nil, files.PathError("open", uri.String(), errInvalidURL) } - return w.conn.Close() -} - -func (w *tcpWriter) Write(b []byte) (n int, err error) { - w.mu.Lock() - defer w.mu.Unlock() - - w.throttle(len(b)) - - return w.conn.Write(b) -} - -func (w *tcpWriter) uri() *url.URL { - q := w.ipSocket.uriQuery() - - if w.laddr != nil { - laddr := w.laddr.(*net.TCPAddr) - - q.Set(FieldLocalAddress, laddr.IP.String()) - setInt(q, FieldLocalPort, laddr.Port) + laddr, err := net.ResolveTCPAddr("tcp", uri.Host) + if err != nil { + return nil, files.PathError("open", uri.String(), err) } - return &url.URL{ - Scheme: "tcp", - Host: w.raddr.String(), - RawQuery: q.Encode(), + l, err := net.ListenTCP("tcp", laddr) + if err != nil { + return nil, files.PathError("open", uri.String(), err) } + + return newStreamReader(ctx, l) } func (h *tcpHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { @@ -88,10 +38,6 @@ func (h *tcpHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, er return nil, files.PathError("create", uri.String(), errInvalidURL) } - w := &tcpWriter{ - closed: make(chan struct{}), - } - raddr, err := net.ResolveTCPAddr("tcp", uri.Host) if err != nil { return nil, files.PathError("create", uri.String(), err) @@ -99,49 +45,37 @@ func (h *tcpHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, er q := uri.Query() - port := q.Get(FieldLocalPort) - addr := q.Get(FieldLocalAddress) - var laddr *net.TCPAddr - if port != "" || addr != "" { - laddr = new(net.TCPAddr) - - laddr.IP, laddr.Port, err = buildAddr(addr, port) + host := q.Get(FieldLocalAddress) + port := q.Get(FieldLocalPort) + if host != "" || port != "" { + laddr, err = net.ResolveTCPAddr("tcp", net.JoinHostPort(host, port)) if err != nil { return nil, files.PathError("create", uri.String(), err) } } - dail := func() error { + var conn *net.TCPConn + dial := func() error { var err error - w.conn, err = net.DialTCP("tcp", laddr, raddr) + conn, err = net.DialTCP("tcp", laddr, raddr) return err } - if err := withContext(ctx, dail); err != nil { + if err := do(ctx, dial); err != nil { return nil, files.PathError("create", uri.String(), err) } - go func() { - select { - case <-w.closed: - case <-ctx.Done(): - w.Close() - } - }() - - if err := w.ipSocket.setForWriter(w.conn, q); err != nil { - w.Close() + sock, err := sockWriter(conn, laddr != nil, q) + if err != nil { + conn.Close() return nil, files.PathError("create", uri.String(), err) } - w.updateDelay(1) - w.Info = wrapper.NewInfo(w.uri(), 0, time.Now()) - - return w, nil + return newStreamWriter(ctx, sock), nil } func (h *tcpHandler) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { diff --git a/lib/files/socketfiles/tcp_test.go b/lib/files/socketfiles/tcp_test.go index 53db281..63f18f3 100644 --- a/lib/files/socketfiles/tcp_test.go +++ b/lib/files/socketfiles/tcp_test.go @@ -9,48 +9,61 @@ import ( ) func TestTCPName(t *testing.T) { - w := &tcpWriter{ - ipSocket: ipSocket{ - laddr: &net.TCPAddr{ - IP: []byte{127, 0, 0, 1}, - Port: 65535, - }, - raddr: &net.TCPAddr{ - IP: []byte{127, 0, 0, 1}, - Port: 80, - }, - bufferSize: 1024, - ttl: 100, - tos: 0x80, - - throttler: throttler{ - bitrate: 2048, - }, + sock := &socket{ + qaddr: &net.TCPAddr{ + IP: []byte{127, 0, 0, 1}, + Port: 65535, + }, + addr: &net.TCPAddr{ + IP: []byte{127, 0, 0, 2}, + Port: 80, + }, + + packetSize: 188, // should not show up + bufferSize: 1024, + + ttl: 100, + tos: 0x80, + + throttler: throttler{ + bitrate: 2048, + }, + } + + uri := sock.uri() + expected := "tcp://127.0.0.2:80?buffer_size=1024&localaddr=127.0.0.1&localport=65535&max_bitrate=2048&tos=0x80&ttl=100" + + if s := uri.String(); s != expected { + t.Errorf("got a bad URI, was expecting, but got:\n\t%v\n\t%v", expected, s) + } + + sock = &socket{ + qaddr: &net.TCPAddr{ + IP: []byte{127, 0, 0, 1}, + Port: 65534, + }, + addr: &net.TCPAddr{ + IP: []byte{127, 0, 0, 2}, + Port: 443, }, } - uri := w.uri() - expected := "tcp://127.0.0.1:80?buffer_size=1024&localaddr=127.0.0.1&localport=65535&max_bitrate=2048&tos=0x80&ttl=100" + uri = sock.uri() + expected = "tcp://127.0.0.2:443?localaddr=127.0.0.1&localport=65534" if s := uri.String(); s != expected { t.Errorf("got a bad URI, was expecting, but got:\n\t%v\n\t%v", expected, s) } - w = &tcpWriter{ - ipSocket: ipSocket{ - laddr: &net.TCPAddr{ - IP: []byte{127, 0, 0, 1}, - Port: 65534, - }, - raddr: &net.TCPAddr{ - IP: []byte{127, 0, 0, 1}, - Port: 443, - }, + sock = &socket{ + addr: &net.TCPAddr{ + IP: []byte{127, 0, 0, 2}, + Port: 8080, }, } - uri = w.uri() - expected = "tcp://127.0.0.1:443?localaddr=127.0.0.1&localport=65534" + uri = sock.uri() + expected = "tcp://127.0.0.2:8080" if s := uri.String(); s != expected { t.Errorf("got a bad URI, was expecting, but got:\n\t%v\n\t%v", expected, s) diff --git a/lib/files/socketfiles/tcpreader.go b/lib/files/socketfiles/tcpreader.go deleted file mode 100644 index 06d0fae..0000000 --- a/lib/files/socketfiles/tcpreader.go +++ /dev/null @@ -1,100 +0,0 @@ -package socketfiles - -import ( - "context" - "net" - "net/url" - "os" - "time" - - "github.com/puellanivis/breton/lib/files" - "github.com/puellanivis/breton/lib/files/wrapper" -) - -type tcpReader struct { - conn *net.TCPConn - *wrapper.Info - - err error - loading <-chan struct{} -} - -func (r *tcpReader) Read(b []byte) (n int, err error) { - for range r.loading { - } - - if r.err != nil { - return 0, r.err - } - - return r.conn.Read(b) -} - -func (r *tcpReader) Seek(offset int64, whence int) (int64, error) { - return 0, os.ErrInvalid -} - -func (r *tcpReader) Close() error { - for range r.loading { - } - - // Ignore the r.err, as it is a request-scope error, and not relevant to closing. - - return r.conn.Close() -} - -func (h *tcpHandler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { - if uri.Host == "" { - return nil, files.PathError("open", uri.String(), errInvalidURL) - } - - laddr, err := net.ResolveTCPAddr("tcp", uri.Host) - if err != nil { - return nil, files.PathError("open", uri.String(), err) - } - - l, err := net.ListenTCP("tcp", laddr) - if err != nil { - return nil, files.PathError("open", uri.String(), err) - } - - // Maybe we asked for an arbitrary port, - // so, we build our own copy of the URL, and use that. - lurl := &url.URL{ - Host: l.Addr().String(), - } - - loading := make(chan struct{}) - r := &tcpReader{ - loading: loading, - Info: wrapper.NewInfo(lurl, 0, time.Now()), - } - - go func() { - defer close(loading) - defer l.Close() - - select { - case loading <- struct{}{}: - case <-ctx.Done(): - r.err = files.PathError("open", uri.String(), ctx.Err()) - return - } - - conn, err := l.AcceptTCP() - if err != nil { - r.err = files.PathError("accept", uri.String(), err) - return - } - - /* if err := conn.CloseWrite(); err != nil { - conn.Close() - r.err = err - return - } */ - - r.conn = conn - }() - - return r, nil -} diff --git a/lib/files/socketfiles/throttling.go b/lib/files/socketfiles/throttling.go index 68682b2..7438bf0 100644 --- a/lib/files/socketfiles/throttling.go +++ b/lib/files/socketfiles/throttling.go @@ -1,7 +1,6 @@ package socketfiles import ( - "net/url" "time" ) @@ -12,16 +11,33 @@ type throttler struct { next *time.Timer } +func (t *throttler) drain() { + if t.next == nil { + return + } + + if !t.next.Stop() { + <-t.next.C + } +} + func (t *throttler) updateDelay(prescale int) { if t.bitrate <= 0 { t.delay = 0 + t.drain() t.next = nil return } + if t.next != nil { + t.drain() + t.next.Reset(0) + } else { + t.next = time.NewTimer(0) + } + // delay = nanoseconds per byte t.delay = (8 * time.Second) / time.Duration(t.bitrate) - t.next = time.NewTimer(0) // recalculate to the actual expected maximum bitrate t.bitrate = int(8 * time.Second / t.delay) @@ -46,14 +62,11 @@ func (t *throttler) throttle(scale int) { t.next.Reset(t.delay) } -func (t *throttler) setThrottle(q url.Values) error { - if bitrate, ok, err := getSize(q, FieldMaxBitrate); ok || err != nil { - if err != nil { - return err - } +func (t *throttler) setBitrate(bitrate, prescale int) int { + prev := t.bitrate - t.bitrate = bitrate - } + t.bitrate = bitrate + t.updateDelay(prescale) - return nil + return prev } diff --git a/lib/files/socketfiles/udp.go b/lib/files/socketfiles/udp.go index bef2151..47e0e54 100644 --- a/lib/files/socketfiles/udp.go +++ b/lib/files/socketfiles/udp.go @@ -2,15 +2,11 @@ package socketfiles import ( "context" - "io" "net" "net/url" "os" - "sync" - "time" "github.com/puellanivis/breton/lib/files" - "github.com/puellanivis/breton/lib/files/wrapper" ) type udpHandler struct{} @@ -19,205 +15,32 @@ func init() { files.RegisterScheme(&udpHandler{}, "udp") } -type udpWriter struct { - mu sync.Mutex - - closed chan struct{} - - conn *net.UDPConn - *wrapper.Info - ipSocket - - noerrs bool - - off int - buf []byte -} - -func (w *udpWriter) IgnoreErrors(state bool) bool { - w.mu.Lock() - defer w.mu.Unlock() - - prev := w.noerrs - - w.noerrs = state - - return prev -} - -func (w *udpWriter) err(err error) error { - if w.noerrs && err != io.ErrShortWrite { - return nil - } - - return err -} - -func (w *udpWriter) SetPacketSize(size int) int { - w.mu.Lock() - defer w.mu.Unlock() - - prev := len(w.buf) - - w.buf = nil - if size > 0 { - w.buf = make([]byte, size) - } - - w.updateDelay(len(w.buf)) - - return prev -} - -func (w *udpWriter) SetBitrate(bitrate int) int { - w.mu.Lock() - defer w.mu.Unlock() - - prev := w.bitrate - - w.bitrate = bitrate - w.updateDelay(len(w.buf)) - - return prev -} - -func (w *udpWriter) Sync() error { - w.mu.Lock() - defer w.mu.Unlock() - - return w.err(w.sync()) -} - -func (w *udpWriter) sync() error { - if w.off < 1 { - return nil - } - - // zero out the end of the buffer. - for i := w.off; i < len(w.buf); i++ { - w.buf[i] = 0 - } - - w.off = 0 - _, err := w.mustWrite(w.buf) - return err -} - -func (w *udpWriter) mustWrite(b []byte) (n int, err error) { - // We should have already prescaled the delay, so scale=1 here. - w.throttle(1) - - n, err = w.conn.Write(b) - if n != len(b) { - if (w.noerrs && n > 0) || err == nil { - err = io.ErrShortWrite - } - } - - return n, err -} - -func (w *udpWriter) Close() error { - w.mu.Lock() - defer w.mu.Unlock() - - err := w.sync() - - select { - case <-w.closed: - default: - close(w.closed) - } - - if err2 := w.conn.Close(); err == nil { - err = err2 - } - - return err -} - -func (w *udpWriter) Write(b []byte) (n int, err error) { - w.mu.Lock() - defer w.mu.Unlock() - - if len(w.buf) < 1 { - w.throttle(len(b)) - - n, err = w.conn.Write(b) - return n, w.err(err) - } - - if w.off > 0 { - n = copy(w.buf[w.off:], b) - w.off += n - - if w.off < len(w.buf) { - // The full length of b was copied into buffer, - // and we haven’t filled the buffer. - // So, we’re done. - return n, nil - } - - w.off = 0 - b = b[n:] - - n2, err2 := w.mustWrite(w.buf) - if err = w.err(err2); err != nil { - if n2 > 0 { - w.off = copy(w.buf, w.buf[n2:]) - } - - /*n -= len(w.buf) - n2 - if n < 0 { - n = 0 - } */ - - return n, err - } +func (h *udpHandler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { + if uri.Host == "" { + return nil, files.PathError("open", uri.String(), errInvalidURL) } - sz := len(w.buf) - - for len(b) >= sz { - n2, err2 := w.mustWrite(b[:sz]) - n += n2 - - if err = w.err(err2); err != nil { - return n, err - } - - // skip the whole packet size, even on a short write. - b = b[sz:] + laddr, err := net.ResolveUDPAddr("udp", uri.Host) + if err != nil { + return nil, files.PathError("open", uri.String(), err) } - if len(b) > 0 { - n2 := copy(w.buf, b) - w.off += n2 - n += n2 + conn, err := net.ListenUDP("udp", laddr) + if err != nil { + return nil, files.PathError("open", uri.String(), err) } - return n, nil -} + // Maybe we asked for an arbitrary port, + // so, refresh our address to the one we’re actually listening on. + laddr = conn.LocalAddr().(*net.UDPAddr) -func (w *udpWriter) uri() *url.URL { - q := w.ipSocket.uriQuery() - - if w.laddr != nil { - laddr := w.laddr.(*net.UDPAddr) - - q.Set(FieldLocalAddress, laddr.IP.String()) - setInt(q, FieldLocalPort, laddr.Port) - } - - if len(w.buf) > 0 { - setInt(q, FieldPacketSize, len(w.buf)) + sock, err := sockReader(conn, uri.Query()) + if err != nil { + conn.Close() + return nil, files.PathError("open", uri.String(), err) } - return &url.URL{ - Scheme: "udp", - Host: w.raddr.String(), - RawQuery: q.Encode(), - } + return newDatagramReader(ctx, sock), nil } func (h *udpHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { @@ -225,10 +48,6 @@ func (h *udpHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, er return nil, files.PathError("create", uri.String(), errInvalidURL) } - w := &udpWriter{ - closed: make(chan struct{}), - } - raddr, err := net.ResolveUDPAddr("udp", uri.Host) if err != nil { return nil, files.PathError("create", uri.String(), err) @@ -236,58 +55,37 @@ func (h *udpHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, er q := uri.Query() - port := q.Get(FieldLocalPort) - addr := q.Get(FieldLocalAddress) - var laddr *net.UDPAddr - if port != "" || addr != "" { - laddr = new(net.UDPAddr) - - laddr.IP, laddr.Port, err = buildAddr(addr, port) + host := q.Get(FieldLocalAddress) + port := q.Get(FieldLocalPort) + if host != "" || port != "" { + laddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(host, port)) if err != nil { return nil, files.PathError("create", uri.String(), err) } } - dail := func() error { + var conn *net.UDPConn + dial := func() error { var err error - w.conn, err = net.DialUDP("udp", laddr, raddr) + conn, err = net.DialUDP("udp", laddr, raddr) return err } - if err := withContext(ctx, dail); err != nil { + if err := do(ctx, dial); err != nil { return nil, files.PathError("create", uri.String(), err) } - go func() { - select { - case <-w.closed: - case <-ctx.Done(): - w.Close() - } - }() - - if err := w.ipSocket.setForWriter(w.conn, q); err != nil { - w.Close() + sock, err := sockWriter(conn, laddr != nil, q) + if err != nil { + conn.Close() return nil, files.PathError("create", uri.String(), err) } - if pktSize, ok, err := getSize(q, FieldPacketSize); ok || err != nil { - if err != nil { - w.Close() - return nil, files.PathError("create", uri.String(), err) - } - - w.buf = make([]byte, pktSize) - } - - w.updateDelay(len(w.buf)) - w.Info = wrapper.NewInfo(w.uri(), 0, time.Now()) - - return w, nil + return newDatagramWriter(ctx, sock), nil } func (h *udpHandler) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { diff --git a/lib/files/socketfiles/udp_test.go b/lib/files/socketfiles/udp_test.go index e4a5bf5..cc8c7ed 100644 --- a/lib/files/socketfiles/udp_test.go +++ b/lib/files/socketfiles/udp_test.go @@ -9,49 +9,61 @@ import ( ) func TestUDPName(t *testing.T) { - w := &udpWriter{ - ipSocket: ipSocket{ - laddr: &net.UDPAddr{ - IP: []byte{127, 0, 0, 1}, - Port: 65535, - }, - raddr: &net.UDPAddr{ - IP: []byte{127, 0, 0, 1}, - Port: 80, - }, - bufferSize: 1024, - ttl: 100, - tos: 0x80, - - throttler: throttler{ - bitrate: 2048, - }, + sock := &socket{ + qaddr: &net.UDPAddr{ + IP: []byte{127, 0, 0, 1}, + Port: 65535, + }, + addr: &net.UDPAddr{ + IP: []byte{127, 0, 0, 2}, + Port: 80, + }, + + packetSize: 188, + bufferSize: 1024, + + ttl: 100, + tos: 0x80, + + throttler: throttler{ + bitrate: 2048, + }, + } + + uri := sock.uri() + expected := "udp://127.0.0.2:80?buffer_size=1024&localaddr=127.0.0.1&localport=65535&max_bitrate=2048&pkt_size=188&tos=0x80&ttl=100" + + if s := uri.String(); s != expected { + t.Errorf("got a bad URI, was expecting, but got:\n\t%v\n\t%v", expected, s) + } + + sock = &socket{ + qaddr: &net.UDPAddr{ + IP: []byte{127, 0, 0, 1}, + Port: 65534, + }, + addr: &net.UDPAddr{ + IP: []byte{127, 0, 0, 2}, + Port: 443, }, - buf: make([]byte, 188), } - uri := w.uri() - expected := "udp://127.0.0.1:80?buffer_size=1024&localaddr=127.0.0.1&localport=65535&max_bitrate=2048&pkt_size=188&tos=0x80&ttl=100" + uri = sock.uri() + expected = "udp://127.0.0.2:443?localaddr=127.0.0.1&localport=65534" if s := uri.String(); s != expected { t.Errorf("got a bad URI, was expecting, but got:\n\t%v\n\t%v", expected, s) } - w = &udpWriter{ - ipSocket: ipSocket{ - laddr: &net.UDPAddr{ - IP: []byte{127, 0, 0, 1}, - Port: 65534, - }, - raddr: &net.UDPAddr{ - IP: []byte{127, 0, 0, 1}, - Port: 443, - }, + sock = &socket{ + addr: &net.UDPAddr{ + IP: []byte{127, 0, 0, 2}, + Port: 8080, }, } - uri = w.uri() - expected = "udp://127.0.0.1:443?localaddr=127.0.0.1&localport=65534" + uri = sock.uri() + expected = "udp://127.0.0.2:8080" if s := uri.String(); s != expected { t.Errorf("got a bad URI, was expecting, but got:\n\t%v\n\t%v", expected, s) diff --git a/lib/files/socketfiles/udpreader.go b/lib/files/socketfiles/udpreader.go deleted file mode 100644 index 040e79d..0000000 --- a/lib/files/socketfiles/udpreader.go +++ /dev/null @@ -1,69 +0,0 @@ -package socketfiles - -import ( - "context" - "net" - "net/url" - "os" - "time" - - "github.com/puellanivis/breton/lib/files" - "github.com/puellanivis/breton/lib/files/wrapper" -) - -type udpReader struct { - conn *net.UDPConn - *wrapper.Info - ipSocket -} - -func (r *udpReader) Read(b []byte) (n int, err error) { - return r.conn.Read(b) -} - -func (r *udpReader) Seek(offset int64, whence int) (int64, error) { - return 0, os.ErrInvalid -} - -func (r *udpReader) Close() error { - return r.conn.Close() -} - -func (r *udpReader) uri() *url.URL { - q := r.ipSocket.uriQuery() - - return &url.URL{ - Scheme: "udp", - Host: r.laddr.String(), - RawQuery: q.Encode(), - } -} - -func (h *udpHandler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { - if uri.Host == "" { - return nil, files.PathError("open", uri.String(), errInvalidURL) - } - - r := new(udpReader) - - laddr, err := net.ResolveUDPAddr("udp", uri.Host) - if err != nil { - return nil, files.PathError("open", uri.String(), err) - } - - q := uri.Query() - - r.conn, err = net.ListenUDP("udp", laddr) - if err != nil { - return nil, files.PathError("open", uri.String(), err) - } - - if err := r.ipSocket.setForReader(r.conn, q); err != nil { - r.conn.Close() - return nil, files.PathError("open", uri.String(), err) - } - - r.Info = wrapper.NewInfo(r.uri(), 0, time.Now()) - - return r, nil -} diff --git a/lib/files/socketfiles/unixsock.go b/lib/files/socketfiles/unixsock.go new file mode 100644 index 0000000..ea1c630 --- /dev/null +++ b/lib/files/socketfiles/unixsock.go @@ -0,0 +1,115 @@ +package socketfiles + +import ( + "context" + "errors" + "net" + "net/url" + "os" + + "github.com/puellanivis/breton/lib/files" +) + +type unixHandler struct{} + +func init() { + files.RegisterScheme(&unixHandler{}, "unix", "unixgram") +} + +func (h *unixHandler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { + path := uri.Path + if path == "" { + path = uri.Opaque + } + network := uri.Scheme + + laddr, err := net.ResolveUnixAddr(network, path) + if err != nil { + return nil, files.PathError("open", uri.String(), err) + } + + switch laddr.Network() { + case "unixgram": + conn, err := net.ListenUnixgram(network, laddr) + if err != nil { + return nil, files.PathError("open", uri.String(), err) + } + + sock, err := sockReader(conn, uri.Query()) + if err != nil { + conn.Close() + return nil, files.PathError("open", uri.String(), err) + } + + return newDatagramReader(ctx, sock), nil + + case "unix": + l, err := net.ListenUnix(network, laddr) + if err != nil { + return nil, files.PathError("open", uri.String(), err) + } + + return newStreamReader(ctx, l) + } + + return nil, files.PathError("create", uri.String(), errors.New("unknown unix socket type")) +} + +func (h *unixHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { + path := uri.Path + if path == "" { + path = uri.Opaque + } + network := uri.Scheme + + raddr, err := net.ResolveUnixAddr(network, path) + if err != nil { + return nil, err + } + + q := uri.Query() + + var laddr *net.UnixAddr + + addr := q.Get(FieldLocalAddress) + if addr != "" { + laddr, err = net.ResolveUnixAddr(network, addr) + if err != nil { + return nil, files.PathError("create", uri.String(), err) + } + } + + var conn *net.UnixConn + dial := func() error { + var err error + + conn, err = net.DialUnix(network, laddr, raddr) + + return err + } + + if err := do(ctx, dial); err != nil { + return nil, files.PathError("create", uri.String(), err) + } + + sock, err := sockWriter(conn, laddr != nil, q) + if err != nil { + conn.Close() + return nil, files.PathError("create", uri.String(), err) + } + + switch network { + case "unix": + return newStreamWriter(ctx, sock), nil + + case "unixgram", "unixpacket": + return newDatagramWriter(ctx, sock), nil + } + + conn.Close() + return nil, files.PathError("create", uri.String(), errors.New("unknown unix socket type")) +} + +func (h *unixHandler) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { + return nil, files.PathError("readdir", uri.String(), os.ErrInvalid) +} diff --git a/lib/files/unixsocket/reader.go b/lib/files/unixsocket/reader.go deleted file mode 100644 index 778ef27..0000000 --- a/lib/files/unixsocket/reader.go +++ /dev/null @@ -1,93 +0,0 @@ -package unixsocket - -import ( - "context" - "net" - "net/url" - "os" - "time" - - "github.com/puellanivis/breton/lib/files" - "github.com/puellanivis/breton/lib/files/wrapper" -) - -type reader struct { - conn *net.UnixConn - *wrapper.Info - - err error - loading <-chan struct{} -} - -func (r *reader) Read(b []byte) (n int, err error) { - for range r.loading { - } - - if r.err != nil { - return 0, r.err - } - - return r.conn.Read(b) -} - -func (r *reader) Seek(offset int64, whence int) (int64, error) { - return 0, os.ErrInvalid -} - -func (r *reader) Close() error { - for range r.loading { - } - - // Ignore the r.err, as it is a request-scope error, and not relevant to closing. - - return r.conn.Close() -} - -func (h *handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { - path := uri.Path - if path == "" { - path = uri.Opaque - } - - laddr, err := net.ResolveUnixAddr("unix", path) - if err != nil { - return nil, err - } - - fixURL := &url.URL{ - Scheme: "unix", - Opaque: laddr.String(), - } - - l, err := net.ListenUnix("unix", laddr) - if err != nil { - return nil, err - } - - loading := make(chan struct{}) - r := &reader{ - loading: loading, - Info: wrapper.NewInfo(fixURL, 0, time.Now()), - } - - go func() { - defer close(loading) - defer l.Close() - - select { - case loading <- struct{}{}: - case <-ctx.Done(): - r.err = ctx.Err() - return - } - - conn, err := l.AcceptUnix() - if err != nil { - r.err = err - return - } - r.conn = conn - }() - - return r, nil -} diff --git a/lib/files/unixsocket/unixsocket.go b/lib/files/unixsocket/unixsocket.go deleted file mode 100644 index 41ba692..0000000 --- a/lib/files/unixsocket/unixsocket.go +++ /dev/null @@ -1,76 +0,0 @@ -// Package unixsocket implements the "unix:" URL scheme, by reading/writing to a raw unix socket. -package unixsocket - -import ( - "context" - "net" - "net/url" - "os" - "time" - - "github.com/puellanivis/breton/lib/files" - "github.com/puellanivis/breton/lib/files/wrapper" -) - -type handler struct{} - -func init() { - files.RegisterScheme(&handler{}, "unix") -} - -type writer struct { - *net.UnixConn - *wrapper.Info -} - -func (w *writer) Sync() error { return nil } - -// URL query field keys. -const ( - FieldLocalAddress = "local_addr" -) - -func (h *handler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { - path := uri.Path - if path == "" { - path = uri.Opaque - } - - raddr, err := net.ResolveUnixAddr("unix", path) - if err != nil { - return nil, err - } - - fixURL := &url.URL{ - Scheme: "unix", - Opaque: raddr.String(), - } - - var laddr *net.UnixAddr - - q := uri.Query() - if addr := q.Get(FieldLocalAddress); addr != "" { - laddr, err = net.ResolveUnixAddr("unix", addr) - if err != nil { - return nil, err - } - q.Set(FieldLocalAddress, laddr.String()) - fixURL.RawQuery = q.Encode() - } - - conn, err := net.DialUnix("unix", laddr, raddr) - if err != nil { - return nil, err - } - - w := &writer{ - UnixConn: conn, - Info: wrapper.NewInfo(fixURL, 0, time.Now()), - } - - return w, nil -} - -func (h *handler) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { - return nil, files.PathError("readdir", uri.String(), os.ErrInvalid) -}