第 42 章

实现一个 RPC 框架

实现一个 RPC 框架

1984 年,Birrell 和 Nelson 发表了论文《Implementing Remote Procedure Calls》。他们提出了一个看起来很简单的想法:调用远程机器上的函数,应该和调用本地函数一样简单。这个想法催生了 RPC(Remote Procedure Call,远程过程调用),也催生了 40 年来分布式系统的核心基础设施。

今天,几乎所有大型系统的内部服务通信都通过 RPC 进行。Google 的 gRPC、Facebook 的 Thrift、阿里巴巴的 Dubbo——他们的本质都是在解决同一个问题:如何把网络调用包装得像函数调用一样透明

这一章我们用 Go 从零构建一个完整的 RPC 框架,然后逐步添加生产级特性:拦截器、负载均衡、健康检查。这个过程会揭示 gRPC 等框架的每一个核心设计决策背后的原因。

Level 1 · RPC 的本质

隐藏网络调用

没有 RPC 框架时,调用另一台机器上的服务是这样的:

// 手动写网络通信的痛苦
conn, _ := net.Dial("tcp", "10.0.0.1:8080")
request := serialize(AddRequest{A: 1, B: 2})
conn.Write(request)
buf := make([]byte, 1024)
conn.Read(buf)
response := deserialize(buf)
fmt.Println(response.Result)

每一次服务调用都要手动处理:建立连接、序列化参数、发送数据、接收响应、反序列化结果、处理错误。重复、繁琐、容易出错。

RPC 框架把这些全部隐藏掉,让调用看起来是这样:

// 有了 RPC 框架
client := NewMathClient("10.0.0.1:8080")
result, err := client.Add(ctx, &AddRequest{A: 1, B: 2})
fmt.Println(result.Result)

这背后涉及几个核心问题:

  1. 序列化:参数怎么变成字节,字节怎么变回参数?
  2. 传输:字节怎么在网络上传送?用什么协议?
  3. 服务发现:客户端怎么知道服务器在哪里?
  4. 错误处理:网络错误、超时、服务端异常怎么传播回来?

REST vs RPC vs GraphQL 的本质权衡

这三种 API 风格代表了三种不同的设计哲学,各有适用场景:

REST(Representational State Transfer):以资源为中心,用 HTTP 动词(GET/POST/PUT/DELETE)表达操作,用 URL 路径表达资源层级。优势在于通用性强(任何 HTTP 客户端都能用)、可缓存、人类可读。劣势在于:操作语义被强行映射到 CRUD,复杂操作("转账"、"审批")很难用资源模型表达;协议开销大(HTTP/1.1 头部每次都要传);版本管理复杂。

RPC:以动作(操作)为中心,直接暴露服务的方法。优势在于:类型安全(参数类型在 schema 中定义)、高性能(自定义序列化,HTTP/2 多路复用)、适合服务间内部通信。劣势在于:需要特定客户端(不是普通浏览器能直接用的)、跨语言支持需要代码生成、接口变更需要版本管理。

GraphQL:以数据图为中心,客户端精确声明需要哪些字段。优势在于:解决了 REST 的 over-fetching 和 under-fetching 问题,特别适合前端驱动的场景。劣势在于:服务端实现复杂(N+1 查询问题)、无法有效缓存、不适合服务间通信。

经验法则

Protobuf vs JSON:序列化的深层权衡

序列化格式的选择对性能影响巨大:

JSON:人类可读,跨语言支持好,调试方便。但是文本格式意味着:数字 1234567890 需要 10 字节,而二进制格式只需要 4 字节(int32);字段名每次都要传输("user_id": 这 9 个字节在每条记录中都重复出现);解析需要字符串比较,比二进制解析慢 5-10 倍。

Protobuf:Google 开源的二进制序列化格式。用字段编号(field number)代替字段名(user_id 在 schema 里是 field 1,序列化时只占 1-2 字节);用变长整数编码(Varint)压缩数值;无需模式时仍然可以向前和向后兼容(增减字段不破坏现有代码)。典型场景下,Protobuf 比 JSON 小 3-10 倍,解析速度快 5-20 倍。

代价:二进制不可读,必须有 .proto schema 文件才能解析,调试需要专门工具。

Level 2 · gRPC 架构原理

HTTP/2 作为传输层

gRPC 选择 HTTP/2 作为传输层不是偶然的:

