第 27 章

中间件模式

中间件模式

1992 年,面向对象社区流行一句话:"好的软件是由层组成的"。这个思想后来演化出了 MVC、分层架构、洋葱模型……但有一个问题始终没有被这些架构模式完美解决:横切关注点(Cross-Cutting Concerns)

横切关注点是那些穿越多个层的功能:日志、鉴权、限流、追踪、缓存、CORS。你不能把它们放在某一层——它们属于所有层,或者说,它们在所有层之外。

中间件就是解决这个问题的答案。不是唯一的答案,但是在 HTTP 服务领域,它是最优雅的答案。

Level 1 · 你需要知道的

中间件是什么:洋葱模型

想象一个洋葱的横截面。HTTP 请求从最外层进入,经过每一层洋葱,到达最内层的业务 handler,然后响应从内到外穿回每一层。

请求方向 →                        ← 响应方向

┌─────────────────────────────────────────────┐
│  Logger Middleware                          │
│  ┌───────────────────────────────────────┐  │
│  │  Auth Middleware                      │  │
│  │  ┌─────────────────────────────────┐  │  │
│  │  │  Rate Limiter Middleware        │  │  │
│  │  │  ┌───────────────────────────┐  │  │  │
│  │  │  │  Business Handler        │  │  │  │
│  │  │  └───────────────────────────┘  │  │  │
│  │  └─────────────────────────────────┘  │  │
│  └───────────────────────────────────────┘  │
└─────────────────────────────────────────────┘

每个中间件都有两个执行阶段:

这就是洋葱模型的精髓:同一个函数同时处理请求的进入和退出

为什么中间件是横切关注点的正确抽象

考虑替代方案。假设我们不用中间件,而是在每个 handler 里手写日志逻辑:

// 没有中间件的世界
func getUserHandler(c *gin.Context) {
    start := time.Now()
    // 鉴权
    token := c.GetHeader("Authorization")
    if !validateToken(token) {
        log.Printf("auth failed: %s %s", c.Request.Method, c.Request.URL)
        c.JSON(401, gin.H{"error": "unauthorized"})
        return
    }
    // 限流检查
    if !rateLimiter.Allow() {
        log.Printf("rate limited: %s %s", c.Request.Method, c.Request.URL)
        c.JSON(429, gin.H{"error": "too many requests"})
        return
    }
    // 实际业务逻辑
    user := getUser(c.Param("id"))
    c.JSON(200, user)
    // 记录日志
    log.Printf("method=%s path=%s status=%d latency=%v",
        c.Request.Method, c.Request.URL, 200, time.Since(start))
}

这段代码有几个问题:

  1. 重复(DRY 违反):每个 handler 都要写一遍鉴权、日志、限流。500 个 handler 就是 500 份副本。
  2. 耦合:业务逻辑和基础设施逻辑混在一起,单元测试 getUserHandler 时必须 mock 鉴权和限流。
  3. 不一致风险:某个实习生写了一个 handler,忘记写限流检查,谁也不知道。

中间件解决了这三个问题:

  1. 集中化:每个横切关注点只有一份实现,所有 handler 自动受益。
  2. 解耦:handler 只关注业务逻辑,中间件负责基础设施。handler 的单元测试可以直接调用,不需要 mock 任何中间件。
  3. 强制性:通过路由组级别应用中间件,所有加入该组的路由自动获得保护,无法遗漏。

中间件的核心哲学:可组合性

中间件的威力来自可组合性。每个中间件是独立的、可单独测试的单元,可以任意组合:

基础中间件:Logger, Recovery, RequestID
安全中间件:CORS, SecurityHeaders, CSRF
认证中间件:JWTAuth, APIKeyAuth, SessionAuth
业务中间件:RateLimiter, Tenant, Permission

不同的路由组使用不同的组合:

public := r.Group("/")           // Logger + Recovery
authed := r.Group("/")           // Logger + Recovery + JWTAuth
admin  := r.Group("/admin")      // Logger + Recovery + JWTAuth + AdminPermission

这是 Unix 管道哲学在 HTTP 世界的体现:每个组件做好一件事,然后把它们串联起来。

Level 2 · 它是怎么工作的

Gin 中间件的执行机制:c.Next() 深度解析

Gin 中间件的执行模型本质上是一个带状态的循环迭代器

