diff --git a/README.md b/README.md index e470c59..89e0ed6 100644 --- a/README.md +++ b/README.md @@ -20,25 +20,26 @@ var log = gn.GetLogger() var server *gn.Server -var encoder = gn.NewHeaderLenEncoder(2, 1024) +type Handler struct{} -type Handler struct {} - -func (Handler) OnConnect(c *gn.Conn) { +func (*Handler) OnConnect(c *gn.Conn) { log.Info("connect:", c.GetFd(), c.GetAddr()) } -func (Handler) OnMessage(c *gn.Conn, bytes []byte) { - encoder.EncodeToFD(c.GetFd(), bytes) +func (*Handler) OnMessage(c *gn.Conn, bytes []byte) { + c.WriteWithEncoder(bytes) log.Info("read:", string(bytes)) } -func (Handler) OnClose(c *gn.Conn, err error) { +func (*Handler) OnClose(c *gn.Conn, err error) { log.Info("close:", c.GetFd(), err) } func main() { var err error - server, err = gn.NewServer(":8080", Handler{}, gn.NewHeaderLenDecoder(2), - gn.WithTimeout(5*time.Second), gn.WithReadBufferLen(10)) + server, err = gn.NewServer(":8080", &Handler{}, + gn.WithDecoder(gn.NewHeaderLenDecoder(2)), + gn.WithEncoder(gn.NewHeaderLenEncoder(2, 1024)), + gn.WithTimeout(5*time.Second), + gn.WithReadBufferLen(10)) if err != nil { log.Info("err") return @@ -46,4 +47,5 @@ func main() { server.Run() } + ``` diff --git a/buffer.go b/buffer.go index 6dc11c1..813a7fb 100644 --- a/buffer.go +++ b/buffer.go @@ -80,6 +80,12 @@ func (b *Buffer) Read(offset, limit int) ([]byte, error) { return buf, nil } +// ReadAll 读取所有字节 +func (b *Buffer) ReadAll() []byte { + buf, _ := b.Read(b.start, b.end) + return buf +} + // reset 重新设置缓存区(将有用字节前移) func (b *Buffer) reset() { if b.start == 0 { diff --git a/codec.go b/codec.go index 3ba93a6..89354dd 100644 --- a/codec.go +++ b/codec.go @@ -6,17 +6,16 @@ import ( "fmt" "io" "sync" - "syscall" ) // Decoder 解码器 type Decoder interface { - Decode(c *Conn) error + Decode(*Buffer, func([]byte)) error } // Encoder 编码器 type Encoder interface { - EncodeToFD(fd int32, bytes []byte) error + EncodeToWriter(w io.Writer, bytes []byte) error } type headerLenDecoder struct { @@ -37,8 +36,7 @@ func NewHeaderLenDecoder(headerLen int) Decoder { } // Decode 解码 -func (d *headerLenDecoder) Decode(c *Conn) error { - buffer := c.GetBuffer() +func (d *headerLenDecoder) Decode(buffer *Buffer, handle func([]byte)) error { for { header, err := buffer.Seek(d.headerLen) if err == ErrNotEnough { @@ -55,7 +53,7 @@ func (d *headerLenDecoder) Decode(c *Conn) error { return nil } - c.OnMessage(body) + handle(body) } } @@ -85,27 +83,6 @@ func NewHeaderLenEncoder(headerLen, writeBufferLen int) *headerLenEncoder { } } -// EncodeToFD 编码数据,并且写入文件描述符 -func (e headerLenEncoder) EncodeToFD(fd int32, bytes []byte) error { - l := len(bytes) - var buffer []byte - if l <= e.writeBufferLen-e.headerLen { - obj := e.writeBufferPool.Get() - defer e.writeBufferPool.Put(obj) - buffer = obj.([]byte)[0 : l+e.headerLen] - } else { - buffer = make([]byte, l+e.headerLen) - } - - // 将消息长度写入buffer - binary.BigEndian.PutUint16(buffer[0:2], uint16(l)) - // 将消息内容内容写入buffer - copy(buffer[e.headerLen:], bytes) - - _, err := syscall.Write(int(fd), buffer) - return err -} - // EncodeToWriter 编码数据,并且写入Writer func (e headerLenEncoder) EncodeToWriter(w io.Writer, bytes []byte) error { l := len(bytes) diff --git a/conn.go b/conn.go index 153f3f9..b0708b8 100644 --- a/conn.go +++ b/conn.go @@ -50,7 +50,7 @@ func (c *Conn) GetBuffer() *Buffer { } // Read 读取数据 -func (c *Conn) Read() error { +func (c *Conn) read() error { if c.server.options.timeout != 0 { c.timer.Reset(c.server.options.timeout) } @@ -66,13 +66,25 @@ func (c *Conn) Read() error { return err } - err = c.server.decoder.Decode(c) - if err != nil { - return err + if c.server.options.decoder == nil { + c.server.handler.OnMessage(c, c.buffer.ReadAll()) + } else { + var handle = func(bytes []byte) { + c.server.handler.OnMessage(c, bytes) + } + err = c.server.options.decoder.Decode(c.buffer, handle) + if err != nil { + return err + } } } } +// WriteWithEncoder 使用编码器写入 +func (c *Conn) WriteWithEncoder(bytes []byte) error { + return c.server.options.encoder.EncodeToWriter(c, bytes) +} + // Write 写入数据 func (c *Conn) Write(bytes []byte) (int, error) { return syscall.Write(int(c.fd), bytes) @@ -105,11 +117,6 @@ func (c *Conn) CloseRead() error { return nil } -// OnMessage 消息处理 -func (c *Conn) OnMessage(bytes []byte) { - c.server.handler.OnMessage(c, bytes) -} - // GetData 获取数据 func (c *Conn) GetData() interface{} { return c.data diff --git a/options.go b/options.go new file mode 100644 index 0000000..b3fba48 --- /dev/null +++ b/options.go @@ -0,0 +1,115 @@ +package gn + +import ( + "runtime" + "time" +) + +// options Server初始化参数 +type options struct { + decoder Decoder // 解码器 + encoder Encoder // 编码器 + readBufferLen int // 所读取的客户端包的最大长度,客户端发送的包不能超过这个长度,默认值是1024字节 + acceptGNum int // 处理接受请求的goroutine数量 + ioGNum int // 处理io的goroutine数量 + ioEventQueueLen int // io事件队列长度 + timeout time.Duration // 超时时间 +} + +type Option interface { + apply(*options) +} + +type funcServerOption struct { + f func(*options) +} + +func (fdo *funcServerOption) apply(do *options) { + fdo.f(do) +} + +func newFuncServerOption(f func(*options)) *funcServerOption { + return &funcServerOption{ + f: f, + } +} + +// WithDecoder 设置解码器 +func WithDecoder(decoder Decoder) Option { + return newFuncServerOption(func(o *options) { + o.decoder = decoder + }) +} + +// WithEncoder 设置解码器 +func WithEncoder(encoder Encoder) Option { + return newFuncServerOption(func(o *options) { + o.encoder = encoder + }) +} + +// WithReadBufferLen 设置缓存区大小 +func WithReadBufferLen(len int) Option { + return newFuncServerOption(func(o *options) { + if len <= 0 { + panic("acceptGNum must greater than 0") + } + o.readBufferLen = len + }) +} + +// WithAcceptGNum 设置建立连接的goroutine数量 +func WithAcceptGNum(num int) Option { + return newFuncServerOption(func(o *options) { + if num <= 0 { + panic("acceptGNum must greater than 0") + } + o.acceptGNum = num + }) +} + +// WithIOGNum 设置处理IO的goroutine数量 +func WithIOGNum(num int) Option { + return newFuncServerOption(func(o *options) { + if num <= 0 { + panic("IOGNum must greater than 0") + } + o.ioGNum = num + }) +} + +// WithIOEventQueueLen 设置IO事件队列长度,默认值是1024 +func WithIOEventQueueLen(num int) Option { + return newFuncServerOption(func(o *options) { + if num <= 0 { + panic("ioEventQueueLen must greater than 0") + } + o.ioEventQueueLen = num + }) +} + +// WithTimeout 设置TCP超时检查的间隔时间以及超时时间 +func WithTimeout(timeout time.Duration) Option { + return newFuncServerOption(func(o *options) { + if timeout <= 0 { + panic("timeoutTicker must greater than 0") + } + + o.timeout = timeout + }) +} + +func getOptions(opts ...Option) *options { + cpuNum := runtime.NumCPU() + options := &options{ + readBufferLen: 1024, + acceptGNum: cpuNum, + ioGNum: cpuNum, + ioEventQueueLen: 1024, + } + + for _, o := range opts { + o.apply(options) + } + return options +} diff --git a/server.go b/server.go index 8f595fc..017ad2a 100644 --- a/server.go +++ b/server.go @@ -4,11 +4,9 @@ import ( "errors" "fmt" "io" - "runtime" "sync" "sync/atomic" "syscall" - "time" ) var ( @@ -22,99 +20,6 @@ type Handler interface { OnClose(c *Conn, err error) // OnClose 当客户端主动断开链接或者超时时回调,err返回关闭的原因 } -// options Server初始化参数 -type options struct { - readBufferLen int // 所读取的客户端包的最大长度,客户端发送的包不能超过这个长度,默认值是1024字节 - acceptGNum int // 处理接受请求的goroutine数量 - ioGNum int // 处理io的goroutine数量 - ioEventQueueLen int // io事件队列长度 - timeout time.Duration // 超时时间 -} - -type Option interface { - apply(*options) -} - -type funcServerOption struct { - f func(*options) -} - -func (fdo *funcServerOption) apply(do *options) { - fdo.f(do) -} - -func newFuncServerOption(f func(*options)) *funcServerOption { - return &funcServerOption{ - f: f, - } -} - -// WithReadBufferLen 设置缓存区大小 -func WithReadBufferLen(len int) Option { - return newFuncServerOption(func(o *options) { - if len <= 0 { - panic("acceptGNum must greater than 0") - } - o.readBufferLen = len - }) -} - -// WithAcceptGNum 设置建立连接的goroutine数量 -func WithAcceptGNum(num int) Option { - return newFuncServerOption(func(o *options) { - if num <= 0 { - panic("acceptGNum must greater than 0") - } - o.acceptGNum = num - }) -} - -// WithIOGNum 设置处理IO的goroutine数量 -func WithIOGNum(num int) Option { - return newFuncServerOption(func(o *options) { - if num <= 0 { - panic("IOGNum must greater than 0") - } - o.ioGNum = num - }) -} - -// WithIOEventQueueLen 设置IO事件队列长度,默认值是1024 -func WithIOEventQueueLen(num int) Option { - return newFuncServerOption(func(o *options) { - if num <= 0 { - panic("ioEventQueueLen must greater than 0") - } - o.ioEventQueueLen = num - }) -} - -// WithTimeout 设置TCP超时检查的间隔时间以及超时时间 -func WithTimeout(timeout time.Duration) Option { - return newFuncServerOption(func(o *options) { - if timeout <= 0 { - panic("timeoutTicker must greater than 0") - } - - o.timeout = timeout - }) -} - -func getOptions(opts ...Option) *options { - cpuNum := runtime.NumCPU() - options := &options{ - readBufferLen: 1024, - acceptGNum: cpuNum, - ioGNum: cpuNum, - ioEventQueueLen: 1024, - } - - for _, o := range opts { - o.apply(options) - } - return options -} - const ( EventIn = 1 // 数据流入 EventClose = 2 // 断开连接 @@ -132,7 +37,6 @@ type Server struct { options *options // 服务参数 readBufferPool *sync.Pool // 读缓存区内存池 handler Handler // 注册的处理 - decoder Decoder // 解码器 ioEventQueues []chan event // IO事件队列集合 ioQueueNum int32 // IO事件队列集合数量 conns sync.Map // TCP长连接管理 @@ -141,7 +45,7 @@ type Server struct { } // NewServer 创建server服务器 -func NewServer(address string, handler Handler, decoder Decoder, opts ...Option) (*Server, error) { +func NewServer(address string, handler Handler, opts ...Option) (*Server, error) { options := getOptions(opts...) // 初始化读缓存区内存池 @@ -170,7 +74,6 @@ func NewServer(address string, handler Handler, decoder Decoder, opts ...Option) options: options, readBufferPool: readBufferPool, handler: handler, - decoder: decoder, ioEventQueues: ioEventQueues, ioQueueNum: int32(options.ioGNum), conns: sync.Map{}, @@ -179,6 +82,15 @@ func NewServer(address string, handler Handler, decoder Decoder, opts ...Option) }, nil } +// GetConn 获取Conn +func (s *Server) GetConn(fd int32) (*Conn, bool) { + value, ok := s.conns.Load(fd) + if !ok { + return nil, false + } + return value.(*Conn), true +} + // Run 启动服务 func (s *Server) Run() { log.Info("gn server run") @@ -285,7 +197,7 @@ func (s *Server) consumeIOEvent(queue chan event) { continue } - err := c.Read() + err := c.read() if err != nil { // 服务端关闭连接 if err == syscall.EBADF { diff --git a/test/client/main.go b/test/client/client_test.go similarity index 54% rename from test/client/main.go rename to test/client/client_test.go index d59f514..d29c523 100644 --- a/test/client/main.go +++ b/test/client/client_test.go @@ -5,6 +5,7 @@ import ( "log" "net" "strconv" + "testing" ) var codecFactory = util.NewHeaderLenCodecFactory(2, 1024) @@ -13,14 +14,50 @@ func init() { log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) } -func main() { +func TestClient(t *testing.T) { + for i := 0; i < 1; i++ { + go startWithCodec(i) + } + select {} +} + +func start(i int) { + log.Println(i, "start") + conn, err := net.Dial("tcp", "127.0.0.1:8085") + if err != nil { + log.Println(i, "Error dialing", err.Error()) + return // 终止程序 + } + + go func() { + for { + bytes := make([]byte, 0, 50) + n, err := conn.Read(bytes) + if err != nil { + log.Println(err) + return + } + log.Println(i, string(bytes[0:n])) + } + }() + for i := 0; i < 10; i++ { + _, err := conn.Write([]byte("hello" + strconv.Itoa(i))) + if err != nil { + log.Println(err) + return + } + } +} + +func TestClientWithCodec(t *testing.T) { + for i := 0; i < 1; i++ { go start(i) } select {} } -func start(i int) { +func startWithCodec(i int) { log.Println(i, "start") conn, err := net.Dial("tcp", "127.0.0.1:8085") if err != nil { diff --git a/test/server/main.go b/test/server/main.go index 68e48fa..34f688f 100644 --- a/test/server/main.go +++ b/test/server/main.go @@ -10,15 +10,13 @@ var log = gn.GetLogger() var server *gn.Server -var encoder = gn.NewHeaderLenEncoder(2, 1024) - type Handler struct{} func (*Handler) OnConnect(c *gn.Conn) { log.Info("connect:", c.GetFd(), c.GetAddr()) } func (*Handler) OnMessage(c *gn.Conn, bytes []byte) { - encoder.EncodeToFD(c.GetFd(), bytes) + c.WriteWithEncoder(bytes) log.Info("read:", string(bytes)) } func (*Handler) OnClose(c *gn.Conn, err error) { @@ -28,7 +26,8 @@ func (*Handler) OnClose(c *gn.Conn, err error) { func main() { var err error server, err = gn.NewServer(":8080", &Handler{}, - gn.NewHeaderLenDecoder(2), + gn.WithDecoder(gn.NewHeaderLenDecoder(2)), + gn.WithEncoder(gn.NewHeaderLenEncoder(2, 1024)), gn.WithTimeout(5*time.Second), gn.WithReadBufferLen(10)) if err != nil {