目的

  • 以最少的代码,实现 RPC 框架中最为重要的部分,帮助大家理解 RPC 框架在设计时需要考虑什么。代码简洁是第一位的,功能是第二位的。
  • 因此,选择从零实现 Go 语言官方的标准库 net/rpc,并在此基础上,新增了协议交换 (protocol exchange)、注册中心 (registry)、服务发现 (service discovery)、负载均衡 (load balance)、超时处理 (timeout processing) 等特性。

RPC 框架需要解决什么问题

  • 确定采用的传输协议
  • 确定报文的编码格式
  • 如果服务端的实例很多,客户端并不关心这些实例的地址和部署位置,只关心自己能否获取到期待的结果,那就引出了注册中心 (registry) 和负载均衡 (load balance) 的问题。

简单地说,即客户端和服务端互相不感知对方的存在,服务端启动时将自己注册到注册中心,客户端调用时,从注册中心获取到所有可用的实例,选择一个来调用。这样服务端和客户端只需要感知注册中心的存在就够了。

注册中心通常还需要实现服务动态添加、删除,使用心跳确保服务处于可用状态等功能。

消息的序列化与反序列化

一个典型的 RPC 调用如下:

1
err = client.Call("Arith.Multiply", args, &reply)

客户端发送的请求包括服务名 Arith,方法名 Multiply,参数 args 三个,服务端的响应包括错误 error,返回值 reply 2 个。我们将请求和响应中的参数和返回值抽象为 body,剩余的信息放在 header 中,那么就可以抽象出数据结构 Header:

1
2
3
4
5
6
7
8
9
package codec

import "io"

type Header struct {
ServiceMethod string // format "Service.Method"
Seq uint64 // sequence number chosen by client
Error string
}
  • ServiceMethod 是服务名和方法名,通常与 Go 语言中的结构体和方法相映射。
  • Seq 是请求的序号,也可以认为是某个请求的 ID,用来区分不同的请求
  • Error 是错误信息,客户端置为空,服务端如果如果发生错误,将错误信息置于 Error 中。

对消息体进行编解码的接口 Codec

1
2
3
4
5
6
type Codec interface {
io.Closer
ReadHeader(*Header) error
ReadBody(interface{}) error
Write(*Header, interface{}) error
}

实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
package codec

import (
"bufio"
"encoding/gob"
"io"
"log"
)

type GobCodec struct {
conn io.ReadWriteCloser // conn 是由构建函数传入,通常是通过 TCP 或者 Unix 建立 socket 时得到的链接实例
buf *bufio.Writer // buf 是为了防止阻塞而创建的带缓冲的 Writer,一般这么做能提升性能
dec *gob.Decoder // dec 和 enc 对应 gob 的 Decoder 和 Encoder
enc *gob.Encoder
}

var _ Codec = (*GobCodec)(nil)

func NewGobCodec(conn io.ReadWriteCloser) Codec {
buf := bufio.NewWriter(conn)
return &GobCodec{
conn: conn,
buf: buf,
dec: gob.NewDecoder(conn), // dec不使用buf
enc: gob.NewEncoder(buf), // encode使用buf
}
}

Noted:可以使用抽象出io.ReadWriteCloser表示conn

通信过程与通信协议

  • 客户端与服务端的通信需要协商一些内容,例如 HTTP 报文,分为 header 和 body 2 部分,body 的格式和长度通过 header 中的 Content-TypeContent-Length 指定,服务端通过解析 header 就能够知道如何从 body 中读取需要的信息。

  • 对于 RPC 协议来说,这部分协商是需要自主设计的。为了提升性能,一般在报文的最开始会规划固定的字节,来协商相关的信息

    比如第 1 个字节用来表示序列化方式,第 2 个字节表示压缩方式,第 3-6 字节表示 header 的长度,7-10 字节表示 body 的长度。

  • 目前需要协商的唯一一项内容是消息的编解码方式。我们将这部分信息,放到结构体 Option 中承载。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    package geerpc

    const MagicNumber = 0x3bef5c

    type Option struct {
    MagicNumber int // MagicNumber marks this's a geerpc request
    CodecType codec.Type // client may choose different Codec to encode body
    }

    var DefaultOption = &Option{
    MagicNumber: MagicNumber,
    CodecType: codec.GobType,
    }
  • 一般来说,涉及协议协商的这部分信息,需要设计固定的字节来传输的。但是为了实现上更简单,GeeRPC 客户端固定采用 JSON 编码 Option,后续的 header 和 body 的编码方式由 Option 中的 CodeType 指定,服务端首先使用 JSON 解码 Option,然后通过 Option 的 CodeType 解码剩余的内容。即报文将以这样的形式发送:

    1
    2
    | Option{MagicNumber: xxx, CodecType: xxx} | Header{ServiceMethod ...} | Body interface{} |
    | <------ 固定 JSON 编码 ------> | <------- 编码方式由 CodeType 决定 ------->|
  • 在一次连接中,Option 固定在报文的最开始,Header 和 Body 可以有多个,即报文可能是这样的。

    1
    | Option | Header1 | Body1 | Header2 | Body2 | ...

Call 的设计

net/rpc 而言,一个函数需要能够被远程调用,需要满足如下五个条件:

  • the method’s type is exported.
  • the method is exported.
  • the method has two arguments, both exported (or builtin) types.
  • the method’s second argument is a pointer.
  • the method has return type error.