// gin/context.go 核心部分(经过简化和注释)
type Context struct {
    // ...
    handlers HandlersChain  // []HandlerFunc,当前请求的完整 handler 链
    index    int8            // 当前执行到哪个 handler,-1 表示未开始
}

func (c *Context) Next() {
    c.index++
    for c.index < int8(len(c.handlers)) {
        c.handlers[c.index](c)
        c.index++
    }
}

让我们追踪一个请求经过 [Logger, Auth, BusinessHandler] 链的完整执行过程:

初始状态:index = -1

1. 服务器调用 c.Next()
   index 变为 0,执行 Logger(前置代码)
   Logger 调用 c.Next()
     index 变为 1,执行 Auth(前置代码)
     Auth 调用 c.Next()
       index 变为 2,执行 BusinessHandler
       BusinessHandler 返回
       index 变为 3,循环条件不满足,退出内层 Next()
     Auth(后置代码)执行
     Auth 返回
     index 变为 2(递增),但这是在 Auth 返回后,外层 Next() 的 index++ 执行
   Logger(后置代码)执行
   Logger 返回

关键洞察c.Next() 是递归调用的。每个中间件的 c.Next() 调用都会阻塞,直到它后面的所有 handler 都执行完。这就是为什么你可以在 c.Next() 之后访问响应状态码:

func LoggerMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        start := time.Now()
        path := c.Request.URL.Path
        
        c.Next()  // ← 执行到这里时,后续所有 handler 都已完成
        
        // 这里可以访问响应的状态码!
        latency := time.Since(start)
        status := c.Writer.Status()  // 已经写入的状态码
        size := c.Writer.Size()      // 已经写入的字节数
        
        log.Printf("[%d] %s %s %v %d bytes",
            status, c.Request.Method, path, latency, size)
    }
}

中间件状态传递:c.Setc.Get

中间件之间需要传递数据(例如,Auth 中间件提取用户 ID,下游 handler 使用它)。Gin 提供了 c.Set/c.Get 机制:

// Auth 中间件设置用户信息
func JWTAuthMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        token := extractToken(c)
        claims, err := validateJWT(token)
        if err != nil {
            c.AbortWithStatusJSON(401, gin.H{"error": "invalid token"})
            return
        }
        
        // 将解析出的用户信息存入 Context
        c.Set("user_id", claims.UserID)
        c.Set("user_role", claims.Role)
        c.Set("user_email", claims.Email)
        
        c.Next()
    }
}

// 下游 handler 获取用户信息
func getProfileHandler(c *gin.Context) {
    userID, exists := c.Get("user_id")
    if !exists {
        c.JSON(500, gin.H{"error": "user_id not found in context"})
        return
    }
    
    // 类型断言(c.Get 返回 interface{})
    uid, ok := userID.(int64)
    if !ok {
        c.JSON(500, gin.H{"error": "user_id type assertion failed"})
        return
    }
    
    profile := fetchProfile(uid)
    c.JSON(200, profile)
}

c.Set/c.Get 底层是 map[string]interface{},有并发安全保证(Gin 内部使用读写锁)。但要注意:这个 map 只在单个请求的生命周期内有效,不能跨请求共享。

类型安全的辅助方法:Gin 提供了 c.GetStringc.GetInt64c.GetBool 等方法,避免手动类型断言:

userID := c.GetInt64("user_id")   // 如果不存在或类型不对,返回零值
userRole := c.GetString("user_role")

中断链:c.Abort() 的变体

除了基础的 c.Abort(),Gin 还提供了几个便捷方法:

c.Abort()                           // 停止链,不写任何响应
c.AbortWithStatus(403)              // 停止链,写入状态码
c.AbortWithStatusJSON(403, data)    // 停止链,写入 JSON 响应
c.AbortWithError(500, err)          // 停止链,写入错误(用于错误收集中间件)

c.AbortWithError 的使用场景:有时你希望中间件"记录错误但让后续中间件决定如何响应":

func ValidationMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        if err := validateRequest(c); err != nil {
            // 记录错误,但不写响应
            c.AbortWithError(http.StatusBadRequest, err).SetType(gin.ErrorTypePublic)
            return
        }
        c.Next()
    }
}

// 错误处理中间件(必须在最前面注册,才能捕获后续中间件的错误)
func ErrorHandlerMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        c.Next()
        
        // 检查是否有错误
        if len(c.Errors) > 0 {
            err := c.Errors.Last()
            if err.IsType(gin.ErrorTypePublic) {
                c.JSON(c.Writer.Status(), gin.H{"error": err.Error()})
            } else {
                c.JSON(500, gin.H{"error": "internal server error"})
            }
        }
    }
}

