-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclient.go
340 lines (300 loc) · 7.77 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
package gorpc
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"gorpc/codec"
"io"
"log"
"net"
"net/http"
"strings"
"sync"
"time"
)
// Call represents a single RPC invocation: the request, where its reply is
// decoded, and the channel on which the Call is delivered when it completes.
type Call struct {
	Seq           uint64      // sequence number assigned by the client when the call is registered
	ServiceMethod string      // remote method to invoke (presumably "Service.Method" — confirm with server)
	Args          interface{} // arguments sent with the request
	Reply         interface{} // destination the response body is decoded into
	Error         error       // set when the call fails; nil on success
	Done          chan *Call  // receives this Call once it has completed
}

// done delivers the completed call on its Done channel.
//
// The send never blocks: done can run on the client's receive goroutine, and a
// blocked send there would stall every other pending call. If Done has no free
// buffer space, the notification is dropped and logged; callers sharing one
// Done channel across many calls must size its buffer accordingly.
func (call *Call) done() {
	select {
	case call.Done <- call:
	default:
		log.Println("rpc client: discarding Call reply due to insufficient Done chan capacity")
	}
}
// Client is an RPC client bound to a single connection. It supports
// asynchronous calls and may be used by several goroutines at once; its
// mutable state is guarded by the two mutexes below.
type Client struct {
	cc  codec.Codec // encodes requests and decodes responses on the connection
	opt *Option     // options the client was created with

	// sending serializes request writes so each request goes out whole;
	// header is the scratch request header reused while sending is held.
	sending sync.Mutex
	header  codec.Header

	// mu guards every field below.
	mu       sync.Mutex
	seq      uint64           // sequence number handed to the next registered call
	pending  map[uint64]*Call // registered calls awaiting a response, keyed by Seq
	closing  bool             // set once the user has called Close
	shutdown bool             // set once the receive loop has failed and terminated all calls
}
var (
	// Compile-time check that *Client satisfies io.Closer.
	_ io.Closer = (*Client)(nil)

	// ErrShutdown is returned for calls made after the client has been closed
	// or after its connection has failed, and by a repeated Close.
	ErrShutdown = errors.New("connection is shut down")
)
// Close closes the underlying connection. Only the first call does any work;
// every later call returns ErrShutdown.
func (client *Client) Close() error {
	client.mu.Lock()
	defer client.mu.Unlock()
	if !client.closing {
		client.closing = true
		return client.cc.Close()
	}
	// Someone else already closed this client.
	return ErrShutdown
}
// IsAvailable reports whether the client can still accept new calls: it has
// neither been closed by the user nor shut down after a connection error.
func (client *Client) IsAvailable() bool {
	client.mu.Lock()
	defer client.mu.Unlock()
	usable := !(client.shutdown || client.closing)
	return usable
}
// registerCall gives call the next sequence number and records it as pending,
// returning that number. It fails with ErrShutdown once the client is closing
// or shut down.
func (client *Client) registerCall(call *Call) (uint64, error) {
	client.mu.Lock()
	defer client.mu.Unlock()
	if client.shutdown || client.closing {
		return 0, ErrShutdown
	}
	id := client.seq
	client.seq = id + 1
	call.Seq = id
	client.pending[id] = call
	return id, nil
}
// removeCall detaches the pending call with sequence number seq and returns
// it, or returns nil when no such call is pending.
func (client *Client) removeCall(seq uint64) *Call {
	client.mu.Lock()
	defer client.mu.Unlock()
	call, ok := client.pending[seq]
	if ok {
		delete(client.pending, seq)
	}
	return call
}
// terminateCalls marks the client as shut down and fails every pending call
// with err. It takes the sending lock as well as mu, so no request is
// mid-write while pending calls are being failed.
func (client *Client) terminateCalls(err error) {
	client.sending.Lock()
	// Must be Unlock: leaving sending held would block every later send forever.
	defer client.sending.Unlock()
	client.mu.Lock()
	defer client.mu.Unlock()
	client.shutdown = true
	for _, call := range client.pending {
		call.Error = err
		call.done()
	}
}
// receive reads responses from the server and completes the matching pending
// calls. It is started on its own goroutine when the client is created and
// runs until the first read error, at which point it fails every call that is
// still pending with that error.
func (client *Client) receive() {
	var err error
	for err == nil {
		var h codec.Header
		if err = client.cc.ReadHeader(&h); err != nil {
			break
		}
		call := client.removeCall(h.Seq)
		switch {
		case call == nil:
			// No pending call for this Seq: the request was not fully written
			// or was canceled, but the server still answered. The body must
			// still be consumed (ReadBody(nil) presumably discards it).
			err = client.cc.ReadBody(nil)
		case h.Error != "":
			// The server reported an error. errors.New, not fmt.Errorf: the
			// text comes from the server, and a '%' in it must not be
			// interpreted as a format verb.
			call.Error = errors.New(h.Error)
			err = client.cc.ReadBody(nil)
			call.done()
		default:
			err = client.cc.ReadBody(call.Reply)
			if err != nil {
				call.Error = errors.New("reading body " + err.Error())
			}
			call.done()
		}
	}
	client.terminateCalls(err)
}
// send registers call and writes its request to the server. Failures are
// reported through call.Error and call.done rather than a return value.
func (client *Client) send(call *Call) {
	// Hold the sending lock for the whole request so writes never interleave
	// and the shared header is not overwritten mid-write.
	client.sending.Lock()
	defer client.sending.Unlock()

	seq, regErr := client.registerCall(call)
	if regErr != nil {
		call.Error = regErr
		call.done()
		return
	}

	h := &client.header
	h.ServiceMethod = call.ServiceMethod
	h.Seq = seq
	h.Error = ""

	writeErr := client.cc.Write(h, call.Args)
	if writeErr == nil {
		return
	}
	// If the call is no longer pending, something else (the receive loop or
	// shutdown) has already completed it; only finish it here if we still own it.
	if pending := client.removeCall(call.Seq); pending != nil {
		pending.Error = writeErr
		pending.done()
	}
}
// Go starts an asynchronous call and returns its Call immediately; the Call is
// delivered on done once it completes. A nil done is replaced by a new channel
// with a buffer of 10. An unbuffered done channel is a programming error and
// panics.
func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call {
	switch {
	case done == nil:
		done = make(chan *Call, 10)
	case cap(done) == 0:
		log.Panic("rpc client: done channl is unbuffered")
	}
	call := new(Call)
	call.ServiceMethod = serviceMethod
	call.Args = args
	call.Reply = reply
	call.Done = done
	client.send(call)
	return call
}
// Call invokes serviceMethod synchronously. It waits until the call completes
// or ctx is done, whichever happens first; use context.WithTimeout to bound
// the call. When ctx ends first, the call is dropped from the pending set and
// the returned error wraps ctx.Err(), so callers can test it with
// errors.Is(err, context.DeadlineExceeded) or errors.Is(err, context.Canceled).
func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
	call := client.Go(serviceMethod, args, reply, make(chan *Call, 1))
	select {
	case <-ctx.Done():
		// Forget the call so a late response is discarded rather than decoded into reply.
		client.removeCall(call.Seq)
		return fmt.Errorf("rpc client:call failed:%w", ctx.Err())
	case finished := <-call.Done:
		return finished.Error
	}
}
type (
	// clientResult carries the outcome of building a client on a separate
	// goroutine back to dialTimeout.
	clientResult struct {
		client *Client
		err    error
	}

	// newClientFunc builds a Client on an established connection; NewClient
	// and NewHTTPClient both have this signature.
	newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error)
)
// dialTimeout connects to address over network and builds a client on the
// connection with f. opt.ConnectTimeout bounds both the dial and the client
// construction; zero means no limit. The connection is closed whenever an
// error is returned.
func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) {
	opt, err := parseOptions(opts...)
	if err != nil {
		return nil, err
	}
	conn, err := net.DialTimeout(network, address, opt.ConnectTimeout)
	if err != nil {
		return nil, err
	}
	// err is the named result, so this sees the value actually returned.
	defer func() {
		if err != nil {
			_ = conn.Close()
		}
	}()

	// Buffered so the constructor goroutine can always deliver its result and
	// exit, even after we have stopped waiting for it on timeout.
	ch := make(chan clientResult, 1)
	go func() {
		c, e := f(conn, opt)
		ch <- clientResult{client: c, err: e}
	}()

	if opt.ConnectTimeout == 0 {
		result := <-ch
		// Return the constructor's error (not the outer, nil err) so failures
		// are reported and the deferred close runs.
		return result.client, result.err
	}
	select {
	case <-time.After(opt.ConnectTimeout):
		return nil, fmt.Errorf("rpc client:connect timeout:expect within %s", opt.ConnectTimeout)
	case result := <-ch:
		return result.client, result.err
	}
}
// NewHTTPClient creates a client on conn for a server reached over HTTP. It
// sends a CONNECT request for defaultRPCPath and, if the response status
// equals connected, hands the connection to NewClient. NewHTTPClient itself
// does not close conn when the handshake fails.
func NewHTTPClient(conn net.Conn, opt *Option) (*Client, error) {
	if _, err := io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath)); err != nil {
		return nil, err
	}
	resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
	if err != nil {
		return nil, err
	}
	if resp.Status != connected {
		return nil, errors.New("unexpected HTTP response: " + resp.Status)
	}
	return NewClient(conn, opt)
}
// DialHTTP connects to an RPC server at addr that is served over HTTP,
// performing the HTTP CONNECT handshake before RPC traffic starts.
func DialHTTP(network, addr string, opts ...*Option) (*Client, error) {
	var build newClientFunc = NewHTTPClient
	return dialTimeout(build, network, addr, opts...)
}
// XDial is the general dialing entry point. rpcAddr must have the form
// protocol@addr with exactly one '@'. The "http" protocol dials addr over TCP
// with the HTTP CONNECT handshake; any other protocol is passed to Dial as the
// network name.
func XDial(rpcAddr string, opts ...*Option) (*Client, error) {
	// A limit of 3 is enough to tell "exactly one @" apart from "two or more".
	fields := strings.SplitN(rpcAddr, "@", 3)
	if len(fields) != 2 {
		return nil, fmt.Errorf("rpc client err:wrong format '%s', expect protocol@addr", rpcAddr)
	}
	protocol, addr := fields[0], fields[1]
	if protocol == "http" {
		return DialHTTP("tcp", addr, opts...)
	}
	return Dial(protocol, addr, opts...)
}
// NewClient creates a client on conn using the codec named by opt.CodecType.
// It first sends opt to the server JSON-encoded, then starts the client's
// receive loop. NewClient takes ownership of conn: it is closed on every
// error path, so callers never need to clean it up after a failure.
func NewClient(conn net.Conn, opt *Option) (*Client, error) {
	f := codec.NewCodecFuncMap[opt.CodecType]
	if f == nil {
		err := fmt.Errorf("invalid codec type %s", opt.CodecType)
		log.Println("rpc client:codec error:", err)
		// Close here too, matching the encode-failure path below.
		_ = conn.Close()
		return nil, err
	}
	// The server is expected to check these options before exchanging requests.
	if err := json.NewEncoder(conn).Encode(opt); err != nil {
		log.Println("rpc client:options error:", err)
		_ = conn.Close()
		return nil, err
	}
	return newClientCodec(f(conn), opt), nil
}
// newClientCodec wraps an already-set-up codec in a Client and starts the
// goroutine that reads responses. Numbering starts at 1, so Seq 0 never
// identifies a registered call.
func newClientCodec(cc codec.Codec, opt *Option) *Client {
	c := new(Client)
	c.cc = cc
	c.opt = opt
	c.seq = 1
	c.pending = make(map[uint64]*Call)
	go c.receive()
	return c
}
// parseOptions resolves the optional Option argument accepted by the dial
// functions. More than one option is an error. With no option, or a nil one,
// it returns DefaultOption. Otherwise it returns a copy of the caller's option
// with MagicNumber forced to DefaultOption's and CodecType defaulted when
// empty; the caller's Option is left unmodified.
func parseOptions(opts ...*Option) (*Option, error) {
	// Check the count first so a leading nil cannot hide extra options.
	if len(opts) > 1 {
		return nil, errors.New("number of options is more than one")
	}
	if len(opts) == 0 || opts[0] == nil {
		return DefaultOption, nil
	}
	opt := *opts[0] // copy so defaults are not written into the caller's value
	opt.MagicNumber = DefaultOption.MagicNumber
	if opt.CodecType == "" {
		opt.CodecType = DefaultOption.CodecType
	}
	return &opt, nil
}
// Dial connects to an RPC server at address on the given network. At most one
// Option may be supplied; when it is omitted, DefaultOption is used.
func Dial(network, address string, opts ...*Option) (client *Client, err error) {
	client, err = dialTimeout(NewClient, network, address, opts...)
	return client, err
}