更直观一些:

1
func (t *T) MethodName(argType T1, replyType *T2) error
1
2
3
4
5
6
7
8
9
10
11
12
13
// Call represents an active RPC.
type Call struct {
Seq uint64
ServiceMethod string // format "<service>.<method>"
Args interface{} // arguments to the function
Reply interface{} // reply from the function
Error error // if error occurs, it will be set
Done chan *Call // Strobes when call is complete.
}

func (call *Call) done() {
call.Done <- call
}

为了支持异步调用,Call 结构体中添加了一个字段 Done,Done 的类型是 chan *Call,当调用结束时,会调用 call.done() 通知调用方。

实现 Client

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// There may be multiple outstanding Calls associated
// with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
type Client struct {
cc codec.Codec
opt *Option
sending sync.Mutex // protect following
header codec.Header
mu sync.Mutex // protect following
seq uint64
pending map[uint64]*Call
closing bool // user has called Close
shutdown bool // server has told us to stop
}

var _ io.Closer = (*Client)(nil)

var ErrShutdown = errors.New("connection is shut down")

// Close the connection
func (client *Client) Close() error {
client.mu.Lock()
defer client.mu.Unlock()
if client.closing {
return ErrShutdown
}
client.closing = true
return client.cc.Close()
}

// IsAvailable return true if the client does work
func (client *Client) IsAvailable() bool {
client.mu.Lock()
defer client.mu.Unlock()
return !client.shutdown && !client.closing
}

Client 的字段比较复杂:

  • cc 是消息的编解码器,和服务端类似,用来序列化将要发送出去的请求,以及反序列化接收到的响应。
  • sending 是一个互斥锁,和服务端类似,为了保证请求的有序发送,即防止出现多个请求报文混淆。
  • header 是每个请求的消息头,header 只有在请求发送时才需要,而请求发送是互斥的,因此每个客户端只需要一个,声明在 Client 结构体中可以复用。
  • seq 用于给发送的请求编号,每个请求拥有唯一编号。
  • pending 存储未处理完的请求,键是编号,值是 Call 实例。
  • closing 和 shutdown 任意一个值置为 true,则表示 Client 处于不可用的状态,但有些许的差别,closing 是用户主动关闭的,即调用 Close 方法,而 shutdown 置为 true 一般是有错误发生。

Call 相关的三个方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
func (client *Client) registerCall(call *Call) (uint64, error) {
client.mu.Lock()
defer client.mu.Unlock()
if client.closing || client.shutdown {
return 0, ErrShutdown
}
call.Seq = client.seq
client.pending[call.Seq] = call
client.seq++
return call.Seq, nil
}

func (client *Client) removeCall(seq uint64) *Call {
client.mu.Lock()
defer client.mu.Unlock()
call := client.pending[seq]
delete(client.pending, seq)
return call
}

func (client *Client) terminateCalls(err error) {
client.sending.Lock()
defer client.sending.Unlock()
client.mu.Lock()
defer client.mu.Unlock()
client.shutdown = true
for _, call := range client.pending {
call.Error = err
call.done()
}
}
  • registerCall:将参数 call 添加到 client.pending 中,并更新 client.seq。
  • removeCall:根据 seq,从 client.pending 中移除对应的 call,并返回。
  • terminateCalls:服务端或客户端发生错误时调用,将 shutdown 设置为 true,且将错误信息通知所有 pending 状态的 call。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// Go invokes the function asynchronously.
// It returns the Call structure representing the invocation.
func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call {
if done == nil {
done = make(chan *Call, 10)
} else if cap(done) == 0 {
log.Panic("rpc client: done channel is unbuffered")
}
call := &Call{
ServiceMethod: serviceMethod,
Args: args,
Reply: reply,
Done: done,
}
client.send(call)
return call
}

// Call invokes the named function, waits for it to complete,
// and returns its error status.
func (client *Client) Call(serviceMethod string, args, reply interface{}) error {
call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
return call.Error
}
  • GoCall 是客户端暴露给用户的两个 RPC 服务调用接口,Go 是一个异步接口,返回 call 实例。
  • Call 是对 Go 的封装,阻塞 call.Done,等待响应返回,是一个同步接口。

结构体映射为服务

假设客户端发过来一个请求,包含 ServiceMethod 和 Argv。

1
2
3
4
{
"ServiceMethod""T.MethodName"
"Argv""0101110101..." // 序列化之后的字节流
}

通过 “T.MethodName” 可以确定调用的是类型 T 的 MethodName,如果硬编码实现这个功能,很可能是这样:

1
2
3
4
5
6
7
8
9
10
11
12
switch req.ServiceMethod {
case "T.MethodName":
t := new(t)
reply := new(T2)
var argv T1
gob.NewDecoder(conn).Decode(&argv)
err := t.MethodName(argv, reply)
server.sendMessage(reply, err)
case "Foo.Sum":
f := new(Foo)
...
}

有没有什么方式,能够将这个映射过程自动化呢?可以借助反射