Level 3 · 代码实战

限流中间件:令牌桶算法

Go 标准库的 golang.org/x/time/rate 提供了生产级的令牌桶实现:

package middleware

import (
    "net/http"
    "sync"

    "github.com/gin-gonic/gin"
    "golang.org/x/time/rate"
)

// IPRateLimiter 基于 IP 的限流器
type IPRateLimiter struct {
    mu       sync.Mutex
    limiters map[string]*rate.Limiter
    r        rate.Limit  // 每秒允许的请求数
    b        int         // 令牌桶容量(允许的突发量)
}

func NewIPRateLimiter(r rate.Limit, b int) *IPRateLimiter {
    return &IPRateLimiter{
        limiters: make(map[string]*rate.Limiter),
        r:        r,
        b:        b,
    }
}

func (i *IPRateLimiter) getLimiter(ip string) *rate.Limiter {
    i.mu.Lock()
    defer i.mu.Unlock()

    limiter, exists := i.limiters[ip]
    if !exists {
        limiter = rate.NewLimiter(i.r, i.b)
        i.limiters[ip] = limiter
    }
    return limiter
}

// RateLimitMiddleware 返回基于 IP 的限流中间件
// r: 每秒请求数(令牌生成速率)
// b: 突发量(令牌桶容量)
func RateLimitMiddleware(r rate.Limit, b int) gin.HandlerFunc {
    limiter := NewIPRateLimiter(r, b)
    
    return func(c *gin.Context) {
        // 获取真实 IP(考虑代理)
        ip := c.ClientIP()
        
        ipLimiter := limiter.getLimiter(ip)
        if !ipLimiter.Allow() {
            c.Header("X-RateLimit-Limit", fmt.Sprintf("%.0f", float64(r)))
            c.Header("Retry-After", "1")
            c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
                "error": "too many requests",
                "code":  4029,
            })
            return
        }
        
        c.Next()
    }
}

// 使用示例
func main() {
    r := gin.New()
    
    // 全局限流:每秒 100 个请求,允许突发 200
    r.Use(RateLimitMiddleware(100, 200))
    
    // API 路由可以有更严格的限流
    api := r.Group("/api")
    api.Use(RateLimitMiddleware(10, 20))  // 每秒 10 个请求
    {
        api.POST("/login", loginHandler)  // 登录接口最严格
    }
}

令牌桶工作原理:令牌桶以恒定速率(r 个/秒)向桶中添加令牌,桶的容量上限为 b。每个请求消耗一个令牌,令牌不够时拒绝请求。桶的容量 b 决定了允许的突发流量——如果一段时间内没有请求,桶会积累令牌,允许后续的短时间高流量。

生产注意事项:上面的实现将限流状态存在内存中。在多实例部署时,需要将 limiter 状态存在 Redis 中,否则每个实例独立计数,无法实现全局限流。使用 github.com/go-redis/redis_rate 可以实现基于 Redis 的分布式令牌桶。

请求日志中间件

一个生产级的请求日志中间件需要记录:请求路径、方法、状态码、延迟、响应大小、客户端 IP、错误信息。

package middleware

import (
    "time"
    "github.com/gin-gonic/gin"
    "go.uber.org/zap"
)

func LoggerMiddleware(logger *zap.Logger) gin.HandlerFunc {
    return func(c *gin.Context) {
        start := time.Now()
        path := c.Request.URL.Path
        query := c.Request.URL.RawQuery
        
        // 记录请求开始(可选,适合长时间运行的请求)
        // logger.Info("request started", zap.String("path", path))
        
        c.Next()  // 执行后续 handler
        
        // 请求结束后记录
        end := time.Now()
        latency := end.Sub(start)
        
        if query != "" {
            path = path + "?" + query
        }
        
        // 收集所有中间件/handler 设置的错误
        errors := c.Errors.Errors()
        
        fields := []zap.Field{
            zap.Int("status",          c.Writer.Status()),
            zap.String("method",       c.Request.Method),
            zap.String("path",         path),
            zap.String("ip",           c.ClientIP()),
            zap.Duration("latency",    latency),
            zap.Int("body_size",       c.Writer.Size()),
            zap.String("user_agent",   c.Request.UserAgent()),
            zap.String("request_id",   c.GetString("X-Request-ID")),
        }
        
        if len(errors) > 0 {
            fields = append(fields, zap.Strings("errors", errors))
        }
        
        // 根据状态码选择日志级别
        status := c.Writer.Status()
        switch {
        case status >= 500:
            logger.Error("server error", fields...)
        case status >= 400:
            logger.Warn("client error", fields...)
        case latency > 3*time.Second:
            logger.Warn("slow request", fields...)
        default:
            logger.Info("request", fields...)
        }
    }
}

