// Option holds the connection-level settings chosen by the client: a magic
// number that identifies geerpc traffic and the codec used to encode
// message bodies.
type Option struct {
	MagicNumber int        // MagicNumber marks this is a geerpc request
	CodecType   codec.Type // client may choose different Codec to encode body
}
// DefaultOption is the baseline Option: it carries the geerpc magic number
// and selects the gob codec for message bodies.
var DefaultOption = &Option{
	CodecType:   codec.GobType,
	MagicNumber: MagicNumber,
}
// Call represents an active RPC.
type Call struct {
	Seq           uint64      // sequence number of the call (presumably assigned by the client when sent — confirm)
	ServiceMethod string      // format "<service>.<method>"
	Args          interface{} // arguments to the function
	Reply         interface{} // reply from the function
	Error         error       // if error occurs, it will be set
	Done          chan *Call  // receives the Call itself once it has completed
}
// Client represents an RPC client.
// There may be multiple outstanding Calls associated
// with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
type Client struct {
	cc       codec.Codec // codec used to exchange messages with the server
	opt      *Option
	sending  sync.Mutex // protects header; presumably serialises outgoing writes — confirm
	header   codec.Header
	mu       sync.Mutex // protects seq, pending, closing and shutdown
	seq      uint64
	pending  map[uint64]*Call // outstanding calls, presumably keyed by Call.Seq — confirm
	closing  bool             // user has called Close
	shutdown bool             // server has told us to stop
}
// Compile-time check that *Client implements io.Closer. It costs nothing at
// run time and breaks the build if Close's signature ever drifts.
var (
	_ io.Closer = (*Client)(nil)
)
// ErrShutdown is returned by Client.Close when the client is already
// closing, i.e. on every Close after the first.
var (
	ErrShutdown = errors.New("connection is shut down")
)
// Close the connection func(client *Client)Close()error { client.mu.Lock() defer client.mu.Unlock() if client.closing { return ErrShutdown } client.closing = true return client.cc.Close() }
// IsAvailable return true if the client does work func(client *Client)IsAvailable()bool { client.mu.Lock() defer client.mu.Unlock() return !client.shutdown && !client.closing }
// Go invokes the function asynchronously. // It returns the Call structure representing the invocation. func(client *Client)Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call { if done == nil { done = make(chan *Call, 10) } elseifcap(done) == 0 { log.Panic("rpc client: done channel is unbuffered") } call := &Call{ ServiceMethod: serviceMethod, Args: args, Reply: reply, Done: done, } client.send(call) return call }
// Call invokes the named function, waits for it to complete, // and returns its error status. func(client *Client)Call(serviceMethod string, args, reply interface{})error { call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done return call.Error }
Go 和 Call 是客户端暴露给用户的两个 RPC 服务调用接口。Go 是一个异步接口,返回 call 实例;Call 是对 Go 的封装,它阻塞在 call.Done 上等待响应返回,是一个同步接口。
func(m *methodType)newArgv()reflect.Value { var argv reflect.Value // arg may be a pointer type, or a value type if m.ArgType.Kind() == reflect.Ptr { argv = reflect.New(m.ArgType.Elem()) } else { argv = reflect.New(m.ArgType).Elem() } return argv }
func(m *methodType)newReplyv()reflect.Value { // reply must be a pointer type replyv := reflect.New(m.ReplyType.Elem()) switch m.ReplyType.Elem().Kind() { case reflect.Map: replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem())) case reflect.Slice: replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0)) } return replyv }
funcnewService(rcvr interface{}) *service { s := new(service) s.rcvr = reflect.ValueOf(rcvr) s.name = reflect.Indirect(s.rcvr).Type().Name() s.typ = reflect.TypeOf(rcvr) if !ast.IsExported(s.name) { log.Fatalf("rpc server: %s is not a valid service name", s.name) } s.registerMethods() return s }
// request bundles everything needed to serve one call: the decoded header,
// the argument and reply values, and the method and service to invoke.
type request struct {
	h            *codec.Header // header of request
	argv, replyv reflect.Value // argv and replyv of request
	mtype        *methodType   // method to invoke (presumably resolved from h.ServiceMethod — confirm)
	svc          *service      // service that mtype belongs to (presumably — confirm)
}
// make sure that argvi is a pointer, ReadBody need a pointer as parameter argvi := req.argv.Interface() if req.argv.Type().Kind() != reflect.Ptr { argvi = req.argv.Addr().Interface() } if err = cc.ReadBody(argvi); err != nil { log.Println("rpc server: read body err:", err) return req, err } return req, nil }
// Option holds the connection-level settings chosen by the client: the
// geerpc magic number, the body codec, and timeouts.
type Option struct {
	MagicNumber    int           // MagicNumber marks this is a geerpc request
	CodecType      codec.Type    // client may choose different Codec to encode body
	ConnectTimeout time.Duration // 0 means no limit; bounds dialing and client setup
	HandleTimeout  time.Duration // presumably bounds server-side handling of a call, 0 meaning no limit — confirm
}
// Dial connects to an RPC server at the specified network address funcDial(network, address string, opts ...*Option)(*Client, error) { return dialTimeout(NewClient, network, address, opts...) }
funcdialTimeout(f newClientFunc, network, address string, opts ...*Option)(client *Client, err error) { opt, err := parseOptions(opts...) if err != nil { returnnil, err } conn, err := net.DialTimeout(network, address, opt.ConnectTimeout) if err != nil { returnnil, err } // close the connection if client is nil deferfunc() { if err != nil { _ = conn.Close() } }() ch := make(chan clientResult) gofunc() { client, err := f(conn, opt) ch <- clientResult{client: client, err: err} }() if opt.ConnectTimeout == 0 { result := <-ch return result.client, result.err } select { case <-time.After(opt.ConnectTimeout): returnnil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout) case result := <-ch: return result.client, result.err } }
// HandleHTTP registers an HTTP handler for RPC messages on rpcPath. // It is still necessary to invoke http.Serve(), typically in a go statement. func(server *Server)HandleHTTP() { http.Handle(defaultRPCPath, server) }
// HandleHTTP is a convenient approach for default server to register HTTP handlers funcHandleHTTP() { DefaultServer.HandleHTTP() }
为了简化调用,提供了一个统一入口 XDial。它根据 rpcAddr(格式为 protocol@addr)中的协议名,选择调用 DialHTTP 或 Dial:
// XDial calls different functions to connect to a RPC server // according the first parameter rpcAddr. // rpcAddr is a general format (protocol@addr) to represent a rpc server // eg, http@10.0.0.1:7001, tcp@10.0.0.1:9999, unix@/tmp/geerpc.sock funcXDial(rpcAddr string, opts ...*Option)(*Client, error) { parts := strings.Split(rpcAddr, "@") iflen(parts) != 2 { returnnil, fmt.Errorf("rpc client err: wrong format '%s', expect protocol@addr", rpcAddr) } protocol, addr := parts[0], parts[1] switch protocol { case"http": return DialHTTP("tcp", addr, opts...) default: // tcp, unix or other transport protocol return Dial(protocol, addr, opts...) } }
简单的 DEBUG 页面
在 /debug/geerpc 上展示服务的调用统计视图:返回一个 HTML 报文,展示所有已注册 service 的每一个方法的调用情况。为此,将 debugHTTP 实例绑定到地址 /debug/geerpc。
// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file.
// MultiServersDiscovery is a discovery for multiple servers without a
// registry center: the user provides the server addresses explicitly.
type MultiServersDiscovery struct {
	r       *rand.Rand   // random source for RandomSelect; *rand.Rand is not safe for concurrent use, so it is only used under mu's write lock
	mu      sync.RWMutex // guards servers and index (and r, see above)
	servers []string
	index   int // next position for the round-robin algorithm
}
// NewMultiServerDiscovery creates a MultiServersDiscovery instance funcNewMultiServerDiscovery(servers []string) *MultiServersDiscovery { d := &MultiServersDiscovery{ servers: servers, r: rand.New(rand.NewSource(time.Now().UnixNano())), } d.index = d.r.Intn(math.MaxInt32 - 1) return d }
r 是一个产生随机数的实例,初始化时使用时间戳设定随机数种子,避免每次产生相同的随机数序列。
index 记录 Round Robin 算法已经轮询到的位置,为了避免每次从 0 开始,初始化时随机设定一个值。
// Refresh doesn't make sense for MultiServersDiscovery, so ignore it func(d *MultiServersDiscovery)Refresh()error { returnnil }
// Update the servers of discovery dynamically if needed func(d *MultiServersDiscovery)Update(servers []string)error { d.mu.Lock() defer d.mu.Unlock() d.servers = servers returnnil }
// Get a server according to mode func(d *MultiServersDiscovery)Get(mode SelectMode)(string, error) { d.mu.Lock() defer d.mu.Unlock() n := len(d.servers) if n == 0 { return"", errors.New("rpc discovery: no available servers") } switch mode { case RandomSelect: return d.servers[d.r.Intn(n)], nil case RoundRobinSelect: s := d.servers[d.index%n] // servers could be updated, so mode n to ensure safety d.index = (d.index + 1) % n return s, nil default: return"", errors.New("rpc discovery: not supported select mode") } }
// returns all servers in discovery func(d *MultiServersDiscovery)GetAll()([]string, error) { d.mu.RLock() defer d.mu.RUnlock() // return a copy of d.servers servers := make([]string, len(d.servers), len(d.servers)) copy(servers, d.servers) return servers, nil }
func(xc *XClient)Close()error { xc.mu.Lock() defer xc.mu.Unlock() for key, client := range xc.clients { // I have no idea how to deal with error, just ignore it. _ = client.Close() delete(xc.clients, key) } returnnil }
// Broadcast invokes the named function for every server registered in discovery func(xc *XClient)Broadcast(ctx context.Context, serviceMethod string, args, reply interface{})error { servers, err := xc.d.GetAll() if err != nil { return err } var wg sync.WaitGroup var mu sync.Mutex // protect e and replyDone var e error replyDone := reply == nil// if reply is nil, don't need to set value ctx, cancel := context.WithCancel(ctx) for _, rpcAddr := range servers { wg.Add(1) gofunc(rpcAddr string) { defer wg.Done() var clonedReply interface{} if reply != nil { clonedReply = reflect.New(reflect.ValueOf(reply).Elem().Type()).Interface() } err := xc.call(rpcAddr, ctx, serviceMethod, args, clonedReply) mu.Lock() if err != nil && e == nil { e = err cancel() // if any call failed, cancel unfinished calls } if err == nil && !replyDone { reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(clonedReply).Elem()) replyDone = true } mu.Unlock() }(rpcAddr) } wg.Wait() return e }
// GeeRegistry is a simple registry center. It adds servers and keeps them
// alive on heartbeat. When listing the alive servers, it also removes the
// dead ones.
type GeeRegistry struct {
	timeout time.Duration          // presumably how long a server stays alive without a heartbeat — confirm
	mu      sync.Mutex             // protect following
	servers map[string]*ServerItem // registered servers, presumably keyed by address — confirm
}
// ServerItem is one server entry known to GeeRegistry.
type ServerItem struct {
	Addr  string
	start time.Time // presumably the time of the last registration/heartbeat — confirm
}
// Runs at /_geerpc_/registry func(r *GeeRegistry)ServeHTTP(w http.ResponseWriter, req *http.Request) { switch req.Method { case"GET": // keep it simple, server is in req.Header w.Header().Set("X-Geerpc-Servers", strings.Join(r.aliveServers(), ",")) case"POST": // keep it simple, server is in req.Header addr := req.Header.Get("X-Geerpc-Server") if addr == "" { w.WriteHeader(http.StatusInternalServerError) return } r.putServer(addr) default: w.WriteHeader(http.StatusMethodNotAllowed) } }
// HandleHTTP registers an HTTP handler for GeeRegistry messages on registryPath func(r *GeeRegistry)HandleHTTP(registryPath string) { http.Handle(registryPath, r) log.Println("rpc registry path:", registryPath) }
// Heartbeat send a heartbeat message every once in a while // it's a helper function for a server to register or send heartbeat funcHeartbeat(registry, addr string, duration time.Duration) { if duration == 0 { // make sure there is enough time to send heart beat // before it's removed from registry duration = defaultTimeout - time.Duration(1)*time.Minute } var err error err = sendHeartbeat(registry, addr) gofunc() { t := time.NewTicker(duration) for err == nil { <-t.C err = sendHeartbeat(registry, addr) } }() }