通过反射,我们能够非常容易地获取某个结构体的所有方法,并且能够通过方法,获取到该方法所有的参数类型与返回值。例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
func main() {
var wg sync.WaitGroup
typ := reflect.TypeOf(&wg)
for i := 0; i < typ.NumMethod(); i++ {
method := typ.Method(i)
argv := make([]string, 0, method.Type.NumIn())
returns := make([]string, 0, method.Type.NumOut())
// j 从 1 开始,第 0 个入参是 wg 自己。
for j := 1; j < method.Type.NumIn(); j++ {
argv = append(argv, method.Type.In(j).Name())
}
for j := 0; j < method.Type.NumOut(); j++ {
returns = append(returns, method.Type.Out(j).Name())
}
log.Printf("func (w *%s) %s(%s) %s",
typ.Elem().Name(),
method.Name,
strings.Join(argv, ","),
strings.Join(returns, ","))
}
}

// 运行的结果是:
// func (w *WaitGroup) Add(int)
// func (w *WaitGroup) Done()
// func (w *WaitGroup) Wait()

通过反射实现 service

通过反射实现结构体与服务的映射关系

第一步,定义结构体 methodType:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
type methodType struct {
method reflect.Method // 方法本身
ArgType reflect.Type // 第一个参数的类型
ReplyType reflect.Type // 第二个参数的类型
numCalls uint64 // 后续统计方法调用次数时会用到
}

func (m *methodType) NumCalls() uint64 {
return atomic.LoadUint64(&m.numCalls)
}

func (m *methodType) newArgv() reflect.Value {
var argv reflect.Value
// arg may be a pointer type, or a value type
if m.ArgType.Kind() == reflect.Ptr {
argv = reflect.New(m.ArgType.Elem())
} else {
argv = reflect.New(m.ArgType).Elem()
}
return argv
}

func (m *methodType) newReplyv() reflect.Value {
// reply must be a pointer type
replyv := reflect.New(m.ReplyType.Elem())
switch m.ReplyType.Elem().Kind() {
case reflect.Map:
replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem()))
case reflect.Slice:
replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0))
}
return replyv
}

2 个方法 newArgvnewReplyv,用于创建对应类型的实例。newArgv 方法有一个小细节,指针类型和值类型创建实例的方式有细微区别。

第二步,定义结构体 service:

1
2
3
4
5
6
type service struct {
name string // 映射的结构体的名称,比如 `T`,比如 `WaitGroup`
typ reflect.Type // 结构体的类型
rcvr reflect.Value // 结构体的实例本身
method map[string]*methodType // 存储映射的结构体的所有符合条件的方法
}

接下来,完成构造函数 newService,入参是任意需要映射为服务的结构体实例。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
func newService(rcvr interface{}) *service {
s := new(service)
s.rcvr = reflect.ValueOf(rcvr)
s.name = reflect.Indirect(s.rcvr).Type().Name()
s.typ = reflect.TypeOf(rcvr)
if !ast.IsExported(s.name) {
log.Fatalf("rpc server: %s is not a valid service name", s.name)
}
s.registerMethods()
return s
}

func (s *service) registerMethods() {
s.method = make(map[string]*methodType)
for i := 0; i < s.typ.NumMethod(); i++ {
method := s.typ.Method(i)
mType := method.Type
if mType.NumIn() != 3 || mType.NumOut() != 1 {
continue
}
if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
continue
}
argType, replyType := mType.In(1), mType.In(2)
if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) {
continue
}
s.method[method.Name] = &methodType{
method: method,
ArgType: argType,
ReplyType: replyType,
}
log.Printf("rpc server: register %s.%s\n", s.name, method.Name)
}
}

func isExportedOrBuiltinType(t reflect.Type) bool {
return ast.IsExported(t.Name()) || t.PkgPath() == ""
}

registerMethods 过滤出了符合条件的方法:

  • 两个导出或内置类型的入参(反射时为 3 个,第 0 个是自身,类似于 python 的 self,java 中的 this)
  • 返回值有且只有 1 个,类型为 error

最后,我们还需要实现 call 方法,即能够通过反射值调用方法。

1
2
3
4
5
6
7
8
9
func (s *service) call(m *methodType, argv, replyv reflect.Value) error {
atomic.AddUint64(&m.numCalls, 1)
f := m.method.Func
returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv})
if errInter := returnValues[0].Interface(); errInter != nil {
return errInter.(error)
}
return nil
}

最后,我们还需要实现 call 方法,即能够通过反射值调用方法。

1
2
3
4
5
6
7
8
9
func (s *service) call(m *methodType, argv, replyv reflect.Value) error {
atomic.AddUint64(&m.numCalls, 1)
f := m.method.Func
returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv})
if errInter := returnValues[0].Interface(); errInter != nil {
return errInter.(error)
}
return nil
}

service 的测试用例

定义结构体 Foo,实现 2 个方法,导出方法 Sum 和 非导出方法 sum。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
type Foo int

type Args struct{ Num1, Num2 int }

func (f Foo) Sum(args Args, reply *int) error {
*reply = args.Num1 + args.Num2
return nil
}

// it's not a exported Method
func (f Foo) sum(args Args, reply *int) error {
*reply = args.Num1 + args.Num2
return nil
}

func _assert(condition bool, msg string, v ...interface{}) {
if !condition {
panic(fmt.Sprintf("assertion failed: "+msg, v...))
}
}

测试 newService 和 call 方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
func TestNewService(t *testing.T) {
var foo Foo
s := newService(&foo)
_assert(len(s.method) == 1, "wrong service Method, expect 1, but got %d", len(s.method))
mType := s.method["Sum"]
_assert(mType != nil, "wrong Method, Sum shouldn't nil")
}