CORS 中间件

CORS(跨域资源共享)是浏览器的安全机制,服务器需要通过响应头告知浏览器哪些跨域请求是允许的:

package middleware

import (
    "net/http"
    "strings"
    "github.com/gin-gonic/gin"
)

type CORSConfig struct {
    AllowOrigins     []string
    AllowMethods     []string
    AllowHeaders     []string
    ExposeHeaders    []string
    AllowCredentials bool
    MaxAge           int  // preflight 结果缓存时间(秒)
}

func DefaultCORSConfig() CORSConfig {
    return CORSConfig{
        AllowOrigins:  []string{"*"},
        AllowMethods:  []string{"GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"},
        AllowHeaders:  []string{"Origin", "Content-Type", "Authorization", "X-Request-ID"},
        ExposeHeaders: []string{"X-Request-ID"},
        MaxAge:        86400,  // 24 小时
    }
}

func CORSMiddleware(config CORSConfig) gin.HandlerFunc {
    allowOriginsMap := make(map[string]bool)
    for _, origin := range config.AllowOrigins {
        allowOriginsMap[origin] = true
    }
    
    allowMethodsStr := strings.Join(config.AllowMethods, ", ")
    allowHeadersStr := strings.Join(config.AllowHeaders, ", ")
    exposeHeadersStr := strings.Join(config.ExposeHeaders, ", ")
    
    return func(c *gin.Context) {
        origin := c.Request.Header.Get("Origin")
        
        // 检查 Origin 是否被允许
        originAllowed := allowOriginsMap["*"] || allowOriginsMap[origin]
        if !originAllowed {
            c.Next()
            return
        }
        
        // 设置 CORS 响应头
        if allowOriginsMap["*"] {
            c.Header("Access-Control-Allow-Origin", "*")
        } else {
            c.Header("Access-Control-Allow-Origin", origin)
            c.Header("Vary", "Origin")  // 告诉缓存这个响应因 Origin 而异
        }
        
        c.Header("Access-Control-Allow-Methods", allowMethodsStr)
        c.Header("Access-Control-Allow-Headers", allowHeadersStr)
        c.Header("Access-Control-Expose-Headers", exposeHeadersStr)
        
        if config.AllowCredentials {
            c.Header("Access-Control-Allow-Credentials", "true")
        }
        
        // 处理预检请求(OPTIONS)
        if c.Request.Method == http.MethodOptions {
            c.Header("Access-Control-Max-Age", fmt.Sprintf("%d", config.MaxAge))
            c.AbortWithStatus(http.StatusNoContent)  // 204
            return
        }
        
        c.Next()
    }
}

CORS 的常见陷阱

安全头中间件

安全头是防御多种 Web 攻击的第一道防线:

package middleware

import "github.com/gin-gonic/gin"

func SecurityHeadersMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        // 防止 MIME 类型嗅探攻击
        c.Header("X-Content-Type-Options", "nosniff")
        
        // 防止点击劫持
        c.Header("X-Frame-Options", "DENY")
        
        // 启用浏览器 XSS 过滤(较旧的浏览器)
        c.Header("X-XSS-Protection", "1; mode=block")
        
        // 控制 Referer 信息暴露
        c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
        
        // 强制 HTTPS(仅在 HTTPS 下生效)
        c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
        
        // 内容安全策略(CSP)——这是最强大也最复杂的安全头
        // 根据你的实际需求调整,过于严格的 CSP 会破坏页面功能
        csp := strings.Join([]string{
            "default-src 'self'",
            "script-src 'self' 'unsafe-inline' https://cdn.example.com",  // 允许自身和 CDN 的脚本
            "style-src 'self' 'unsafe-inline'",
            "img-src 'self' data: https:",  // 允许 HTTPS 图片和 data URI
            "font-src 'self' https://fonts.gstatic.com",
            "connect-src 'self' https://api.example.com",
            "frame-ancestors 'none'",  // 等同于 X-Frame-Options: DENY
            "base-uri 'self'",
            "form-action 'self'",
        }, "; ")
        c.Header("Content-Security-Policy", csp)
        
        // 权限策略:禁用不需要的浏览器 API
        c.Header("Permissions-Policy", 
            "geolocation=(), camera=(), microphone=(), payment=()")
        
        c.Next()
    }
}