多路复用:HTTP/1.1 的一个连接同一时刻只能处理一个请求(串行)。HTTP/2 的一个连接可以同时处理多个请求(流,Stream),完全消除了 HTTP/1.1 的队头阻塞(Head-of-Line Blocking)问题。gRPC 的并发调用复用同一个 TCP 连接,节省了连接建立开销。

头部压缩:HTTP/1.1 每次请求都要发送完整的 HTTP 头,gRPC 的元数据(方法名、认证 token 等)就是 HTTP/2 头部,通过 HPACK 压缩,重复头部几乎不占带宽。

流式传输:HTTP/2 原生支持双向流,gRPC 的 Server Streaming、Client Streaming、Bidirectional Streaming 都建立在这个基础上。

:HTTP/2 把数据切成帧(Frame)发送,每个帧有 Stream ID,接收方可以把不同流的帧重新组装。这使得一个连接真正并发处理多个请求成为可能。

拦截器:请求的中间件层

gRPC 的拦截器(Interceptor)是 RPC 框架最重要的扩展机制之一。它允许在不修改业务代码的情况下,给所有 RPC 调用统一添加:日志记录、指标收集、认证验证、限流、链路追踪。

拦截器形成一个调用链(类似 HTTP 中间件):

请求 → 拦截器A → 拦截器B → 业务处理函数 → 拦截器B → 拦截器A → 响应

gRPC 分别提供 Server-side 拦截器和 Client-side 拦截器,各自有两种:Unary(一元)和 Stream(流式)。

服务发现与负载均衡

gRPC 原生支持服务发现,通过 Resolver 接口接入 etcd、Consul、Kubernetes 等注册中心:

客户端 → Resolver(解析服务名→地址列表) → Balancer(选择一个地址) → 建立连接

负载均衡策略:

Level 3 · 从零构建 RPC 框架

整体架构设计

我们的 RPC 框架分为以下几层:

Client Side:                     Server Side:
                                 
┌─────────────────────┐         ┌─────────────────────┐
│  Generated Stub     │         │   Service Registry  │
│  (type-safe API)    │         │   (reflect-based)   │
└─────────────────────┘         └─────────────────────┘
         │                               │
┌─────────────────────┐         ┌─────────────────────┐
│  Client Interceptor │         │  Server Interceptor │
│  Chain              │         │  Chain              │
└─────────────────────┘         └─────────────────────┘
         │                               │
┌─────────────────────┐         ┌─────────────────────┐
│  Codec (gob/json/   │         │  Codec (decode req, │
│  protobuf)          │         │  encode resp)       │
└─────────────────────┘         └─────────────────────┘
         │                               │