func TestMethodType_Call(t *testing.T) {
var foo Foo
s := newService(&foo)
mType := s.method["Sum"]

argv := mType.newArgv()
replyv := mType.newReplyv()
argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3}))
err := s.call(mType, argv, replyv)
_assert(err == nil && *replyv.Interface().(*int) == 4 && mType.NumCalls() == 1, "failed to call Foo.Sum")
}

readRequest 方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// request stores all information of a call
type request struct {
h *codec.Header // header of request
argv, replyv reflect.Value // argv and replyv of request
mtype *methodType
svc *service
}

func (server *Server) readRequest(cc codec.Codec) (*request, error) {
h, err := server.readRequestHeader(cc)
if err != nil {
return nil, err
}
req := &request{h: h}
req.svc, req.mtype, err = server.findService(h.ServiceMethod)
if err != nil {
return req, err
}
req.argv = req.mtype.newArgv()
req.replyv = req.mtype.newReplyv()

// make sure that argvi is a pointer, ReadBody need a pointer as parameter
argvi := req.argv.Interface()
if req.argv.Type().Kind() != reflect.Ptr {
argvi = req.argv.Addr().Interface()
}
if err = cc.ReadBody(argvi); err != nil {
log.Println("rpc server: read body err:", err)
return req, err
}
return req, nil
}

readRequest 方法中最重要的部分,即通过 newArgv()newReplyv() 两个方法创建出两个入参实例,然后通过 cc.ReadBody() 将请求报文反序列化为第一个入参 argv,在这里同样需要注意 argv 可能是值类型,也可能是指针类型,所以处理方式有点差异。

超时机制

纵观整个远程调用的过程,需要客户端处理超时的地方有:

  • 与服务端建立连接,导致的超时
  • 发送请求到服务端,写报文导致的超时
  • 等待服务端处理时,等待处理导致的超时(比如服务端已挂死,迟迟不响应)
  • 从服务端接收响应时,读报文导致的超时

需要服务端处理超时的地方有:

  • 读取客户端请求报文时,读报文导致的超时
  • 发送响应报文时,写报文导致的超时
  • 调用映射服务的方法时,处理报文导致的超时

创建连接超时(select+chan 方式)

为了实现上的简单,将超时设定放在了 Option 中。ConnectTimeout 默认值为 10s,HandleTimeout 默认值为 0,即不设限。

1
2
3
4
5
6
7
8
9
10
11
12
type Option struct {
MagicNumber int // MagicNumber marks this's a geerpc request
CodecType codec.Type // client may choose different Codec to encode body
ConnectTimeout time.Duration // 0 means no limit
HandleTimeout time.Duration
}

var DefaultOption = &Option{
MagicNumber: MagicNumber,
CodecType: codec.GobType,
ConnectTimeout: time.Second * 10,
}

客户端连接超时,只需要为 Dial 添加一层超时处理的外壳即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Option) (*Client, error) {
return dialTimeout(NewClient, network, address, opts...)
}

func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
conn, err := net.DialTimeout(network, address, opt.ConnectTimeout)
if err != nil {
return nil, err
}
// close the connection if client is nil
defer func() {
if err != nil {
_ = conn.Close()
}
}()
ch := make(chan clientResult)
go func() {
client, err := f(conn, opt)
ch <- clientResult{client: client, err: err}
}()
if opt.ConnectTimeout == 0 {
result := <-ch
return result.client, result.err
}
select {
case <-time.After(opt.ConnectTimeout):
return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout)
case result := <-ch:
return result.client, result.err
}
}

Call超时(context 方式)

Client.Call 的超时处理机制,使用 context 包实现,控制权交给用户,控制更为灵活。

1
2
3
4
5
6
7
8
9
10
11
12
// Call invokes the named function, waits for it to complete,
// and returns its error status.
func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
call := client.Go(serviceMethod, args, reply, make(chan *Call, 1))
select {
case <-ctx.Done():
client.removeCall(call.Seq)
return errors.New("rpc client: call failed: " + ctx.Err().Error())
case call := <-call.Done:
return call.Error
}
}

用户可以使用 context.WithTimeout 创建具备超时检测能力的 context 对象来控制。例如:

1
2
3
4
ctx, _ := context.WithTimeout(context.Background(), time.Second)
var reply int
err := client.Call(ctx, "Foo.Sum", &Args{1, 2}, &reply)
...

服务端处理超时(select+chan 方式)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) {
defer wg.Done()
called := make(chan struct{})
sent := make(chan struct{})
go func() {
err := req.svc.call(req.mtype, req.argv, req.replyv)
called <- struct{}{}
if err != nil {
req.h.Error = err.Error()
server.sendResponse(cc, req.h, invalidRequest, sending)
sent <- struct{}{}
return
}
server.sendResponse(cc, req.h, req.replyv.Interface(), sending)
sent <- struct{}{}
}()

if timeout == 0 {
<-called
<-sent
return
}
select {
case <-time.After(timeout):
req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout)
server.sendResponse(cc, req.h, invalidRequest, sending)
case <-called:
<-sent
}
}

这里需要确保 sendResponse 仅调用一次,因此将整个过程拆分为 calledsent 两个阶段,在这段代码中只会发生如下两种情况:

  1. called 信道接收到消息,代表处理没有超时,继续执行 sendResponse。
  2. time.After() 先于 called 接收到消息,说明处理已经超时,called 和 sent 都将被阻塞。在 case <-time.After(timeout) 处调用 sendResponse

支持 HTTP 协议