Panic 恢复中间件(带完整堆栈信息)

Gin 的 gin.Recovery() 可以防止 panic 导致服务崩溃,但它的错误信息有限。下面是一个增强版:

package middleware

import (
    "fmt"
    "net/http"
    "runtime/debug"
    "time"

    "github.com/gin-gonic/gin"
    "go.uber.org/zap"
)

func RecoveryMiddleware(logger *zap.Logger) gin.HandlerFunc {
    return func(c *gin.Context) {
        defer func() {
            if err := recover(); err != nil {
                // 捕获完整的堆栈信息
                stack := debug.Stack()
                
                // 记录结构化日志
                logger.Error("panic recovered",
                    zap.Any("error", err),
                    zap.String("stack", string(stack)),
                    zap.String("path", c.Request.URL.Path),
                    zap.String("method", c.Request.Method),
                    zap.String("ip", c.ClientIP()),
                    zap.String("request_id", c.GetString("X-Request-ID")),
                    zap.Time("time", time.Now()),
                )
                
                // 发送告警(生产中可以接入钉钉、Slack、PagerDuty 等)
                go sendAlert(fmt.Sprintf("PANIC: %v\n%s", err, stack))
                
                // 给客户端返回 500,但不暴露内部错误详情
                c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
                    "error": "internal server error",
                    "code":  5000,
                    // 在开发环境可以暴露 request_id,生产只返回这个让用户提交 bug 报告
                    "request_id": c.GetString("X-Request-ID"),
                })
            }
        }()
        c.Next()
    }
}

func sendAlert(message string) {
    // 实际实现:调用告警 webhook
    // 这里是占位实现
    _ = message
}

Level 4 · 深水区

中间件顺序的工程原则

中间件的注册顺序决定了执行顺序,错误的顺序会导致安全漏洞或逻辑错误。以下是推荐的顺序:

r := gin.New()

// 第 1 层:基础设施(最外层,总是执行)
r.Use(RecoveryMiddleware(logger))   // 必须最早注册,防止其他中间件的 panic
r.Use(RequestIDMiddleware())        // 最早生成 ID,后续日志都可以带上它

// 第 2 层:可观测性(在安全检查之前,可以记录被拒绝的请求)
r.Use(LoggerMiddleware(logger))     // 日志要包含所有请求,包括被拒绝的

// 第 3 层:安全与访问控制
r.Use(SecurityHeadersMiddleware())  // 所有响应都要有安全头
r.Use(CORSMiddleware(corsConfig))   // CORS 要在业务逻辑之前处理

// 第 4 层:流量控制
r.Use(RateLimitMiddleware(100, 200))

// 第 5 层:认证(在业务中间件之前)
authGroup := r.Group("/")
authGroup.Use(JWTAuthMiddleware())

// 第 6 层:业务中间件(最内层,靠近 handler)
authGroup.Use(TenantMiddleware())
authGroup.Use(PermissionMiddleware())

为什么 Recovery 必须最早? 因为中间件链是从外到内执行的,Recovery 的 defer 语句是在最外层注册的,可以捕获所有内层 panic,包括其他中间件的 panic。如果 Recovery 在 Logger 之后注册,Logger 的 panic 就无法被捕获。

为什么 Logger 在安全中间件之前? 安全中间件可能会拒绝请求(返回 403、429 等),你需要记录这些被拒绝的请求用于安全审计和监控。如果 Logger 在安全中间件之后,被拒绝的请求就不会有日志。

熔断器模式

熔断器(Circuit Breaker)是微服务架构中防止级联故障的关键模式。它监控调用成功率,当失败率超过阈值时"断路"——快速失败而不是让请求等待超时:

package middleware

import (
    "net/http"
    "sync"
    "time"
    
    "github.com/gin-gonic/gin"
)

type CircuitState int

const (
    StateClosed   CircuitState = iota  // 正常,允许请求
    StateOpen                          // 断路,快速拒绝
    StateHalfOpen                      // 半开,允许少量请求探测
)