┌─────────────────────┐         ┌─────────────────────┐
│  TCP Transport      │◄───────►│  TCP Transport      │
│  + Connection Pool  │         │  (per conn goroutine│
└─────────────────────┘         └─────────────────────┘

第一步:请求/响应协议和 Codec

package codec

import (
    "bufio"
    "encoding/gob"
    "encoding/json"
    "fmt"
    "io"
)

// Header 每个 RPC 调用的元数据
type Header struct {
    ServiceMethod string // "ServiceName.MethodName"
    Seq           uint64 // 请求序列号(用于异步调用的 response 匹配)
    Error         string // 服务端错误(非空时忽略 Body)
}

// Codec 负责将 Header+Body 编解码
type Codec interface {
    io.Closer
    ReadHeader(*Header) error
    ReadBody(interface{}) error
    Write(*Header, interface{}) error
}

// GobCodec 使用 Go 标准库 gob 编解码
type GobCodec struct {
    conn io.ReadWriteCloser
    buf  *bufio.Writer
    dec  *gob.Decoder
    enc  *gob.Encoder
}

func NewGobCodec(conn io.ReadWriteCloser) Codec {
    buf := bufio.NewWriter(conn)
    return &GobCodec{
        conn: conn,
        buf:  buf,
        dec:  gob.NewDecoder(conn),
        enc:  gob.NewEncoder(buf),
    }
}

func (c *GobCodec) ReadHeader(h *Header) error {
    return c.dec.Decode(h)
}

func (c *GobCodec) ReadBody(body interface{}) error {
    return c.dec.Decode(body)
}

func (c *GobCodec) Write(h *Header, body interface{}) (err error) {
    defer func() {
        _ = c.buf.Flush()
        if err != nil {
            _ = c.Close()
        }
    }()
    if err = c.enc.Encode(h); err != nil {
        return fmt.Errorf("encode header: %w", err)
    }
    if err = c.enc.Encode(body); err != nil {
        return fmt.Errorf("encode body: %w", err)
    }
    return nil
}

func (c *GobCodec) Close() error { return c.conn.Close() }

// JSONCodec 使用 JSON 编解码(更易调试)
type JSONCodec struct {
    conn io.ReadWriteCloser
    dec  *json.Decoder
    enc  *json.Encoder
}

func NewJSONCodec(conn io.ReadWriteCloser) Codec {
    return &JSONCodec{
        conn: conn,
        dec:  json.NewDecoder(conn),
        enc:  json.NewEncoder(conn),
    }
}

func (c *JSONCodec) ReadHeader(h *Header) error   { return c.dec.Decode(h) }
func (c *JSONCodec) ReadBody(body interface{}) error { return c.dec.Decode(body) }
func (c *JSONCodec) Write(h *Header, body interface{}) error {
    if err := c.enc.Encode(h); err != nil {
        return err
    }
    return c.enc.Encode(body)
}
func (c *JSONCodec) Close() error { return c.conn.Close() }

type CodecType string

const (
    GobType  CodecType = "application/gob"
    JSONType CodecType = "application/json"
)

// NewCodecFunc 根据类型创建 Codec 的工厂函数
type NewCodecFunc func(io.ReadWriteCloser) Codec

var NewCodecFuncMap = map[CodecType]NewCodecFunc{
    GobType:  NewGobCodec,
    JSONType: NewJSONCodec,
}

第二步:服务注册(反射实现)

package server

import (
    "fmt"
    "go/token"
    "log"
    "reflect"
    "strings"
    "sync"
)

// methodType 描述一个可供远程调用的方法
type methodType struct {
    method    reflect.Method
    ArgType   reflect.Type
    ReplyType reflect.Type
    numCalls  uint64
}

func (m *methodType) newArgv() reflect.Value {
    // 如果是指针类型,创建对应元素类型的值
    if m.ArgType.Kind() == reflect.Ptr {
        return reflect.New(m.ArgType.Elem())
    }
    return reflect.New(m.ArgType).Elem()
}

func (m *methodType) newReplyv() reflect.Value {
    // Reply 一定是指针类型
    replyv := reflect.New(m.ReplyType.Elem())
    switch m.ReplyType.Elem().Kind() {
    case reflect.Map:
        replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem()))
    case reflect.Slice:
        replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0))
    }
    return replyv
}

// service 代表一个注册的服务
type service struct {
    name    string
    typ     reflect.Type
    rcvr    reflect.Value
    methods map[string]*methodType
}

func newService(rcvr interface{}) *service {
    s := &service{}
    s.rcvr = reflect.ValueOf(rcvr)
    s.typ = reflect.TypeOf(rcvr)
    s.name = reflect.Indirect(s.rcvr).Type().Name()
    if !token.IsExported(s.name) {
        log.Fatalf("rpc server: %s is not a valid service name", s.name)
    }
    s.registerMethods()
    return s
}

func (s *service) registerMethods() {
    s.methods = make(map[string]*methodType)
    for i := 0; i < s.typ.NumMethod(); i++ {
        method := s.typ.Method(i)
        mType := method.Type
        // 合法的 RPC 方法签名: func (t *T) MethodName(ctx context.Context, args *Args, reply *Reply) error
        // 简化版: func (t *T) MethodName(args ArgType, reply *ReplyType) error
        if mType.NumIn() != 3 || mType.NumOut() != 1 {
            continue
        }
        if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
            continue
        }
        argType, replyType := mType.In(1), mType.In(2)
        if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) {
            continue
        }
        s.methods[method.Name] = &methodType{
            method:    method,
            ArgType:   argType,
            ReplyType: replyType,
        }
        log.Printf("rpc server: register %s.%s\n", s.name, method.Name)
    }
}

func (s *service) call(m *methodType, argv, replyv reflect.Value) error {
    f := m.method.Func
    returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv})
    if errInter := returnValues[0].Interface(); errInter != nil {
        return errInter.(error)
    }
    return nil
}

func isExportedOrBuiltinType(t reflect.Type) bool {
    return token.IsExported(t.Name()) || t.PkgPath() == ""
}

// Server RPC 服务器
type Server struct {
    serviceMap sync.Map
}

func NewServer() *Server { return &Server{} }