RPC 的消息格式与标准的 HTTP 协议并不兼容,在这种情况下,就需要一个协议的转换过程。HTTP 协议的 CONNECT 方法恰好提供了这个能力,CONNECT 一般用于代理服务。

假设浏览器与服务器之间的 HTTPS 通信都是加密的,浏览器通过代理服务器发起 HTTPS 请求时,由于请求的站点地址和端口号都是加密保存在 HTTPS 请求报文头中的,代理服务器如何知道往哪里发送请求呢?为了解决这个问题,浏览器通过 HTTP 明文形式向代理服务器发送一个 CONNECT 请求告诉代理服务器目标地址和端口,代理服务器接收到这个请求后,会在对应端口与目标站点建立一个 TCP 连接,连接建立成功后返回 HTTP 200 状态码告诉浏览器与该站点的加密通道已经完成。接下来代理服务器仅需透传浏览器和服务器之间的加密数据包即可,代理服务器无需解析 HTTPS 报文

举一个简单例子:

  1. 浏览器向代理服务器发送 CONNECT 请求。

    1
    CONNECT geektutu.com:443 HTTP/1.0
  2. 代理服务器返回 HTTP 200 状态码表示连接已经建立。

    1
    HTTP/1.0 200 Connection Established
  3. 之后浏览器和服务器开始 HTTPS 握手并交换加密数据,代理服务器只负责传输彼此的数据包,并不能读取具体数据内容(代理服务器也可以选择安装可信根证书解密 HTTPS 报文)。

事实上,这个过程其实是通过代理服务器将 HTTP 协议转换为 HTTPS 协议的过程。对 RPC 服务端来,需要做的是将 HTTP 协议转换为 RPC 协议,对客户端来说,需要新增通过 HTTP CONNECT 请求创建连接的逻辑。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
const (
connected = "200 Connected to Gee RPC"
defaultRPCPath = "/_geeprc_"
defaultDebugPath = "/debug/geerpc"
)

// ServeHTTP implements an http.Handler that answers RPC requests.
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method != "CONNECT" {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusMethodNotAllowed)
_, _ = io.WriteString(w, "405 must CONNECT\n")
return
}
conn, _, err := w.(http.Hijacker).Hijack()
if err != nil {
log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
return
}
_, _ = io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
server.ServeConn(conn)
}

// HandleHTTP registers an HTTP handler for RPC messages on rpcPath.
// It is still necessary to invoke http.Serve(), typically in a go statement.
func (server *Server) HandleHTTP() {
http.Handle(defaultRPCPath, server)
}

// HandleHTTP is a convenient approach for default server to register HTTP handlers
func HandleHTTP() {
DefaultServer.HandleHTTP()
}

为了简化调用,提供了一个统一入口 XDial

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// XDial calls different functions to connect to a RPC server
// according the first parameter rpcAddr.
// rpcAddr is a general format (protocol@addr) to represent a rpc server
// eg, http@10.0.0.1:7001, tcp@10.0.0.1:9999, unix@/tmp/geerpc.sock
func XDial(rpcAddr string, opts ...*Option) (*Client, error) {
parts := strings.Split(rpcAddr, "@")
if len(parts) != 2 {
return nil, fmt.Errorf("rpc client err: wrong format '%s', expect protocol@addr", rpcAddr)
}
protocol, addr := parts[0], parts[1]
switch protocol {
case "http":
return DialHTTP("tcp", addr, opts...)
default:
// tcp, unix or other transport protocol
return Dial(protocol, addr, opts...)
}
}

简单的 DEBUG 页面

/debug/geerpc 上展示服务的调用统计视图。我们将返回一个 HTML 报文,这个报文将展示注册所有的 service 的每一个方法的调用情况。将 debugHTTP 实例绑定到地址 /debug/geerpc

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package geerpc

import (
"fmt"
"html/template"
"net/http"
)

const debugText = `<html>
<body>
<title>GeeRPC Services</title>
{{range .}}
<hr>
Service {{.Name}}
<hr>
<table>
<th align=center>Method</th><th align=center>Calls</th>
{{range $name, $mtype := .Method}}
<tr>
<td align=left font=fixed>{{$name}}({{$mtype.ArgType}}, {{$mtype.ReplyType}}) error</td>
<td align=center>{{$mtype.NumCalls}}</td>
</tr>
{{end}}
</table>
{{end}}
</body>
</html>`

var debug = template.Must(template.New("RPC debug").Parse(debugText))

type debugHTTP struct {
*Server
}

type debugService struct {
Name string
Method map[string]*methodType
}

// Runs at /debug/geerpc
func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Build a sorted version of the data.
var services []debugService
server.serviceMap.Range(func(namei, svci interface{}) bool {
svc := svci.(*service)
services = append(services, debugService{
Name: namei.(string),
Method: svc.method,
})
return true
})
err := debug.Execute(w, services)
if err != nil {
_, _ = fmt.Fprintln(w, "rpc: error executing template:", err.Error())
}
}

负载均衡

通过随机选择和 Round Robin 轮询调度算法实现服务端负载均衡

负载均衡策略