type CircuitBreaker struct {
    mu           sync.Mutex
    state        CircuitState
    failures     int
    successes    int
    lastFailTime time.Time
    
    // 配置
    maxFailures    int           // 触发断路的失败次数阈值
    resetTimeout   time.Duration // 断路后多久进入半开状态
    halfOpenMaxReq int           // 半开状态允许的最大请求数
}

func NewCircuitBreaker(maxFailures int, resetTimeout time.Duration) *CircuitBreaker {
    return &CircuitBreaker{
        state:          StateClosed,
        maxFailures:    maxFailures,
        resetTimeout:   resetTimeout,
        halfOpenMaxReq: 3,
    }
}

func (cb *CircuitBreaker) Allow() bool {
    cb.mu.Lock()
    defer cb.mu.Unlock()
    
    switch cb.state {
    case StateClosed:
        return true
    case StateOpen:
        // 检查是否应该转入半开状态
        if time.Since(cb.lastFailTime) > cb.resetTimeout {
            cb.state = StateHalfOpen
            cb.successes = 0
            return true
        }
        return false
    case StateHalfOpen:
        return cb.successes < cb.halfOpenMaxReq
    }
    return false
}

func (cb *CircuitBreaker) RecordSuccess() {
    cb.mu.Lock()
    defer cb.mu.Unlock()
    
    switch cb.state {
    case StateHalfOpen:
        cb.successes++
        if cb.successes >= cb.halfOpenMaxReq {
            cb.state = StateClosed
            cb.failures = 0
        }
    case StateClosed:
        cb.failures = 0  // 重置失败计数
    }
}

func (cb *CircuitBreaker) RecordFailure() {
    cb.mu.Lock()
    defer cb.mu.Unlock()
    
    cb.lastFailTime = time.Now()
    
    switch cb.state {
    case StateClosed:
        cb.failures++
        if cb.failures >= cb.maxFailures {
            cb.state = StateOpen
        }
    case StateHalfOpen:
        cb.state = StateOpen  // 半开状态失败,重新断路
    }
}

// CircuitBreakerMiddleware 将熔断器应用于 Gin 路由
func CircuitBreakerMiddleware(cb *CircuitBreaker) gin.HandlerFunc {
    return func(c *gin.Context) {
        if !cb.Allow() {
            c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{
                "error": "service temporarily unavailable, please retry later",
                "code":  5030,
            })
            return
        }
        
        c.Next()
        
        // 根据响应状态码记录成功或失败
        if c.Writer.Status() >= 500 {
            cb.RecordFailure()
        } else {
            cb.RecordSuccess()
        }
    }
}

分布式追踪注入

在微服务架构中,可以在中间件层注入 OpenTelemetry 追踪信息,实现自动化的分布式追踪:

package middleware

import (
    "github.com/gin-gonic/gin"
    "go.opentelemetry.io/otel"
    "go.opentelemetry.io/otel/attribute"
    "go.opentelemetry.io/otel/propagation"
    semconv "go.opentelemetry.io/otel/semconv/v1.12.0"
    "go.opentelemetry.io/otel/trace"
)

func TracingMiddleware(serviceName string) gin.HandlerFunc {
    tracer := otel.Tracer(serviceName)
    propagator := otel.GetTextMapPropagator()
    
    return func(c *gin.Context) {
        // 从请求头提取上游的追踪上下文(W3C TraceContext / B3 格式)
        ctx := propagator.Extract(c.Request.Context(), 
            propagation.HeaderCarrier(c.Request.Header))
        
        // 创建新的 Span
        spanName := c.FullPath()  // 使用路由模板(/users/:id),而非实际路径
        if spanName == "" {
            spanName = c.Request.URL.Path
        }
        
        ctx, span := tracer.Start(ctx, spanName,
            trace.WithSpanKind(trace.SpanKindServer),
            trace.WithAttributes(
                semconv.HTTPMethodKey.String(c.Request.Method),
                semconv.HTTPURLKey.String(c.Request.URL.String()),
                semconv.HTTPSchemeKey.String(c.Request.URL.Scheme),
                semconv.NetHostNameKey.String(c.Request.Host),
            ),
        )
        defer span.End()
        
        // 将追踪上下文注入请求的 Context
        c.Request = c.Request.WithContext(ctx)
        
        // 将 TraceID 存入 Gin Context,便于日志关联
        c.Set("trace_id", span.SpanContext().TraceID().String())
        
        c.Next()
        
        // 记录响应信息
        span.SetAttributes(
            semconv.HTTPStatusCodeKey.Int(c.Writer.Status()),
        )
        
        if c.Writer.Status() >= 500 {
            span.RecordError(fmt.Errorf("HTTP %d", c.Writer.Status()))
        }
    }
}