func (s *Server) Register(rcvr interface{}) error {
    svc := newService(rcvr)
    if _, dup := s.serviceMap.LoadOrStore(svc.name, svc); dup {
        return fmt.Errorf("rpc: service already defined: %s", svc.name)
    }
    return nil
}

func (s *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) {
    dot := strings.LastIndex(serviceMethod, ".")
    if dot < 0 {
        err = fmt.Errorf("rpc server: service/method request ill-formed: %s", serviceMethod)
        return
    }
    serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:]
    svci, ok := s.serviceMap.Load(serviceName)
    if !ok {
        err = fmt.Errorf("rpc server: can't find service %s", serviceName)
        return
    }
    svc = svci.(*service)
    mtype, ok = svc.methods[methodName]
    if !ok {
        err = fmt.Errorf("rpc server: can't find method %s", methodName)
    }
    return
}

第三步:服务器连接处理

package server

import (
    "encoding/json"
    "fmt"
    "io"
    "log"
    "net"
    "reflect"
    "strings"
    "sync"
    "time"

    "github.com/yourname/minirpc/codec"
)

// Option 协商阶段的配置(客户端发送,服务端读取)
type Option struct {
    MagicNumber int        // 魔数,验证这是一个 minirpc 请求
    CodecType   codec.CodecType
    ConnTimeout time.Duration // 连接超时
    HandleTimeout time.Duration // 处理超时
}

const MagicNumber = 0x3bef5c

var DefaultOption = &Option{
    MagicNumber:   MagicNumber,
    CodecType:     codec.GobType,
    ConnTimeout:   time.Second * 10,
}

// ServeConn 处理单个连接
func (s *Server) ServeConn(conn io.ReadWriteCloser) {
    defer conn.Close()
    var opt Option
    if err := json.NewDecoder(conn).Decode(&opt); err != nil {
        log.Printf("rpc server: decode option error: %v", err)
        return
    }
    if opt.MagicNumber != MagicNumber {
        log.Printf("rpc server: invalid magic number %x", opt.MagicNumber)
        return
    }
    newCodec := codec.NewCodecFuncMap[opt.CodecType]
    if newCodec == nil {
        log.Printf("rpc server: invalid codec type %s", opt.CodecType)
        return
    }
    s.serveCodec(newCodec(conn), &opt)
}

type request struct {
    h      *codec.Header
    argv   reflect.Value
    replyv reflect.Value
    mtype  *methodType
    svc    *service
}

func (s *Server) serveCodec(cc codec.Codec, opt *Option) {
    sending := new(sync.Mutex)
    wg := new(sync.WaitGroup)

    for {
        req, err := s.readRequest(cc)
        if err != nil {
            if req == nil {
                break // 无法恢复的错误,关闭连接
            }
            req.h.Error = err.Error()
            s.sendResponse(cc, req.h, invalidRequest, sending)
            continue
        }
        wg.Add(1)
        go s.handleRequest(cc, req, sending, wg, opt.HandleTimeout)
    }
    wg.Wait()
    _ = cc.Close()
}

var invalidRequest = struct{}{}

func (s *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
    var h codec.Header
    if err := cc.ReadHeader(&h); err != nil {
        if err != io.EOF && !strings.HasSuffix(err.Error(), "EOF") {
            log.Printf("rpc server: read header error: %v", err)
        }
        return nil, err
    }
    return &h, nil
}

func (s *Server) readRequest(cc codec.Codec) (*request, error) {
    h, err := s.readRequestHeader(cc)
    if err != nil {
        return nil, err
    }
    req := &request{h: h}
    req.svc, req.mtype, err = s.findService(h.ServiceMethod)
    if err != nil {
        if err := cc.ReadBody(nil); err != nil {
            log.Printf("rpc server: read body err: %v", err)
        }
        return req, err
    }
    req.argv = req.mtype.newArgv()
    req.replyv = req.mtype.newReplyv()

    argvi := req.argv.Interface()
    if req.argv.Type().Kind() != reflect.Ptr {
        argvi = req.argv.Addr().Interface()
    }
    if err = cc.ReadBody(argvi); err != nil {
        log.Printf("rpc server: read body err: %v", err)
        return req, err
    }
    return req, nil
}