假设有多个服务实例,每个实例提供相同的功能,为了提高整个系统的吞吐量,每个实例部署在不同的机器上。客户端可以选择任意一个实例进行调用,获取想要的结果。那如何选择呢?取决了负载均衡的策略。对于 RPC 框架来说,我们可以很容易地想到这么几种策略:

  • 随机选择策略 - 从服务列表中随机选择一个。
  • 轮询算法 (Round Robin) - 依次调度不同的服务器,每次调度执行 i = (i + 1) mode n。
  • 加权轮询 (Weight Round Robin) - 在轮询算法的基础上,为每个服务实例设置一个权重,高性能的机器赋予更高的权重,也可以根据服务实例的当前的负载情况做动态的调整,例如考虑最近 5 分钟部署服务器的 CPU、内存消耗情况。
  • 哈希 / 一致性哈希策略 - 依据请求的某些特征,计算一个 hash 值,根据 hash 值将请求发送到对应的机器。一致性 hash 还可以解决服务实例动态添加情况下,调度抖动的问题。一致性哈希的一个典型应用场景是分布式缓存服务。感兴趣可以阅读动手写分布式缓存 - GeeCache 第四天 一致性哈希 (hash)

服务发现

负载均衡的前提是有多个服务实例,那我们首先实现一个最基础的服务发现模块 Discovery。为了与通信部分解耦,这部分的代码统一放置在 xclient 子目录下。

定义 2 个类型:

  • SelectMode 代表不同的负载均衡策略,简单起见,GeeRPC 仅实现 Random 和 RoundRobin 两种策略。
  • Discovery 是一个接口类型,包含了服务发现所需要的最基本的函数。
    • Refresh():从注册中心更新服务列表
    • Update(servers [] string):手动更新服务列表
    • Get(mode SelectMode):根据负载均衡策略,选择一个服务实例
    • GetAll():返回所有的服务实例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
package xclient

type SelectMode int

const (
RandomSelect SelectMode = iota // select randomly
RoundRobinSelect // select using Robbin algorithm
)

type Discovery interface {
Refresh() error // refresh from remote registry
Update(servers []string) error
Get(mode SelectMode) (string, error)
GetAll() ([]string, error)
}

紧接着,我们实现一个不需要注册中心,服务列表由手工维护的服务发现的结构体:MultiServersDiscovery

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// MultiServersDiscovery is a discovery for multi servers without a registry center
// user provides the server addresses explicitly instead
type MultiServersDiscovery struct {
r *rand.Rand // generate random number
mu sync.RWMutex // protect following
servers []string
index int // record the selected position for robin algorithm
}

// NewMultiServerDiscovery creates a MultiServersDiscovery instance
func NewMultiServerDiscovery(servers []string) *MultiServersDiscovery {
d := &MultiServersDiscovery{
servers: servers,
r: rand.New(rand.NewSource(time.Now().UnixNano())),
}
d.index = d.r.Intn(math.MaxInt32 - 1)
return d
}
  • r 是一个产生随机数的实例,初始化时使用时间戳设定随机数种子,避免每次产生相同的随机数序列。
  • index 记录 Round Robin 算法已经轮询到的位置,为了避免每次从 0 开始,初始化时随机设定一个值。

然后,实现 Discovery 接口

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
var _ Discovery = (*MultiServersDiscovery)(nil)

// Refresh doesn't make sense for MultiServersDiscovery, so ignore it
func (d *MultiServersDiscovery) Refresh() error {
return nil
}

// Update the servers of discovery dynamically if needed
func (d *MultiServersDiscovery) Update(servers []string) error {
d.mu.Lock()
defer d.mu.Unlock()
d.servers = servers
return nil
}

// Get a server according to mode
func (d *MultiServersDiscovery) Get(mode SelectMode) (string, error) {
d.mu.Lock()
defer d.mu.Unlock()
n := len(d.servers)
if n == 0 {
return "", errors.New("rpc discovery: no available servers")
}
switch mode {
case RandomSelect:
return d.servers[d.r.Intn(n)], nil
case RoundRobinSelect:
s := d.servers[d.index%n] // servers could be updated, so mode n to ensure safety
d.index = (d.index + 1) % n
return s, nil
default:
return "", errors.New("rpc discovery: not supported select mode")
}
}

// returns all servers in discovery
func (d *MultiServersDiscovery) GetAll() ([]string, error) {
d.mu.RLock()
defer d.mu.RUnlock()
// return a copy of d.servers
servers := make([]string, len(d.servers), len(d.servers))
copy(servers, d.servers)
return servers, nil
}

支持负载均衡的客户端

我们向用户暴露一个支持负载均衡的客户端 XClient。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
package xclient

type XClient struct {
d Discovery
mode SelectMode
opt *Option
mu sync.Mutex // protect following
clients map[string]*Client
}

var _ io.Closer = (*XClient)(nil)

func NewXClient(d Discovery, mode SelectMode, opt *Option) *XClient {
return &XClient{d: d, mode: mode, opt: opt, clients: make(map[string]*Client)}
}

func (xc *XClient) Close() error {
xc.mu.Lock()
defer xc.mu.Unlock()
for key, client := range xc.clients {
// I have no idea how to deal with error, just ignore it.
_ = client.Close()
delete(xc.clients, key)
}
return nil
}

XClient 的构造函数需要传入三个参数,服务发现实例 Discovery、负载均衡模式 SelectMode 以及协议选项 Option。为了尽量地复用已经创建好的 Socket 连接,使用 clients 保存创建成功的 Client 实例,并提供 Close 方法在结束后,关闭已经建立的连接。