httptest 测试中间件

中间件应该有独立的单元测试,不依赖完整的服务启动:

package middleware_test

import (
    "net/http"
    "net/http/httptest"
    "testing"

    "github.com/gin-gonic/gin"
    "github.com/stretchr/testify/assert"
    "yourproject/middleware"
)

func TestRateLimitMiddleware(t *testing.T) {
    gin.SetMode(gin.TestMode)
    
    r := gin.New()
    // 极小的限流:每秒 1 个请求,桶容量 1
    r.Use(middleware.RateLimitMiddleware(1, 1))
    r.GET("/test", func(c *gin.Context) {
        c.JSON(200, gin.H{"ok": true})
    })
    
    // 第一个请求应该成功
    w1 := httptest.NewRecorder()
    req1, _ := http.NewRequest("GET", "/test", nil)
    r.ServeHTTP(w1, req1)
    assert.Equal(t, http.StatusOK, w1.Code)
    
    // 第二个请求应该被限流(桶已空)
    w2 := httptest.NewRecorder()
    req2, _ := http.NewRequest("GET", "/test", nil)
    r.ServeHTTP(w2, req2)
    assert.Equal(t, http.StatusTooManyRequests, w2.Code)
}

func TestCORSMiddleware(t *testing.T) {
    gin.SetMode(gin.TestMode)
    
    r := gin.New()
    config := middleware.CORSConfig{
        AllowOrigins: []string{"https://example.com"},
        AllowMethods: []string{"GET", "POST"},
        AllowHeaders: []string{"Content-Type"},
    }
    r.Use(middleware.CORSMiddleware(config))
    r.GET("/test", func(c *gin.Context) {
        c.JSON(200, gin.H{"ok": true})
    })
    
    t.Run("allowed origin", func(t *testing.T) {
        w := httptest.NewRecorder()
        req, _ := http.NewRequest("GET", "/test", nil)
        req.Header.Set("Origin", "https://example.com")
        r.ServeHTTP(w, req)
        
        assert.Equal(t, "https://example.com", 
            w.Header().Get("Access-Control-Allow-Origin"))
    })
    
    t.Run("disallowed origin", func(t *testing.T) {
        w := httptest.NewRecorder()
        req, _ := http.NewRequest("GET", "/test", nil)
        req.Header.Set("Origin", "https://evil.com")
        r.ServeHTTP(w, req)
        
        assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
    })
    
    t.Run("preflight request", func(t *testing.T) {
        w := httptest.NewRecorder()
        req, _ := http.NewRequest("OPTIONS", "/test", nil)
        req.Header.Set("Origin", "https://example.com")
        req.Header.Set("Access-Control-Request-Method", "POST")
        r.ServeHTTP(w, req)
        
        assert.Equal(t, http.StatusNoContent, w.Code)
    })
}

// 测试中间件状态传递
func TestContextValuePropagation(t *testing.T) {
    gin.SetMode(gin.TestMode)
    
    r := gin.New()
    
    // 设置值的中间件
    r.Use(func(c *gin.Context) {
        c.Set("test_key", "test_value")
        c.Next()
    })
    
    var capturedValue string
    r.GET("/test", func(c *gin.Context) {
        capturedValue = c.GetString("test_key")
        c.JSON(200, nil)
    })
    
    w := httptest.NewRecorder()
    req, _ := http.NewRequest("GET", "/test", nil)
    r.ServeHTTP(w, req)
    
    assert.Equal(t, "test_value", capturedValue)
}

核心要点回顾

  1. 中间件是横切关注点的最佳抽象,将基础设施逻辑从业务逻辑中彻底分离。
  2. Gin 的 c.Next() 实现了洋葱模型:同一个函数可以在请求进入和响应退出时各执行一段逻辑。
  3. c.Abort() 通过设置 index 标志来停止链,而不是通过异常或返回值传递。
  4. 中间件顺序至关重要:Recovery 最外,日志在安全之前,认证在业务之前。
  5. httptest 可以独立测试每个中间件,无需启动完整服务。
本章评分
4.8  / 5  (4 评分)

💬 留言讨论