func (s *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) {
    defer wg.Done()

    called := make(chan struct{})
    sent := make(chan struct{})

    go func() {
        err := req.svc.call(req.mtype, req.argv, req.replyv)
        called <- struct{}{}
        if err != nil {
            req.h.Error = err.Error()
            s.sendResponse(cc, req.h, invalidRequest, sending)
            sent <- struct{}{}
            return
        }
        s.sendResponse(cc, req.h, req.replyv.Interface(), sending)
        sent <- struct{}{}
    }()

    if timeout == 0 {
        <-called
        <-sent
        return
    }
    select {
    case <-time.After(timeout):
        req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout)
        s.sendResponse(cc, req.h, invalidRequest, sending)
    case <-called:
        <-sent
    }
}

func (s *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) {
    sending.Lock()
    defer sending.Unlock()
    if err := cc.Write(h, body); err != nil {
        log.Printf("rpc server: write response error: %v", err)
    }
}

// Accept 监听并处理请求
func (s *Server) Accept(lis net.Listener) {
    for {
        conn, err := lis.Accept()
        if err != nil {
            log.Printf("rpc server: accept error: %v", err)
            return
        }
        go s.ServeConn(conn)
    }
}

var DefaultServer = NewServer()

func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) }
func Accept(lis net.Listener)         { DefaultServer.Accept(lis) }

第四步:客户端(支持异步调用)

package client

import (
    "context"
    "encoding/json"
    "fmt"
    "io"
    "log"
    "net"
    "sync"
    "time"

    "github.com/yourname/minirpc/codec"
    "github.com/yourname/minirpc/server"
)

// Call 代表一次正在进行的 RPC 调用
type Call struct {
    Seq           uint64
    ServiceMethod string
    Args          interface{}
    Reply         interface{}
    Error         error
    Done          chan *Call // 调用完成后通知
}

func (c *Call) done() {
    c.Done <- c
}

// Client RPC 客户端
type Client struct {
    cc       codec.Codec
    opt      *server.Option
    sending  sync.Mutex
    header   codec.Header
    mu       sync.Mutex
    seq      uint64
    pending  map[uint64]*Call
    closing  bool // 用户主动关闭
    shutdown bool // 服务端要求关闭或错误
}

func NewClient(conn net.Conn, opt *server.Option) (*Client, error) {
    newCodecFunc := codec.NewCodecFuncMap[opt.CodecType]
    if newCodecFunc == nil {
        err := fmt.Errorf("invalid codec type %s", opt.CodecType)
        return nil, err
    }
    // 发送协商选项
    if err := json.NewEncoder(conn).Encode(opt); err != nil {
        return nil, fmt.Errorf("rpc client: encode option error: %w", err)
    }
    client := &Client{
        seq:     1,
        cc:      newCodecFunc(conn),
        opt:     opt,
        pending: make(map[uint64]*Call),
    }
    go client.receive()
    return client, nil
}

func Dial(network, address string, opts ...*server.Option) (client *Client, err error) {
    opt := server.DefaultOption
    if len(opts) > 0 && opts[0] != nil {
        opt = opts[0]
    }

    conn, err := net.DialTimeout(network, address, opt.ConnTimeout)
    if err != nil {
        return nil, err
    }
    defer func() {
        if err != nil {
            _ = conn.Close()
        }
    }()
    return NewClient(conn, opt)
}

func (c *Client) Close() error {
    c.mu.Lock()
    defer c.mu.Unlock()
    if c.closing {
        return fmt.Errorf("connection already closed")
    }
    c.closing = true
    return c.cc.Close()
}

func (c *Client) IsAvailable() bool {
    c.mu.Lock()
    defer c.mu.Unlock()
    return !c.shutdown && !c.closing
}

func (c *Client) registerCall(call *Call) (uint64, error) {
    c.mu.Lock()
    defer c.mu.Unlock()
    if c.closing || c.shutdown {
        return 0, fmt.Errorf("rpc client: client is shutting down")
    }
    call.Seq = c.seq
    c.pending[call.Seq] = call
    c.seq++
    return call.Seq, nil
}

func (c *Client) removeCall(seq uint64) *Call {
    c.mu.Lock()
    defer c.mu.Unlock()
    call := c.pending[seq]
    delete(c.pending, seq)
    return call
}

func (c *Client) terminateCalls(err error) {
    c.sending.Lock()
    defer c.sending.Unlock()
    c.mu.Lock()
    defer c.mu.Unlock()
    c.shutdown = true
    for _, call := range c.pending {
        call.Error = err
        call.done()
    }
}