接下来,实现客户端最基本的功能 Call

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
func (xc *XClient) dial(rpcAddr string) (*Client, error) {
xc.mu.Lock()
defer xc.mu.Unlock()
client, ok := xc.clients[rpcAddr]
if ok && !client.IsAvailable() {
_ = client.Close()
delete(xc.clients, rpcAddr)
client = nil
}
if client == nil {
var err error
client, err = XDial(rpcAddr, xc.opt)
if err != nil {
return nil, err
}
xc.clients[rpcAddr] = client
}
return client, nil
}

func (xc *XClient) call(rpcAddr string, ctx context.Context, serviceMethod string, args, reply interface{}) error {
client, err := xc.dial(rpcAddr)
if err != nil {
return err
}
return client.Call(ctx, serviceMethod, args, reply)
}

// Call invokes the named function, waits for it to complete,
// and returns its error status.
// xc will choose a proper server.
func (xc *XClient) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
rpcAddr, err := xc.d.Get(xc.mode)
if err != nil {
return err
}
return xc.call(rpcAddr, ctx, serviceMethod, args, reply)
}

我们将复用 Client 的能力封装在方法 dial 中,dial 的处理逻辑如下:

  1. 检查 xc.clients 是否有缓存的 Client,如果有,检查是否是可用状态,如果是则返回缓存的 Client,如果不可用,则从缓存中删除。
  2. 如果步骤 1 没有返回缓存的 Client,则说明需要创建新的 Client,缓存并返回。

另外,我们为 XClient 添加一个常用功能:Broadcast

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
// Broadcast invokes the named function for every server registered in discovery
func (xc *XClient) Broadcast(ctx context.Context, serviceMethod string, args, reply interface{}) error {
servers, err := xc.d.GetAll()
if err != nil {
return err
}
var wg sync.WaitGroup
var mu sync.Mutex // protect e and replyDone
var e error
replyDone := reply == nil // if reply is nil, don't need to set value
ctx, cancel := context.WithCancel(ctx)
for _, rpcAddr := range servers {
wg.Add(1)
go func(rpcAddr string) {
defer wg.Done()
var clonedReply interface{}
if reply != nil {
clonedReply = reflect.New(reflect.ValueOf(reply).Elem().Type()).Interface()
}
err := xc.call(rpcAddr, ctx, serviceMethod, args, clonedReply)
mu.Lock()
if err != nil && e == nil {
e = err
cancel() // if any call failed, cancel unfinished calls
}
if err == nil && !replyDone {
reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(clonedReply).Elem())
replyDone = true
}
mu.Unlock()
}(rpcAddr)
}
wg.Wait()
return e
}

Broadcast 将请求广播到所有的服务实例,如果任意一个实例发生错误,则返回其中一个错误;如果调用成功,则返回其中一个的结果。有以下几点需要注意:

  1. 为了提升性能,请求是并发的。
  2. 并发情况下需要使用互斥锁保证 error 和 reply 能被正确赋值。
  3. 借助 context.WithCancel 确保有错误发生时,快速失败。

服务发现与注册中心 (registry)

实现一个简单的注册中心,支持服务注册、接收心跳等功能。客户端实现基于注册中心的服务发现机制。

注册中心的位置

geerpc registry

注册中心的好处在于,客户端和服务端都只需要感知注册中心的存在,而无需感知对方的存在。

  1. 服务端启动后,向注册中心发送注册消息,注册中心得知该服务已经启动,处于可用状态。一般来说,服务端还需要定期向注册中心发送心跳,证明自己还活着。
  2. 客户端向注册中心询问,当前哪天服务是可用的,注册中心将可用的服务列表返回客户端。
  3. 客户端根据注册中心得到的服务列表,选择其中一个发起调用。

如果没有注册中心,客户端需要硬编码服务端的地址,而且没有机制保证服务端是否处于可用状态

当然注册中心的功能还有很多,比如配置的动态同步、通知机制等。比较常用的注册中心有 etcdzookeeperconsul,一般比较出名的微服务或者 RPC 框架,这些主流的注册中心都是支持的。

Registry

主流的注册中心 etcd、zookeeper 等功能强大,与这类注册中心的对接代码量是比较大的,需要实现的接口很多。GeeRPC 选择自己实现一个简单的支持心跳保活的注册中心。

GeeRegistry 的代码独立放置在子目录 registry 中。

首先定义 GeeRegistry 结构体,默认超时时间设置为 5 min,也就是说,任何注册的服务超过 5 min,即视为不可用状态。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// GeeRegistry is a simple register center, provide following functions.
// add a server and receive heartbeat to keep it alive.
// returns all alive servers and delete dead servers sync simultaneously.
type GeeRegistry struct {
timeout time.Duration
mu sync.Mutex // protect following
servers map[string]*ServerItem
}

type ServerItem struct {
Addr string
start time.Time
}

const (
defaultPath = "/_geerpc_/registry"
defaultTimeout = time.Minute * 5
)

// New create a registry instance with timeout setting
func New(timeout time.Duration) *GeeRegistry {
return &GeeRegistry{
servers: make(map[string]*ServerItem),
timeout: timeout,
}
}

var DefaultGeeRegister = New(defaultTimeout)

为 GeeRegistry 实现添加服务实例和返回服务列表的方法。

  • putServer:添加服务实例,如果服务已经存在,则更新 start。
  • aliveServers:返回可用的服务列表,如果存在超时的服务,则删除。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
