实现一个 RPC 框架
实现一个 RPC 框架
1984 年,Birrell 和 Nelson 发表了论文《Implementing Remote Procedure Calls》。他们提出了一个看起来很简单的想法:调用远程机器上的函数,应该和调用本地函数一样简单。这个想法催生了 RPC(Remote Procedure Call,远程过程调用),也催生了 40 年来分布式系统的核心基础设施。
今天,几乎所有大型系统的内部服务通信都通过 RPC 进行。Google 的 gRPC、Facebook 的 Thrift、阿里巴巴的 Dubbo——他们的本质都是在解决同一个问题:如何把网络调用包装得像函数调用一样透明。
这一章我们用 Go 从零构建一个完整的 RPC 框架,然后逐步添加生产级特性:拦截器、负载均衡、健康检查。这个过程会揭示 gRPC 等框架的每一个核心设计决策背后的原因。
Level 1 · RPC 的本质
隐藏网络调用
没有 RPC 框架时,调用另一台机器上的服务是这样的:
// 手动写网络通信的痛苦
conn, _ := net.Dial("tcp", "10.0.0.1:8080")
request := serialize(AddRequest{A: 1, B: 2})
conn.Write(request)
buf := make([]byte, 1024)
conn.Read(buf)
response := deserialize(buf)
fmt.Println(response.Result)
每一次服务调用都要手动处理:建立连接、序列化参数、发送数据、接收响应、反序列化结果、处理错误。重复、繁琐、容易出错。
RPC 框架把这些全部隐藏掉,让调用看起来是这样:
// 有了 RPC 框架
client := NewMathClient("10.0.0.1:8080")
result, err := client.Add(ctx, &AddRequest{A: 1, B: 2})
fmt.Println(result.Result)
这背后涉及几个核心问题:
- 序列化:参数怎么变成字节,字节怎么变回参数?
- 传输:字节怎么在网络上传送?用什么协议?
- 服务发现:客户端怎么知道服务器在哪里?
- 错误处理:网络错误、超时、服务端异常怎么传播回来?
REST vs RPC vs GraphQL 的本质权衡
这三种 API 风格代表了三种不同的设计哲学,各有适用场景:
REST(Representational State Transfer):以资源为中心,用 HTTP 动词(GET/POST/PUT/DELETE)表达操作,用 URL 路径表达资源层级。优势在于通用性强(任何 HTTP 客户端都能用)、可缓存、人类可读。劣势在于:操作语义被强行映射到 CRUD,复杂操作("转账"、"审批")很难用资源模型表达;协议开销大(HTTP/1.1 头部每次都要传);版本管理复杂。
RPC:以动作(操作)为中心,直接暴露服务的方法。优势在于:类型安全(参数类型在 schema 中定义)、高性能(自定义序列化,HTTP/2 多路复用)、适合服务间内部通信。劣势在于:需要特定客户端(不是普通浏览器能直接用的)、跨语言支持需要代码生成、接口变更需要版本管理。
GraphQL:以数据图为中心,客户端精确声明需要哪些字段。优势在于:解决了 REST 的 over-fetching 和 under-fetching 问题,特别适合前端驱动的场景。劣势在于:服务端实现复杂(N+1 查询问题)、无法有效缓存、不适合服务间通信。
经验法则:
- 对外 API(面向浏览器、第三方):REST 或 GraphQL
- 内部服务间通信(微服务):RPC(gRPC)
- 前端密集型产品需要灵活数据查询:GraphQL
Protobuf vs JSON:序列化的深层权衡
序列化格式的选择对性能影响巨大:
JSON:人类可读,跨语言支持好,调试方便。但是文本格式意味着:数字 1234567890 需要 10 字节,而二进制格式只需要 4 字节(int32);字段名每次都要传输("user_id": 这 9 个字节在每条记录中都重复出现);解析需要字符串比较,比二进制解析慢 5-10 倍。
Protobuf:Google 开源的二进制序列化格式。用字段编号(field number)代替字段名(user_id 在 schema 里是 field 1,序列化时只占 1-2 字节);用变长整数编码(Varint)压缩数值;无需模式时仍然可以向前和向后兼容(增减字段不破坏现有代码)。典型场景下,Protobuf 比 JSON 小 3-10 倍,解析速度快 5-20 倍。
代价:二进制不可读,必须有 .proto schema 文件才能解析,调试需要专门工具。
Level 2 · gRPC 架构原理
HTTP/2 作为传输层
gRPC 选择 HTTP/2 作为传输层不是偶然的:
多路复用:HTTP/1.1 的一个连接同一时刻只能处理一个请求(串行)。HTTP/2 的一个连接可以同时处理多个请求(流,Stream),完全消除了 HTTP/1.1 的队头阻塞(Head-of-Line Blocking)问题。gRPC 的并发调用复用同一个 TCP 连接,节省了连接建立开销。
头部压缩:HTTP/1.1 每次请求都要发送完整的 HTTP 头,gRPC 的元数据(方法名、认证 token 等)就是 HTTP/2 头部,通过 HPACK 压缩,重复头部几乎不占带宽。
流式传输:HTTP/2 原生支持双向流,gRPC 的 Server Streaming、Client Streaming、Bidirectional Streaming 都建立在这个基础上。
帧:HTTP/2 把数据切成帧(Frame)发送,每个帧有 Stream ID,接收方可以把不同流的帧重新组装。这使得一个连接真正并发处理多个请求成为可能。
拦截器:请求的中间件层
gRPC 的拦截器(Interceptor)是 RPC 框架最重要的扩展机制之一。它允许在不修改业务代码的情况下,给所有 RPC 调用统一添加:日志记录、指标收集、认证验证、限流、链路追踪。
拦截器形成一个调用链(类似 HTTP 中间件):
请求 → 拦截器A → 拦截器B → 业务处理函数 → 拦截器B → 拦截器A → 响应
gRPC 分别提供 Server-side 拦截器和 Client-side 拦截器,各自有两种:Unary(一元)和 Stream(流式)。
服务发现与负载均衡
gRPC 原生支持服务发现,通过 Resolver 接口接入 etcd、Consul、Kubernetes 等注册中心:
客户端 → Resolver(解析服务名→地址列表) → Balancer(选择一个地址) → 建立连接
负载均衡策略:
- Round Robin:轮询,最简单,适合无状态服务
- Least Connection:选择当前连接数最少的实例,适合请求处理时间差异大的场景
- Weighted Round Robin:按权重轮询,适合异构机器
- Consistent Hashing:一致性哈希,适合需要会话粘连的场景
Level 3 · 从零构建 RPC 框架
整体架构设计
我们的 RPC 框架分为以下几层:
Client Side: Server Side:
┌─────────────────────┐ ┌─────────────────────┐
│ Generated Stub │ │ Service Registry │
│ (type-safe API) │ │ (reflect-based) │
└─────────────────────┘ └─────────────────────┘
│ │
┌─────────────────────┐ ┌─────────────────────┐
│ Client Interceptor │ │ Server Interceptor │
│ Chain │ │ Chain │
└─────────────────────┘ └─────────────────────┘
│ │
┌─────────────────────┐ ┌─────────────────────┐
│ Codec (gob/json/ │ │ Codec (decode req, │
│ protobuf) │ │ encode resp) │
└─────────────────────┘ └─────────────────────┘
│ │
┌─────────────────────┐ ┌─────────────────────┐
│ TCP Transport │◄───────►│ TCP Transport │
│ + Connection Pool │ │ (per conn goroutine│
└─────────────────────┘ └─────────────────────┘
第一步:请求/响应协议和 Codec
package codec
import (
"bufio"
"encoding/gob"
"encoding/json"
"fmt"
"io"
)
// Header 每个 RPC 调用的元数据
type Header struct {
ServiceMethod string // "ServiceName.MethodName"
Seq uint64 // 请求序列号(用于异步调用的 response 匹配)
Error string // 服务端错误(非空时忽略 Body)
}
// Codec 负责将 Header+Body 编解码
type Codec interface {
io.Closer
ReadHeader(*Header) error
ReadBody(interface{}) error
Write(*Header, interface{}) error
}
// GobCodec 使用 Go 标准库 gob 编解码
type GobCodec struct {
conn io.ReadWriteCloser
buf *bufio.Writer
dec *gob.Decoder
enc *gob.Encoder
}
func NewGobCodec(conn io.ReadWriteCloser) Codec {
buf := bufio.NewWriter(conn)
return &GobCodec{
conn: conn,
buf: buf,
dec: gob.NewDecoder(conn),
enc: gob.NewEncoder(buf),
}
}
func (c *GobCodec) ReadHeader(h *Header) error {
return c.dec.Decode(h)
}
func (c *GobCodec) ReadBody(body interface{}) error {
return c.dec.Decode(body)
}
func (c *GobCodec) Write(h *Header, body interface{}) (err error) {
defer func() {
_ = c.buf.Flush()
if err != nil {
_ = c.Close()
}
}()
if err = c.enc.Encode(h); err != nil {
return fmt.Errorf("encode header: %w", err)
}
if err = c.enc.Encode(body); err != nil {
return fmt.Errorf("encode body: %w", err)
}
return nil
}
func (c *GobCodec) Close() error { return c.conn.Close() }
// JSONCodec 使用 JSON 编解码(更易调试)
type JSONCodec struct {
conn io.ReadWriteCloser
dec *json.Decoder
enc *json.Encoder
}
func NewJSONCodec(conn io.ReadWriteCloser) Codec {
return &JSONCodec{
conn: conn,
dec: json.NewDecoder(conn),
enc: json.NewEncoder(conn),
}
}
func (c *JSONCodec) ReadHeader(h *Header) error { return c.dec.Decode(h) }
func (c *JSONCodec) ReadBody(body interface{}) error { return c.dec.Decode(body) }
func (c *JSONCodec) Write(h *Header, body interface{}) error {
if err := c.enc.Encode(h); err != nil {
return err
}
return c.enc.Encode(body)
}
func (c *JSONCodec) Close() error { return c.conn.Close() }
type CodecType string
const (
GobType CodecType = "application/gob"
JSONType CodecType = "application/json"
)
// NewCodecFunc 根据类型创建 Codec 的工厂函数
type NewCodecFunc func(io.ReadWriteCloser) Codec
var NewCodecFuncMap = map[CodecType]NewCodecFunc{
GobType: NewGobCodec,
JSONType: NewJSONCodec,
}
第二步:服务注册(反射实现)
package server
import (
"fmt"
"go/token"
"log"
"reflect"
"strings"
"sync"
)
// methodType 描述一个可供远程调用的方法
type methodType struct {
method reflect.Method
ArgType reflect.Type
ReplyType reflect.Type
numCalls uint64
}
func (m *methodType) newArgv() reflect.Value {
// 如果是指针类型,创建对应元素类型的值
if m.ArgType.Kind() == reflect.Ptr {
return reflect.New(m.ArgType.Elem())
}
return reflect.New(m.ArgType).Elem()
}
func (m *methodType) newReplyv() reflect.Value {
// Reply 一定是指针类型
replyv := reflect.New(m.ReplyType.Elem())
switch m.ReplyType.Elem().Kind() {
case reflect.Map:
replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem()))
case reflect.Slice:
replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0))
}
return replyv
}
// service 代表一个注册的服务
type service struct {
name string
typ reflect.Type
rcvr reflect.Value
methods map[string]*methodType
}
func newService(rcvr interface{}) *service {
s := &service{}
s.rcvr = reflect.ValueOf(rcvr)
s.typ = reflect.TypeOf(rcvr)
s.name = reflect.Indirect(s.rcvr).Type().Name()
if !token.IsExported(s.name) {
log.Fatalf("rpc server: %s is not a valid service name", s.name)
}
s.registerMethods()
return s
}
func (s *service) registerMethods() {
s.methods = make(map[string]*methodType)
for i := 0; i < s.typ.NumMethod(); i++ {
method := s.typ.Method(i)
mType := method.Type
// 合法的 RPC 方法签名: func (t *T) MethodName(ctx context.Context, args *Args, reply *Reply) error
// 简化版: func (t *T) MethodName(args ArgType, reply *ReplyType) error
if mType.NumIn() != 3 || mType.NumOut() != 1 {
continue
}
if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
continue
}
argType, replyType := mType.In(1), mType.In(2)
if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) {
continue
}
s.methods[method.Name] = &methodType{
method: method,
ArgType: argType,
ReplyType: replyType,
}
log.Printf("rpc server: register %s.%s\n", s.name, method.Name)
}
}
func (s *service) call(m *methodType, argv, replyv reflect.Value) error {
f := m.method.Func
returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv})
if errInter := returnValues[0].Interface(); errInter != nil {
return errInter.(error)
}
return nil
}
func isExportedOrBuiltinType(t reflect.Type) bool {
return token.IsExported(t.Name()) || t.PkgPath() == ""
}
// Server RPC 服务器
type Server struct {
serviceMap sync.Map
}
func NewServer() *Server { return &Server{} }
func (s *Server) Register(rcvr interface{}) error {
svc := newService(rcvr)
if _, dup := s.serviceMap.LoadOrStore(svc.name, svc); dup {
return fmt.Errorf("rpc: service already defined: %s", svc.name)
}
return nil
}
func (s *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) {
dot := strings.LastIndex(serviceMethod, ".")
if dot < 0 {
err = fmt.Errorf("rpc server: service/method request ill-formed: %s", serviceMethod)
return
}
serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:]
svci, ok := s.serviceMap.Load(serviceName)
if !ok {
err = fmt.Errorf("rpc server: can't find service %s", serviceName)
return
}
svc = svci.(*service)
mtype, ok = svc.methods[methodName]
if !ok {
err = fmt.Errorf("rpc server: can't find method %s", methodName)
}
return
}
第三步:服务器连接处理
package server
import (
"encoding/json"
"fmt"
"io"
"log"
"net"
"reflect"
"strings"
"sync"
"time"
"github.com/yourname/minirpc/codec"
)
// Option 协商阶段的配置(客户端发送,服务端读取)
type Option struct {
MagicNumber int // 魔数,验证这是一个 minirpc 请求
CodecType codec.CodecType
ConnTimeout time.Duration // 连接超时
HandleTimeout time.Duration // 处理超时
}
const MagicNumber = 0x3bef5c
var DefaultOption = &Option{
MagicNumber: MagicNumber,
CodecType: codec.GobType,
ConnTimeout: time.Second * 10,
}
// ServeConn 处理单个连接
func (s *Server) ServeConn(conn io.ReadWriteCloser) {
defer conn.Close()
var opt Option
if err := json.NewDecoder(conn).Decode(&opt); err != nil {
log.Printf("rpc server: decode option error: %v", err)
return
}
if opt.MagicNumber != MagicNumber {
log.Printf("rpc server: invalid magic number %x", opt.MagicNumber)
return
}
newCodec := codec.NewCodecFuncMap[opt.CodecType]
if newCodec == nil {
log.Printf("rpc server: invalid codec type %s", opt.CodecType)
return
}
s.serveCodec(newCodec(conn), &opt)
}
type request struct {
h *codec.Header
argv reflect.Value
replyv reflect.Value
mtype *methodType
svc *service
}
func (s *Server) serveCodec(cc codec.Codec, opt *Option) {
sending := new(sync.Mutex)
wg := new(sync.WaitGroup)
for {
req, err := s.readRequest(cc)
if err != nil {
if req == nil {
break // 无法恢复的错误,关闭连接
}
req.h.Error = err.Error()
s.sendResponse(cc, req.h, invalidRequest, sending)
continue
}
wg.Add(1)
go s.handleRequest(cc, req, sending, wg, opt.HandleTimeout)
}
wg.Wait()
_ = cc.Close()
}
var invalidRequest = struct{}{}
func (s *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
var h codec.Header
if err := cc.ReadHeader(&h); err != nil {
if err != io.EOF && !strings.HasSuffix(err.Error(), "EOF") {
log.Printf("rpc server: read header error: %v", err)
}
return nil, err
}
return &h, nil
}
func (s *Server) readRequest(cc codec.Codec) (*request, error) {
h, err := s.readRequestHeader(cc)
if err != nil {
return nil, err
}
req := &request{h: h}
req.svc, req.mtype, err = s.findService(h.ServiceMethod)
if err != nil {
if err := cc.ReadBody(nil); err != nil {
log.Printf("rpc server: read body err: %v", err)
}
return req, err
}
req.argv = req.mtype.newArgv()
req.replyv = req.mtype.newReplyv()
argvi := req.argv.Interface()
if req.argv.Type().Kind() != reflect.Ptr {
argvi = req.argv.Addr().Interface()
}
if err = cc.ReadBody(argvi); err != nil {
log.Printf("rpc server: read body err: %v", err)
return req, err
}
return req, nil
}
func (s *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) {
defer wg.Done()
called := make(chan struct{})
sent := make(chan struct{})
go func() {
err := req.svc.call(req.mtype, req.argv, req.replyv)
called <- struct{}{}
if err != nil {
req.h.Error = err.Error()
s.sendResponse(cc, req.h, invalidRequest, sending)
sent <- struct{}{}
return
}
s.sendResponse(cc, req.h, req.replyv.Interface(), sending)
sent <- struct{}{}
}()
if timeout == 0 {
<-called
<-sent
return
}
select {
case <-time.After(timeout):
req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout)
s.sendResponse(cc, req.h, invalidRequest, sending)
case <-called:
<-sent
}
}
func (s *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) {
sending.Lock()
defer sending.Unlock()
if err := cc.Write(h, body); err != nil {
log.Printf("rpc server: write response error: %v", err)
}
}
// Accept 监听并处理请求
func (s *Server) Accept(lis net.Listener) {
for {
conn, err := lis.Accept()
if err != nil {
log.Printf("rpc server: accept error: %v", err)
return
}
go s.ServeConn(conn)
}
}
var DefaultServer = NewServer()
func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) }
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
第四步:客户端(支持异步调用)
package client
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net"
"sync"
"time"
"github.com/yourname/minirpc/codec"
"github.com/yourname/minirpc/server"
)
// Call 代表一次正在进行的 RPC 调用
type Call struct {
Seq uint64
ServiceMethod string
Args interface{}
Reply interface{}
Error error
Done chan *Call // 调用完成后通知
}
func (c *Call) done() {
c.Done <- c
}
// Client RPC 客户端
type Client struct {
cc codec.Codec
opt *server.Option
sending sync.Mutex
header codec.Header
mu sync.Mutex
seq uint64
pending map[uint64]*Call
closing bool // 用户主动关闭
shutdown bool // 服务端要求关闭或错误
}
func NewClient(conn net.Conn, opt *server.Option) (*Client, error) {
newCodecFunc := codec.NewCodecFuncMap[opt.CodecType]
if newCodecFunc == nil {
err := fmt.Errorf("invalid codec type %s", opt.CodecType)
return nil, err
}
// 发送协商选项
if err := json.NewEncoder(conn).Encode(opt); err != nil {
return nil, fmt.Errorf("rpc client: encode option error: %w", err)
}
client := &Client{
seq: 1,
cc: newCodecFunc(conn),
opt: opt,
pending: make(map[uint64]*Call),
}
go client.receive()
return client, nil
}
func Dial(network, address string, opts ...*server.Option) (client *Client, err error) {
opt := server.DefaultOption
if len(opts) > 0 && opts[0] != nil {
opt = opts[0]
}
conn, err := net.DialTimeout(network, address, opt.ConnTimeout)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = conn.Close()
}
}()
return NewClient(conn, opt)
}
func (c *Client) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closing {
return fmt.Errorf("connection already closed")
}
c.closing = true
return c.cc.Close()
}
func (c *Client) IsAvailable() bool {
c.mu.Lock()
defer c.mu.Unlock()
return !c.shutdown && !c.closing
}
func (c *Client) registerCall(call *Call) (uint64, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.closing || c.shutdown {
return 0, fmt.Errorf("rpc client: client is shutting down")
}
call.Seq = c.seq
c.pending[call.Seq] = call
c.seq++
return call.Seq, nil
}
func (c *Client) removeCall(seq uint64) *Call {
c.mu.Lock()
defer c.mu.Unlock()
call := c.pending[seq]
delete(c.pending, seq)
return call
}
func (c *Client) terminateCalls(err error) {
c.sending.Lock()
defer c.sending.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
c.shutdown = true
for _, call := range c.pending {
call.Error = err
call.done()
}
}
func (c *Client) receive() {
var err error
for err == nil {
var h codec.Header
if err = c.cc.ReadHeader(&h); err != nil {
break
}
call := c.removeCall(h.Seq)
switch {
case call == nil:
err = c.cc.ReadBody(nil)
case h.Error != "":
call.Error = fmt.Errorf(h.Error)
err = c.cc.ReadBody(nil)
call.done()
default:
err = c.cc.ReadBody(call.Reply)
if err != nil {
call.Error = fmt.Errorf("reading body: %w", err)
}
call.done()
}
}
c.terminateCalls(err)
}
func (c *Client) send(call *Call) {
c.sending.Lock()
defer c.sending.Unlock()
seq, err := c.registerCall(call)
if err != nil {
call.Error = err
call.done()
return
}
c.header.ServiceMethod = call.ServiceMethod
c.header.Seq = seq
c.header.Error = ""
if err := c.cc.Write(&c.header, call.Args); err != nil {
call := c.removeCall(seq)
if call != nil {
call.Error = err
call.done()
}
}
}
// Go 异步调用,立即返回 Call
func (c *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call {
if done == nil {
done = make(chan *Call, 1)
}
call := &Call{
ServiceMethod: serviceMethod,
Args: args,
Reply: reply,
Done: done,
}
c.send(call)
return call
}
// Call 同步调用,阻塞直到完成或超时
func (c *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
call := c.Go(serviceMethod, args, reply, make(chan *Call, 1))
select {
case <-ctx.Done():
c.removeCall(call.Seq)
return fmt.Errorf("rpc client: call failed: %s", ctx.Err())
case call := <-call.Done:
return call.Error
}
}
第五步:拦截器和负载均衡
package middleware
import (
"context"
"log"
"time"
)
// UnaryInterceptor 一元 RPC 拦截器类型
type UnaryInterceptor func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error)
type UnaryServerInfo struct {
Server interface{}
FullMethod string
}
type UnaryHandler func(ctx context.Context, req interface{}) (interface{}, error)
// ChainUnaryInterceptors 将多个拦截器链接成一个
func ChainUnaryInterceptors(interceptors ...UnaryInterceptor) UnaryInterceptor {
return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
chained := handler
for i := len(interceptors) - 1; i >= 0; i-- {
interceptor := interceptors[i]
next := chained
chained = func(ctx context.Context, req interface{}) (interface{}, error) {
return interceptor(ctx, req, info, next)
}
}
return chained(ctx, req)
}
}
// LoggingInterceptor 日志拦截器
func LoggingInterceptor(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
start := time.Now()
log.Printf("RPC call: %s, req: %v", info.FullMethod, req)
resp, err := handler(ctx, req)
elapsed := time.Since(start)
if err != nil {
log.Printf("RPC failed: %s, error: %v, elapsed: %v", info.FullMethod, err, elapsed)
} else {
log.Printf("RPC success: %s, elapsed: %v", info.FullMethod, elapsed)
}
return resp, err
}
// RecoveryInterceptor panic 恢复拦截器
func RecoveryInterceptor(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) {
defer func() {
if r := recover(); r != nil {
log.Printf("RPC panic recovered: %s, panic: %v", info.FullMethod, r)
err = fmt.Errorf("internal server error: panic")
}
}()
return handler(ctx, req)
}
// ===================== 负载均衡 =====================
package balancer
import (
"fmt"
"math/rand"
"sync"
"sync/atomic"
)
// SelectMode 负载均衡策略
type SelectMode int
const (
RandomSelect SelectMode = iota
RoundRobinSelect
)
// Discovery 服务发现接口
type Discovery interface {
Refresh() error // 从注册中心刷新服务列表
Update(servers []string) error // 手动更新服务列表
Get(mode SelectMode) (string, error) // 根据策略选择一个服务
GetAll() ([]string, error) // 获取所有服务
}
// MultiServerDiscovery 手动维护的服务列表(不依赖注册中心)
type MultiServerDiscovery struct {
r *rand.Rand
mu sync.RWMutex
servers []string
index uint64 // Round Robin 计数器
}
func NewMultiServerDiscovery(servers []string) *MultiServerDiscovery {
return &MultiServerDiscovery{
r: rand.New(rand.NewSource(rand.Int63())),
servers: servers,
}
}
func (d *MultiServerDiscovery) Refresh() error { return nil }
func (d *MultiServerDiscovery) Update(servers []string) error {
d.mu.Lock()
defer d.mu.Unlock()
d.servers = servers
return nil
}
func (d *MultiServerDiscovery) Get(mode SelectMode) (string, error) {
d.mu.RLock()
defer d.mu.RUnlock()
n := len(d.servers)
if n == 0 {
return "", fmt.Errorf("rpc discovery: no available servers")
}
switch mode {
case RandomSelect:
return d.servers[d.r.Intn(n)], nil
case RoundRobinSelect:
// 原子自增,取模选择
idx := atomic.AddUint64(&d.index, 1) - 1
return d.servers[idx%uint64(n)], nil
default:
return "", fmt.Errorf("rpc discovery: unsupported select mode")
}
}
func (d *MultiServerDiscovery) GetAll() ([]string, error) {
d.mu.RLock()
defer d.mu.RUnlock()
return d.servers, nil
}
// XClient 支持负载均衡的 RPC 客户端
type XClient struct {
d Discovery
mode SelectMode
opt *server.Option
mu sync.Mutex
clients map[string]*client.Client
}
func NewXClient(d Discovery, mode SelectMode, opt *server.Option) *XClient {
return &XClient{
d: d,
mode: mode,
opt: opt,
clients: make(map[string]*client.Client),
}
}
func (xc *XClient) Close() error {
xc.mu.Lock()
defer xc.mu.Unlock()
for _, c := range xc.clients {
_ = c.Close()
}
return nil
}
func (xc *XClient) dial(rpcAddr string) (*client.Client, error) {
xc.mu.Lock()
defer xc.mu.Unlock()
c, ok := xc.clients[rpcAddr]
if ok && !c.IsAvailable() {
_ = c.Close()
delete(xc.clients, rpcAddr)
c = nil
}
if c == nil {
var err error
c, err = client.Dial("tcp", rpcAddr, xc.opt)
if err != nil {
return nil, err
}
xc.clients[rpcAddr] = c
}
return c, nil
}
func (xc *XClient) call(rpcAddr string, ctx context.Context, serviceMethod string, args, reply interface{}) error {
c, err := xc.dial(rpcAddr)
if err != nil {
return err
}
return c.Call(ctx, serviceMethod, args, reply)
}
func (xc *XClient) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
rpcAddr, err := xc.d.Get(xc.mode)
if err != nil {
return err
}
return xc.call(rpcAddr, ctx, serviceMethod, args, reply)
}
// Broadcast 广播:调用所有服务,有任一成功则返回,有任一失败则收集错误
func (xc *XClient) Broadcast(ctx context.Context, serviceMethod string, args, reply interface{}) error {
servers, err := xc.d.GetAll()
if err != nil {
return err
}
var wg sync.WaitGroup
var mu sync.Mutex
var e error
replyDone := reply == nil
ctx, cancel := context.WithCancel(ctx)
defer cancel()
for _, rpcAddr := range servers {
wg.Add(1)
go func(rpcAddr string) {
defer wg.Done()
var clonedReply interface{}
if reply != nil {
clonedReply = reflect.New(reflect.ValueOf(reply).Elem().Type()).Interface()
}
err := xc.call(rpcAddr, ctx, serviceMethod, args, clonedReply)
mu.Lock()
defer mu.Unlock()
if err != nil && e == nil {
e = err
cancel() // 有错误,取消其他调用
}
if err == nil && !replyDone {
reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(clonedReply).Elem())
replyDone = true
}
}(rpcAddr)
}
wg.Wait()
return e
}
完整使用示例
package main
import (
"context"
"fmt"
"log"
"net"
"sync"
"time"
"github.com/yourname/minirpc/client"
"github.com/yourname/minirpc/server"
)
// 定义服务
type MathService struct{}
type Args struct{ A, B int }
type Reply struct{ C int }
func (m *MathService) Add(args Args, reply *Reply) error {
reply.C = args.A + args.B
return nil
}
func (m *MathService) Multiply(args Args, reply *Reply) error {
reply.C = args.A * args.B
return nil
}
func startServer(wg *sync.WaitGroup) {
var m MathService
if err := server.Register(&m); err != nil {
log.Fatal("register error:", err)
}
lis, err := net.Listen("tcp", ":9999")
if err != nil {
log.Fatal("listen error:", err)
}
log.Println("rpc server started on :9999")
wg.Done()
server.Accept(lis)
}
func main() {
var wg sync.WaitGroup
wg.Add(1)
go startServer(&wg)
wg.Wait()
time.Sleep(time.Second)
c, err := client.Dial("tcp", ":9999")
if err != nil {
log.Fatal("dial error:", err)
}
defer c.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
var reply Reply
if err := c.Call(ctx, "MathService.Add", Args{A: 1, B: 2}, &reply); err != nil {
log.Fatal("call error:", err)
}
fmt.Printf("1 + 2 = %d\n", reply.C) // 1 + 2 = 3
if err := c.Call(ctx, "MathService.Multiply", Args{A: 6, B: 7}, &reply); err != nil {
log.Fatal("call error:", err)
}
fmt.Printf("6 * 7 = %d\n", reply.C) // 6 * 7 = 42
}
Level 4 · 进阶与边界
健康检查协议
生产环境中的 RPC 框架必须能检测到后端实例是否健康。gRPC 定义了标准的健康检查协议,服务端实现 grpc.health.v1.Health 服务:
// 健康检查服务接口(简化)
type HealthServer interface {
Check(ctx context.Context, req *HealthCheckRequest) (*HealthCheckResponse, error)
Watch(req *HealthCheckRequest, stream Health_WatchServer) error
}
// 周期性探活
func healthCheck(addr string, interval time.Duration, unhealthy chan<- string) {
for {
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
if err != nil {
unhealthy <- addr
} else {
conn.Close()
}
time.Sleep(interval)
}
}
gRPC-Gateway:REST 兼容性
内部服务用 gRPC,但对外 API 需要 REST 接口——这是常见需求。grpc-gateway 是解决方案:它读取 Protobuf 文件中的 HTTP 注解,自动生成一个 HTTP 代理层,把 REST 请求转换成 gRPC 调用:
service MathService {
rpc Add(AddRequest) returns (AddResponse) {
option (google.api.http) = {
post: "/v1/math/add"
body: "*"
};
}
}
超时传播
在微服务链路中,超时必须向下传播。如果调用链是 A → B → C,A 设置了 1 秒超时,B 调用 C 时必须知道还剩多少时间,否则 C 可能在 A 已经超时后还在执行。
gRPC 用 context.Context 传播截止时间(Deadline):
// A 发起调用,设置截止时间
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// context.Deadline() 会被 gRPC 序列化到请求的元数据中
resp, err := clientB.Call(ctx, ...)
// B 收到请求时,context 已经带有截止时间
// B 调用 C 时,把同一个 ctx 传下去
respC, err := clientC.Call(ctx, ...) // ctx.Deadline() 仍然是 A 设置的时间点
gRPC vs Thrift vs Dubbo
gRPC:Google 出品,HTTP/2 + Protobuf,跨语言支持最广(几乎所有语言都有官方支持),流式 RPC 支持好,CNCF 生态深度集成(Kubernetes、Istio 等都原生支持)。
Thrift:Facebook 出品,自定义二进制协议,比 gRPC 更老,性能接近,但生态稍逊,流式支持有限。在 Facebook 内部使用,HBase 的客户端协议也用 Thrift。
Dubbo:阿里巴巴出品,专注 Java 生态,服务治理功能更完整(限流、熔断、灰度发布内建),适合 Java 微服务体系。
性能粗比(QPS,仅供参考):
| 框架 | 序列化 | 传输 | 相对性能 |
|---|---|---|---|
| gRPC | Protobuf | HTTP/2 | 基准 |
| Thrift Binary | Thrift | TCP | ~110% |
| JSON over HTTP/1.1 | JSON | HTTP/1.1 | ~30-40% |
| gob over TCP | gob | TCP | ~90% |
Protobuf Schema 演化规则:
- 可以添加新字段(旧代码忽略未知字段)
- 不能修改已有字段的编号
- 不能删除已有字段的编号(标记为
reserved而非直接删除) - 更改字段类型需谨慎(int32 → int64 兼容,int32 → string 不兼容)
这些规则保证了服务的滚动升级(rolling upgrade)不会因为 schema 变化导致新旧代码无法通信。这是生产环境中真正让 RPC 框架可用的关键特性之一。