func (c *Client) receive() {
    var err error
    for err == nil {
        var h codec.Header
        if err = c.cc.ReadHeader(&h); err != nil {
            break
        }
        call := c.removeCall(h.Seq)
        switch {
        case call == nil:
            err = c.cc.ReadBody(nil)
        case h.Error != "":
            call.Error = fmt.Errorf(h.Error)
            err = c.cc.ReadBody(nil)
            call.done()
        default:
            err = c.cc.ReadBody(call.Reply)
            if err != nil {
                call.Error = fmt.Errorf("reading body: %w", err)
            }
            call.done()
        }
    }
    c.terminateCalls(err)
}

func (c *Client) send(call *Call) {
    c.sending.Lock()
    defer c.sending.Unlock()

    seq, err := c.registerCall(call)
    if err != nil {
        call.Error = err
        call.done()
        return
    }

    c.header.ServiceMethod = call.ServiceMethod
    c.header.Seq = seq
    c.header.Error = ""

    if err := c.cc.Write(&c.header, call.Args); err != nil {
        call := c.removeCall(seq)
        if call != nil {
            call.Error = err
            call.done()
        }
    }
}

// Go 异步调用,立即返回 Call
func (c *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call {
    if done == nil {
        done = make(chan *Call, 1)
    }
    call := &Call{
        ServiceMethod: serviceMethod,
        Args:          args,
        Reply:         reply,
        Done:          done,
    }
    c.send(call)
    return call
}

// Call 同步调用,阻塞直到完成或超时
func (c *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
    call := c.Go(serviceMethod, args, reply, make(chan *Call, 1))
    select {
    case <-ctx.Done():
        c.removeCall(call.Seq)
        return fmt.Errorf("rpc client: call failed: %s", ctx.Err())
    case call := <-call.Done:
        return call.Error
    }
}

第五步:拦截器和负载均衡

package middleware

import (
    "context"
    "log"
    "time"
)

// UnaryInterceptor 一元 RPC 拦截器类型
type UnaryInterceptor func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error)

type UnaryServerInfo struct {
    Server        interface{}
    FullMethod    string
}

type UnaryHandler func(ctx context.Context, req interface{}) (interface{}, error)

// ChainUnaryInterceptors 将多个拦截器链接成一个
func ChainUnaryInterceptors(interceptors ...UnaryInterceptor) UnaryInterceptor {
    return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
        chained := handler
        for i := len(interceptors) - 1; i >= 0; i-- {
            interceptor := interceptors[i]
            next := chained
            chained = func(ctx context.Context, req interface{}) (interface{}, error) {
                return interceptor(ctx, req, info, next)
            }
        }
        return chained(ctx, req)
    }
}

// LoggingInterceptor 日志拦截器
func LoggingInterceptor(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
    start := time.Now()
    log.Printf("RPC call: %s, req: %v", info.FullMethod, req)
    resp, err := handler(ctx, req)
    elapsed := time.Since(start)
    if err != nil {
        log.Printf("RPC failed: %s, error: %v, elapsed: %v", info.FullMethod, err, elapsed)
    } else {
        log.Printf("RPC success: %s, elapsed: %v", info.FullMethod, elapsed)
    }
    return resp, err
}

// RecoveryInterceptor panic 恢复拦截器
func RecoveryInterceptor(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) {
    defer func() {
        if r := recover(); r != nil {
            log.Printf("RPC panic recovered: %s, panic: %v", info.FullMethod, r)
            err = fmt.Errorf("internal server error: panic")
        }
    }()
    return handler(ctx, req)
}

// ===================== 负载均衡 =====================

package balancer

import (
    "fmt"
    "math/rand"
    "sync"
    "sync/atomic"
)

// SelectMode 负载均衡策略
type SelectMode int

const (
    RandomSelect     SelectMode = iota
    RoundRobinSelect
)

// Discovery 服务发现接口
type Discovery interface {
    Refresh() error                   // 从注册中心刷新服务列表
    Update(servers []string) error    // 手动更新服务列表
    Get(mode SelectMode) (string, error) // 根据策略选择一个服务
    GetAll() ([]string, error)        // 获取所有服务
}

// MultiServerDiscovery 手动维护的服务列表(不依赖注册中心)
type MultiServerDiscovery struct {
    r       *rand.Rand
    mu      sync.RWMutex
    servers []string
    index   uint64 // Round Robin 计数器
}