func (r *GeeRegistry) putServer(addr string) {
r.mu.Lock()
defer r.mu.Unlock()
s := r.servers[addr]
if s == nil {
r.servers[addr] = &ServerItem{Addr: addr, start: time.Now()}
} else {
s.start = time.Now() // if exists, update start time to keep alive
}
}

func (r *GeeRegistry) aliveServers() []string {
r.mu.Lock()
defer r.mu.Unlock()
var alive []string
for addr, s := range r.servers {
if r.timeout == 0 || s.start.Add(r.timeout).After(time.Now()) {
alive = append(alive, addr)
} else {
delete(r.servers, addr)
}
}
sort.Strings(alive)
return alive
}

为了实现上的简单,GeeRegistry 采用 HTTP 协议提供服务,且所有的有用信息都承载在 HTTP Header 中。

  • Get:返回所有可用的服务列表,通过自定义字段 X-Geerpc-Servers 承载。
  • Post:添加服务实例或发送心跳,通过自定义字段 X-Geerpc-Server 承载。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// Runs at /_geerpc_/registry
func (r *GeeRegistry) ServeHTTP(w http.ResponseWriter, req *http.Request) {
switch req.Method {
case "GET":
// keep it simple, server is in req.Header
w.Header().Set("X-Geerpc-Servers", strings.Join(r.aliveServers(), ","))
case "POST":
// keep it simple, server is in req.Header
addr := req.Header.Get("X-Geerpc-Server")
if addr == "" {
w.WriteHeader(http.StatusInternalServerError)
return
}
r.putServer(addr)
default:
w.WriteHeader(http.StatusMethodNotAllowed)
}
}

// HandleHTTP registers an HTTP handler for GeeRegistry messages on registryPath
func (r *GeeRegistry) HandleHTTP(registryPath string) {
http.Handle(registryPath, r)
log.Println("rpc registry path:", registryPath)
}

func HandleHTTP() {
DefaultGeeRegister.HandleHTTP(defaultPath)
}

另外,提供 Heartbeat 方法,便于服务启动时定时向注册中心发送心跳,默认周期比注册中心设置的过期时间少 1 min。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
// Heartbeat send a heartbeat message every once in a while
// it's a helper function for a server to register or send heartbeat
func Heartbeat(registry, addr string, duration time.Duration) {
if duration == 0 {
// make sure there is enough time to send heart beat
// before it's removed from registry
duration = defaultTimeout - time.Duration(1)*time.Minute
}
var err error
err = sendHeartbeat(registry, addr)
go func() {
t := time.NewTicker(duration)
for err == nil {
<-t.C
err = sendHeartbeat(registry, addr)
}
}()
}

func sendHeartbeat(registry, addr string) error {
log.Println(addr, "send heart beat to registry", registry)
httpClient := &http.Client{}
req, _ := http.NewRequest("POST", registry, nil)
req.Header.Set("X-Geerpc-Server", addr)
if _, err := httpClient.Do(req); err != nil {
log.Println("rpc server: heart beat err:", err)
return err
}
return nil
}

GeeRegistryDiscovery

在 xclient 中对应实现 Discovery。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
package xclient

type GeeRegistryDiscovery struct {
*MultiServersDiscovery // 嵌套了 MultiServersDiscovery,很多能力可以复用。
registry string // 注册中心的地址
timeout time.Duration // 服务列表的过期时间
lastUpdate time.Time // 代表最后从注册中心更新服务列表的时间,默认 10s 过期,即 10s 之后,需要从注册中心更新新的列表。
}

const defaultUpdateTimeout = time.Second * 10

func NewGeeRegistryDiscovery(registerAddr string, timeout time.Duration) *GeeRegistryDiscovery {
if timeout == 0 {
timeout = defaultUpdateTimeout
}
d := &GeeRegistryDiscovery{
MultiServersDiscovery: NewMultiServerDiscovery(make([]string, 0)),
registry: registerAddr,
timeout: timeout,
}
return d
}

实现 Update 和 Refresh 方法,超时重新获取的逻辑在 Refresh 中实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
func (d *GeeRegistryDiscovery) Update(servers []string) error {
d.mu.Lock()
defer d.mu.Unlock()
d.servers = servers
d.lastUpdate = time.Now()
return nil
}

func (d *GeeRegistryDiscovery) Refresh() error {
d.mu.Lock()
defer d.mu.Unlock()
if d.lastUpdate.Add(d.timeout).After(time.Now()) {
return nil
}
log.Println("rpc registry: refresh servers from registry", d.registry)
resp, err := http.Get(d.registry)
if err != nil {
log.Println("rpc registry refresh err:", err)
return err
}
servers := strings.Split(resp.Header.Get("X-Geerpc-Servers"), ",")
d.servers = make([]string, 0, len(servers))
for _, server := range servers {
if strings.TrimSpace(server) != "" {
d.servers = append(d.servers, strings.TrimSpace(server))
}
}
d.lastUpdate = time.Now()
return nil
}

func (d *GeeRegistryDiscovery) Get(mode SelectMode) (string, error) {
// GeeRegistryDiscovery 需要先调用 Refresh 确保服务列表没有过期。
if err := d.Refresh(); err != nil {
return "", err
}
return d.MultiServersDiscovery.Get(mode)
}

func (d *GeeRegistryDiscovery) GetAll() ([]string, error) {
if err := d.Refresh(); err != nil {
return nil, err
}
return d.MultiServersDiscovery.GetAll()
}

reference