func NewMultiServerDiscovery(servers []string) *MultiServerDiscovery {
    return &MultiServerDiscovery{
        r:       rand.New(rand.NewSource(rand.Int63())),
        servers: servers,
    }
}

func (d *MultiServerDiscovery) Refresh() error { return nil }

func (d *MultiServerDiscovery) Update(servers []string) error {
    d.mu.Lock()
    defer d.mu.Unlock()
    d.servers = servers
    return nil
}

func (d *MultiServerDiscovery) Get(mode SelectMode) (string, error) {
    d.mu.RLock()
    defer d.mu.RUnlock()
    n := len(d.servers)
    if n == 0 {
        return "", fmt.Errorf("rpc discovery: no available servers")
    }
    switch mode {
    case RandomSelect:
        return d.servers[d.r.Intn(n)], nil
    case RoundRobinSelect:
        // 原子自增,取模选择
        idx := atomic.AddUint64(&d.index, 1) - 1
        return d.servers[idx%uint64(n)], nil
    default:
        return "", fmt.Errorf("rpc discovery: unsupported select mode")
    }
}

func (d *MultiServerDiscovery) GetAll() ([]string, error) {
    d.mu.RLock()
    defer d.mu.RUnlock()
    return d.servers, nil
}

// XClient 支持负载均衡的 RPC 客户端
type XClient struct {
    d       Discovery
    mode    SelectMode
    opt     *server.Option
    mu      sync.Mutex
    clients map[string]*client.Client
}

func NewXClient(d Discovery, mode SelectMode, opt *server.Option) *XClient {
    return &XClient{
        d:       d,
        mode:    mode,
        opt:     opt,
        clients: make(map[string]*client.Client),
    }
}

func (xc *XClient) Close() error {
    xc.mu.Lock()
    defer xc.mu.Unlock()
    for _, c := range xc.clients {
        _ = c.Close()
    }
    return nil
}

func (xc *XClient) dial(rpcAddr string) (*client.Client, error) {
    xc.mu.Lock()
    defer xc.mu.Unlock()
    c, ok := xc.clients[rpcAddr]
    if ok && !c.IsAvailable() {
        _ = c.Close()
        delete(xc.clients, rpcAddr)
        c = nil
    }
    if c == nil {
        var err error
        c, err = client.Dial("tcp", rpcAddr, xc.opt)
        if err != nil {
            return nil, err
        }
        xc.clients[rpcAddr] = c
    }
    return c, nil
}

func (xc *XClient) call(rpcAddr string, ctx context.Context, serviceMethod string, args, reply interface{}) error {
    c, err := xc.dial(rpcAddr)
    if err != nil {
        return err
    }
    return c.Call(ctx, serviceMethod, args, reply)
}

func (xc *XClient) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
    rpcAddr, err := xc.d.Get(xc.mode)
    if err != nil {
        return err
    }
    return xc.call(rpcAddr, ctx, serviceMethod, args, reply)
}

// Broadcast 广播:调用所有服务,有任一成功则返回,有任一失败则收集错误
func (xc *XClient) Broadcast(ctx context.Context, serviceMethod string, args, reply interface{}) error {
    servers, err := xc.d.GetAll()
    if err != nil {
        return err
    }
    var wg sync.WaitGroup
    var mu sync.Mutex
    var e error
    replyDone := reply == nil

    ctx, cancel := context.WithCancel(ctx)
    defer cancel()

    for _, rpcAddr := range servers {
        wg.Add(1)
        go func(rpcAddr string) {
            defer wg.Done()
            var clonedReply interface{}
            if reply != nil {
                clonedReply = reflect.New(reflect.ValueOf(reply).Elem().Type()).Interface()
            }
            err := xc.call(rpcAddr, ctx, serviceMethod, args, clonedReply)
            mu.Lock()
            defer mu.Unlock()
            if err != nil && e == nil {
                e = err
                cancel() // 有错误,取消其他调用
            }
            if err == nil && !replyDone {
                reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(clonedReply).Elem())
                replyDone = true
            }
        }(rpcAddr)
    }
    wg.Wait()
    return e
}

完整使用示例

package main

import (
    "context"
    "fmt"
    "log"
    "net"
    "sync"
    "time"

    "github.com/yourname/minirpc/client"
    "github.com/yourname/minirpc/server"
)

// 定义服务
type MathService struct{}

type Args struct{ A, B int }
type Reply struct{ C int }

func (m *MathService) Add(args Args, reply *Reply) error {
    reply.C = args.A + args.B
    return nil
}

func (m *MathService) Multiply(args Args, reply *Reply) error {
    reply.C = args.A * args.B
    return nil
}

func startServer(wg *sync.WaitGroup) {
    var m MathService
    if err := server.Register(&m); err != nil {
        log.Fatal("register error:", err)
    }
    lis, err := net.Listen("tcp", ":9999")
    if err != nil {
        log.Fatal("listen error:", err)
    }
    log.Println("rpc server started on :9999")
    wg.Done()
    server.Accept(lis)
}

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go startServer(&wg)
    wg.Wait()
    time.Sleep(time.Second)

    c, err := client.Dial("tcp", ":9999")
    if err != nil {
        log.Fatal("dial error:", err)
    }
    defer c.Close()

    ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
    defer cancel()

    var reply Reply
    if err := c.Call(ctx, "MathService.Add", Args{A: 1, B: 2}, &reply); err != nil {
        log.Fatal("call error:", err)
    }
    fmt.Printf("1 + 2 = %d\n", reply.C) // 1 + 2 = 3

    if err := c.Call(ctx, "MathService.Multiply", Args{A: 6, B: 7}, &reply); err != nil {
        log.Fatal("call error:", err)
    }
    fmt.Printf("6 * 7 = %d\n", reply.C) // 6 * 7 = 42
}

Level 4 · 进阶与边界

健康检查协议

生产环境中的 RPC 框架必须能检测到后端实例是否健康。gRPC 定义了标准的健康检查协议,服务端实现 grpc.health.v1.Health 服务:

// 健康检查服务接口(简化)
type HealthServer interface {
    Check(ctx context.Context, req *HealthCheckRequest) (*HealthCheckResponse, error)
    Watch(req *HealthCheckRequest, stream Health_WatchServer) error
}

// 周期性探活
func healthCheck(addr string, interval time.Duration, unhealthy chan<- string) {
    for {
        conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
        if err != nil {
            unhealthy <- addr
        } else {
            conn.Close()
        }
        time.Sleep(interval)
    }
}

gRPC-Gateway:REST 兼容性

内部服务用 gRPC,但对外 API 需要 REST 接口——这是常见需求。grpc-gateway 是解决方案:它读取 Protobuf 文件中的 HTTP 注解,自动生成一个 HTTP 代理层,把 REST 请求转换成 gRPC 调用:

service MathService {
  rpc Add(AddRequest) returns (AddResponse) {
    option (google.api.http) = {
      post: "/v1/math/add"
      body: "*"
    };
  }
}

超时传播

在微服务链路中,超时必须向下传播。如果调用链是 A → B → C,A 设置了 1 秒超时,B 调用 C 时必须知道还剩多少时间,否则 C 可能在 A 已经超时后还在执行。

gRPC 用 context.Context 传播截止时间(Deadline):

// A 发起调用,设置截止时间
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// context.Deadline() 会被 gRPC 序列化到请求的元数据中
resp, err := clientB.Call(ctx, ...)

// B 收到请求时,context 已经带有截止时间
// B 调用 C 时,把同一个 ctx 传下去
respC, err := clientC.Call(ctx, ...) // ctx.Deadline() 仍然是 A 设置的时间点

gRPC vs Thrift vs Dubbo

gRPC:Google 出品,HTTP/2 + Protobuf,跨语言支持最广(几乎所有语言都有官方支持),流式 RPC 支持好,CNCF 生态深度集成(Kubernetes、Istio 等都原生支持)。

Thrift:Facebook 出品,自定义二进制协议,比 gRPC 更老,性能接近,但生态稍逊,流式支持有限。在 Facebook 内部使用,HBase 的客户端协议也用 Thrift。

Dubbo:阿里巴巴出品,专注 Java 生态,服务治理功能更完整(限流、熔断、灰度发布内建),适合 Java 微服务体系。

性能粗比(QPS,仅供参考)

框架 序列化 传输 相对性能
gRPC Protobuf HTTP/2 基准
Thrift Binary Thrift TCP ~110%
JSON over HTTP/1.1 JSON HTTP/1.1 ~30-40%
gob over TCP gob TCP ~90%

Protobuf Schema 演化规则

这些规则保证了服务的滚动升级(rolling upgrade)不会因为 schema 变化导致新旧代码无法通信。这是生产环境中真正让 RPC 框架可用的关键特性之一。

本章评分
4.7  / 5  (3 评分)

💬 留言讨论